Skip to content

Commit a5a6587

Browse files
committed
Numba does not output numpy scalars
1 parent 92f95df commit a5a6587

File tree

3 files changed

+26
-13
lines changed

3 files changed

+26
-13
lines changed

tests/scalar/test_basic.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from pytensor.compile.mode import Mode
77
from pytensor.graph.fg import FunctionGraph
88
from pytensor.link.c.basic import DualLinker
9+
from pytensor.link.numba import NumbaLinker
910
from pytensor.scalar.basic import (
1011
EQ,
1112
ComplexError,
@@ -368,7 +369,9 @@ def _test_unary(unary_op, x_range):
368369
outi = fi(x_val)
369370
outf = ff(x_val)
370371

371-
assert outi.dtype == outf.dtype, "incorrect dtype"
372+
if not isinstance(ff.maker.linker, NumbaLinker):
373+
# Numba doesn't return numpy scalars
374+
assert outi.dtype == outf.dtype, "incorrect dtype"
372375
assert np.allclose(outi, outf), "insufficient precision"
373376

374377
@staticmethod
@@ -389,7 +392,9 @@ def _test_binary(binary_op, x_range, y_range):
389392
outi = fi(x_val, y_val)
390393
outf = ff(x_val, y_val)
391394

392-
assert outi.dtype == outf.dtype, "incorrect dtype"
395+
if not isinstance(ff.maker.linker, NumbaLinker):
396+
# Numba doesn't return numpy scalars
397+
assert outi.dtype == outf.dtype, "incorrect dtype"
393398
assert np.allclose(outi, outf), "insufficient precision"
394399

395400
def test_true_div(self):
@@ -414,7 +419,9 @@ def test_true_div(self):
414419
outi = fi(x_val, y_val)
415420
outf = ff(x_val, y_val)
416421

417-
assert outi.dtype == outf.dtype, "incorrect dtype"
422+
if not isinstance(ff.maker.linker, NumbaLinker):
423+
# Numba doesn't return numpy scalars
424+
assert outi.dtype == outf.dtype, "incorrect dtype"
418425
assert np.allclose(outi, outf), "insufficient precision"
419426

420427
def test_unary(self):

tests/tensor/test_basic.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from pytensor.graph.basic import Apply, equal_computations
1919
from pytensor.graph.op import Op
2020
from pytensor.graph.replace import clone_replace
21+
from pytensor.link.numba import NumbaLinker
2122
from pytensor.raise_op import Assert
2223
from pytensor.scalar import autocast_float, autocast_float_as
2324
from pytensor.tensor import NoneConst, vectorize
@@ -2191,24 +2192,29 @@ def test_ScalarFromTensor(cast_policy):
21912192
assert ss.owner.op is scalar_from_tensor
21922193
assert ss.type.dtype == tc.type.dtype
21932194

2194-
v = eval_outputs([ss])
2195+
mode = get_default_mode()
2196+
v = eval_outputs([ss], mode=mode)
21952197

21962198
assert v == 56
2197-
assert v.shape == ()
2198-
2199-
if cast_policy == "custom":
2200-
assert isinstance(v, np.int8)
2201-
elif cast_policy == "numpy+floatX":
2202-
assert isinstance(v, np.int64)
2199+
assert isinstance(v, int)
2200+
if not isinstance(mode.linker, NumbaLinker):
2201+
# Numba doesn't return numpy scalars
2202+
assert v.shape == ()
2203+
if cast_policy == "custom":
2204+
assert isinstance(v, np.int8)
2205+
elif cast_policy == "numpy+floatX":
2206+
assert isinstance(v, np.int64)
22032207

22042208
pts = lscalar()
22052209
ss = scalar_from_tensor(pts)
22062210
ss.owner.op.grad([pts], [ss])
22072211
fff = function([pts], ss)
22082212
v = fff(np.asarray(5))
22092213
assert v == 5
2210-
assert isinstance(v, np.int64)
2211-
assert v.shape == ()
2214+
assert isinstance(v, int)
2215+
if not isinstance(mode.linker, NumbaLinker):
2216+
assert isinstance(v, np.int64)
2217+
assert v.shape == ()
22122218

22132219
with pytest.raises(TypeError):
22142220
scalar_from_tensor(vector())

tests/unittest_tools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ def _compile_and_check(
259259
numeric_outputs = outputs_function(*numeric_inputs)
260260
numeric_shapes = shapes_function(*numeric_inputs)
261261
for out, shape in zip(numeric_outputs, numeric_shapes, strict=True):
262-
assert np.all(out.shape == shape), (out.shape, shape)
262+
assert np.all(np.asarray(out).shape == shape), (out.shape, shape)
263263

264264

265265
class WrongValue(Exception):

0 commit comments

Comments
 (0)