Commit 19a2397 (parent: 0bb06df)

XFAIL/SKIP float16 tests

3 files changed: +72 additions, -5 deletions
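
The commit uses two complementary strategies. Where float16 is the point of the test, it attaches a conditional `pytest.mark.xfail` so the test still runs under the Numba linker and reports XFAIL on failure (or XPASS, should Numba gain float16 support). Where a test sweeps all dtypes in a loop, it instead skips only the float16 combinations at runtime, so the rest of the sweep keeps its coverage. A minimal sketch of the conditional-xfail pattern; `linker_is_numba()` is a hypothetical stand-in for the commit's `isinstance(get_default_mode().linker, NumbaLinker)` check:

    import numpy as np
    import pytest

    def linker_is_numba() -> bool:
        # Hypothetical stand-in for the commit's actual condition:
        # isinstance(get_default_mode().linker, NumbaLinker)
        return False

    @pytest.mark.xfail(
        condition=linker_is_numba(),
        reason="Numba does not support float16",
    )
    def test_float16_mean():
        x = np.arange(4, dtype="float16")
        # Runs normally while the condition is False; when True, a failure
        # is reported as XFAIL rather than FAIL.
        assert x.mean() == 1.5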


tests/tensor/rewriting/test_basic.py

Lines changed: 5 additions & 0 deletions
@@ -18,6 +18,7 @@
 from pytensor.graph.rewriting.basic import check_stack_trace, out2in
 from pytensor.graph.rewriting.db import RewriteDatabaseQuery
 from pytensor.graph.rewriting.utils import rewrite_graph
+from pytensor.link.numba import NumbaLinker
 from pytensor.printing import debugprint, pprint
 from pytensor.raise_op import Assert, CheckAndRaise
 from pytensor.scalar import Composite, float64

@@ -1206,6 +1207,10 @@ def test_sum_bool_upcast(self):
         f(5)


+@pytest.mark.xfail(
+    condition=isinstance(get_default_mode().linker, NumbaLinker),
+    reason="Numba does not support float16",
+)
 class TestLocalOptAllocF16(TestLocalOptAlloc):
     dtype = "float16"
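
Note that the mark is applied to the whole class: pytest propagates class-level marks to every test the class defines or inherits, so all of `TestLocalOptAlloc`'s tests become conditional xfails when rerun with `dtype = "float16"`. A minimal sketch of that behavior, with a hypothetical `IS_NUMBA` flag in place of the live linker check:

    import pytest

    IS_NUMBA = False  # hypothetical flag; the commit checks the default mode's linker

    class TestAlloc:
        dtype = "float32"

        def test_dtype_is_float(self):
            assert self.dtype.startswith("float")

    @pytest.mark.xfail(condition=IS_NUMBA, reason="Numba does not support float16")
    class TestAllocF16(TestAlloc):
        # Inherits test_dtype_is_float; any failure here is reported
        # as XFAIL when IS_NUMBA is True.
        dtype = "float16"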

tests/tensor/test_math.py

Lines changed: 60 additions & 5 deletions
@@ -24,6 +24,7 @@
 from pytensor.graph.replace import vectorize_node
 from pytensor.graph.traversal import ancestors, applys_between
 from pytensor.link.c.basic import DualLinker
+from pytensor.link.numba import NumbaLinker
 from pytensor.printing import pprint
 from pytensor.raise_op import Assert
 from pytensor.tensor import blas, blas_c

@@ -858,6 +859,10 @@ def test_basic_2(self, axis, np_axis):
             ([1, 0], None),
         ],
     )
+    @pytest.mark.xfail(
+        condition=isinstance(get_default_mode().linker, NumbaLinker),
+        reason="Numba does not support float16",
+    )
     def test_basic_2_float16(self, axis, np_axis):
         # Test negative values and bigger range to make sure numpy don't do the argmax as on uint16
         data = (random(20, 30).astype("float16") - 0.5) * 20

@@ -1114,6 +1119,10 @@ def test2(self):
         v_shape = eval_outputs(fct(n, axis).shape)
         assert tuple(v_shape) == nfct(data, np_axis).shape

+    @pytest.mark.xfail(
+        condition=isinstance(get_default_mode().linker, NumbaLinker),
+        reason="Numba does not support float16",
+    )
     def test2_float16(self):
         # Test negative values and bigger range to make sure numpy don't do the argmax as on uint16
         data = (random(20, 30).astype("float16") - 0.5) * 20

@@ -1981,6 +1990,10 @@ def test_mean_single_element(self):
         res = mean(np.zeros(1))
         assert res.eval() == 0.0

+    @pytest.mark.xfail(
+        condition=isinstance(get_default_mode().linker, NumbaLinker),
+        reason="Numba does not support float16",
+    )
     def test_mean_f16(self):
         x = vector(dtype="float16")
         y = x.mean()

@@ -3153,7 +3166,9 @@ class TestSumProdReduceDtype:
     op = CAReduce
     axes = [None, 0, 1, [], [0], [1], [0, 1]]
     methods = ["sum", "prod"]
-    dtypes = list(map(str, ps.all_types))
+    dtypes = tuple(map(str, ps.all_types))
+    if isinstance(mode.linker, NumbaLinker):
+        dtypes = tuple(d for d in dtypes if d != "float16")

     # Test the default dtype of a method().
     def test_reduce_default_dtype(self):

@@ -3313,10 +3328,13 @@ def test_reduce_precision(self):
 class TestMeanDtype:
     def test_mean_default_dtype(self):
         # Test the default dtype of a mean().
+        is_numba = isinstance(get_default_mode().linker, NumbaLinker)

         # We try multiple axis combinations even though axis should not matter.
         axes = [None, 0, 1, [], [0], [1], [0, 1]]
         for idx, dtype in enumerate(map(str, ps.all_types)):
+            if is_numba and dtype == "float16":
+                continue
             axis = axes[idx % len(axes)]
             x = matrix(dtype=dtype)
             m = x.mean(axis=axis)

@@ -3337,7 +3355,13 @@ def test_mean_default_dtype(self):
             "uint16",
             "int8",
             "int64",
-            "float16",
+            pytest.param(
+                "float16",
+                marks=pytest.mark.xfail(
+                    condition=isinstance(get_default_mode().linker, NumbaLinker),
+                    reason="Numba does not support float16",
+                ),
+            ),
             "float32",
             "float64",
             "complex64",

@@ -3351,7 +3375,13 @@ def test_mean_default_dtype(self):
             "uint16",
             "int8",
             "int64",
-            "float16",
+            pytest.param(
+                "float16",
+                marks=pytest.mark.xfail(
+                    condition=isinstance(get_default_mode().linker, NumbaLinker),
+                    reason="Numba does not support float16",
+                ),
+            ),
             "float32",
             "float64",
             "complex64",

@@ -3411,10 +3441,13 @@ def test_prod_without_zeros_default_dtype(self):

     def test_prod_without_zeros_default_acc_dtype(self):
         # Test the default dtype of a ProdWithoutZeros().
+        is_numba = isinstance(get_default_mode().linker, NumbaLinker)

         # We try multiple axis combinations even though axis should not matter.
         axes = [None, 0, 1, [], [0], [1], [0, 1]]
         for idx, dtype in enumerate(map(str, ps.all_types)):
+            if is_numba and dtype == "float16":
+                continue
             axis = axes[idx % len(axes)]
             x = matrix(dtype=dtype)
             p = ProdWithoutZeros(axis=axis)(x)

@@ -3442,13 +3475,17 @@ def test_prod_without_zeros_default_acc_dtype(self):
     @pytest.mark.slow
     def test_prod_without_zeros_custom_dtype(self):
         # Test ability to provide your own output dtype for a ProdWithoutZeros().
-
+        is_numba = isinstance(get_default_mode().linker, NumbaLinker)
         # We try multiple axis combinations even though axis should not matter.
         axes = [None, 0, 1, [], [0], [1], [0, 1]]
         idx = 0
         for input_dtype in map(str, ps.all_types):
+            if is_numba and input_dtype == "float16":
+                continue
             x = matrix(dtype=input_dtype)
             for output_dtype in map(str, ps.all_types):
+                if is_numba and output_dtype == "float16":
+                    continue
                 axis = axes[idx % len(axes)]
                 prod_woz_var = ProdWithoutZeros(axis=axis, dtype=output_dtype)(x)
                 assert prod_woz_var.dtype == output_dtype

@@ -3464,13 +3501,18 @@ def test_prod_without_zeros_custom_dtype(self):
     @pytest.mark.slow
     def test_prod_without_zeros_custom_acc_dtype(self):
         # Test ability to provide your own acc_dtype for a ProdWithoutZeros().
+        is_numba = isinstance(get_default_mode().linker, NumbaLinker)

         # We try multiple axis combinations even though axis should not matter.
         axes = [None, 0, 1, [], [0], [1], [0, 1]]
         idx = 0
         for input_dtype in map(str, ps.all_types):
+            if is_numba and input_dtype == "float16":
+                continue
             x = matrix(dtype=input_dtype)
             for acc_dtype in map(str, ps.all_types):
+                if is_numba and acc_dtype == "float16":
+                    continue
                 axis = axes[idx % len(axes)]
                 # If acc_dtype would force a downcast, we expect a TypeError
                 # We always allow int/uint inputs with float/complex outputs.

@@ -3746,7 +3788,20 @@ def test_scalar_error(self):
         with pytest.raises(ValueError, match="cannot be scalar"):
             self.op(4, [4, 1])

-    @pytest.mark.parametrize("dtype", (np.float16, np.float32, np.float64))
+    @pytest.mark.parametrize(
+        "dtype",
+        (
+            pytest.param(
+                np.float16,
+                marks=pytest.mark.xfail(
+                    condition=isinstance(get_default_mode().linker, NumbaLinker),
+                    reason="Numba does not support float16",
+                ),
+            ),
+            np.float32,
+            np.float64,
+        ),
+    )
     def test_dtype_param(self, dtype):
         sol = self.op([1, 2, 3], [3, 2, 1], dtype=dtype)
         assert sol.eval().dtype == dtype
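
This file shows the other two patterns. In parametrized tests, `pytest.param(..., marks=...)` attaches the xfail to the float16 case alone, leaving the other dtypes untouched; in loops over `ps.all_types`, a `continue` skips the float16 combinations at runtime so a single test function keeps exercising every other pair. A sketch of the `pytest.param` form, again with a hypothetical `IS_NUMBA` flag in place of the linker check:

    import numpy as np
    import pytest

    IS_NUMBA = False  # hypothetical flag standing in for the linker check

    @pytest.mark.parametrize(
        "dtype",
        (
            pytest.param(
                np.float16,
                marks=pytest.mark.xfail(
                    condition=IS_NUMBA, reason="Numba does not support float16"
                ),
            ),
            np.float32,
            np.float64,
        ),
    )
    def test_zeros_dtype(dtype):
        # Only the float16 case carries the mark; float32/float64 run as usual.
        assert np.zeros(3, dtype=dtype).dtype == dtype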

tests/tensor/test_slinalg.py

Lines changed: 7 additions & 0 deletions
@@ -10,8 +10,10 @@

 from pytensor import function, grad
 from pytensor import tensor as pt
+from pytensor.compile import get_default_mode
 from pytensor.configdefaults import config
 from pytensor.graph.basic import equal_computations
+from pytensor.link.numba import NumbaLinker
 from pytensor.tensor import TensorVariable
 from pytensor.tensor.slinalg import (
     Cholesky,

@@ -606,6 +608,8 @@ def test_solve_correctness(self):
         )

     def test_solve_dtype(self):
+        is_numba = isinstance(get_default_mode().linker, NumbaLinker)
+
         dtypes = [
             "uint8",
             "uint16",

@@ -626,6 +630,9 @@ def test_solve_dtype(self):

         # try all dtype combinations
         for A_dtype, b_dtype in itertools.product(dtypes, dtypes):
+            if is_numba and (A_dtype == "float16" or b_dtype == "float16"):
+                # Numba does not support float16
+                continue
             A = matrix(dtype=A_dtype)
             b = matrix(dtype=b_dtype)
             x = op(A, b)
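
The runtime-skip pattern used here, reduced to its essentials (the dtype list and flag are illustrative): filtering inside the loop drops only the float16 pairs, so with three dtypes the sweep still covers the four float32/float64 combinations under Numba instead of losing the whole test.

    import itertools

    IS_NUMBA = False  # hypothetical flag standing in for the linker check
    DTYPES = ["float16", "float32", "float64"]

    def test_all_dtype_pairs():
        exercised = []
        for a_dtype, b_dtype in itertools.product(DTYPES, DTYPES):
            if IS_NUMBA and "float16" in (a_dtype, b_dtype):
                # Numba does not support float16: skip only this combination
                continue
            exercised.append((a_dtype, b_dtype))
        # 9 pairs normally; 4 once the float16 combinations are skipped
        assert len(exercised) == (4 if IS_NUMBA else 9)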
