1515from pytensor .tensor .elemwise import DimShuffle
1616from pytensor .tensor .rewriting .basic import register_specialize
1717from pytensor .tensor .rewriting .blockwise import blockwise_of
18- from pytensor .tensor .rewriting .linalg import is_matrix_transpose
1918from pytensor .tensor .slinalg import Solve , cho_solve , cholesky , lu_factor , lu_solve
2019from pytensor .tensor .variable import TensorVariable
2120
@@ -79,28 +78,26 @@ def get_root_A(a: TensorVariable) -> tuple[TensorVariable, bool]:
7978 # the root variable is the pre-DimShuffled input.
8079 # Otherwise, `a` is considered the root variable.
8180 # We also return whether the root `a` is transposed.
81+ root_a = a
8282 transposed = False
83- if a .owner is not None and isinstance ( a . owner . op , DimShuffle ) :
84- if a . owner . op . is_left_expand_dims :
85- [ a ] = a . owner . inputs
86- elif is_matrix_transpose ( a ):
87- [ a ] = a . owner . inputs
88- transposed = True
89- return a , transposed
83+ match a .owner_op_and_inputs :
84+ case ( DimShuffle ( is_left_expand_dims = True ), root_a ): # type: ignore[misc]
85+ transposed = False
86+ case ( DimShuffle ( is_left_expanded_matrix_transpose = True ), root_a ): # type: ignore[misc]
87+ transposed = True # type: ignore[unreachable]
88+
89+ return root_a , transposed
9090
9191 def find_solve_clients (var , assume_a ):
9292 clients = []
9393 for cl , idx in fgraph .clients [var ]:
94- if (
95- idx == 0
96- and isinstance (cl .op , Blockwise )
97- and isinstance (cl .op .core_op , Solve )
98- and (cl .op .core_op .assume_a == assume_a )
99- ):
100- clients .append (cl )
101- elif isinstance (cl .op , DimShuffle ) and cl .op .is_left_expand_dims :
102- # If it's a left expand_dims, recurse on the output
103- clients .extend (find_solve_clients (cl .outputs [0 ], assume_a ))
94+ match (idx , cl .op , * cl .outputs ):
95+ case (0 , Blockwise (Solve (assume_a = assume_a_var )), * _) if (
96+ assume_a_var == assume_a
97+ ):
98+ clients .append (cl )
99+ case (0 , DimShuffle (is_left_expand_dims = True ), cl_out ):
100+ clients .extend (find_solve_clients (cl_out , assume_a ))
104101 return clients
105102
106103 assume_a = node .op .core_op .assume_a
@@ -119,11 +116,11 @@ def find_solve_clients(var, assume_a):
119116
120117 # Find Solves using A.T
121118 for cl , _ in fgraph .clients [A ]:
122- if isinstance (cl .op , DimShuffle ) and is_matrix_transpose ( cl .out ):
123- A_T = cl . out
124- A_solve_clients_and_transpose .extend (
125- (client , True ) for client in find_solve_clients (A_T , assume_a )
126- )
119+ match (cl .op , * cl .outputs ):
120+ case ( DimShuffle ( is_left_expanded_matrix_transpose = True ), A_T ):
121+ A_solve_clients_and_transpose .extend (
122+ (client , True ) for client in find_solve_clients (A_T , assume_a )
123+ )
127124
128125 if not eager and len (A_solve_clients_and_transpose ) == 1 :
129126 # If theres' a single use don't do it... unless it's being broadcast in a Blockwise (or we're eager)
@@ -185,34 +182,34 @@ def _scan_split_non_sequence_decomposition_and_solve(
185182 changed = False
186183 while True :
187184 for inner_node in new_scan_fgraph .toposort ():
188- if (
189- isinstance (inner_node .op , Blockwise )
190- and isinstance (inner_node .op .core_op , Solve )
191- and inner_node .op .core_op .assume_a in allowed_assume_a
192- ):
193- A , _b = inner_node .inputs
194- if all (
195- (isinstance (root_inp , Constant ) or (root_inp in non_sequences ))
196- for root_inp in graph_inputs ([A ])
185+ match (inner_node .op , * inner_node .inputs ):
186+ case (Blockwise (Solve (assume_a = assume_a_var )), A , _b ) if (
187+ assume_a_var in allowed_assume_a
197188 ):
198- if new_scan_fgraph is scan_op .fgraph :
199- # Clone the first time to avoid mutating the original fgraph
200- new_scan_fgraph , equiv = new_scan_fgraph .clone_get_equiv ()
201- non_sequences = {equiv [non_seq ] for non_seq in non_sequences }
202- inner_node = equiv [inner_node ] # type: ignore
203-
204- replace_dict = _split_decomp_and_solve_steps (
205- new_scan_fgraph ,
206- inner_node ,
207- eager = True ,
208- allowed_assume_a = allowed_assume_a ,
209- )
210- assert isinstance (replace_dict , dict ) and len (replace_dict ) > 0 , (
211- "Rewrite failed"
212- )
213- new_scan_fgraph .replace_all (replace_dict .items ())
214- changed = True
215- break # Break to start over with a fresh toposort
189+ if all (
190+ (isinstance (root_inp , Constant ) or (root_inp in non_sequences ))
191+ for root_inp in graph_inputs ([A ])
192+ ):
193+ if new_scan_fgraph is scan_op .fgraph :
194+ # Clone the first time to avoid mutating the original fgraph
195+ new_scan_fgraph , equiv = new_scan_fgraph .clone_get_equiv ()
196+ non_sequences = {
197+ equiv [non_seq ] for non_seq in non_sequences
198+ }
199+ inner_node = equiv [inner_node ] # type: ignore
200+
201+ replace_dict = _split_decomp_and_solve_steps (
202+ new_scan_fgraph ,
203+ inner_node ,
204+ eager = True ,
205+ allowed_assume_a = allowed_assume_a ,
206+ )
207+ assert (
208+ isinstance (replace_dict , dict ) and len (replace_dict ) > 0
209+ ), "Rewrite failed"
210+ new_scan_fgraph .replace_all (replace_dict .items ())
211+ changed = True
212+ break # Break to start over with a fresh toposort
216213 else : # no_break
217214 break # Nothing else changed
218215
0 commit comments