# Source code for torch_geometric.utils.unbatch

from typing import List, Optional

import torch
from torch import Tensor

from torch_geometric.utils import degree


def unbatch(
    src: Tensor,
    batch: Tensor,
    dim: int = 0,
    batch_size: Optional[int] = None,
) -> List[Tensor]:
    r"""Splits :obj:`src` according to a :obj:`batch` vector along dimension
    :obj:`dim`.

    Args:
        src (Tensor): The source tensor.
        batch (LongTensor): The batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
            entry in :obj:`src` to a specific example. Must be ordered.
        dim (int, optional): The dimension along which to split the :obj:`src`
            tensor. (default: :obj:`0`)
        batch_size (int, optional): The number of examples :math:`B`.
            Inferred from :obj:`batch` if not given, in which case trailing
            examples without any entry cannot be detected and are omitted
            from the result. (default: :obj:`None`)

    :rtype: :class:`List[Tensor]`

    Example:
        >>> src = torch.arange(7)
        >>> batch = torch.tensor([0, 0, 0, 1, 1, 2, 2])
        >>> unbatch(src, batch)
        (tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6]))
    """
    # Count how many entries belong to each example. Because `batch` is
    # ordered, these counts are exactly the lengths of the consecutive chunks.
    counts = degree(batch, batch_size, dtype=torch.long)

    # `Tensor.split` expects plain Python ints for the per-chunk sizes.
    sizes = counts.tolist()
    return src.split(sizes, dim)