|
14 | 14 | node_rewriter, |
15 | 15 | ) |
16 | 16 | from pytensor.graph.rewriting.unify import OpPattern |
17 | | -from pytensor.scalar.basic import Abs, Log, Mul, Sign |
| 17 | +from pytensor.scalar.basic import Abs, Exp, Log, Mul, Sign, Sqr |
18 | 18 | from pytensor.tensor.basic import ( |
19 | 19 | AllocDiag, |
20 | 20 | ExtractDiag, |
@@ -295,27 +295,41 @@ def local_det_chol(fgraph, node): |
295 | 295 | return [prod(diagonal(L, axis1=-2, axis2=-1) ** 2, axis=-1)] |
296 | 296 |
|
297 | 297 |
|
298 | | -@register_canonicalize |
299 | 298 | @register_stabilize |
300 | 299 | @register_specialize |
301 | 300 | @node_rewriter([log]) |
302 | | -def local_log_prod_sqr(fgraph, node): |
303 | | - """ |
304 | | - This utilizes a boolean `positive` tag on matrices. |
305 | | - """ |
306 | | - (x,) = node.inputs |
307 | | - if x.owner and isinstance(x.owner.op, Prod): |
308 | | - # we cannot always make this substitution because |
309 | | - # the prod might include negative terms |
310 | | - p = x.owner.inputs[0] |
| 301 | +def local_log_prod_to_sum_log(fgraph, node): |
| 302 | + """Rewrite log(prod(x)) as sum(log(x)), when x is known to be positive.""" |
| 303 | + [p] = node.inputs |
| 304 | + p_node = p.owner |
| 305 | + |
| 306 | + if p_node is None: |
| 307 | + return None |
| 308 | + |
| 309 | + p_op = p_node.op |
311 | 310 |
|
312 | | - # p is the matrix we're reducing with prod |
313 | | - if getattr(p.tag, "positive", None) is True: |
314 | | - return [log(p).sum(axis=x.owner.op.axis)] |
| 311 | + if isinstance(p_op, Prod): |
| 312 | + x = p_node.inputs[0] |
| 313 | + |
| 314 | + # TODO: The product of diagonals of a Cholesky(A) are also strictly positive |
| 315 | + if ( |
| 316 | + x.owner is not None |
| 317 | + and isinstance(x.owner.op, Elemwise) |
| 318 | + and isinstance(x.owner.op.scalar_op, Abs | Sqr | Exp) |
| 319 | + ) or getattr(x.tag, "positive", False): |
| 320 | + return [log(x).sum(axis=p_node.op.axis)] |
315 | 321 |
|
316 | 322 | # TODO: have a reduction like prod and sum that simply |
317 | 323 | # returns the sign of the prod multiplication. |
318 | 324 |
|
| 325 | + # Special case for log(abs(prod(x))) -> sum(log(abs(x))) that shows up in slogdet |
| 326 | + elif isinstance(p_op, Elemwise) and isinstance(p_op.scalar_op, Abs): |
| 327 | + [p] = p_node.inputs |
| 328 | + p_node = p.owner |
| 329 | + if p_node is not None and isinstance(p_node.op, Prod): |
| 330 | + [x] = p.owner.inputs |
| 331 | + return [log(abs(x)).sum(axis=p_node.op.axis)] |
| 332 | + |
319 | 333 |
|
320 | 334 | @register_specialize |
321 | 335 | @node_rewriter([blockwise_of(MatrixInverse | Cholesky | MatrixPinv)]) |
|
0 commit comments