Expand distributed indexing, match numpy indexing scheme #938

ClaudiaComito wants to merge 214 commits into main from 914_adv-indexing-outshape-outsplit
Conversation
Thank you for the PR!
On the hunk `@@ -879,6 +882,641 @@ def fill_diagonal(self, value: float) -> DNDarray:`:

```python
        return self

    def __process_key(
```

Comment: `__process_key` and `__process_scalar_key` do not use `self`, so they should not be declared inside `DNDarray`. Possibly move to `indexing.py`?
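A minimal sketch of the suggested refactor, assuming a module-level home such as `heat/core/indexing.py`; the body below is purely illustrative and not the PR's actual implementation:

```python
# Module-level helper: takes the array as an explicit argument instead of
# relying on `self`, so it can live outside the DNDarray class.
def process_key(arr, key):
    """Normalize `key` against arr.ndim (illustrative stand-in body)."""
    if not isinstance(key, tuple):
        key = (key,)
    # Pad with full slices so the key addresses every dimension.
    return key + (slice(None),) * (arr.ndim - len(key))
```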
On the same hunk, at the `__process_key` signature:

```python
    def __process_key(
        arr: DNDarray,
        key: Union[Tuple[int, ...], List[int, ...]],
```

Comment: Would be good to define a key type to make type definitions easier:

```python
Index = int | slice | Ellipsis | None
Indexer = Index | tuple[Index, ...]
```

And then apply it everywhere.
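Note that the aliases as written would not evaluate at runtime, because `Ellipsis` is a singleton instance rather than a type. A runtime-valid sketch (assuming Python >= 3.10, which provides both `types.EllipsisType` and the `|` union syntax):

```python
from types import EllipsisType  # Python >= 3.10

# Runtime-valid variants of the proposed aliases: the ellipsis needs its
# type, EllipsisType, inside the union.
Index = int | slice | EllipsisType | None
Indexer = Index | tuple[Index, ...]

def describe(key: Indexer) -> str:
    return repr(key)

print(describe((0, slice(None), ..., None)))  # a typical mixed key
```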
On the `__getitem__` and `__process_key` signatures:

```python
def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDarray:

def __process_key(
    arr: "DNDarray",
    key: tuple[int, ...] | list[int],
```

Comment: We should define an index/indexer type for shorter type hints (the same `Index`/`Indexer` aliases as above).

On the deferred type import:

```python
        key = kst + slices + kend
    else:
        key = key + [slice(None)] * (self.ndim - len(key))
    from .types import bool as ht_bool, uint8 as ht_uint8  # avoid circulars
```

Comment: We probably need a better solution; not sure what performance impact this could have over the long run.
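One hedged alternative to a per-call deferred import: cache the import behind `functools.lru_cache`, so the circular import is still broken but the import machinery runs only on first use. The module path below is assumed for illustration:

```python
import functools

@functools.lru_cache(maxsize=None)
def _heat_bool_uint8():
    # Deferred import, executed only once; later calls hit the cache instead
    # of re-running the import machinery on every indexing call.
    from heat.core.types import bool as ht_bool, uint8 as ht_uint8
    return ht_bool, ht_uint8
```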
On the inline helper definition:

```python
for i in range(len(key[: self.split + 1])):
    if self.__key_is_singular(key, i, self_proxy):
        new_split = None if i == self.split else new_split - 1

def _normalize_index_component(comp):
```

Comment: What is the reason for defining the function here?

Comment: Probably should be moved to the sanitation module.
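For context (a general Python point, not specific to this PR): a `def` inside a frequently called method re-creates the function object on every call, so hoisting the helper to module scope, e.g. into the sanitation module as suggested, is both cleaner and marginally cheaper. A self-contained illustration:

```python
import timeit

def with_inner():
    def _helper(x):  # re-created on every call to with_inner()
        return x
    return _helper(1)

def _module_helper(x):
    return x

def without_inner():
    return _module_helper(1)

# The module-level variant skips the per-call function construction.
print(timeit.timeit(with_inner, number=1_000_000),
      timeit.timeit(without_inner, number=1_000_000))
```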
On the key normalization:

```python
if isinstance(key, DNDarray):
    key = _normalize_index_component(key)
elif isinstance(key, (list, tuple)):
    key = type(key)(_normalize_index_component(k) for k in key)
```

Comment: Double-check whether `key` is a tuple here; the normalization function just returns the list/tuple in most cases, so the logic could be simplified.

Comment: `key` also might always be a tuple? Need to actually check the entries of the first element.
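If the tuple invariant the comment suspects actually holds at this point, the branching collapses; a sketch of the simplification (assumption, not verified against the PR):

```python
# With key guaranteed to be a tuple, the DNDarray branch above is unreachable
# and a single comprehension handles every entry:
key = tuple(_normalize_index_component(k) for k in key)
```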
Comment: Create different functions for the scalar case: `__process_key` should return the tuple with the key information and send it to the separate functions. Additionally, unnecessary function definitions in …

Comment: Make the table and diagram available in a markdown file in the docs folder for now.
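A hedged sketch of the dispatch structure this suggests; every name below (`__getitem_ordered`, `__getitem_descending`, `__getitem_unordered`) is illustrative, not the PR's actual API:

```python
def __getitem__(self, key):
    # __process_key returns the normalized key plus routing metadata ...
    key, output_shape, new_split, key_is_ordered = self.__process_key(key)
    # ... and __getitem__ only dispatches to dedicated handlers per case.
    if key_is_ordered == 1:
        return self.__getitem_ordered(key, output_shape, new_split)
    if key_is_ordered == -1:
        return self.__getitem_descending(key, output_shape, new_split)
    return self.__getitem_unordered(key, output_shape, new_split)
```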
brownbaerchen left a comment:
I haven't looked at the actual advanced indexing stuff. But I think we would do well to clean up a bit before getting into the details of this. There have been some changes mixed in that, to me, seem unrelated to advanced indexing in heat. They should be moved to separate PRs if we want to keep them. Other changes are cosmetic or temporary and should be removed entirely.
Inline review comments:

- Can we remove this change from this PR?
- Let's remove this change from this PR (see #2216 for a similar case).
- Not sure about these changes, but some of them can be removed from this PR, right?
- Let's remove this change from this PR.
- Let's remove these changes from this PR.
- These changes seem unrelated to advanced indexing. Maybe put them in a separate PR.
- Is this needed here?
- Let's remove this from the PR.
- This seems unrelated to advanced indexing. Maybe move it to a separate PR?
- This seems unrelated to advanced indexing. Maybe move it to a separate PR?
Description
This pull request introduces a significant overhaul of distributed indexing within `dndarray.py`, specifically targeting the `__getitem__` and `__setitem__` methods. The primary objective is to achieve full NumPy indexing compliance in a distributed environment while minimizing MPI overhead and memory footprint. The logic has been refactored to identify zero-communication paths ("early out") and to route heavy unordered advanced indexing through (hopefully?) optimized communication.
The following table shows the distribution semantics of the DNDarray indexing operations: the first column shows the operation, the second the distribution semantics of the `key`, the third the distribution semantics of the `value`, and the last the distribution semantics of the result.

(Table: four `array[key]` rows and eight `array[key] = value` rows; the per-column cell contents did not survive extraction.)
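To make the cases concrete, a hedged sketch of the operations those rows cover, assuming a standard heat setup run under MPI (e.g. `mpirun -n 2 python example.py`) and the routing semantics this PR introduces:

```python
import heat as ht

x = ht.zeros((4, 4), split=0)    # rows distributed across ranks

row = x[1]                        # single item on the split axis
sub = x[0:2, 1:3]                 # ascending slice: ordered fast path
rev = x[::-1]                     # descending slice on the split axis
x[ht.array([2, 0, 1])] = 1.0      # unordered advanced setitem
```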
Routing logic

The flowchart (DRAFT) maps out the MPI routing decisions based on the evaluated state of the indexing key.
```mermaid
graph TD
    classDef default fill:#ffffff,stroke:#ced4da,stroke-width:1px,color:#212529,rx:4px,ry:4px;
    classDef terminal fill:#343a40,stroke:#343a40,stroke-width:2px,color:#ffffff,rx:15px,ry:15px;
    classDef decision fill:#e3f2fd,stroke:#4dabf7,stroke-width:2px,color:#000000;
    classDef highlight fill:#e8f5e9,stroke:#69b3a2,stroke-width:2px,color:#212529;

    Start(["Start: arr[key] or arr[key] = value"]):::terminal
    Norm["Normalize Key (e.g., Bool Masks -> Int Indices)"]
    ProcessKey["__process_key() Expands dims, aligns shapes"]
    StateCalc{"Calculate State: split_key_is_ordered"}:::decision

    Start --> Norm
    Norm --> ProcessKey
    ProcessKey --> StateCalc

    Branch1{"Single item on split axis? (root != None)"}:::decision
    Branch0{"Operation?"}:::decision
    BranchNeg1{"Operation?"}:::decision

    StateCalc -- "1: Ordered / Ascending" --> Branch1
    StateCalc -- "0: Unordered / Random" --> Branch0
    StateCalc -- "-1: Descending Slice" --> BranchNeg1

    subgraph Ordered ["Fast Path: Ordered Indexing (split_key_is_ordered = 1)"]
        style Ordered fill:#f8f9fa,stroke:#dee2e6,stroke-width:1px,stroke-dasharray: 5 5,color:#495057
        Branch1 -- "Yes: Get" --> RootGet["Root fetches local data"]
        RootGet --> Bcast["MPI.Bcast to all ranks"]
        Branch1 -- "Yes: Set" --> RootSet["Root updates local data in-place"]
        Branch1 -- "No" --> FastLocal["Pure Basic Slicing: Apply locally, NO MPI needed"]:::highlight
    end

    subgraph Descending ["Descending Slices (split_key_is_ordered = -1)"]
        style Descending fill:#f8f9fa,stroke:#dee2e6,stroke-width:1px,stroke-dasharray: 5 5,color:#495057
        BranchNeg1 -- "Set" --> FlipVal["Flip 'value' array"]
        FlipVal --> MatchDist["Align distribution map"]
        MatchDist --> SetLocal["-1 Local Set"]
        BranchNeg1 -- "Get" --> UnorderedFallback["Converts to arange -> falls back to unordered"]
    end

    subgraph Unordered ["Heavy Path: Unordered Advanced Indexing (split_key_is_ordered = 0)"]
        style Unordered fill:#f8f9fa,stroke:#dee2e6,stroke-width:1px,stroke-dasharray: 5 5,color:#495057
        Branch0 -- "Get" --> G_Allgather["MPI.Allgather: Share recv_counts"]
        G_Allgather --> G_SendIdx["MPI.Isend/Recv: Send requested indices to owning ranks"]
        G_SendIdx --> G_Fetch["Owning ranks fetch local data"]
        G_Fetch --> G_SendData["MPI.Isend/Recv: Send requested data back"]
        G_SendData --> G_Reconstruct["Reconstruct recv_buf on original rank"]
        Branch0 -- "Set" --> S_CheckVal{"Is 'value' distributed?"}:::decision
        S_CheckVal -- "No / Scalar" --> S_LocalMask["_advanced_setitem_unordered_local (Apply locally)"]
        S_CheckVal -- "Yes" --> S_Align["Redistribute 'value' to match 'key' distribution"]
        S_Align --> S_AllToAll["MPI.Alltoallv: Exchange data AND indices"]
        S_AllToAll --> S_ApplyRecv["Apply received data to local elements"]
    end

    Bcast --> End(["Return / Complete"]):::terminal
    RootSet --> End
    FastLocal --> End
    SetLocal --> End
    UnorderedFallback -.-> Branch0
    G_Reconstruct --> End
    S_LocalMask --> End
    S_ApplyRecv --> End
```
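The zero-communication branch relies on each rank translating a global (ascending) slice into purely local bounds; a minimal sketch of that translation, with all names assumed rather than taken from the PR:

```python
def global_slice_to_local(s: slice, offset: int, lshape: int, gshape: int) -> slice:
    """Intersect global slice `s` with the local chunk [offset, offset + lshape).

    Assumes an ascending step; descending slices take the separate branch
    shown in the flowchart.
    """
    start, stop, step = s.indices(gshape)
    if start < offset:
        # Advance to the first index of the progression inside the chunk.
        start += ((offset - start + step - 1) // step) * step
    stop = min(stop, offset + lshape)
    # Shift into local coordinates; yields an empty slice if nothing is local.
    return slice(start - offset, max(stop - offset, start - offset), step)
```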
Main changes

To Be Continued...
Memory footprint
Scaling behaviour
Issue/s resolved: #914 #918
Changes proposed:
Type of change
Memory requirements
Performance
Due Diligence
Does this change modify the behaviour of other functions? If so, which?
yes / no
skip ci