Skip to content

Commit cccce32

Browse files
committed
feat(typing): add types to expansion.loopy
1 parent da7829f commit cccce32

File tree

1 file changed

+72
-60
lines changed

1 file changed

+72
-60
lines changed

sumpy/expansion/loopy.py

Lines changed: 72 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
import numpy as np
3030

3131
import loopy as lp
32-
import pymbolic
32+
import pymbolic.primitives as prim
3333

3434
import sumpy.symbolic as sym
3535
from sumpy.assignment_collection import SymbolicAssignmentCollection
@@ -39,6 +39,8 @@
3939
if TYPE_CHECKING:
4040
from collections.abc import Sequence
4141

42+
from pymbolic.typing import ArithmeticExpression
43+
4244
from sumpy.expansion import ExpansionBase
4345
from sumpy.kernel import Kernel
4446

@@ -49,27 +51,30 @@
4951
def make_e2p_loopy_kernel(
5052
expansion: ExpansionBase, kernels: Sequence[Kernel]) -> lp.TranslationUnit:
5153
"""
52-
This is a helper function to create a loopy kernel for multipole/local
53-
evaluation. This function uses symbolic expressions given by the expansion class,
54-
converts them to pymbolic expressions and generates a loopy
55-
kernel. Note that the loopy kernel returned has lots of expressions in it and
56-
takes a long time. Therefore, this function should be used only as a fallback
57-
when there is no "loop-y" kernel to evaluate the expansion.
54+
A helper function that creates a :mod:`loopy` kernel for multipole/local evaluation.
55+
56+
This function uses symbolic expressions given by the expansion class,
57+
converts them to :mod:`pymbolic` expressions and generates a :mod:`loopy`
58+
kernel. Note that the :mod:`loopy` kernel returned has lots of expressions
59+
in it and (likely) takes a long time. Therefore, this function should be
60+
used only as a fallback when there is no "loop-y" kernel to evaluate the
61+
expansion.
5862
"""
59-
dim = expansion.dim
6063

61-
bvec = sym.make_sym_vector("b", dim)
64+
sac = SymbolicAssignmentCollection()
65+
66+
dim = expansion.dim
6267
ncoeffs = len(expansion.get_coefficient_identifiers())
6368

69+
bvec = sym.make_sym_vector("b", dim)
6470
rscale = sym.Symbol("rscale")
6571

66-
sac = SymbolicAssignmentCollection()
67-
6872
domains = [
6973
"{[idim]: 0<=idim<dim}",
7074
"{[iknl]: 0<=iknl<nresults}",
7175
]
72-
insns = []
76+
77+
insns: list[lp.Assignment | lp.CallInstruction] = []
7378
insns.append(
7479
lp.Assignment(
7580
assignee="b[idim]",
@@ -79,15 +84,17 @@ def make_e2p_loopy_kernel(
7984
target_args = gather_loopy_arguments((expansion, *tuple(kernels)))
8085

8186
coeff_exprs = sym.make_sym_vector("coeffs", ncoeffs)
82-
coeff_names = [
87+
coeff_names = {
8388
sac.add_assignment(f"result{i}",
8489
expansion.evaluate(knl, coeff_exprs, bvec, rscale, sac=sac))
85-
for i, knl in enumerate(kernels)]
90+
for i, knl in enumerate(kernels)
91+
}
8692

8793
sac.run_global_cse()
8894

89-
code_transformers = [expansion.get_code_transformer()] \
90-
+ [kernel.get_code_transformer() for kernel in kernels]
95+
code_transformers = (
96+
[expansion.get_code_transformer()]
97+
+ [kernel.get_code_transformer() for kernel in kernels])
9198

9299
from sumpy.codegen import to_loopy_insns
93100
insns += to_loopy_insns(
@@ -98,37 +105,37 @@ def make_e2p_loopy_kernel(
98105
complex_dtype=np.complex128 # FIXME
99106
)
100107

101-
result = pymbolic.var("result")
108+
result = prim.Variable("result")
102109

103110
# change result{i} = expr into result[i] += expr
104111
for i in range(len(insns)):
105112
insn = insns[i]
106-
if isinstance(insn, lp.Assignment) and \
107-
isinstance(insn.assignee, pymbolic.var) and \
108-
insn.assignee.name.startswith(result.name):
113+
if (isinstance(insn, lp.Assignment)
114+
and isinstance(insn.assignee, prim.Variable)
115+
and insn.assignee.name.startswith(result.name)):
109116
idx = int(insn.assignee.name[len(result.name):])
110117
insns[i] = lp.Assignment(
111118
assignee=result[idx],
112119
expression=(
113120
result[idx]
114-
+ cast("pymbolic.ArithmeticExpression", insn.expression)),
121+
+ cast("ArithmeticExpression", insn.expression)),
115122
id=f"result_{idx}",
116123
happens_after=insn.happens_after,
117124
)
118125

119126
loopy_knl = lp.make_function(domains, insns,
120127
kernel_data=[
121-
lp.GlobalArg("result", shape=(len(kernels),), is_input=True,
122-
is_output=True),
123-
lp.GlobalArg("coeffs",
124-
shape=(ncoeffs,), is_input=True, is_output=False),
125-
lp.GlobalArg("center, target",
126-
shape=(dim,), is_input=True, is_output=False),
128+
lp.GlobalArg("result", shape=(len(kernels),),
129+
is_input=True, is_output=True),
130+
lp.GlobalArg("coeffs", shape=(ncoeffs,),
131+
is_input=True, is_output=False),
132+
lp.GlobalArg("center, target", shape=(dim,),
133+
is_input=True, is_output=False),
127134
lp.ValueArg("rscale", is_input=True),
128135
lp.ValueArg("itgt", is_input=True),
129136
lp.ValueArg("ntargets", is_input=True),
130-
lp.GlobalArg("targets",
131-
shape=(dim, "ntargets"), is_input=True, is_output=False),
137+
lp.GlobalArg("targets", shape=(dim, "ntargets"),
138+
is_input=True, is_output=False),
132139
*target_args,
133140
...],
134141
name="e2p",
@@ -144,29 +151,33 @@ def make_e2p_loopy_kernel(
144151

145152

146153
def make_p2e_loopy_kernel(
147-
expansion: ExpansionBase, kernels: Sequence[Kernel],
148-
strength_usage: Sequence[int], nstrengths: int) -> lp.TranslationUnit:
154+
expansion: ExpansionBase,
155+
kernels: Sequence[Kernel],
156+
strength_usage: Sequence[int],
157+
nstrengths: int) -> lp.TranslationUnit:
149158
"""
150-
This is a helper function to create a loopy kernel for multipole/local
151-
expression. This function uses symbolic expressions given by the expansion
152-
class, converts them to pymbolic expressions and generates a loopy
153-
kernel. Note that the loopy kernel returned has lots of expressions in it and
154-
takes a long time. Therefore, this function should be used only as a fallback
155-
when there is no "loop-y" kernel to evaluate the expansion.
159+
A helper function that creates a :mod:`loopy` kernel for multipole/local evaluation.
160+
161+
This function uses symbolic expressions given by the expansion class,
162+
converts them to :mod:`pymbolic` expressions and generates a :mod:`loopy`
163+
kernel. Note that the :mod:`loopy` kernel returned has lots of expressions
164+
in it and (likely) takes a long time. Therefore, this function should be
165+
used only as a fallback when there is no "loop-y" kernel to evaluate the
166+
expansion.
156167
"""
157-
dim = expansion.dim
168+
sac = SymbolicAssignmentCollection()
158169

159-
avec = sym.make_sym_vector("a", dim)
170+
dim = expansion.dim
160171
ncoeffs = len(expansion.get_coefficient_identifiers())
161172

173+
avec = sym.make_sym_vector("a", dim)
162174
rscale = sym.Symbol("rscale")
163175

164-
sac = SymbolicAssignmentCollection()
165-
166176
domains = [
167177
"{[idim]: 0<=idim<dim}",
168178
]
169-
insns = []
179+
180+
insns: list[lp.Assignment | lp.CallInstruction] = []
170181
insns.append(
171182
lp.Assignment(
172183
assignee="a[idim]",
@@ -179,15 +190,15 @@ def make_p2e_loopy_kernel(
179190
strengths = [all_strengths[i] for i in strength_usage]
180191
coeffs = expansion.coefficients_from_source_vec(kernels,
181192
avec, None, rscale, strengths, sac=sac)
182-
183-
coeff_names = [
193+
coeff_names = {
184194
sac.add_assignment(f"coeffs{i}", coeff) for i, coeff in enumerate(coeffs)
185-
]
195+
}
186196

187197
sac.run_global_cse()
188198

189-
code_transformers = [expansion.get_code_transformer()] \
190-
+ [kernel.get_code_transformer() for kernel in kernels]
199+
code_transformers = (
200+
[expansion.get_code_transformer()]
201+
+ [kernel.get_code_transformer() for kernel in kernels])
191202

192203
from sumpy.codegen import to_loopy_insns
193204
insns += to_loopy_insns(
@@ -198,37 +209,38 @@ def make_p2e_loopy_kernel(
198209
complex_dtype=np.complex128 # FIXME
199210
)
200211

201-
coeffs = pymbolic.var("coeffs")
212+
coeffs = prim.Variable("coeffs")
202213

203214
# change coeff{i} = expr into coeff[i] += expr
204215
for i in range(len(insns)):
205216
insn = insns[i]
206-
if isinstance(insn, lp.Assignment) and \
207-
isinstance(insn.assignee, pymbolic.var) and \
208-
insn.assignee.name.startswith(coeffs.name):
217+
if (isinstance(insn, lp.Assignment)
218+
and isinstance(insn.assignee, prim.Variable)
219+
and insn.assignee.name.startswith(coeffs.name)):
209220
idx = int(insn.assignee.name[len(coeffs.name):])
221+
210222
insns[i] = lp.Assignment(
211223
assignee=coeffs[idx],
212224
expression=(
213225
coeffs[idx]
214-
+ cast("pymbolic.ArithmeticExpression", insn.expression)),
226+
+ cast("ArithmeticExpression", insn.expression)),
215227
id=f"coeff_{idx}",
216228
happens_after=insn.happens_after,
217229
)
218230

219231
loopy_knl = lp.make_function(domains, insns,
220232
kernel_data=[
221-
lp.GlobalArg("coeffs",
222-
shape=(ncoeffs,), is_input=True, is_output=True),
223-
lp.GlobalArg("center, source",
224-
shape=(dim,), is_input=True, is_output=False),
225-
lp.GlobalArg("strength",
226-
shape=(nstrengths,), is_input=True, is_output=False),
233+
lp.GlobalArg("coeffs", shape=(ncoeffs,),
234+
is_input=True, is_output=True),
235+
lp.GlobalArg("center, source", shape=(dim,),
236+
is_input=True, is_output=False),
237+
lp.GlobalArg("strength", shape=(nstrengths,),
238+
is_input=True, is_output=False),
227239
lp.ValueArg("rscale", is_input=True),
228240
lp.ValueArg("isrc", is_input=True),
229241
lp.ValueArg("nsources", is_input=True),
230-
lp.GlobalArg("sources",
231-
shape=(dim, "nsources"), is_input=True, is_output=False),
242+
lp.GlobalArg("sources", shape=(dim, "nsources"),
243+
is_input=True, is_output=False),
232244
*source_args,
233245
...],
234246
name="p2e",

0 commit comments

Comments
 (0)