Skip to content

Commit c3b70b3

Browse files
committed
Numba does not output numpy scalars
1 parent 7d4ac51 commit c3b70b3

File tree

3 files changed

+28
-13
lines changed

3 files changed

+28
-13
lines changed

tests/scalar/test_basic.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from pytensor.compile.mode import Mode
77
from pytensor.graph.fg import FunctionGraph
88
from pytensor.link.c.basic import DualLinker
9+
from pytensor.link.numba import NumbaLinker
910
from pytensor.scalar.basic import (
1011
EQ,
1112
ComplexError,
@@ -368,7 +369,9 @@ def _test_unary(unary_op, x_range):
368369
outi = fi(x_val)
369370
outf = ff(x_val)
370371

371-
assert outi.dtype == outf.dtype, "incorrect dtype"
372+
if not isinstance(ff.maker.linker, NumbaLinker):
373+
# Numba doesn't return numpy scalars
374+
assert outi.dtype == outf.dtype, "incorrect dtype"
372375
assert np.allclose(outi, outf), "insufficient precision"
373376

374377
@staticmethod
@@ -389,7 +392,9 @@ def _test_binary(binary_op, x_range, y_range):
389392
outi = fi(x_val, y_val)
390393
outf = ff(x_val, y_val)
391394

392-
assert outi.dtype == outf.dtype, "incorrect dtype"
395+
if not isinstance(ff.maker.linker, NumbaLinker):
396+
# Numba doesn't return numpy scalars
397+
assert outi.dtype == outf.dtype, "incorrect dtype"
393398
assert np.allclose(outi, outf), "insufficient precision"
394399

395400
def test_true_div(self):
@@ -414,7 +419,9 @@ def test_true_div(self):
414419
outi = fi(x_val, y_val)
415420
outf = ff(x_val, y_val)
416421

417-
assert outi.dtype == outf.dtype, "incorrect dtype"
422+
if not isinstance(ff.maker.linker, NumbaLinker):
423+
# Numba doesn't return numpy scalars
424+
assert outi.dtype == outf.dtype, "incorrect dtype"
418425
assert np.allclose(outi, outf), "insufficient precision"
419426

420427
def test_unary(self):

tests/tensor/test_basic.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from pytensor.graph.basic import Apply, equal_computations
1919
from pytensor.graph.op import Op
2020
from pytensor.graph.replace import clone_replace
21+
from pytensor.link.numba import NumbaLinker
2122
from pytensor.raise_op import Assert
2223
from pytensor.scalar import autocast_float, autocast_float_as
2324
from pytensor.tensor import NoneConst, vectorize
@@ -2193,24 +2194,31 @@ def test_ScalarFromTensor(cast_policy):
21932194
assert ss.owner.op is scalar_from_tensor
21942195
assert ss.type.dtype == tc.type.dtype
21952196

2196-
v = eval_outputs([ss])
2197+
mode = get_default_mode()
2198+
v = eval_outputs([ss], mode=mode)
21972199

21982200
assert v == 56
2199-
assert v.shape == ()
2200-
2201-
if cast_policy == "custom":
2202-
assert isinstance(v, np.int8)
2203-
elif cast_policy == "numpy+floatX":
2204-
assert isinstance(v, np.int64)
2201+
if isinstance(mode.linker, NumbaLinker):
2202+
# Numba doesn't return numpy scalars
2203+
assert isinstance(v, int)
2204+
else:
2205+
assert v.shape == ()
2206+
if cast_policy == "custom":
2207+
assert isinstance(v, np.int8)
2208+
elif cast_policy == "numpy+floatX":
2209+
assert isinstance(v, np.int64)
22052210

22062211
pts = lscalar()
22072212
ss = scalar_from_tensor(pts)
22082213
ss.owner.op.grad([pts], [ss])
22092214
fff = function([pts], ss)
22102215
v = fff(np.asarray(5))
22112216
assert v == 5
2212-
assert isinstance(v, np.int64)
2213-
assert v.shape == ()
2217+
if isinstance(mode.linker, NumbaLinker):
2218+
assert isinstance(v, int)
2219+
else:
2220+
assert isinstance(v, np.int64)
2221+
assert v.shape == ()
22142222

22152223
with pytest.raises(TypeError):
22162224
scalar_from_tensor(vector())

tests/unittest_tools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ def _compile_and_check(
259259
numeric_outputs = outputs_function(*numeric_inputs)
260260
numeric_shapes = shapes_function(*numeric_inputs)
261261
for out, shape in zip(numeric_outputs, numeric_shapes, strict=True):
262-
assert np.all(out.shape == shape), (out.shape, shape)
262+
assert np.all(np.asarray(out).shape == shape), (out.shape, shape)
263263

264264

265265
class WrongValue(Exception):

0 commit comments

Comments
 (0)