Skip to content

Added support for custom functions in collate and separate#10636

Open
FilippoMB wants to merge 5 commits intopyg-team:masterfrom
FilippoMB:master
Open

Added support for custom functions in collate and separate#10636
FilippoMB wants to merge 5 commits intopyg-team:masterfrom
FilippoMB:master

Conversation

@FilippoMB
Copy link
Copy Markdown

For advanced use-cases, torch_geometric.data.collate.collate and torch_geometric.data.separate.separate accept function maps that let you override batching/separation for user-defined types.
Dispatch first checks exact type matches, and then falls back to isinstance checks.

import torch
from torch import Tensor
from torch_geometric.data import Data
from torch_geometric.data.collate import collate
from torch_geometric.data.separate import separate


class Foo:
    def __init__(self, x: Tensor):
        self.x = x


def foo_collate(*, values, **kwargs):
    xs = [v.x for v in values]
    sizes = torch.tensor([x.size(0) for x in xs], dtype=torch.long)
    slices = torch.cat([torch.zeros(1, dtype=torch.long), sizes.cumsum(0)])
    return Foo(torch.cat(xs, dim=0)), slices, None


def foo_separate(*, values, idx, slices, **kwargs):
    start, end = int(slices[idx]), int(slices[idx + 1])
    return Foo(values.x[start:end])


data_list = [Data(foo=Foo(torch.tensor([1, 2]))), Data(foo=Foo(torch.tensor([3])))]

batch, slice_dict, inc_dict = collate(
    Data,
    data_list=data_list,
    collate_fn_map={Foo: foo_collate},
    add_batch=False,
)
data = separate(
    Data,
    batch=batch,
    idx=0,
    slice_dict=slice_dict,
    inc_dict=inc_dict,
    separate_fn_map={Foo: foo_separate},
)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants