Skip to content

Commit b1c1a78

Browse files
author
Jesse Grabowski
committed
Add JAX dispatch for expm
1 parent 34f70ac commit b1c1a78

File tree

2 files changed

+18
-0
lines changed

2 files changed

+18
-0
lines changed

pytensor/link/jax/dispatch/slinalg.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
Cholesky,
1111
CholeskySolve,
1212
Eigvalsh,
13+
Expm,
1314
LUFactor,
1415
PivotToPermutations,
1516
Solve,
@@ -179,3 +180,11 @@ def qr(x, mode=mode):
179180
return jax.scipy.linalg.qr(x, mode=mode)
180181

181182
return qr
183+
184+
185+
@jax_funcify.register(Expm)
186+
def jax_funcify_Expm(op, **kwargs):
187+
def expm(x):
188+
return jax.scipy.linalg.expm(x)
189+
190+
return expm

tests/link/jax/test_slinalg.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,3 +361,12 @@ def test_jax_cho_solve(b_shape, lower):
361361
out = pt_slinalg.cho_solve((c, lower), b, b_ndim=len(b_shape))
362362

363363
compare_jax_and_py([A, b], [out], [A_val, b_val])
364+
365+
366+
def test_jax_expm():
367+
rng = np.random.default_rng(utt.fetch_seed())
368+
A = pt.tensor(name="A", shape=(5, 5))
369+
A_val = rng.normal(size=(5, 5)).astype(config.floatX)
370+
out = pt_slinalg.expm(A)
371+
372+
compare_jax_and_py([A], [out], [A_val])

0 commit comments

Comments
 (0)