Skip to content

Commit 425f783

Browse files
committed
Generalize log(prod(x)) -> sum(log(x)) rewrite
1 parent 1d13f8c commit 425f783

File tree

1 file changed

+28
-14
lines changed

1 file changed

+28
-14
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
node_rewriter,
1515
)
1616
from pytensor.graph.rewriting.unify import OpPattern
17-
from pytensor.scalar.basic import Abs, Log, Mul, Sign
17+
from pytensor.scalar.basic import Abs, Exp, Log, Mul, Sign, Sqr
1818
from pytensor.tensor.basic import (
1919
AllocDiag,
2020
ExtractDiag,
@@ -295,27 +295,41 @@ def local_det_chol(fgraph, node):
295295
return [prod(diagonal(L, axis1=-2, axis2=-1) ** 2, axis=-1)]
296296

297297

298-
@register_canonicalize
299298
@register_stabilize
300299
@register_specialize
301300
@node_rewriter([log])
302-
def local_log_prod_sqr(fgraph, node):
303-
"""
304-
This utilizes a boolean `positive` tag on matrices.
305-
"""
306-
(x,) = node.inputs
307-
if x.owner and isinstance(x.owner.op, Prod):
308-
# we cannot always make this substitution because
309-
# the prod might include negative terms
310-
p = x.owner.inputs[0]
301+
def local_log_prod_to_sum_log(fgraph, node):
302+
"""Rewrite log(prod(x)) as sum(log(x)), when x is known to be positive."""
303+
[p] = node.inputs
304+
p_node = p.owner
305+
306+
if p_node is None:
307+
return None
308+
309+
p_op = p_node.op
311310

312-
# p is the matrix we're reducing with prod
313-
if getattr(p.tag, "positive", None) is True:
314-
return [log(p).sum(axis=x.owner.op.axis)]
311+
if isinstance(p_op, Prod):
312+
x = p_node.inputs[0]
313+
314+
# TODO: The product of diagonals of a Cholesky(A) are also strictly positive
315+
if (
316+
x.owner is not None
317+
and isinstance(x.owner.op, Elemwise)
318+
and isinstance(x.owner.op.scalar_op, Abs | Sqr | Exp)
319+
) or getattr(x.tag, "positive", False):
320+
return [log(x).sum(axis=p_node.op.axis)]
315321

316322
# TODO: have a reduction like prod and sum that simply
317323
# returns the sign of the prod multiplication.
318324

325+
# Special case for log(abs(prod(x))) -> sum(log(abs(x))) that shows up in slogdet
326+
elif isinstance(p_op, Elemwise) and isinstance(p_op.scalar_op, Abs):
327+
[p] = p_node.inputs
328+
p_node = p.owner
329+
if p_node is not None and isinstance(p_node.op, Prod):
330+
[x] = p.owner.inputs
331+
return [log(abs(x)).sum(axis=p_node.op.axis)]
332+
319333

320334
@register_specialize
321335
@node_rewriter([blockwise_of(MatrixInverse | Cholesky | MatrixPinv)])

0 commit comments

Comments
 (0)