Skip to content

Commit 01ac0c7

Browse files
committed
Numba RavelMultiIndex: Fix scalars with clip mode
1 parent 30d6a74 commit 01ac0c7

File tree

2 files changed

+17
-1
lines changed

2 files changed

+17
-1
lines changed

pytensor/link/numba/dispatch/extra_ops.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,11 @@ def ravelmultiindex(*inp):
154154
stacked_indices[..., i] %= dim_limit
155155
elif mode == "clip":
156156
dim_indices = stacked_indices[..., i]
157-
stacked_indices[..., i] = np.clip(dim_indices, 0, dim_limit - 1)
157+
# Cannot call np.clip on scalars
158+
if vec_indices:
159+
stacked_indices[..., i] = np.clip(dim_indices, 0, dim_limit - 1)
160+
else:
161+
stacked_indices[..., i] = max(0, min(dim_indices, dim_limit - 1))
158162
else: # raise
159163
dim_indices = stacked_indices[..., i]
160164
invalid_indices = (dim_indices < 0) | (dim_indices >= shape[i])

tests/link/numba/test_extra_ops.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,12 @@ def test_FillDiagonalOffset(a, val, offset):
171171
"raise",
172172
ValueError,
173173
),
174+
(
175+
tuple((pt.lscalar(), v) for v in np.array([0, 0, 3])),
176+
(pt.lvector(), np.array([2, 3, 4])),
177+
"wrap",
178+
None,
179+
),
174180
(
175181
tuple(
176182
(pt.lvector(), v) for v in np.array([[0, 1, 2], [2, 0, 3], [1, 3, 5]])
@@ -188,6 +194,12 @@ def test_FillDiagonalOffset(a, val, offset):
188194
"wrap",
189195
None,
190196
),
197+
(
198+
tuple((pt.lscalar(), v) for v in np.array([0, 0, 3])),
199+
(pt.lvector(), np.array([2, 3, 4])),
200+
"clip",
201+
None,
202+
),
191203
(
192204
tuple(
193205
(pt.lvector(), v) for v in np.array([[0, 1, 2], [2, 0, 3], [1, 3, 5]])

0 commit comments

Comments
 (0)