fix: return 0-D array for full reductions per Array API standard#932
fix: return 0-D array for full reductions per Array API standard#932Abineshabee wants to merge 5 commits intopydata:mainfrom
Conversation
Fixes pydata#921 - Remove if out.ndim == 0: return out[()] in SparseArray.reduce() which was converting 0-D COO arrays into NumPy scalars - Fix mean() to handle 0-D result uniformly without scalar branch - Fix _einsum_single() two scalar escape hatches to return 0-D COO - Update est_einsum.py scalar branches to use .todense() - Add TestArrayAPIReductions with 16 regression tests
for more information, see https://pre-commit.ci
|
One thing that can be improved here: |
|
Fixed out[()] scalar unwrap — reduce() and var() now return 0-D sparse arrays directly per Array API standard. All 316 tests passing. @hameerabbasi please re-review! |
|
Please let me know if any further changes are needed! |
hameerabbasi
left a comment
There was a problem hiding this comment.
Just one small concern.
| mean along all axes. | ||
|
|
||
| >>> s.mean() | ||
| np.float64(0.5) |
There was a problem hiding this comment.
Huh? s.mean() should produce SOME output.
Merging this PR will degrade performance by 54.75%
|
| Benchmark | BASE |
HEAD |
Efficiency | |
|---|---|---|---|---|
| ❌ | test_index_slice[side=100-rank=2-format='gcxs'] |
1.7 ms | 3.8 ms | -54.75% |
Comparing Abineshabee:fix/array-api-reduction-returns-0d-array (9304ca7) with main (3589a7c)
Fixes #921
Problem
sparse.sum()and all other full reductions (max,min,prod,mean,any,all,einsum) were returning a NumPy scalar instead of a 0-D array, violating the Array API standard.Root Cause
SparseArray.reduce()hadif out.ndim == 0: return out[()]which unwrapped the 0-D COO into a scalar._einsum_single()also had two explicit scalar return paths.Changes
_sparse_array.py— remove scalar unwrap inreduce(), unifymean()path_common.py— fix both scalar paths in_einsum_single()test_einsum.py— update scalar-output branches to use.todense()test_array_function.py— addTestArrayAPIReductionswith 16 regression testsTesting
6079 passed, 0 failed