Skip to content

Commit a199419

Browse files
committed
Fix a few typing issues
1 parent 8cef19c commit a199419

File tree

3 files changed

+12
-6
lines changed

3 files changed

+12
-6
lines changed

sumpy/codegen.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,10 @@
2323
THE SOFTWARE.
2424
"""
2525

26-
26+
from abc import ABC
2727
import logging
2828
import re
29+
from typing import ParamSpec
2930

3031
import numpy as np
3132
from constantdict import constantdict
@@ -43,6 +44,9 @@
4344
logger = logging.getLogger(__name__)
4445

4546

47+
P = ParamSpec("P")
48+
49+
4650
__doc__ = """
4751
4852
Conversion of :mod:`sympy` expressions to :mod:`loopy`
@@ -237,7 +241,8 @@ def register_optimization_preambles(loopy_knl, device):
237241

238242
# {{{ custom mapper base classes
239243

240-
class CSECachingIdentityMapper(IdentityMapper, CSECachingMapperMixin):
244+
class CSECachingIdentityMapper(
245+
IdentityMapper[P], CSECachingMapperMixin[Expression, P], ABC):
241246
pass
242247

243248

sumpy/expansion/m2l.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,13 @@
2525

2626
import logging
2727
from abc import ABC, abstractmethod
28-
from typing import Any, ClassVar
28+
from typing import Any, ClassVar, cast
2929

3030
import numpy as np
3131

3232
import loopy as lp
3333
import pymbolic
34+
import pymbolic.primitives as p
3435

3536
import sumpy.symbolic as sym
3637
from sumpy.tools import add_to_sac, matvec_toeplitz_upper_triangular
@@ -1075,8 +1076,8 @@ def loopy_translation_classes_dependent_data(tgt_expansion, src_expansion,
10751076
for i in range(len(insns)):
10761077
insn = insns[i]
10771078
if isinstance(insn, lp.Assignment) and \
1078-
insn.assignee.name.startswith(vec_name):
1079-
idx = int(insn.assignee.name[len(vec_name):])
1079+
cast(p.Variable, insn.assignee).name.startswith(vec_name):
1080+
idx = int(cast(p.Variable, insn.assignee).name[len(vec_name):])
10801081
insns[i] = lp.Assignment(
10811082
assignee=data[idx],
10821083
expression=insn.expression,

sumpy/tools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -972,7 +972,7 @@ def run_opencl_fft(
972972
input_vec: Any,
973973
inverse: bool = False,
974974
wait_for: list[pyopencl.Event] | None = None
975-
) -> tuple[pyopencl.Event, Any]:
975+
) -> tuple[pyopencl.Event | MarkerBasedProfilingEvent, Any]:
976976
"""Runs an FFT on input_vec and returns a :class:`MarkerBasedProfilingEvent`
977977
that indicate the end and start of the operations carried out and the output
978978
vector.

0 commit comments

Comments
 (0)