Skip to content

Commit e7132ca

Browse files
author
Jesse Grabowski
committed
Add return_components argument to pt.linalg.eig
This argument allows the user to get back both the real and imaginary parts of the eigenvalues, if required.
1 parent 602eb04 commit e7132ca

File tree

2 files changed

+72
-9
lines changed

2 files changed

+72
-9
lines changed

pytensor/tensor/nlinalg.py

Lines changed: 53 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,15 @@
1616
from pytensor.tensor import math as ptm
1717
from pytensor.tensor.basic import as_tensor_variable, diagonal
1818
from pytensor.tensor.blockwise import Blockwise
19-
from pytensor.tensor.type import Variable, dvector, lscalar, matrix, scalar, vector
19+
from pytensor.tensor.type import (
20+
Variable,
21+
dvector,
22+
lscalar,
23+
matrix,
24+
scalar,
25+
tensor3,
26+
vector,
27+
)
2028

2129

2230
class MatrixPinv(Op):
@@ -297,31 +305,68 @@ def slogdet(x: TensorLike) -> tuple[ptb.TensorVariable, ptb.TensorVariable]:
297305
class Eig(Op):
298306
"""
299307
Compute the eigenvalues and right eigenvectors of a square array.
300-
301308
"""
302309

303-
__props__: tuple[str, ...] = ()
304-
gufunc_signature = "(m,m)->(m),(m,m)"
310+
__props__: tuple[str, ...] = ("return_components",)
305311
gufunc_spec = ("numpy.linalg.eig", 1, 2)
306312

313+
def __init__(self, return_components: bool = False):
314+
self.return_components = return_components
315+
if return_components:
316+
signature = "(m,m)->(a,m),(a,m,m)"
317+
else:
318+
signature = "(m,m)->(m),(m,m)"
319+
self.gufunc_signature = signature
320+
307321
def make_node(self, x):
308322
x = as_tensor_variable(x)
309323
assert x.ndim == 2
310-
w = vector(dtype=x.dtype)
311-
v = matrix(dtype=x.dtype)
324+
325+
if self.return_components:
326+
w = matrix(dtype=x.dtype)
327+
v = tensor3(dtype=x.dtype)
328+
329+
else:
330+
w = vector(dtype=x.dtype)
331+
v = matrix(dtype=x.dtype)
332+
312333
return Apply(self, [x], [w, v])
313334

314335
def perform(self, node, inputs, outputs):
315336
(x,) = inputs
316337
(w, v) = outputs
317-
w[0], v[0] = (z.astype(x.dtype) for z in np.linalg.eig(x))
338+
if self.return_components:
339+
w_res, v_res = np.linalg.eig(x)
340+
w[0] = np.stack([w_res.real, w_res.imag], axis=0)
341+
v[0] = np.stack([v_res.real, v_res.imag], axis=0)
342+
else:
343+
w[0], v[0] = (z.astype(x.dtype) for z in np.linalg.eig(x))
318344

319345
def infer_shape(self, fgraph, node, shapes):
320346
n = shapes[0][0]
347+
if self.return_components:
348+
return [(2, n), (2, n, n)]
321349
return [(n,), (n, n)]
322350

323351

324-
eig = Blockwise(Eig())
352+
def eig(x: TensorLike, return_components: bool = False):
353+
"""
354+
Return the eigenvalues and right eigenvectors of a square array.
355+
356+
Parameters
357+
----------
358+
x: TensorLike
359+
Square matrix, or array of such matrices
360+
return_components: bool, optional
361+
By default, only the real component of the eigenvalues and eigenvectors are returned, as Pytensor does not
362+
allow the dtype of a variable to change during graph execution.
363+
364+
To circumvent this, if `return_components` is set to True, the real and imaginary *components* of the
365+
eigenvalues are returned as a concatenated array of shape (2, n), where the first row contains the real parts
366+
and the second row contains the imaginary parts. Similarly, the eigenvectors are returned as an array of shape
367+
(2, n, n).
368+
"""
369+
return Blockwise(Eig(return_components=return_components))(x)
325370

326371

327372
class Eigh(Eig):

tests/tensor/test_nlinalg.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -394,11 +394,12 @@ def test_trace():
394394

395395
class TestEig(utt.InferShapeTester):
396396
op_class = Eig
397-
op = eig
398397
dtype = "float64"
398+
op = staticmethod(eig)
399399

400400
def setup_method(self):
401401
super().setup_method()
402+
402403
self.rng = np.random.default_rng(utt.fetch_seed())
403404
self.A = matrix(dtype=self.dtype)
404405
self.X = np.asarray(self.rng.random((5, 5)), dtype=self.dtype)
@@ -423,6 +424,23 @@ def test_eval(self):
423424
w, v = (e.eval({A: x}) for e in self.op(A))
424425
assert_array_almost_equal(np.dot(x, v), w * v)
425426

427+
def test_eval_return_components(self):
428+
A = matrix(dtype=self.dtype)
429+
A_val = self.rng.normal(size=(5, 5))
430+
w, v = (e.eval({A: A_val}) for e in self.op(A, return_components=True))
431+
assert w.shape == (2, 5)
432+
assert v.shape == (2, 5, 5)
433+
434+
w = w[0] + 1j * w[1]
435+
v = v[0] + 1j * v[1]
436+
437+
w_np, v_np = np.linalg.eig(A_val)
438+
439+
np.testing.assert_allclose(w, w_np)
440+
np.testing.assert_allclose(v, v_np)
441+
442+
assert_array_almost_equal(A_val @ v, w * v)
443+
426444

427445
class TestEigh(TestEig):
428446
op = staticmethod(eigh)

0 commit comments

Comments
 (0)