|
23 | 23 | concatenate, |
24 | 24 | diag, |
25 | 25 | diagonal, |
| 26 | + ones, |
26 | 27 | ) |
27 | 28 | from pytensor.tensor.blockwise import Blockwise |
28 | 29 | from pytensor.tensor.elemwise import DimShuffle, Elemwise |
|
46 | 47 | ) |
47 | 48 | from pytensor.tensor.rewriting.blockwise import blockwise_of |
48 | 49 | from pytensor.tensor.slinalg import ( |
| 50 | + LU, |
| 51 | + QR, |
49 | 52 | BlockDiagonal, |
50 | 53 | Cholesky, |
51 | 54 | CholeskySolve, |
| 55 | + LUFactor, |
52 | 56 | Solve, |
53 | 57 | SolveBase, |
54 | 58 | SolveTriangular, |
|
65 | 69 | MATRIX_INVERSE_OPS = (MatrixInverse, MatrixPinv) |
66 | 70 |
|
67 | 71 |
|
| 72 | +def matrix_diagonal_product(x): |
| 73 | + return pt.prod(diagonal(x, axis1=-2, axis2=-1), axis=-1) |
| 74 | + |
| 75 | + |
68 | 76 | def is_matrix_transpose(x: TensorVariable) -> bool: |
69 | 77 | """Check if a variable corresponds to a transpose of the last two axes""" |
70 | 78 | node = x.owner |
@@ -279,22 +287,6 @@ def cholesky_ldotlt(fgraph, node): |
279 | 287 | return [r] |
280 | 288 |
|
281 | 289 |
|
282 | | -@register_stabilize |
283 | | -@register_specialize |
284 | | -@node_rewriter([det]) |
285 | | -def local_det_chol(fgraph, node): |
286 | | - """ |
287 | | - If we have det(X) and there is already an L=cholesky(X) |
288 | | - floating around, then we can use prod(diag(L)) to get the determinant. |
289 | | -
|
290 | | - """ |
291 | | - (x,) = node.inputs |
292 | | - for cl, xpos in fgraph.clients[x]: |
293 | | - if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, Cholesky): |
294 | | - L = cl.outputs[0] |
295 | | - return [prod(diagonal(L, axis1=-2, axis2=-1) ** 2, axis=-1)] |
296 | | - |
297 | | - |
298 | 290 | @register_stabilize |
299 | 291 | @register_specialize |
300 | 292 | @node_rewriter([log]) |
@@ -456,6 +448,127 @@ def _find_diag_from_eye_mul(potential_mul_input): |
456 | 448 | return eye_input, non_eye_inputs |
457 | 449 |
|
458 | 450 |
|
| 451 | +@register_stabilize |
| 452 | +@register_specialize |
| 453 | +@node_rewriter([det]) |
| 454 | +def det_of_matrix_factorized_elsewhere(fgraph, node): |
| 455 | + """ |
| 456 | + If we have det(X) or abs(det(X)) and there is already a nice decomposition(X) floating around, |
| 457 | + use it to compute it more cheaply |
| 458 | +
|
| 459 | + """ |
| 460 | + [det] = node.outputs |
| 461 | + [x] = node.inputs |
| 462 | + |
| 463 | + only_used_by_abs = all( |
| 464 | + isinstance(client.op, Elemwise) and isinstance(client.op.scalar_op, Abs) |
| 465 | + for client, _ in fgraph.clients[det] |
| 466 | + ) |
| 467 | + |
| 468 | + new_det = None |
| 469 | + for client, _ in fgraph.clients[x]: |
| 470 | + core_op = client.op.core_op if isinstance(client.op, Blockwise) else client.op |
| 471 | + match core_op: |
| 472 | + case Cholesky(): |
| 473 | + L = client.outputs[0] |
| 474 | + new_det = matrix_diagonal_product(L) ** 2 |
| 475 | + case LU(): |
| 476 | + U = client.outputs[-1] |
| 477 | + new_det = matrix_diagonal_product(U) |
| 478 | + case LUFactor(): |
| 479 | + LU_packed = client.outputs[0] |
| 480 | + new_det = matrix_diagonal_product(LU_packed) |
| 481 | + case _: |
| 482 | + if not only_used_by_abs: |
| 483 | + continue |
| 484 | + match core_op: |
| 485 | + case SVD(): |
| 486 | + lmbda = ( |
| 487 | + client.outputs[1] |
| 488 | + if core_op.compute_uv |
| 489 | + else client.outputs[0] |
| 490 | + ) |
| 491 | + new_det = prod(lmbda, axis=-1) |
| 492 | + case QR(): |
| 493 | + R = client.outputs[-1] |
| 494 | + # if mode == "economic", R may not be square and this rewrite could hide a shape error |
| 495 | + # That's why it's tagged as `shape_unsafe` |
| 496 | + new_det = matrix_diagonal_product(R) |
| 497 | + |
| 498 | + if new_det is not None: |
| 499 | + # found a match |
| 500 | + break |
| 501 | + else: # no-break (i.e., no-match) |
| 502 | + return None |
| 503 | + |
| 504 | + [det] = node.outputs |
| 505 | + copy_stack_trace(det, new_det) |
| 506 | + return [new_det] |
| 507 | + |
| 508 | + |
| 509 | +@register_stabilize("shape_unsafe") |
| 510 | +@register_specialize("shape_unsafe") |
| 511 | +@node_rewriter(tracks=[det]) |
| 512 | +def det_of_factorized_matrix(fgraph, node): |
| 513 | + """Introduce special forms for det(decomposition(X)). |
| 514 | +
|
| 515 | + Some cases are only known up to a sign change such as det(QR(X)), |
| 516 | + and are only introduced if the determinant is only ever used inside an abs |
| 517 | + """ |
| 518 | + [det] = node.outputs |
| 519 | + [x] = node.inputs |
| 520 | + |
| 521 | + only_used_by_abs = all( |
| 522 | + isinstance(client.op, Elemwise) and isinstance(client.op.scalar_op, Abs) |
| 523 | + for client, _ in fgraph.clients[det] |
| 524 | + ) |
| 525 | + |
| 526 | + x_node = x.owner |
| 527 | + if x_node is None: |
| 528 | + return None |
| 529 | + |
| 530 | + x_op = x_node.op |
| 531 | + core_op = x_op.core_op if isinstance(x_op, Blockwise) else x_op |
| 532 | + |
| 533 | + new_det = None |
| 534 | + match core_op: |
| 535 | + case Cholesky(): |
| 536 | + new_det = matrix_diagonal_product(x) |
| 537 | + case LU(): |
| 538 | + if x is x_node.outputs[-2]: |
| 539 | + # x is L |
| 540 | + new_det = ones(x.shape[:-2], dtype=det.dtype) |
| 541 | + elif x is x_node.outputs[-1]: |
| 542 | + # x is U |
| 543 | + new_det = matrix_diagonal_product(x) |
| 544 | + case SVD(): |
| 545 | + if not core_op.compute_uv or x is x_node.outputs[1]: |
| 546 | + # x is lambda |
| 547 | + new_det = prod(x, axis=-1) |
| 548 | + elif only_used_by_abs: |
| 549 | + # x is either U or Vt and only ever used inside an abs |
| 550 | + new_det = ones(x.shape[:-2], dtype=det.dtype) |
| 551 | + case QR(): |
| 552 | + # if mode == "economic", Q/R may not be square and this rewrite could hide a shape error |
| 553 | + # That's why it's tagged as `shape_unsafe` |
| 554 | + if x is x_node.outputs[-1]: |
| 555 | + # x is R |
| 556 | + new_det = matrix_diagonal_product(x) |
| 557 | + elif ( |
| 558 | + only_used_by_abs |
| 559 | + and core_op.mode in ("economic", "full") |
| 560 | + and x is x_node.outputs[0] |
| 561 | + ): |
| 562 | + # x is Q and it's only ever used inside an abs |
| 563 | + new_det = ones(x.shape[:-2], dtype=det.dtype) |
| 564 | + |
| 565 | + if new_det is None: |
| 566 | + return None |
| 567 | + |
| 568 | + copy_stack_trace(det, new_det) |
| 569 | + return [new_det] |
| 570 | + |
| 571 | + |
459 | 572 | @register_canonicalize("shape_unsafe") |
460 | 573 | @register_stabilize("shape_unsafe") |
461 | 574 | @node_rewriter([det]) |
|
0 commit comments