1414import types
1515from typing import Generic
1616
17+ from array_api_compat import size
1718from pint import Quantity
1819from pint .facets .plain import MagnitudeT , PlainQuantity
1920
@@ -107,6 +108,15 @@ def __repr__(self):
107108 f" '{ self .units } '\n )>"
108109 )
109110
111+ def __mul__ (self , other ):
112+ if hasattr (other , "units" ):
113+ magnitude = self ._call_super_method ("__mul__" , other .magnitude )
114+ units = self .units * other .units
115+ else :
116+ magnitude = self ._call_super_method ("__mul__" , other )
117+ units = self .units
118+ return ArrayUnitQuantity (magnitude , units )
119+
110120 ## Linear Algebra Methods ##
111121 def __matmul__ (self , other ):
112122 return mod .matmul (self , other )
@@ -133,11 +143,11 @@ def mT(self):
133143 def __dlpack_device__ (self ):
134144 return self .magnitude .__dlpack_device__ ()
135145
136- def __dlpack__ (self , ** kwargs ):
146+ def __dlpack__ (self , stream = None , max_version = None , dl_device = None , copy = None ):
137147 # really not sure how to define this
138- return self .magnitude .__dlpack__ (** kwargs )
139-
140- __dlpack__ . __signature__ = inspect . signature ( xp . empty ( 0 ). __dlpack__ )
148+ return self .magnitude .__dlpack__ (
149+ stream = stream , max_version = max_version , dl_device = dl_device , copy = copy
150+ )
141151
142152 def to_device (self , device , / , * , stream = None ):
143153 _magnitude = self ._magnitude .to_device (device , stream = stream )
@@ -185,7 +195,7 @@ def fun(self, name=name):
185195 "__lshift__" ,
186196 "__lt__" ,
187197 "__mod__" ,
188- "__mul__" ,
198+ # "__mul__",
189199 "__ne__" ,
190200 "__or__" ,
191201 "__pow__" ,
@@ -301,7 +311,8 @@ def manip_fun(x, *args, **kwargs):
301311 magnitude = xp .asarray (x .magnitude , copy = True )
302312 units = x .units
303313 elif hasattr (x , "__array_namespace__" ):
304- magnitude = x
314+ x = asarray (x )
315+ magnitude = xp .asarray (x .magnitude , copy = True )
305316 units = None
306317 one_array = True
307318 else :
@@ -390,7 +401,9 @@ def astype(x, dtype, /, *, copy=True, device=None):
390401 if device is None and not copy and dtype == x .dtype :
391402 return x
392403 x = asarray (x )
393- magnitude = xp .astype (x .magnitude , dtype , copy = copy , device = device )
404+ # https://github.com/data-apis/array-api-compat/issues/226
405+ # magnitude = xp.astype(x.magnitude, dtype, copy=copy, device=device)
406+ magnitude = xp .astype (x .magnitude , dtype , copy = copy )
394407 return ArrayUnitQuantity (magnitude , x .units )
395408
396409 mod .astype = astype
@@ -600,7 +613,7 @@ def where(condition, x1, x2, /):
600613 def fun (x , / , * args , func_str = func_str , ** kwargs ):
601614 x = asarray (x )
602615 magnitude = xp .asarray (x .magnitude , copy = True )
603- magnitude = getattr (xp , func_str )(x , * args , ** kwargs )
616+ magnitude = getattr (xp , func_str )(magnitude , * args , ** kwargs )
604617 return ArrayUnitQuantity (magnitude , x .units )
605618
606619 setattr (mod , func_str , fun )
@@ -651,6 +664,20 @@ def fun(x1, x2, /, *args, func_str=func_str, **kwargs):
651664
652665 setattr (mod , func_str , fun )
653666
667+ def multiply (x1 , x2 , / , * args , ** kwargs ):
668+ x1 = asarray (x1 )
669+ x2 = asarray (x2 )
670+
671+ units = x1 .units * x2 .units
672+
673+ x1_magnitude = xp .asarray (x1 .magnitude , copy = True )
674+ x2_magnitude = x2 .m_as (x1 .units )
675+
676+ magnitude = xp .multiply (x1_magnitude , x2_magnitude , * args , ** kwargs )
677+ return ArrayUnitQuantity (magnitude , units )
678+
679+ mod .multiply = multiply
680+
654681 ## Indexing Functions
655682 def take (x , indices , / , ** kwargs ):
656683 magnitude = xp .take (x .magnitude , indices .magnitude , ** kwargs )
@@ -791,7 +818,7 @@ def var(x, /, *args, **kwargs):
791818 def prod (x , / , * args , axis = None , ** kwargs ):
792819 x = asarray (x )
793820 magnitude = xp .asarray (x .magnitude , copy = True )
794- exponent = magnitude .shape [axis ] if axis is not None else magnitude . size
821+ exponent = magnitude .shape [axis ] if axis is not None else size ( magnitude )
795822 units = x .units ** exponent
796823 magnitude = xp .prod (magnitude , * args , axis = axis , ** kwargs )
797824 return ArrayUnitQuantity (magnitude , units )
0 commit comments