2424from pytensor .graph .replace import vectorize_node
2525from pytensor .graph .traversal import ancestors , applys_between
2626from pytensor .link .c .basic import DualLinker
27+ from pytensor .link .numba import NumbaLinker
2728from pytensor .printing import pprint
2829from pytensor .raise_op import Assert
2930from pytensor .tensor import blas , blas_c
@@ -858,6 +859,10 @@ def test_basic_2(self, axis, np_axis):
858859 ([1 , 0 ], None ),
859860 ],
860861 )
862+ @pytest .mark .xfail (
863+ condition = isinstance (get_default_mode ().linker , NumbaLinker ),
864+ reason = "Numba does not support float16" ,
865+ )
861866 def test_basic_2_float16 (self , axis , np_axis ):
862867 # Test negative values and bigger range to make sure numpy don't do the argmax as on uint16
863868 data = (random (20 , 30 ).astype ("float16" ) - 0.5 ) * 20
@@ -1114,6 +1119,10 @@ def test2(self):
11141119 v_shape = eval_outputs (fct (n , axis ).shape )
11151120 assert tuple (v_shape ) == nfct (data , np_axis ).shape
11161121
1122+ @pytest .mark .xfail (
1123+ condition = isinstance (get_default_mode ().linker , NumbaLinker ),
1124+ reason = "Numba does not support float16" ,
1125+ )
11171126 def test2_float16 (self ):
11181127 # Test negative values and bigger range to make sure numpy don't do the argmax as on uint16
11191128 data = (random (20 , 30 ).astype ("float16" ) - 0.5 ) * 20
@@ -1981,6 +1990,10 @@ def test_mean_single_element(self):
19811990 res = mean (np .zeros (1 ))
19821991 assert res .eval () == 0.0
19831992
1993+ @pytest .mark .xfail (
1994+ condition = isinstance (get_default_mode ().linker , NumbaLinker ),
1995+ reason = "Numba does not support float16" ,
1996+ )
19841997 def test_mean_f16 (self ):
19851998 x = vector (dtype = "float16" )
19861999 y = x .mean ()
@@ -3153,7 +3166,9 @@ class TestSumProdReduceDtype:
31533166 op = CAReduce
31543167 axes = [None , 0 , 1 , [], [0 ], [1 ], [0 , 1 ]]
31553168 methods = ["sum" , "prod" ]
3156- dtypes = list (map (str , ps .all_types ))
3169+ dtypes = tuple (map (str , ps .all_types ))
3170+ if isinstance (mode .linker , NumbaLinker ):
3171+ dtypes = tuple (d for d in dtypes if d != "float16" )
31573172
31583173 # Test the default dtype of a method().
31593174 def test_reduce_default_dtype (self ):
@@ -3313,10 +3328,13 @@ def test_reduce_precision(self):
33133328class TestMeanDtype :
33143329 def test_mean_default_dtype (self ):
33153330 # Test the default dtype of a mean().
3331+ is_numba = isinstance (get_default_mode ().linker , NumbaLinker )
33163332
33173333 # We try multiple axis combinations even though axis should not matter.
33183334 axes = [None , 0 , 1 , [], [0 ], [1 ], [0 , 1 ]]
33193335 for idx , dtype in enumerate (map (str , ps .all_types )):
3336+ if is_numba and dtype == "float16" :
3337+ continue
33203338 axis = axes [idx % len (axes )]
33213339 x = matrix (dtype = dtype )
33223340 m = x .mean (axis = axis )
@@ -3337,7 +3355,13 @@ def test_mean_default_dtype(self):
33373355 "uint16" ,
33383356 "int8" ,
33393357 "int64" ,
3340- "float16" ,
3358+ pytest .param (
3359+ "float16" ,
3360+ marks = pytest .mark .xfail (
3361+ condition = isinstance (get_default_mode ().linker , NumbaLinker ),
3362+ reason = "Numba does not support float16" ,
3363+ ),
3364+ ),
33413365 "float32" ,
33423366 "float64" ,
33433367 "complex64" ,
@@ -3351,7 +3375,13 @@ def test_mean_default_dtype(self):
33513375 "uint16" ,
33523376 "int8" ,
33533377 "int64" ,
3354- "float16" ,
3378+ pytest .param (
3379+ "float16" ,
3380+ marks = pytest .mark .xfail (
3381+ condition = isinstance (get_default_mode ().linker , NumbaLinker ),
3382+ reason = "Numba does not support float16" ,
3383+ ),
3384+ ),
33553385 "float32" ,
33563386 "float64" ,
33573387 "complex64" ,
@@ -3411,10 +3441,13 @@ def test_prod_without_zeros_default_dtype(self):
34113441
34123442 def test_prod_without_zeros_default_acc_dtype (self ):
34133443 # Test the default dtype of a ProdWithoutZeros().
3444+ is_numba = isinstance (get_default_mode ().linker , NumbaLinker )
34143445
34153446 # We try multiple axis combinations even though axis should not matter.
34163447 axes = [None , 0 , 1 , [], [0 ], [1 ], [0 , 1 ]]
34173448 for idx , dtype in enumerate (map (str , ps .all_types )):
3449+ if is_numba and dtype == "float16" :
3450+ continue
34183451 axis = axes [idx % len (axes )]
34193452 x = matrix (dtype = dtype )
34203453 p = ProdWithoutZeros (axis = axis )(x )
@@ -3442,13 +3475,17 @@ def test_prod_without_zeros_default_acc_dtype(self):
34423475 @pytest .mark .slow
34433476 def test_prod_without_zeros_custom_dtype (self ):
34443477 # Test ability to provide your own output dtype for a ProdWithoutZeros().
3445-
3478+ is_numba = isinstance ( get_default_mode (). linker , NumbaLinker )
34463479 # We try multiple axis combinations even though axis should not matter.
34473480 axes = [None , 0 , 1 , [], [0 ], [1 ], [0 , 1 ]]
34483481 idx = 0
34493482 for input_dtype in map (str , ps .all_types ):
3483+ if is_numba and input_dtype == "float16" :
3484+ continue
34503485 x = matrix (dtype = input_dtype )
34513486 for output_dtype in map (str , ps .all_types ):
3487+ if is_numba and output_dtype == "float16" :
3488+ continue
34523489 axis = axes [idx % len (axes )]
34533490 prod_woz_var = ProdWithoutZeros (axis = axis , dtype = output_dtype )(x )
34543491 assert prod_woz_var .dtype == output_dtype
@@ -3464,13 +3501,18 @@ def test_prod_without_zeros_custom_dtype(self):
34643501 @pytest .mark .slow
34653502 def test_prod_without_zeros_custom_acc_dtype (self ):
34663503 # Test ability to provide your own acc_dtype for a ProdWithoutZeros().
3504+ is_numba = isinstance (get_default_mode ().linker , NumbaLinker )
34673505
34683506 # We try multiple axis combinations even though axis should not matter.
34693507 axes = [None , 0 , 1 , [], [0 ], [1 ], [0 , 1 ]]
34703508 idx = 0
34713509 for input_dtype in map (str , ps .all_types ):
3510+ if is_numba and input_dtype == "float16" :
3511+ continue
34723512 x = matrix (dtype = input_dtype )
34733513 for acc_dtype in map (str , ps .all_types ):
3514+ if is_numba and acc_dtype == "float16" :
3515+ continue
34743516 axis = axes [idx % len (axes )]
34753517 # If acc_dtype would force a downcast, we expect a TypeError
34763518 # We always allow int/uint inputs with float/complex outputs.
@@ -3746,7 +3788,20 @@ def test_scalar_error(self):
37463788 with pytest .raises (ValueError , match = "cannot be scalar" ):
37473789 self .op (4 , [4 , 1 ])
37483790
3749- @pytest .mark .parametrize ("dtype" , (np .float16 , np .float32 , np .float64 ))
3791+ @pytest .mark .parametrize (
3792+ "dtype" ,
3793+ (
3794+ pytest .param (
3795+ np .float16 ,
3796+ marks = pytest .mark .xfail (
3797+ condition = isinstance (get_default_mode ().linker , NumbaLinker ),
3798+ reason = "Numba does not support float16" ,
3799+ ),
3800+ ),
3801+ np .float32 ,
3802+ np .float64 ,
3803+ ),
3804+ )
37503805 def test_dtype_param (self , dtype ):
37513806 sol = self .op ([1 , 2 , 3 ], [3 , 2 , 1 ], dtype = dtype )
37523807 assert sol .eval ().dtype == dtype
0 commit comments