from typing import List, Optional

import torch
from torch import Tensor

from torch_geometric.utils import degree
def unbatch(
    src: Tensor,
    batch: Tensor,
    dim: int = 0,
    batch_size: Optional[int] = None,
) -> List[Tensor]:
    r"""Splits :obj:`src` according to a :obj:`batch` vector along dimension
    :obj:`dim`.

    Args:
        src (Tensor): The source tensor.
        batch (LongTensor): The batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
            entry in :obj:`src` to a specific example. Must be ordered
            (non-decreasing); this is not checked, and an unordered vector
            silently produces wrong splits.
        dim (int, optional): The dimension along which to split the :obj:`src`
            tensor. Negative values are allowed. (default: :obj:`0`)
        batch_size (int, optional): The total number of examples :math:`B`.
            If given, trailing examples without any entries are returned as
            empty tensors. If :obj:`None`, it is inferred as
            :obj:`batch.max() + 1`. (default: :obj:`None`)

    Raises:
        ValueError: If :obj:`batch` references an example index
            :obj:`>= batch_size`.

    :rtype: :class:`List[Tensor]`

    Example:
        >>> src = torch.arange(7)
        >>> batch = torch.tensor([0, 0, 0, 1, 1, 2, 2])
        >>> unbatch(src, batch)
        (tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6]))
    """
    # Number of entries belonging to each example. `minlength` pads with zeros
    # so that examples after the last populated index still get a (possibly
    # empty) split when `batch_size` is supplied.
    minlength = 0 if batch_size is None else batch_size
    counts = torch.bincount(batch, minlength=minlength)
    if batch_size is not None and counts.numel() > batch_size:
        raise ValueError(
            f"'batch' contains example index {counts.numel() - 1}, which is "
            f"out of range for 'batch_size={batch_size}'")
    sizes = counts.tolist()
    return src.split(sizes, dim)