Skip to content

Commit a9d058c

Browse files
committed
Generalize determinant from factorization rewrites
1 parent 425f783 commit a9d058c

File tree

2 files changed

+136
-21
lines changed

2 files changed

+136
-21
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 129 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
concatenate,
2424
diag,
2525
diagonal,
26+
ones,
2627
)
2728
from pytensor.tensor.blockwise import Blockwise
2829
from pytensor.tensor.elemwise import DimShuffle, Elemwise
@@ -46,9 +47,12 @@
4647
)
4748
from pytensor.tensor.rewriting.blockwise import blockwise_of
4849
from pytensor.tensor.slinalg import (
50+
LU,
51+
QR,
4952
BlockDiagonal,
5053
Cholesky,
5154
CholeskySolve,
55+
LUFactor,
5256
Solve,
5357
SolveBase,
5458
SolveTriangular,
@@ -65,6 +69,10 @@
6569
MATRIX_INVERSE_OPS = (MatrixInverse, MatrixPinv)
6670

6771

72+
def matrix_diagonal_product(x):
73+
return pt.prod(diagonal(x, axis1=-2, axis2=-1), axis=-1)
74+
75+
6876
def is_matrix_transpose(x: TensorVariable) -> bool:
6977
"""Check if a variable corresponds to a transpose of the last two axes"""
7078
node = x.owner
@@ -279,22 +287,6 @@ def cholesky_ldotlt(fgraph, node):
279287
return [r]
280288

281289

282-
@register_stabilize
283-
@register_specialize
284-
@node_rewriter([det])
285-
def local_det_chol(fgraph, node):
286-
"""
287-
If we have det(X) and there is already an L=cholesky(X)
288-
floating around, then we can use prod(diag(L)) to get the determinant.
289-
290-
"""
291-
(x,) = node.inputs
292-
for cl, xpos in fgraph.clients[x]:
293-
if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, Cholesky):
294-
L = cl.outputs[0]
295-
return [prod(diagonal(L, axis1=-2, axis2=-1) ** 2, axis=-1)]
296-
297-
298290
@register_stabilize
299291
@register_specialize
300292
@node_rewriter([log])
@@ -456,6 +448,127 @@ def _find_diag_from_eye_mul(potential_mul_input):
456448
return eye_input, non_eye_inputs
457449

458450

451+
@register_stabilize
452+
@register_specialize
453+
@node_rewriter([det])
454+
def det_of_matrix_factorized_elsewhere(fgraph, node):
455+
"""
456+
If we have det(X) or abs(det(X)) and there is already a nice decomposition(X) floating around,
457+
use it to compute it more cheaply
458+
459+
"""
460+
[det] = node.outputs
461+
[x] = node.inputs
462+
463+
only_used_by_abs = all(
464+
isinstance(client.op, Elemwise) and isinstance(client.op.scalar_op, Abs)
465+
for client, _ in fgraph.clients[det]
466+
)
467+
468+
new_det = None
469+
for client, _ in fgraph.clients[x]:
470+
core_op = client.op.core_op if isinstance(client.op, Blockwise) else client.op
471+
match core_op:
472+
case Cholesky():
473+
L = client.outputs[0]
474+
new_det = matrix_diagonal_product(L) ** 2
475+
case LU():
476+
U = client.outputs[-1]
477+
new_det = matrix_diagonal_product(U)
478+
case LUFactor():
479+
LU_packed = client.outputs[0]
480+
new_det = matrix_diagonal_product(LU_packed)
481+
case _:
482+
if not only_used_by_abs:
483+
continue
484+
match core_op:
485+
case SVD():
486+
lmbda = (
487+
client.outputs[1]
488+
if core_op.compute_uv
489+
else client.outputs[0]
490+
)
491+
new_det = prod(lmbda, axis=-1)
492+
case QR():
493+
R = client.outputs[-1]
494+
# if mode == "economic", R may not be square and this rewrite could hide a shape error
495+
# That's why it's tagged as `shape_unsafe`
496+
new_det = matrix_diagonal_product(R)
497+
498+
if new_det is not None:
499+
# found a match
500+
break
501+
else: # no-break (i.e., no-match)
502+
return None
503+
504+
[det] = node.outputs
505+
copy_stack_trace(det, new_det)
506+
return [new_det]
507+
508+
509+
@register_stabilize("shape_unsafe")
510+
@register_specialize("shape_unsafe")
511+
@node_rewriter(tracks=[det])
512+
def det_of_factorized_matrix(fgraph, node):
513+
"""Introduce special forms for det(decomposition(X)).
514+
515+
Some cases are only known up to a sign change such as det(QR(X)),
516+
and are only introduced if the determinant is only ever used inside an abs
517+
"""
518+
[det] = node.outputs
519+
[x] = node.inputs
520+
521+
only_used_by_abs = all(
522+
isinstance(client.op, Elemwise) and isinstance(client.op.scalar_op, Abs)
523+
for client, _ in fgraph.clients[det]
524+
)
525+
526+
x_node = x.owner
527+
if x_node is None:
528+
return None
529+
530+
x_op = x_node.op
531+
core_op = x_op.core_op if isinstance(x_op, Blockwise) else x_op
532+
533+
new_det = None
534+
match core_op:
535+
case Cholesky():
536+
new_det = matrix_diagonal_product(x)
537+
case LU():
538+
if x is x_node.outputs[-2]:
539+
# x is L
540+
new_det = ones(x.shape[:-2], dtype=det.dtype)
541+
elif x is x_node.outputs[-1]:
542+
# x is U
543+
new_det = matrix_diagonal_product(x)
544+
case SVD():
545+
if not core_op.compute_uv or x is x_node.outputs[1]:
546+
# x is lambda
547+
new_det = prod(x, axis=-1)
548+
elif only_used_by_abs:
549+
# x is either U or Vt and only ever used inside an abs
550+
new_det = ones(x.shape[:-2], dtype=det.dtype)
551+
case QR():
552+
# if mode == "economic", Q/R may not be square and this rewrite could hide a shape error
553+
# That's why it's tagged as `shape_unsafe`
554+
if x is x_node.outputs[-1]:
555+
# x is R
556+
new_det = matrix_diagonal_product(x)
557+
elif (
558+
only_used_by_abs
559+
and core_op.mode in ("economic", "full")
560+
and x is x_node.outputs[0]
561+
):
562+
# x is Q and it's only ever used inside an abs
563+
new_det = ones(x.shape[:-2], dtype=det.dtype)
564+
565+
if new_det is None:
566+
return None
567+
568+
copy_stack_trace(det, new_det)
569+
return [new_det]
570+
571+
459572
@register_canonicalize("shape_unsafe")
460573
@register_stabilize("shape_unsafe")
461574
@node_rewriter([det])

tests/tensor/rewriting/test_linalg.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -243,14 +243,16 @@ def test_local_det_chol():
243243
det_X = pt.linalg.det(X)
244244

245245
f = function([X], [L, det_X])
246-
247-
nodes = f.maker.fgraph.toposort()
248-
assert not any(isinstance(node, Det) for node in nodes)
246+
assert not any(isinstance(node, Det) for node in f.maker.fgraph.apply_nodes)
249247

250248
# This previously raised an error (issue #392)
251249
f = function([X], [L, det_X, X])
252-
nodes = f.maker.fgraph.toposort()
253-
assert not any(isinstance(node, Det) for node in nodes)
250+
assert not any(isinstance(node, Det) for node in f.maker.fgraph.apply_nodes)
251+
252+
# Test graph that only has det_X
253+
f = function([X], [det_X])
254+
f.dprint()
255+
assert not any(isinstance(node, Det) for node in f.maker.fgraph.apply_nodes)
254256

255257

256258
def test_psd_solve_with_chol():

0 commit comments

Comments
 (0)