@@ -246,7 +246,23 @@ def astype(x, dtype, /, *, copy=True, device=None):
246246
247247 mod .astype = astype
248248
249- # Handle functions that ignore units on input and output
249+ # Functions with output units equal to input units
250+ for func_str in (
251+ "max" ,
252+ "min" ,
253+ "mean" ,
254+ ):
255+
256+ def func (x , / , * args , func_str = func_str , ** kwargs ):
257+ x = asarray (x )
258+ magnitude = xp .asarray (x .magnitude , copy = True )
259+ xp_func = getattr (xp , func_str )
260+ magnitude = xp_func (magnitude , * args , ** kwargs )
261+ return ArrayUnitQuantity (magnitude , x .units )
262+
263+ setattr (mod , func_str , func )
264+
265+ # Functions which ignore units on input and output
250266 for func_str in (
251267 "ones_like" ,
252268 "zeros_like" ,
@@ -261,7 +277,7 @@ def func(x, /, *args, func_str=func_str, **kwargs):
261277 x = asarray (x )
262278 magnitude = xp .asarray (x .magnitude , copy = True )
263279 xp_func = getattr (xp , func_str )
264- magnitude = xp_func (x , * args , ** kwargs )
280+ magnitude = xp_func (magnitude , * args , ** kwargs )
265281 return ArrayUnitQuantity (magnitude , None )
266282
267283 setattr (mod , func_str , func )
@@ -281,7 +297,7 @@ def func(x, /, *args, func_str=func_str, **kwargs):
281297 magnitude = xp .asarray (x .magnitude , copy = True )
282298 units = x .units
283299 xp_func = getattr (xp , func_str )
284- magnitude = xp_func (x , * args , ** kwargs )
300+ magnitude = xp_func (magnitude , * args , ** kwargs )
285301 units = (1 * units + 1 * units ).units
286302 return ArrayUnitQuantity (magnitude , units )
287303
@@ -290,16 +306,28 @@ def func(x, /, *args, func_str=func_str, **kwargs):
290306 # output_unit="variance":
291307 # square of `x.units`,
292308 # unless non-multiplicative, which raises `OffsetUnitCalculusError`
293- def var (x , / , * , axis = None , correction = 0.0 , keepdims = False ):
309+ def var (x , / , * args , ** kwargs ):
294310 x = asarray (x )
295311 magnitude = xp .asarray (x .magnitude , copy = True )
296312 units = x .units
297- magnitude = xp .var (x , axis = axis , correction = correction , keepdims = keepdims )
313+ magnitude = xp .var (magnitude , * args , ** kwargs )
298314 units = ((1 * units + 1 * units ) ** 2 ).units
299315 return ArrayUnitQuantity (magnitude , units )
300316
301317 mod .var = var
302318
319+ # Output unit is the product of the input unit with itself along axis,
320+ # or the input unit to the power of the size of the array for axis=None
321+ def prod (x , / , * args , axis = None , ** kwargs ):
322+ x = asarray (x )
323+ magnitude = xp .asarray (x .magnitude , copy = True )
324+ exponent = magnitude .shape [axis ] if axis is not None else magnitude .size
325+ units = x .units ** exponent
326+ magnitude = xp .prod (magnitude , * args , axis = axis , ** kwargs )
327+ return ArrayUnitQuantity (magnitude , units )
328+
329+ mod .prod = prod
330+
303331 # "mul": product of all units in `all_args`
304332 # - "delta": `first_input_units`, unless non-multiplicative,
305333 # which uses delta version
0 commit comments