-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy path_sum.py
More file actions
130 lines (105 loc) · 3.69 KB
/
_sum.py
File metadata and controls
130 lines (105 loc) · 3.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# SPDX-License-Identifier: MPL-2.0
from __future__ import annotations
from functools import partial
from typing import TYPE_CHECKING, Any, cast, overload
import numpy as np
from numpy.typing import NDArray
from .. import types
from .._import import lazy_singledispatch
from .._validation import validate_axis
if TYPE_CHECKING:
from typing import Literal
from numpy.typing import ArrayLike, DTypeLike
@overload
def sum(
    x: ArrayLike | types.CSBase | types.ZarrArray,
    /,
    *,
    axis: None = None,
    dtype: DTypeLike | None = None,
) -> np.number[Any]: ...


@overload
def sum(
    x: ArrayLike | types.CSBase | types.ZarrArray,
    /,
    *,
    axis: Literal[0, 1],
    dtype: DTypeLike | None = None,
) -> NDArray[Any]: ...


@overload
def sum(
    x: types.DaskArray, /, *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None
) -> types.DaskArray: ...


def sum(
    x: ArrayLike | types.CSBase | types.ZarrArray | types.DaskArray,
    /,
    *,
    axis: Literal[0, 1, None] = None,
    dtype: DTypeLike | None = None,
) -> NDArray[Any] | np.number[Any] | types.DaskArray:
    """Sum over both or one axis.

    Parameters
    ----------
    x
        Array to sum up. Dense arrays, compressed sparse arrays/matrices
        and dask arrays are dispatched to dedicated implementations.
    axis
        Axis to sum over, or :data:`None` to sum over all elements.
        Invalid values are rejected by ``validate_axis`` before dispatch.
    dtype
        Data type of the result. :data:`None` leaves the choice to the
        underlying implementation.

    Returns
    -------
    If ``axis`` is :data:`None`, then the sum over all elements is returned as a scalar.
    Otherwise, the sum over the given axis is returned as a 1D array.
    For dask input, the result is a lazy dask array of the corresponding shape.

    See Also
    --------
    :func:`numpy.sum`
    """
    validate_axis(axis)
    return _sum(x, axis=axis, dtype=dtype)
@lazy_singledispatch
def _sum(
    x: ArrayLike | types.CSBase | types.DaskArray,
    /,
    *,
    axis: Literal[0, 1, None] = None,
    dtype: DTypeLike | None = None,
) -> NDArray[Any] | np.number[Any] | types.DaskArray:
    """Default implementation: reduce ``x`` with :func:`numpy.sum`.

    Presumably reached for any input type without a more specific
    registration (assuming ``lazy_singledispatch`` follows
    :func:`functools.singledispatch` semantics).
    """
    total = np.sum(x, axis=axis, dtype=dtype)
    return cast(NDArray[Any] | np.number[Any], total)
@_sum.register("fast_array_utils.types:CSBase", "scipy.sparse")
def _(
    x: types.CSBase, /, *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None
) -> NDArray[Any] | np.number[Any]:
    """Sum a compressed sparse matrix or array (presumably CSR or CSC).

    Sparse *matrix* inputs are first re-wrapped as the corresponding sparse
    *array* type. Reductions on the legacy matrix classes yield
    :class:`numpy.matrix` results, while the array classes yield plain
    arrays/scalars.
    """
    import scipy.sparse as sp

    from ..types import CSMatrix

    if isinstance(x, CSMatrix):
        array_cls = sp.csr_array if x.format == "csr" else sp.csc_array
        x = array_cls(x)
    reduced = np.sum(x, axis=axis, dtype=dtype)
    return cast(NDArray[Any] | np.number[Any], reduced)
@_sum.register("dask.array:Array")
def _(
x: types.DaskArray, /, *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None
) -> types.DaskArray:
if TYPE_CHECKING:
from dask.array.reductions import reduction
else:
from dask.array import reduction
if isinstance(x._meta, np.matrix): # pragma: no cover # noqa: SLF001
msg = "sum does not support numpy matrices"
raise TypeError(msg)
def sum_drop_keepdims(
a: NDArray[Any] | types.CSBase,
/,
*,
axis: tuple[Literal[0], Literal[1]] | Literal[0, 1, None] = None,
dtype: DTypeLike | None = None,
keepdims: bool = False,
) -> NDArray[Any]:
del keepdims
match axis:
case (0 | 1 as n,):
axis = n
case (0, 1) | (1, 0):
axis = None
case tuple(): # pragma: no cover
msg = f"`sum` can only sum over `axis=0|1|(0,1)` but got {axis} instead"
raise ValueError(msg)
rv = sum(a, axis=axis, dtype=dtype)
rv = np.array(rv, ndmin=1) # make sure rv is at least 1D
return rv.reshape((1, len(rv)))
if dtype is None:
# Explicitly use numpy result dtype (e.g. `NDArray[bool].sum().dtype == int64`)
dtype = np.zeros(1, dtype=x.dtype).sum().dtype
return cast(
types.DaskArray,
reduction( # type: ignore[no-untyped-call]
x,
sum_drop_keepdims,
partial(np.sum, dtype=dtype),
axis=axis,
dtype=dtype,
meta=np.array([], dtype=dtype),
),
)