Skip to content

Commit 3e7d24d

Browse files
authored
perf:improve performance of morton order iter (#3705)
* improve performance of morton order iter * changelog
1 parent 1ed266d commit 3e7d24d

File tree

3 files changed

+65
-21
lines changed

3 files changed

+65
-21
lines changed

changes/3705.bugfix.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix a performance bug in morton curve generation.

src/zarr/core/indexing.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from collections.abc import Iterator, Sequence
88
from dataclasses import dataclass
99
from enum import Enum
10-
from functools import reduce
10+
from functools import lru_cache, reduce
1111
from types import EllipsisType
1212
from typing import (
1313
TYPE_CHECKING,
@@ -1467,16 +1467,21 @@ def decode_morton(z: int, chunk_shape: tuple[int, ...]) -> tuple[int, ...]:
14671467
return tuple(out)
14681468

14691469

1470-
def morton_order_iter(chunk_shape: tuple[int, ...]) -> Iterator[tuple[int, ...]]:
1471-
i = 0
1470+
@lru_cache
1471+
def _morton_order(chunk_shape: tuple[int, ...]) -> tuple[tuple[int, ...], ...]:
1472+
n_total = product(chunk_shape)
14721473
order: list[tuple[int, ...]] = []
1473-
while len(order) < product(chunk_shape):
1474+
i = 0
1475+
while len(order) < n_total:
14741476
m = decode_morton(i, chunk_shape)
1475-
if m not in order and all(x < y for x, y in zip(m, chunk_shape, strict=False)):
1477+
if all(x < y for x, y in zip(m, chunk_shape, strict=False)):
14761478
order.append(m)
14771479
i += 1
1478-
for j in range(product(chunk_shape)):
1479-
yield order[j]
1480+
return tuple(order)
1481+
1482+
1483+
def morton_order_iter(chunk_shape: tuple[int, ...]) -> Iterator[tuple[int, ...]]:
1484+
return iter(_morton_order(tuple(chunk_shape)))
14801485

14811486

14821487
def c_order_iter(chunks_per_shard: tuple[int, ...]) -> Iterator[tuple[int, ...]]:

tests/test_codecs/test_codecs.py

Lines changed: 52 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
TransposeCodec,
1919
)
2020
from zarr.core.buffer import default_buffer_prototype
21-
from zarr.core.indexing import BasicSelection, morton_order_iter
21+
from zarr.core.indexing import BasicSelection, decode_morton, morton_order_iter
2222
from zarr.core.metadata.v3 import ArrayV3Metadata
2323
from zarr.dtype import UInt8
2424
from zarr.errors import ZarrUserWarning
@@ -171,7 +171,8 @@ def test_open(store: Store) -> None:
171171
assert a.metadata == b.metadata
172172

173173

174-
def test_morton() -> None:
174+
def test_morton_exact_order() -> None:
175+
"""Test exact morton ordering for power-of-2 shapes."""
175176
assert list(morton_order_iter((2, 2))) == [(0, 0), (1, 0), (0, 1), (1, 1)]
176177
assert list(morton_order_iter((2, 2, 2))) == [
177178
(0, 0, 0),
@@ -206,21 +207,58 @@ def test_morton() -> None:
206207
@pytest.mark.parametrize(
207208
"shape",
208209
[
209-
[2, 2, 2],
210-
[5, 2],
211-
[2, 5],
212-
[2, 9, 2],
213-
[3, 2, 12],
214-
[2, 5, 1],
215-
[4, 3, 6, 2, 7],
216-
[3, 2, 1, 6, 4, 5, 2],
210+
(2, 2, 2),
211+
(5, 2),
212+
(2, 5),
213+
(2, 9, 2),
214+
(3, 2, 12),
215+
(2, 5, 1),
216+
(4, 3, 6, 2, 7),
217+
(3, 2, 1, 6, 4, 5, 2),
218+
(1,),
219+
(1, 1),
220+
(5, 1, 3),
221+
(1, 4, 1, 2),
217222
],
218223
)
219-
def test_morton2(shape: tuple[int, ...]) -> None:
224+
def test_morton_is_permutation(shape: tuple[int, ...]) -> None:
225+
"""Test that morton_order_iter produces every valid coordinate exactly once."""
226+
import itertools
227+
228+
from zarr.core.common import product
229+
230+
order = list(morton_order_iter(shape))
231+
expected_len = product(shape)
232+
# completeness: every valid coordinate is present
233+
assert len(order) == expected_len
234+
# no duplicates
235+
assert len(set(order)) == expected_len
236+
# all coordinates are within bounds
237+
assert all(all(c < s for c, s in zip(coord, shape, strict=True)) for coord in order)
238+
# the set of coordinates equals the full cartesian product
239+
assert set(order) == set(itertools.product(*(range(s) for s in shape)))
240+
241+
242+
@pytest.mark.parametrize(
243+
"shape",
244+
[
245+
(2, 2),
246+
(4, 4),
247+
(2, 2, 2),
248+
(4, 4, 4),
249+
(2, 2, 2, 2),
250+
],
251+
)
252+
def test_morton_ordering(shape: tuple[int, ...]) -> None:
253+
"""Test that the iteration order matches consecutive decode_morton outputs.
254+
255+
For power-of-2 shapes, every decode_morton output is in-bounds,
256+
so the ordering should be exactly decode_morton(0), decode_morton(1), ...
257+
"""
258+
220259
order = list(morton_order_iter(shape))
221-
for i, x in enumerate(order):
222-
assert x not in order[:i] # no duplicates
223-
assert all(x[j] < shape[j] for j in range(len(shape))) # all indices are within bounds
260+
for i, coord in enumerate(order):
261+
assert coord == decode_morton(i, shape)
224262

225263

226264
@pytest.mark.parametrize("store", ["local", "memory"], indirect=["store"])

0 commit comments

Comments
 (0)