@@ -365,18 +365,17 @@ def map_common_subexpression_uncached(
365365class BesselSubstitutor (CSECachingIdentityMapper [P ]):
366366 name_gen : Callable [[str ], str ]
367367 bessel_j_arg_to_top_order : dict [Expression , int ]
368- assignments : list [Assignment | CallInstruction ]
369368
369+ assignments : list [Assignment | CallInstruction ]
370370 cse_cache : dict [Expression , prim .CommonSubexpression ]
371371
372372 def __init__ (self ,
373373 name_gen : Callable [[str ], str ],
374- bessel_j_arg_to_top_order : dict [Expression , int ],
375- assignments : Sequence [Assignment | CallInstruction ]) -> None :
374+ bessel_j_arg_to_top_order : dict [Expression , int ]) -> None :
376375 self .name_gen = name_gen
377376 self .bessel_j_arg_to_top_order = bessel_j_arg_to_top_order
378377 self .cse_cache = {}
379- self .assignments = list ( assignments )
378+ self .assignments = []
380379
381380 @override
382381 def map_call (self , expr : prim .Call , / ,
@@ -809,20 +808,19 @@ def convert_expr(name: str, expr: Expression) -> Expression:
809808 from pytools import UniqueNameGenerator
810809 name_gen = UniqueNameGenerator ({name for name , _expr in pymbolic_assignments })
811810
812- result : list [Assignment | CallInstruction ] = []
813- bessel_sub = BesselSubstitutor (
814- name_gen , btog .bessel_j_arg_to_top_order ,
815- result )
816-
817- import loopy as lp
818811 from pytools import MinRecursionLimit
819812
813+ result : list [Assignment | CallInstruction ] = []
814+ bessel_sub = BesselSubstitutor (name_gen , btog .bessel_j_arg_to_top_order )
815+
820816 with MinRecursionLimit (3000 ):
821817 for name , expr in pymbolic_assignments :
822818 result .append (lp .Assignment (id = None ,
823819 assignee = name , expression = bessel_sub (expr ),
824820 temp_var_type = lp .Optional (None )))
825821
822+ result .extend (bessel_sub .assignments )
823+
826824 logger .info ("loopy instruction generation: done" )
827825 return result
828826
0 commit comments