Skip to content

Commit 34f70ac

Browse files
author
Jesse Grabowski
committed
Clean up pytensor.linalg.expm and related tests
1 parent e7132ca commit 34f70ac

File tree

2 files changed

+52
-80
lines changed

2 files changed

+52
-80
lines changed

pytensor/tensor/slinalg.py

Lines changed: 32 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1304,82 +1304,63 @@ def eigvalsh(a, b, lower=True):
13041304
class Expm(Op):
13051305
"""
13061306
Compute the matrix exponential of a square array.
1307-
13081307
"""
13091308

13101309
__props__ = ()
1310+
gufunc_signature = "(m,m)->(m,m)"
13111311

13121312
def make_node(self, A):
13131313
A = as_tensor_variable(A)
13141314
assert A.ndim == 2
1315-
expm = matrix(dtype=A.dtype)
1316-
return Apply(
1317-
self,
1318-
[
1319-
A,
1320-
],
1321-
[
1322-
expm,
1323-
],
1324-
)
1315+
1316+
expm = matrix(dtype=A.dtype, shape=A.type.shape)
1317+
1318+
return Apply(self, [A], [expm])
13251319

13261320
def perform(self, node, inputs, outputs):
13271321
(A,) = inputs
13281322
(expm,) = outputs
13291323
expm[0] = scipy_linalg.expm(A)
13301324

1331-
def grad(self, inputs, outputs):
1325+
def L_op(self, inputs, outputs, output_grads):
1326+
# Kalbfleisch and Lawless, J. Am. Stat. Assoc. 80 (1985) Equation 3.4
1327+
# Kind of... You need to do some algebra from there to arrive at
1328+
# this expression.
13321329
(A,) = inputs
1333-
(g_out,) = outputs
1334-
return [ExpmGrad()(A, g_out)]
1335-
1336-
def infer_shape(self, fgraph, node, shapes):
1337-
return [shapes[0]]
1330+
(_,) = outputs # Outputs not used; included for signature consistency only
1331+
(A_bar,) = output_grads
13381332

1333+
w, V = pt.linalg.eig(A, return_components=True)
13391334

1340-
class ExpmGrad(Op):
1341-
"""
1342-
Gradient of the matrix exponential of a square array.
1335+
w = w[0] + 1j * w[1]
1336+
V = V[0] + 1j * V[1]
13431337

1344-
"""
1338+
exp_w = pt.exp(w)
1339+
numer = pt.sub.outer(exp_w, exp_w)
1340+
denom = pt.sub.outer(w, w)
13451341

1346-
__props__ = ()
1342+
# When w_i ≈ w_j, we have a removable singularity in the expression for X, because
1343+
# lim b->a (e^a - e^b) / (a - b) = e^a (derivation left for the motivated reader)
1344+
X = pt.where(pt.abs(denom) < 1e-8, exp_w, numer / denom)
13471345

1348-
def make_node(self, A, gw):
1349-
A = as_tensor_variable(A)
1350-
assert A.ndim == 2
1351-
out = matrix(dtype=A.dtype)
1352-
return Apply(
1353-
self,
1354-
[A, gw],
1355-
[
1356-
out,
1357-
],
1358-
)
1346+
diag_idx = pt.arange(w.shape[0])
1347+
X = X[..., diag_idx, diag_idx].set(exp_w)
13591348

1360-
def infer_shape(self, fgraph, node, shapes):
1361-
return [shapes[0]]
1349+
inner = solve(V, A_bar.T @ V).T
1350+
result = solve(V.T, inner * X) @ V.T
13621351

1363-
def perform(self, node, inputs, outputs):
1364-
# Kalbfleisch and Lawless, J. Am. Stat. Assoc. 80 (1985) Equation 3.4
1365-
# Kind of... You need to do some algebra from there to arrive at
1366-
# this expression.
1367-
(A, gA) = inputs
1368-
(out,) = outputs
1369-
w, V = scipy_linalg.eig(A, right=True)
1370-
U = scipy_linalg.inv(V).T
1352+
# At this point, result is always a complex dtype. If the input was real, the output should be
1353+
# real as well (and all the imaginary parts are numerical noise)
1354+
if A.dtype not in ("complex64", "complex128"):
1355+
return [result.real]
13711356

1372-
exp_w = np.exp(w)
1373-
X = np.subtract.outer(exp_w, exp_w) / np.subtract.outer(w, w)
1374-
np.fill_diagonal(X, exp_w)
1375-
Y = U.dot(V.T.dot(gA).dot(U) * X).dot(V.T)
1357+
return [result]
13761358

1377-
with warnings.catch_warnings():
1378-
warnings.simplefilter("ignore", ComplexWarning)
1379-
out[0] = Y.astype(A.dtype)
1359+
def infer_shape(self, fgraph, node, shapes):
1360+
return [shapes[0]]
13801361

13811362

1382-
expm = Expm()
1363+
expm = Blockwise(Expm())
13831364

13841365

13851366
class SolveContinuousLyapunov(Op):

tests/tensor/test_slinalg.py

Lines changed: 20 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -880,35 +880,26 @@ def test_expm():
880880
np.testing.assert_array_almost_equal(val, ref)
881881

882882

883-
def test_expm_grad_1():
884-
# with symmetric matrix (real eigenvectors)
885-
rng = np.random.default_rng(utt.fetch_seed())
886-
# Always test in float64 for better numerical stability.
887-
A = rng.standard_normal((5, 5))
888-
A = A + A.T
889-
890-
utt.verify_grad(expm, [A], rng=rng)
891-
892-
893-
def test_expm_grad_2():
894-
# with non-symmetric matrix with real eigenspecta
895-
rng = np.random.default_rng(utt.fetch_seed())
896-
# Always test in float64 for better numerical stability.
897-
A = rng.standard_normal((5, 5))
898-
w = rng.standard_normal(5) ** 2
899-
A = (np.diag(w**0.5)).dot(A + A.T).dot(np.diag(w ** (-0.5)))
900-
assert not np.allclose(A, A.T)
901-
902-
utt.verify_grad(expm, [A], rng=rng)
903-
904-
905-
def test_expm_grad_3():
906-
# with non-symmetric matrix (complex eigenvectors)
907-
rng = np.random.default_rng(utt.fetch_seed())
908-
# Always test in float64 for better numerical stability.
909-
A = rng.standard_normal((5, 5))
910-
911-
utt.verify_grad(expm, [A], rng=rng)
883+
@pytest.mark.parametrize(
884+
"mode", ["symmetric", "nonsymmetric_real_eig", "nonsymmetric_complex_eig"][-1:]
885+
)
886+
def test_expm_grad(mode):
887+
rng = np.random.default_rng()
888+
889+
match mode:
890+
case "symmetric":
891+
A = rng.standard_normal((5, 5))
892+
A = A + A.T
893+
case "nonsymmetric_real_eig":
894+
A = rng.standard_normal((5, 5))
895+
w = rng.standard_normal(5) ** 2
896+
A = (np.diag(w**0.5)).dot(A + A.T).dot(np.diag(w ** (-0.5)))
897+
case "nonsymmetric_complex_eig":
898+
A = rng.standard_normal((5, 5))
899+
case _:
900+
raise ValueError(f"Invalid mode: {mode}")
901+
902+
utt.verify_grad(expm, [A], rng=rng, abs_tol=1e-5, rel_tol=1e-5)
912903

913904

914905
def recover_Q(A, X, continuous=True):

0 commit comments

Comments
 (0)