Skip to content

Commit 9ec1de8

Browse files
committed
feat: do not take assignments in BesselSubstitutor
1 parent dda86e5 commit 9ec1de8

File tree

1 file changed

+8
-10
lines changed

1 file changed

+8
-10
lines changed

sumpy/codegen.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -365,18 +365,17 @@ def map_common_subexpression_uncached(
365365
class BesselSubstitutor(CSECachingIdentityMapper[P]):
366366
name_gen: Callable[[str], str]
367367
bessel_j_arg_to_top_order: dict[Expression, int]
368-
assignments: list[Assignment | CallInstruction]
369368

369+
assignments: list[Assignment | CallInstruction]
370370
cse_cache: dict[Expression, prim.CommonSubexpression]
371371

372372
def __init__(self,
373373
name_gen: Callable[[str], str],
374-
bessel_j_arg_to_top_order: dict[Expression, int],
375-
assignments: Sequence[Assignment | CallInstruction]) -> None:
374+
bessel_j_arg_to_top_order: dict[Expression, int]) -> None:
376375
self.name_gen = name_gen
377376
self.bessel_j_arg_to_top_order = bessel_j_arg_to_top_order
378377
self.cse_cache = {}
379-
self.assignments = list(assignments)
378+
self.assignments = []
380379

381380
@override
382381
def map_call(self, expr: prim.Call, /,
@@ -809,20 +808,19 @@ def convert_expr(name: str, expr: Expression) -> Expression:
809808
from pytools import UniqueNameGenerator
810809
name_gen = UniqueNameGenerator({name for name, _expr in pymbolic_assignments})
811810

812-
result: list[Assignment | CallInstruction] = []
813-
bessel_sub = BesselSubstitutor(
814-
name_gen, btog.bessel_j_arg_to_top_order,
815-
result)
816-
817-
import loopy as lp
818811
from pytools import MinRecursionLimit
819812

813+
result: list[Assignment | CallInstruction] = []
814+
bessel_sub = BesselSubstitutor(name_gen, btog.bessel_j_arg_to_top_order)
815+
820816
with MinRecursionLimit(3000):
821817
for name, expr in pymbolic_assignments:
822818
result.append(lp.Assignment(id=None,
823819
assignee=name, expression=bessel_sub(expr),
824820
temp_var_type=lp.Optional(None)))
825821

822+
result.extend(bessel_sub.assignments)
823+
826824
logger.info("loopy instruction generation: done")
827825
return result
828826

0 commit comments

Comments
 (0)