From 08e03ba7953ee2bf8a85e814ee870cbdd3427149 Mon Sep 17 00:00:00 2001 From: Filippo Maria Bianchi Date: Fri, 6 Mar 2026 16:32:24 +0100 Subject: [PATCH 1/4] Add support for custom collate and separate functions in collate.py and separate.py --- torch_geometric/data/collate.py | 110 ++++++++++++++++++++++--------- torch_geometric/data/separate.py | 95 +++++++++++++++++++------- 2 files changed, 150 insertions(+), 55 deletions(-) diff --git a/torch_geometric/data/collate.py b/torch_geometric/data/collate.py index 9dfdc0901ffb..8049bbc4aa03 100644 --- a/torch_geometric/data/collate.py +++ b/torch_geometric/data/collate.py @@ -2,6 +2,7 @@ from collections.abc import Mapping, Sequence from typing import ( Any, + Callable, Dict, Iterable, List, @@ -13,9 +14,8 @@ ) import torch -from torch import Tensor - import torch_geometric.typing +from torch import Tensor from torch_geometric import EdgeIndex, Index from torch_geometric.data.data import BaseData from torch_geometric.data.storage import BaseStorage, NodeStorage @@ -29,7 +29,7 @@ from torch_geometric.utils import cumsum, is_sparse, is_torch_sparse_tensor from torch_geometric.utils.sparse import cat -T = TypeVar('T') +T = TypeVar("T") SliceDictType = Dict[str, Union[Tensor, Dict[str, Tensor]]] IncDictType = Dict[str, Union[Tensor, Dict[str, Tensor]]] @@ -41,6 +41,7 @@ def collate( add_batch: bool = True, follow_batch: Optional[Iterable[str]] = None, exclude_keys: Optional[Iterable[str]] = None, + collate_fn_map: Optional[Dict[Any, Callable[..., Tuple[Any, Any, Any]]]] = None, ) -> Tuple[T, SliceDictType, IncDictType]: # Collates a list of `data` objects into a single object of type `cls`. # `collate` can handle both homogeneous and heterogeneous data objects by @@ -88,7 +89,6 @@ def collate( key = out_store._key stores = key_to_stores[key] for attr in stores[0].keys(): - if attr in exclude_keys: # Do not include top-level attribute. 
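+                # Excluded attributes are skipped entirely and never
+                # appear in the collated object.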
continue @@ -96,18 +96,19 @@ def collate( # The `num_nodes` attribute needs special treatment, as we need to # sum their values up instead of merging them to a list: - if attr == 'num_nodes': + if attr == "num_nodes": out_store._num_nodes = values out_store.num_nodes = sum(values) continue # Skip batching of `ptr` vectors for now: - if attr == 'ptr': + if attr == "ptr": continue # Collate attributes into a unified representation: - value, slices, incs = _collate(attr, values, data_list, stores, - increment) + value, slices, incs = _collate( + attr, values, data_list, stores, increment, collate_fn_map + ) # If parts of the data are already on GPU, make sure that auxiliary # data like `batch` or `ptr` are also created on GPU: @@ -133,12 +134,15 @@ def collate( # Add an additional batch vector for the given attribute: if attr in follow_batch: batch, ptr = _batch_and_ptr(slices, device) - out_store[f'{attr}_batch'] = batch - out_store[f'{attr}_ptr'] = ptr + out_store[f"{attr}_batch"] = batch + out_store[f"{attr}_ptr"] = ptr # In case of node-level storages, we add a top-level batch vector it: - if (add_batch and isinstance(stores[0], NodeStorage) - and stores[0].can_infer_num_nodes): + if ( + add_batch + and isinstance(stores[0], NodeStorage) + and stores[0].can_infer_num_nodes + ): repeats = [store.num_nodes or 0 for store in stores] out_store.batch = repeat_interleave(repeats, device=device) out_store.ptr = cumsum(torch.tensor(repeats, device=device)) @@ -152,9 +156,32 @@ def _collate( data_list: List[BaseData], stores: List[BaseStorage], increment: bool, + collate_fn_map: Optional[Dict[Any, Callable[..., Tuple[Any, Any, Any]]]] = None, ) -> Tuple[Any, Any, Any]: - elem = values[0] + elem_type = type(elem) + + if collate_fn_map is not None: + if elem_type in collate_fn_map: + return collate_fn_map[elem_type]( + key=key, + values=values, + data_list=data_list, + stores=stores, + increment=increment, + collate_fn_map=collate_fn_map, + ) + + for collate_type in collate_fn_map: + if isinstance(elem, collate_type): + return collate_fn_map[collate_type]( + key=key, + values=values, + data_list=data_list, + stores=stores, + increment=increment, + collate_fn_map=collate_fn_map, + ) if isinstance(elem, Tensor) and not is_sparse(elem): # Concatenate a list of `torch.Tensor` along the `cat_dim`. @@ -169,13 +196,12 @@ def _collate( incs = get_incs(key, values, data_list, stores) if incs.dim() > 1 or int(incs[-1]) != 0: values = [ - value + inc.to(value.device) - for value, inc in zip(values, incs) + value + inc.to(value.device) for value, inc in zip(values, incs) ] else: incs = None - if getattr(elem, 'is_nested', False): + if getattr(elem, "is_nested", False): tensors = [] for nested_tensor in values: tensors.extend(nested_tensor.unbind()) @@ -184,13 +210,15 @@ def _collate( return value, slices, incs out = None - if (torch.utils.data.get_worker_info() is not None - and not isinstance(elem, (Index, EdgeIndex))): + if torch.utils.data.get_worker_info() is not None and not isinstance( + elem, (Index, EdgeIndex) + ): # Write directly into shared memory to avoid an extra copy: numel = sum(value.numel() for value in values) if torch_geometric.typing.WITH_PT20: storage = elem.untyped_storage()._new_shared( - numel * elem.element_size(), device=elem.device) + numel * elem.element_size(), device=elem.device + ) else: storage = elem.storage()._new_shared(numel, device=elem.device) shape = list(elem.size()) @@ -229,7 +257,7 @@ def _collate( # NOTE: `cat_dim` may return a tuple to allow for diagonal stacking. 
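+        # (For adjacency matrices, `cat_dims` covers both the row and the
+        # column dimension, so sizes are tracked per dimension below.)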
key = str(key) cat_dim = data_list[0].__cat_dim__(key, elem, stores[0]) - cat_dims = (cat_dim, ) if isinstance(cat_dim, int) else cat_dim + cat_dims = (cat_dim,) if isinstance(cat_dim, int) else cat_dim repeats = [[value.size(dim) for dim in cat_dims] for value in values] slices = cumsum(torch.tensor(repeats)) if is_torch_sparse_tensor(elem): @@ -255,16 +283,32 @@ def _collate( value_dict, slice_dict, inc_dict = {}, {}, {} for key in elem.keys(): value_dict[key], slice_dict[key], inc_dict[key] = _collate( - key, [v[key] for v in values], data_list, stores, increment) + key, + [v[key] for v in values], + data_list, + stores, + increment, + collate_fn_map, + ) return value_dict, slice_dict, inc_dict - elif (isinstance(elem, Sequence) and not isinstance(elem, str) - and len(elem) > 0 and isinstance(elem[0], (Tensor, SparseTensor))): + elif ( + isinstance(elem, Sequence) + and not isinstance(elem, str) + and len(elem) > 0 + and isinstance(elem[0], (Tensor, SparseTensor)) + ): # Recursively collate elements of lists. value_list, slice_list, inc_list = [], [], [] for i in range(len(elem)): - value, slices, incs = _collate(key, [v[i] for v in values], - data_list, stores, increment) + value, slices, incs = _collate( + key, + [v[i] for v in values], + data_list, + stores, + increment, + collate_fn_map, + ) value_list.append(value) slice_list.append(slices) inc_list.append(incs) @@ -280,7 +324,7 @@ def _batch_and_ptr( slices: Any, device: Optional[torch.device] = None, ) -> Tuple[Any, Any]: - if (isinstance(slices, Tensor) and slices.dim() == 1): + if isinstance(slices, Tensor) and slices.dim() == 1: # Default case, turn slices tensor into batch. repeats = slices[1:] - slices[:-1] batch = repeat_interleave(repeats.tolist(), device=device) @@ -294,8 +338,11 @@ def _batch_and_ptr( batch[k], ptr[k] = _batch_and_ptr(v, device) return batch, ptr - elif (isinstance(slices, Sequence) and not isinstance(slices, str) - and isinstance(slices[0], Tensor)): + elif ( + isinstance(slices, Sequence) + and not isinstance(slices, str) + and isinstance(slices[0], Tensor) + ): # Recursively batch elements of lists. 
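+        # Each slice tensor in the list yields its own batch/ptr pair.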
batch, ptr = [], [] for s in slices: @@ -316,12 +363,13 @@ def repeat_interleave( repeats: List[int], device: Optional[torch.device] = None, ) -> Tensor: - outs = [torch.full((n, ), i, device=device) for i, n in enumerate(repeats)] + outs = [torch.full((n,), i, device=device) for i, n in enumerate(repeats)] return torch.cat(outs, dim=0) -def get_incs(key, values: List[Any], data_list: List[BaseData], - stores: List[BaseStorage]) -> Tensor: +def get_incs( + key, values: List[Any], data_list: List[BaseData], stores: List[BaseStorage] +) -> Tensor: repeats = [ data.__inc__(key, value, store) for value, data, store in zip(values, data_list, stores) diff --git a/torch_geometric/data/separate.py b/torch_geometric/data/separate.py index 2910b6679f60..de2803fa39be 100644 --- a/torch_geometric/data/separate.py +++ b/torch_geometric/data/separate.py @@ -1,15 +1,14 @@ from collections.abc import Mapping, Sequence -from typing import Any, Type, TypeVar +from typing import Any, Callable, Dict, Optional, Type, TypeVar from torch import Tensor - from torch_geometric import EdgeIndex, Index from torch_geometric.data.data import BaseData from torch_geometric.data.storage import BaseStorage from torch_geometric.typing import SparseTensor, TensorFrame from torch_geometric.utils import narrow -T = TypeVar('T') +T = TypeVar("T") def separate( @@ -19,6 +18,7 @@ def separate( slice_dict: Any, inc_dict: Any = None, decrement: bool = True, + separate_fn_map: Optional[Dict[Any, Callable[..., Any]]] = None, ) -> T: # Separates the individual element from a `batch` at index `idx`. # `separate` can handle both homogeneous and heterogeneous data objects by @@ -45,12 +45,21 @@ def separate( slices = slice_dict[attr] incs = inc_dict[attr] if decrement else None - data_store[attr] = _separate(attr, batch_store[attr], idx, slices, - incs, batch, batch_store, decrement) + data_store[attr] = _separate( + attr, + batch_store[attr], + idx, + slices, + incs, + batch, + batch_store, + decrement, + separate_fn_map, + ) # The `num_nodes` attribute needs special treatment, as we cannot infer # the real number of nodes from the total number of nodes alone: - if hasattr(batch_store, '_num_nodes'): + if hasattr(batch_store, "_num_nodes"): data_store.num_nodes = batch_store._num_nodes[idx] return data @@ -65,7 +74,37 @@ def _separate( batch: BaseData, store: BaseStorage, decrement: bool, + separate_fn_map: Optional[Dict[Any, Callable[..., Any]]] = None, ) -> Any: + elem_type = type(values) + + if separate_fn_map is not None: + if elem_type in separate_fn_map: + return separate_fn_map[elem_type]( + key=key, + values=values, + idx=idx, + slices=slices, + incs=incs, + batch=batch, + store=store, + decrement=decrement, + separate_fn_map=separate_fn_map, + ) + + for separate_type in separate_fn_map: + if isinstance(values, separate_type): + return separate_fn_map[separate_type]( + key=key, + values=values, + idx=idx, + slices=slices, + incs=incs, + batch=batch, + store=store, + decrement=decrement, + separate_fn_map=separate_fn_map, + ) if isinstance(values, Tensor): # Narrow a `torch.Tensor` based on `slices`. 
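+        # Consecutive entries of `slices` give the start/end offsets of
+        # element `idx` along each concatenation dimension.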
@@ -87,8 +126,7 @@ def _separate( value._sort_order = values._cat_metadata.sort_order[idx] value._is_undirected = values._cat_metadata.is_undirected[idx] - if (decrement and incs is not None - and (incs.dim() > 1 or int(incs[idx]) != 0)): + if decrement and incs is not None and (incs.dim() > 1 or int(incs[idx]) != 0): value = value - incs[idx].to(value.device) return value @@ -98,7 +136,7 @@ def _separate( # NOTE: `cat_dim` may return a tuple to allow for diagonal stacking. key = str(key) cat_dim = batch.__cat_dim__(key, values, store) - cat_dims = (cat_dim, ) if isinstance(cat_dim, int) else cat_dim + cat_dims = (cat_dim,) if isinstance(cat_dim, int) else cat_dim for i, dim in enumerate(cat_dims): start, end = int(slices[idx][i]), int(slices[idx + 1][i]) values = values.narrow(dim, start, end - start) @@ -113,30 +151,37 @@ def _separate( elif isinstance(values, Mapping): # Recursively separate elements of dictionaries. return { - key: - _separate( - key, + sub_key: _separate( + sub_key, value, idx, - slices=slices[key], - incs=incs[key] if decrement else None, + slices=slices[sub_key], + incs=incs[sub_key] if decrement and incs is not None else None, batch=batch, store=store, decrement=decrement, + separate_fn_map=separate_fn_map, ) - for key, value in values.items() + for sub_key, value in values.items() } - elif (isinstance(values, Sequence) and isinstance(values[0], Sequence) - and not isinstance(values[0], str) and len(values[0]) > 0 - and isinstance(values[0][0], (Tensor, SparseTensor)) - and isinstance(slices, Sequence)): + elif ( + isinstance(values, Sequence) + and isinstance(values[0], Sequence) + and not isinstance(values[0], str) + and len(values[0]) > 0 + and isinstance(values[0][0], (Tensor, SparseTensor)) + and isinstance(slices, Sequence) + ): # Recursively separate elements of lists of lists. return [value[idx] for value in values] - elif (isinstance(values, Sequence) and not isinstance(values, str) - and isinstance(values[0], (Tensor, SparseTensor)) - and isinstance(slices, Sequence)): + elif ( + isinstance(values, Sequence) + and not isinstance(values, str) + and isinstance(values[0], (Tensor, SparseTensor)) + and isinstance(slices, Sequence) + ): # Recursively separate elements of lists of Tensors/SparseTensors. 
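+        # `slices[i]` (and `incs[i]`, when present) carry the bookkeeping
+        # for the i-th list entry.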
return [ _separate( @@ -144,11 +189,13 @@ def _separate( value, idx, slices=slices[i], - incs=incs[i] if decrement else None, + incs=incs[i] if decrement and incs is not None else None, batch=batch, store=store, decrement=decrement, - ) for i, value in enumerate(values) + separate_fn_map=separate_fn_map, + ) + for i, value in enumerate(values) ] else: From a4ffd5dddc6e5e9e6a27e12926ad20122ec51026 Mon Sep 17 00:00:00 2001 From: Carlo Date: Fri, 6 Mar 2026 16:47:41 +0100 Subject: [PATCH 2/4] Add unit tests for custom collate and separate functions in test_collate_fn_separate_fn.py --- test/data/test_collate_fn_separate_fn.py | 274 +++++++++++++++++++++++ 1 file changed, 274 insertions(+) create mode 100644 test/data/test_collate_fn_separate_fn.py diff --git a/test/data/test_collate_fn_separate_fn.py b/test/data/test_collate_fn_separate_fn.py new file mode 100644 index 000000000000..46d2f95297c1 --- /dev/null +++ b/test/data/test_collate_fn_separate_fn.py @@ -0,0 +1,274 @@ +import torch +from torch import Tensor + +from torch_geometric.data import Data +from torch_geometric.data.collate import collate +from torch_geometric.data.separate import separate + + +class Foo: + def __init__(self, x: Tensor): + self.x = x + + +class FooChild(Foo): + pass + + +def foo_collate( + *, + key: str, + values: list[Foo], + data_list, + stores, + increment: bool, + collate_fn_map, +): + xs = [v.x for v in values] + sizes = torch.tensor([x.size(0) for x in xs], dtype=torch.long) + slices = torch.cat([torch.zeros(1, dtype=torch.long), sizes.cumsum(0)], dim=0) + return Foo(torch.cat(xs, dim=0)), slices, None + + +def foo_separate( + *, + key: str, + values: Foo, + idx: int, + slices: Tensor, + incs, + batch, + store, + decrement: bool, + separate_fn_map, +): + start, end = int(slices[idx]), int(slices[idx + 1]) + return Foo(values.x[start:end]) + + +class FooBatch(Foo): + pass + + +def foobatch_collate( + *, + key: str, + values: list[Foo], + data_list, + stores, + increment: bool, + collate_fn_map, +): + xs = [v.x for v in values] + sizes = torch.tensor([x.size(0) for x in xs], dtype=torch.long) + slices = torch.cat([torch.zeros(1, dtype=torch.long), sizes.cumsum(0)], dim=0) + return FooBatch(torch.cat(xs, dim=0)), slices, None + + +def foobatch_separate( + *, + key: str, + values: Foo, + idx: int, + slices: Tensor, + incs, + batch, + store, + decrement: bool, + separate_fn_map, +): + start, end = int(slices[idx]), int(slices[idx + 1]) + return FooBatch(values.x[start:end]) + + +class MapObj: + def __init__(self, x: Tensor): + self.x = x + + +def mapobj_collate( + *, + key: str, + values: list[MapObj], + data_list, + stores, + increment: bool, + collate_fn_map, +): + xs = [v.x for v in values] + sizes = torch.tensor([x.size(0) for x in xs], dtype=torch.long) + slices = torch.cat([torch.zeros(1, dtype=torch.long), sizes.cumsum(0)], dim=0) + value = {"x": torch.cat(xs, dim=0)} + slice_dict = {"x": slices} + return value, slice_dict, None + + +class SeqObj: + def __init__(self, a: Tensor, b: Tensor): + self.a = a + self.b = b + + +def seqobj_collate( + *, + key: str, + values: list[SeqObj], + data_list, + stores, + increment: bool, + collate_fn_map, +): + as_ = [v.a for v in values] + bs_ = [v.b for v in values] + a_sizes = torch.tensor([x.size(0) for x in as_], dtype=torch.long) + b_sizes = torch.tensor([x.size(0) for x in bs_], dtype=torch.long) + a_slices = torch.cat([torch.zeros(1, dtype=torch.long), a_sizes.cumsum(0)], dim=0) + b_slices = torch.cat([torch.zeros(1, dtype=torch.long), b_sizes.cumsum(0)], dim=0) 
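+    # Return a list-valued batch entry: two stacked tensors with a matching
+    # list of slice tensors, mirroring the built-in list handling.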
+ return [torch.cat(as_, dim=0), torch.cat(bs_, dim=0)], [a_slices, b_slices], None + + +def test_collate_separate_fn_map_roundtrip(): + data_list = [ + Data(foo=Foo(torch.tensor([1, 2]))), + Data(foo=Foo(torch.tensor([3]))), + Data(foo=Foo(torch.tensor([4, 5, 6]))), + ] + + batch, slice_dict, inc_dict = collate( + Data, + data_list=data_list, + increment=True, + add_batch=False, + collate_fn_map={Foo: foo_collate}, + ) + + for i, ref in enumerate(data_list): + out = separate( + Data, + batch=batch, + idx=i, + slice_dict=slice_dict, + inc_dict=inc_dict, + decrement=True, + separate_fn_map={Foo: foo_separate}, + ) + assert isinstance(out.foo, Foo) + assert torch.equal(out.foo.x, ref.foo.x) + + +def test_collate_separate_fn_map_isinstance_dispatch_and_recursion(): + data_list = [ + Data( + foo=FooChild(torch.tensor([1])), + nested={"foo": Foo(torch.tensor([2, 3])), "t": torch.tensor([4])}, + ), + Data( + foo=FooChild(torch.tensor([5, 6])), + nested={"foo": Foo(torch.tensor([7])), "t": torch.tensor([8, 9])}, + ), + ] + + batch, slice_dict, inc_dict = collate( + Data, + data_list=data_list, + increment=True, + add_batch=False, + collate_fn_map={Foo: foo_collate}, + ) + + out0 = separate( + Data, + batch=batch, + idx=0, + slice_dict=slice_dict, + inc_dict=inc_dict, + decrement=True, + separate_fn_map={Foo: foo_separate}, + ) + + assert torch.equal(out0.foo.x, data_list[0].foo.x) + assert isinstance(out0.nested["foo"], Foo) + assert torch.equal(out0.nested["foo"].x, data_list[0].nested["foo"].x) + assert torch.equal(out0.nested["t"], data_list[0].nested["t"]) + + +def test_separate_mapping_handles_incs_none_from_custom_collate(): + data_list = [ + Data(m=MapObj(torch.tensor([1, 2]))), + Data(m=MapObj(torch.tensor([3]))), + ] + + batch, slice_dict, inc_dict = collate( + Data, + data_list=data_list, + increment=True, + add_batch=False, + collate_fn_map={MapObj: mapobj_collate}, + ) + + out0 = separate( + Data, + batch=batch, + idx=0, + slice_dict=slice_dict, + inc_dict=inc_dict, + decrement=True, + ) + + assert torch.equal(out0.m["x"], torch.tensor([1, 2])) + + +def test_separate_sequence_handles_incs_none_from_custom_collate(): + data_list = [ + Data(s=SeqObj(torch.tensor([1]), torch.tensor([2, 3]))), + Data(s=SeqObj(torch.tensor([4, 5]), torch.tensor([6]))), + ] + + batch, slice_dict, inc_dict = collate( + Data, + data_list=data_list, + increment=True, + add_batch=False, + collate_fn_map={SeqObj: seqobj_collate}, + ) + + out1 = separate( + Data, + batch=batch, + idx=1, + slice_dict=slice_dict, + inc_dict=inc_dict, + decrement=True, + ) + + assert torch.equal(out1.s[0], torch.tensor([4, 5])) + assert torch.equal(out1.s[1], torch.tensor([6])) + + +def test_separate_fn_map_isinstance_dispatch_for_subclass_batch_value(): + data_list = [ + Data(foo=Foo(torch.tensor([1, 2]))), + Data(foo=Foo(torch.tensor([3]))), + ] + + batch, slice_dict, inc_dict = collate( + Data, + data_list=data_list, + increment=True, + add_batch=False, + collate_fn_map={Foo: foobatch_collate}, + ) + + out0 = separate( + Data, + batch=batch, + idx=0, + slice_dict=slice_dict, + inc_dict=inc_dict, + decrement=True, + separate_fn_map={Foo: foobatch_separate}, + ) + + assert isinstance(batch.foo, FooBatch) + assert isinstance(out0.foo, FooBatch) + assert torch.equal(out0.foo.x, data_list[0].foo.x) From ac7efe71bbe8c4efd85216c9f0bf3c52fa6a818d Mon Sep 17 00:00:00 2001 From: Filippo Maria Bianchi Date: Fri, 6 Mar 2026 18:12:09 +0100 Subject: [PATCH 3/4] Update changelog and documentation for custom collate and separate function 
support; refactor type hints in collate and separate modules --- CHANGELOG.md | 4 ++ docs/source/advanced/batching.rst | 52 ++++++++++++++++++++++++ test/data/test_collate_fn_separate_fn.py | 51 +++++++++++++++++++++++ torch_geometric/data/collate.py | 6 ++- torch_geometric/data/separate.py | 6 ++- 5 files changed, 115 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 529b429f7483..53b83caa0d0a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/). ### Added +- Added support for custom `collate_fn_map` and `separate_fn_map` hooks in + `torch_geometric.data.collate.collate` and + `torch_geometric.data.separate.separate` + ### Changed - Dropped support for TorchScript in `GATConv` and `GATv2Conv` for correctness ([#10596](https://github.com/pyg-team/pytorch_geometric/pull/10596)) diff --git a/docs/source/advanced/batching.rst b/docs/source/advanced/batching.rst index 881755eda112..ed2b12333511 100644 --- a/docs/source/advanced/batching.rst +++ b/docs/source/advanced/batching.rst @@ -236,3 +236,55 @@ Specifically, a list of attributes of shape :obj:`[num_features]` should be retu >>> MyDataBatch(num_nodes=6, edge_index=[2, 8], foo=[2, 16]) As desired, :obj:`batch.foo` is now described by two dimensions: The batch dimension and the feature dimension. + +Custom Collate and Separate Functions +------------------------------------- + +For advanced use-cases, :func:`torch_geometric.data.collate.collate` and +:func:`torch_geometric.data.separate.separate` accept function maps that let +you override batching/separation for user-defined types. +Dispatch first checks exact type matches, and then falls back to +:func:`isinstance` checks. + +.. code-block:: python + + import torch + from torch import Tensor + from torch_geometric.data import Data + from torch_geometric.data.collate import collate + from torch_geometric.data.separate import separate + + + class Foo: + def __init__(self, x: Tensor): + self.x = x + + + def foo_collate(*, values, **kwargs): + xs = [v.x for v in values] + sizes = torch.tensor([x.size(0) for x in xs], dtype=torch.long) + slices = torch.cat([torch.zeros(1, dtype=torch.long), sizes.cumsum(0)]) + return Foo(torch.cat(xs, dim=0)), slices, None + + + def foo_separate(*, values, idx, slices, **kwargs): + start, end = int(slices[idx]), int(slices[idx + 1]) + return Foo(values.x[start:end]) + + + data_list = [Data(foo=Foo(torch.tensor([1, 2]))), Data(foo=Foo(torch.tensor([3])))] + + batch, slice_dict, inc_dict = collate( + Data, + data_list=data_list, + collate_fn_map={Foo: foo_collate}, + add_batch=False, + ) + data = separate( + Data, + batch=batch, + idx=0, + slice_dict=slice_dict, + inc_dict=inc_dict, + separate_fn_map={Foo: foo_separate}, + ) diff --git a/test/data/test_collate_fn_separate_fn.py b/test/data/test_collate_fn_separate_fn.py index 46d2f95297c1..9ca27eeafc02 100644 --- a/test/data/test_collate_fn_separate_fn.py +++ b/test/data/test_collate_fn_separate_fn.py @@ -272,3 +272,54 @@ def test_separate_fn_map_isinstance_dispatch_for_subclass_batch_value(): assert isinstance(batch.foo, FooBatch) assert isinstance(out0.foo, FooBatch) assert torch.equal(out0.foo.x, data_list[0].foo.x) + + +def test_collate_fn_map_prefers_exact_type_match(): + data_list = [ + Data(foo=FooChild(torch.tensor([1]))), + Data(foo=FooChild(torch.tensor([2, 3]))), + ] + + batch, _, _ = collate( + Data, + data_list=data_list, + increment=True, + add_batch=False, + collate_fn_map={ + Foo: 
foo_collate, + FooChild: foobatch_collate, + }, + ) + + assert isinstance(batch.foo, FooBatch) + + +def test_separate_fn_map_prefers_exact_type_match(): + data_list = [ + Data(foo=Foo(torch.tensor([1, 2]))), + Data(foo=Foo(torch.tensor([3]))), + ] + + batch, slice_dict, inc_dict = collate( + Data, + data_list=data_list, + increment=True, + add_batch=False, + collate_fn_map={Foo: foobatch_collate}, + ) + + out0 = separate( + Data, + batch=batch, + idx=0, + slice_dict=slice_dict, + inc_dict=inc_dict, + decrement=True, + separate_fn_map={ + Foo: foo_separate, + FooBatch: foobatch_separate, + }, + ) + + assert isinstance(out0.foo, FooBatch) + assert torch.equal(out0.foo.x, data_list[0].foo.x) diff --git a/torch_geometric/data/collate.py b/torch_geometric/data/collate.py index 8049bbc4aa03..894ef3e2b1e4 100644 --- a/torch_geometric/data/collate.py +++ b/torch_geometric/data/collate.py @@ -30,6 +30,8 @@ from torch_geometric.utils.sparse import cat T = TypeVar("T") +CollateFn = Callable[..., Tuple[Any, Any, Any]] +CollateFnMap = Dict[Type[Any], CollateFn] SliceDictType = Dict[str, Union[Tensor, Dict[str, Tensor]]] IncDictType = Dict[str, Union[Tensor, Dict[str, Tensor]]] @@ -41,7 +43,7 @@ def collate( add_batch: bool = True, follow_batch: Optional[Iterable[str]] = None, exclude_keys: Optional[Iterable[str]] = None, - collate_fn_map: Optional[Dict[Any, Callable[..., Tuple[Any, Any, Any]]]] = None, + collate_fn_map: Optional[CollateFnMap] = None, ) -> Tuple[T, SliceDictType, IncDictType]: # Collates a list of `data` objects into a single object of type `cls`. # `collate` can handle both homogeneous and heterogeneous data objects by @@ -156,7 +158,7 @@ def _collate( data_list: List[BaseData], stores: List[BaseStorage], increment: bool, - collate_fn_map: Optional[Dict[Any, Callable[..., Tuple[Any, Any, Any]]]] = None, + collate_fn_map: Optional[CollateFnMap] = None, ) -> Tuple[Any, Any, Any]: elem = values[0] elem_type = type(elem) diff --git a/torch_geometric/data/separate.py b/torch_geometric/data/separate.py index de2803fa39be..e5c5f571d374 100644 --- a/torch_geometric/data/separate.py +++ b/torch_geometric/data/separate.py @@ -9,6 +9,8 @@ from torch_geometric.utils import narrow T = TypeVar("T") +SeparateFn = Callable[..., Any] +SeparateFnMap = Dict[Type[Any], SeparateFn] def separate( @@ -18,7 +20,7 @@ def separate( slice_dict: Any, inc_dict: Any = None, decrement: bool = True, - separate_fn_map: Optional[Dict[Any, Callable[..., Any]]] = None, + separate_fn_map: Optional[SeparateFnMap] = None, ) -> T: # Separates the individual element from a `batch` at index `idx`. 
# `separate` can handle both homogeneous and heterogeneous data objects by @@ -74,7 +76,7 @@ def _separate( batch: BaseData, store: BaseStorage, decrement: bool, - separate_fn_map: Optional[Dict[Any, Callable[..., Any]]] = None, + separate_fn_map: Optional[SeparateFnMap] = None, ) -> Any: elem_type = type(values) From a9d2a5375193c79d6231ced21fcf5f3f5efc887c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 12 Mar 2026 10:36:33 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- test/data/test_collate_fn_separate_fn.py | 28 ++++++++++---- torch_geometric/data/collate.py | 48 +++++++++--------------- torch_geometric/data/separate.py | 33 +++++++--------- 3 files changed, 52 insertions(+), 57 deletions(-) diff --git a/test/data/test_collate_fn_separate_fn.py b/test/data/test_collate_fn_separate_fn.py index 9ca27eeafc02..ce7802a76271 100644 --- a/test/data/test_collate_fn_separate_fn.py +++ b/test/data/test_collate_fn_separate_fn.py @@ -26,7 +26,8 @@ def foo_collate( ): xs = [v.x for v in values] sizes = torch.tensor([x.size(0) for x in xs], dtype=torch.long) - slices = torch.cat([torch.zeros(1, dtype=torch.long), sizes.cumsum(0)], dim=0) + slices = torch.cat([torch.zeros(1, dtype=torch.long), + sizes.cumsum(0)], dim=0) return Foo(torch.cat(xs, dim=0)), slices, None @@ -61,7 +62,8 @@ def foobatch_collate( ): xs = [v.x for v in values] sizes = torch.tensor([x.size(0) for x in xs], dtype=torch.long) - slices = torch.cat([torch.zeros(1, dtype=torch.long), sizes.cumsum(0)], dim=0) + slices = torch.cat([torch.zeros(1, dtype=torch.long), + sizes.cumsum(0)], dim=0) return FooBatch(torch.cat(xs, dim=0)), slices, None @@ -97,7 +99,8 @@ def mapobj_collate( ): xs = [v.x for v in values] sizes = torch.tensor([x.size(0) for x in xs], dtype=torch.long) - slices = torch.cat([torch.zeros(1, dtype=torch.long), sizes.cumsum(0)], dim=0) + slices = torch.cat([torch.zeros(1, dtype=torch.long), + sizes.cumsum(0)], dim=0) value = {"x": torch.cat(xs, dim=0)} slice_dict = {"x": slices} return value, slice_dict, None @@ -122,9 +125,12 @@ def seqobj_collate( bs_ = [v.b for v in values] a_sizes = torch.tensor([x.size(0) for x in as_], dtype=torch.long) b_sizes = torch.tensor([x.size(0) for x in bs_], dtype=torch.long) - a_slices = torch.cat([torch.zeros(1, dtype=torch.long), a_sizes.cumsum(0)], dim=0) - b_slices = torch.cat([torch.zeros(1, dtype=torch.long), b_sizes.cumsum(0)], dim=0) - return [torch.cat(as_, dim=0), torch.cat(bs_, dim=0)], [a_slices, b_slices], None + a_slices = torch.cat([torch.zeros(1, dtype=torch.long), + a_sizes.cumsum(0)], dim=0) + b_slices = torch.cat([torch.zeros(1, dtype=torch.long), + b_sizes.cumsum(0)], dim=0) + return [torch.cat(as_, dim=0), + torch.cat(bs_, dim=0)], [a_slices, b_slices], None def test_collate_separate_fn_map_roundtrip(): @@ -160,11 +166,17 @@ def test_collate_separate_fn_map_isinstance_dispatch_and_recursion(): data_list = [ Data( foo=FooChild(torch.tensor([1])), - nested={"foo": Foo(torch.tensor([2, 3])), "t": torch.tensor([4])}, + nested={ + "foo": Foo(torch.tensor([2, 3])), + "t": torch.tensor([4]) + }, ), Data( foo=FooChild(torch.tensor([5, 6])), - nested={"foo": Foo(torch.tensor([7])), "t": torch.tensor([8, 9])}, + nested={ + "foo": Foo(torch.tensor([7])), + "t": torch.tensor([8, 9]) + }, ), ] diff --git a/torch_geometric/data/collate.py b/torch_geometric/data/collate.py index 894ef3e2b1e4..a788b7a0d57e 100644 --- 
a/torch_geometric/data/collate.py +++ b/torch_geometric/data/collate.py @@ -14,8 +14,9 @@ ) import torch -import torch_geometric.typing from torch import Tensor + +import torch_geometric.typing from torch_geometric import EdgeIndex, Index from torch_geometric.data.data import BaseData from torch_geometric.data.storage import BaseStorage, NodeStorage @@ -108,9 +109,8 @@ def collate( continue # Collate attributes into a unified representation: - value, slices, incs = _collate( - attr, values, data_list, stores, increment, collate_fn_map - ) + value, slices, incs = _collate(attr, values, data_list, stores, + increment, collate_fn_map) # If parts of the data are already on GPU, make sure that auxiliary # data like `batch` or `ptr` are also created on GPU: @@ -140,11 +140,8 @@ def collate( out_store[f"{attr}_ptr"] = ptr # In case of node-level storages, we add a top-level batch vector it: - if ( - add_batch - and isinstance(stores[0], NodeStorage) - and stores[0].can_infer_num_nodes - ): + if (add_batch and isinstance(stores[0], NodeStorage) + and stores[0].can_infer_num_nodes): repeats = [store.num_nodes or 0 for store in stores] out_store.batch = repeat_interleave(repeats, device=device) out_store.ptr = cumsum(torch.tensor(repeats, device=device)) @@ -198,7 +195,8 @@ def _collate( incs = get_incs(key, values, data_list, stores) if incs.dim() > 1 or int(incs[-1]) != 0: values = [ - value + inc.to(value.device) for value, inc in zip(values, incs) + value + inc.to(value.device) + for value, inc in zip(values, incs) ] else: incs = None @@ -213,14 +211,12 @@ def _collate( out = None if torch.utils.data.get_worker_info() is not None and not isinstance( - elem, (Index, EdgeIndex) - ): + elem, (Index, EdgeIndex)): # Write directly into shared memory to avoid an extra copy: numel = sum(value.numel() for value in values) if torch_geometric.typing.WITH_PT20: storage = elem.untyped_storage()._new_shared( - numel * elem.element_size(), device=elem.device - ) + numel * elem.element_size(), device=elem.device) else: storage = elem.storage()._new_shared(numel, device=elem.device) shape = list(elem.size()) @@ -259,7 +255,7 @@ def _collate( # NOTE: `cat_dim` may return a tuple to allow for diagonal stacking. key = str(key) cat_dim = data_list[0].__cat_dim__(key, elem, stores[0]) - cat_dims = (cat_dim,) if isinstance(cat_dim, int) else cat_dim + cat_dims = (cat_dim, ) if isinstance(cat_dim, int) else cat_dim repeats = [[value.size(dim) for dim in cat_dims] for value in values] slices = cumsum(torch.tensor(repeats)) if is_torch_sparse_tensor(elem): @@ -294,12 +290,8 @@ def _collate( ) return value_dict, slice_dict, inc_dict - elif ( - isinstance(elem, Sequence) - and not isinstance(elem, str) - and len(elem) > 0 - and isinstance(elem[0], (Tensor, SparseTensor)) - ): + elif (isinstance(elem, Sequence) and not isinstance(elem, str) + and len(elem) > 0 and isinstance(elem[0], (Tensor, SparseTensor))): # Recursively collate elements of lists. value_list, slice_list, inc_list = [], [], [] for i in range(len(elem)): @@ -340,11 +332,8 @@ def _batch_and_ptr( batch[k], ptr[k] = _batch_and_ptr(v, device) return batch, ptr - elif ( - isinstance(slices, Sequence) - and not isinstance(slices, str) - and isinstance(slices[0], Tensor) - ): + elif (isinstance(slices, Sequence) and not isinstance(slices, str) + and isinstance(slices[0], Tensor)): # Recursively batch elements of lists. 
batch, ptr = [], [] for s in slices: @@ -365,13 +354,12 @@ def repeat_interleave( repeats: List[int], device: Optional[torch.device] = None, ) -> Tensor: - outs = [torch.full((n,), i, device=device) for i, n in enumerate(repeats)] + outs = [torch.full((n, ), i, device=device) for i, n in enumerate(repeats)] return torch.cat(outs, dim=0) -def get_incs( - key, values: List[Any], data_list: List[BaseData], stores: List[BaseStorage] -) -> Tensor: +def get_incs(key, values: List[Any], data_list: List[BaseData], + stores: List[BaseStorage]) -> Tensor: repeats = [ data.__inc__(key, value, store) for value, data, store in zip(values, data_list, stores) diff --git a/torch_geometric/data/separate.py b/torch_geometric/data/separate.py index e5c5f571d374..21cb6b0c8ef0 100644 --- a/torch_geometric/data/separate.py +++ b/torch_geometric/data/separate.py @@ -2,6 +2,7 @@ from typing import Any, Callable, Dict, Optional, Type, TypeVar from torch import Tensor + from torch_geometric import EdgeIndex, Index from torch_geometric.data.data import BaseData from torch_geometric.data.storage import BaseStorage @@ -128,7 +129,8 @@ def _separate( value._sort_order = values._cat_metadata.sort_order[idx] value._is_undirected = values._cat_metadata.is_undirected[idx] - if decrement and incs is not None and (incs.dim() > 1 or int(incs[idx]) != 0): + if decrement and incs is not None and (incs.dim() > 1 + or int(incs[idx]) != 0): value = value - incs[idx].to(value.device) return value @@ -138,7 +140,7 @@ def _separate( # NOTE: `cat_dim` may return a tuple to allow for diagonal stacking. key = str(key) cat_dim = batch.__cat_dim__(key, values, store) - cat_dims = (cat_dim,) if isinstance(cat_dim, int) else cat_dim + cat_dims = (cat_dim, ) if isinstance(cat_dim, int) else cat_dim for i, dim in enumerate(cat_dims): start, end = int(slices[idx][i]), int(slices[idx + 1][i]) values = values.narrow(dim, start, end - start) @@ -153,7 +155,8 @@ def _separate( elif isinstance(values, Mapping): # Recursively separate elements of dictionaries. return { - sub_key: _separate( + sub_key: + _separate( sub_key, value, idx, @@ -167,23 +170,16 @@ def _separate( for sub_key, value in values.items() } - elif ( - isinstance(values, Sequence) - and isinstance(values[0], Sequence) - and not isinstance(values[0], str) - and len(values[0]) > 0 - and isinstance(values[0][0], (Tensor, SparseTensor)) - and isinstance(slices, Sequence) - ): + elif (isinstance(values, Sequence) and isinstance(values[0], Sequence) + and not isinstance(values[0], str) and len(values[0]) > 0 + and isinstance(values[0][0], (Tensor, SparseTensor)) + and isinstance(slices, Sequence)): # Recursively separate elements of lists of lists. return [value[idx] for value in values] - elif ( - isinstance(values, Sequence) - and not isinstance(values, str) - and isinstance(values[0], (Tensor, SparseTensor)) - and isinstance(slices, Sequence) - ): + elif (isinstance(values, Sequence) and not isinstance(values, str) + and isinstance(values[0], (Tensor, SparseTensor)) + and isinstance(slices, Sequence)): # Recursively separate elements of lists of Tensors/SparseTensors. return [ _separate( @@ -196,8 +192,7 @@ def _separate( store=store, decrement=decrement, separate_fn_map=separate_fn_map, - ) - for i, value in enumerate(values) + ) for i, value in enumerate(values) ] else:
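+        # Fallback: any remaining value type is returned for index `idx`
+        # without modification.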