Skip to content

Commit 0da540c

Browse files
authored
Merge pull request #380 from ev-br/tuples_not_lists
`broadcast_arrays` et al: change the return type into a tuple
2 parents e240e75 + d2e2f0e commit 0da540c

File tree

10 files changed

+51
-14
lines changed

10 files changed

+51
-14
lines changed

array_api_compat/cupy/_aliases.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
from builtins import bool as py_bool
4+
from typing import Literal
45

56
import cupy as cp
67

@@ -139,6 +140,15 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array:
139140
return cp.take_along_axis(x, indices, axis=axis)
140141

141142

143+
# https://github.com/cupy/cupy/pull/9582
144+
def broadcast_arrays(*arrays: Array) -> tuple[Array, ...]:
145+
return tuple(cp.broadcast_arrays(*arrays))
146+
147+
148+
def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> tuple[Array, ...]:
149+
return tuple(cp.meshgrid(*arrays, indexing=indexing))
150+
151+
142152
# These functions are completely new here. If the library already has them
143153
# (i.e., numpy 2.0), use the library version instead of our wrapper.
144154
if hasattr(cp, 'vecdot'):
@@ -161,7 +171,8 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array:
161171
'atan2', 'atanh', 'bitwise_left_shift',
162172
'bitwise_invert', 'bitwise_right_shift',
163173
'bool', 'concat', 'count_nonzero', 'pow', 'sign',
164-
'ceil', 'floor', 'trunc', 'take_along_axis']
174+
'ceil', 'floor', 'trunc', 'take_along_axis',
175+
'broadcast_arrays', 'meshgrid']
165176

166177

167178
def __dir__() -> list[str]:

array_api_compat/cupy/_info.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -333,4 +333,4 @@ def devices(self):
333333
__array_namespace_info__.dtypes
334334
335335
"""
336-
return [cuda.Device(i) for i in range(cuda.runtime.getDeviceCount())]
336+
return tuple(cuda.Device(i) for i in range(cuda.runtime.getDeviceCount()))

array_api_compat/dask/array/_info.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ def dtypes(
379379
return res
380380
raise ValueError(f"unsupported kind: {kind!r}")
381381

382-
def devices(self) -> list[Device]:
382+
def devices(self) -> tuple[Device]:
383383
"""
384384
The devices supported by Dask.
385385
@@ -404,4 +404,4 @@ def devices(self) -> list[Device]:
404404
['cpu', DASK_DEVICE]
405405
406406
"""
407-
return ["cpu", _DASK_DEVICE]
407+
return ("cpu", _DASK_DEVICE)

array_api_compat/numpy/_info.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ def dtypes(
332332
return res
333333
raise ValueError(f"unsupported kind: {kind!r}")
334334

335-
def devices(self) -> list[Device]:
335+
def devices(self) -> tuple[Device]:
336336
"""
337337
The devices supported by NumPy.
338338
@@ -357,7 +357,7 @@ def devices(self) -> list[Device]:
357357
['cpu']
358358
359359
"""
360-
return ["cpu"]
360+
return ("cpu",)
361361

362362

363363
__all__ = ["__array_namespace_info__"]

array_api_compat/torch/_aliases.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -711,9 +711,9 @@ def astype(
711711
return x.to(dtype=dtype, copy=copy)
712712

713713

714-
def broadcast_arrays(*arrays: Array) -> list[Array]:
714+
def broadcast_arrays(*arrays: Array) -> tuple[Array, ...]:
715715
shape = torch.broadcast_shapes(*[a.shape for a in arrays])
716-
return [torch.broadcast_to(a, shape) for a in arrays]
716+
return tuple(torch.broadcast_to(a, shape) for a in arrays)
717717

718718
# Note that these named tuples aren't actually part of the standard namespace,
719719
# but I don't see any issue with exporting the names here regardless.
@@ -897,10 +897,11 @@ def sign(x: Array, /) -> Array:
897897
return out
898898

899899

900-
def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> list[Array]:
901-
# enforce the default of 'xy'
902-
# TODO: is the return type a list or a tuple
903-
return list(torch.meshgrid(*arrays, indexing=indexing))
900+
def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> tuple[Array, ...]:
901+
# torch <= 2.9 emits a UserWarning: "torch.meshgrid: in an upcoming release, it
902+
# will be required to pass the indexing argument."
903+
# Thus always pass it explicitly.
904+
return torch.meshgrid(*arrays, indexing=indexing)
904905

905906

906907
__all__ = ['asarray', 'result_type', 'can_cast',

array_api_compat/torch/_info.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -366,4 +366,4 @@ def devices(self):
366366
break
367367
i += 1
368368

369-
return devices
369+
return tuple(devices)

dask-xfails.txt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,13 @@ array_api_tests/test_linalg.py::test_matrix_norm
129129
array_api_tests/test_linalg.py::test_qr
130130
array_api_tests/test_manipulation_functions.py::test_roll
131131

132+
# 2025.12 support
133+
array_api_tests/test_has_names.py::test_has_names[manipulation-broadcast_shapes]
134+
array_api_tests/test_signatures.py::test_func_signature[broadcast_shapes]
135+
array_api_tests/test_data_type_functions.py::TestBroadcastShapes::test_broadcast_shapes
136+
array_api_tests/test_data_type_functions.py::TestBroadcastShapes::test_empty
137+
array_api_tests/test_data_type_functions.py::TestBroadcastShapes::test_error
138+
132139
# Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.)
133140
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
134141
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]

numpy-1-22-xfails.txt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,13 @@ array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift]
152152
array_api_tests/test_signatures.py::test_func_signature[bitwise_xor]
153153
array_api_tests/test_data_type_functions.py::TestResultType::test_with_scalars
154154

155+
# 2025.12 support
156+
157+
# older numpies return lists not tuples
158+
array_api_tests/test_creation_functions.py::test_meshgrid
159+
array_api_tests/test_data_type_functions.py::test_broadcast_arrays
160+
161+
155162
# Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.); Apparently,NumPy does just that
156163
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
157164
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]

numpy-1-26-xfails.txt

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,15 @@ array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__]
4242

4343
# 2024.12 support
4444
array_api_tests/test_data_type_functions.py::TestResultType::test_with_scalars
45-
4645
array_api_tests/test_operators_and_elementwise_functions.py::test_where_with_scalars
4746

47+
# 2025.12 support
48+
49+
# older numpies return lists not tuples
50+
array_api_tests/test_creation_functions.py::test_meshgrid
51+
array_api_tests/test_data_type_functions.py::test_broadcast_arrays
52+
53+
4854
# Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.); Apparently, NumPy does just that
4955
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
5056
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]

torch-xfails.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,11 @@ array_api_tests/test_signatures.py::test_func_signature[from_dlpack]
144144
# Argument 'max_version' missing from signature
145145
array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__]
146146

147+
# 2025.12 support
148+
149+
# broadcast_shapes emits a RuntimeError where the spec says ValueError
150+
array_api_tests/test_data_type_functions.py::TestBroadcastShapes::test_error
151+
147152

148153
# 2024.12 support: binary functions reject python scalar arguments
149154
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[atan2]

0 commit comments

Comments
 (0)