Skip to content

Commit 829520e

Browse files
committed
Type more of fmm and toys
1 parent e25b26c commit 829520e

File tree

3 files changed

+107
-44
lines changed

3 files changed

+107
-44
lines changed

doc/conf.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
"np.floating": "class:numpy.floating",
3636
"np.complexfloating": "class:numpy.complexfloating",
3737
"np.inexact": "class:numpy.inexact",
38+
"np.dtype": "class:numpy.dtype",
3839
# pytools
3940
"obj_array.ObjectArray1D": "obj:pytools.obj_array.ObjectArray1D",
4041
# sympy
@@ -49,6 +50,8 @@
4950
"Expression": "obj:pymbolic.typing.Expression",
5051
# arraycontext
5152
"Array": "obj:arraycontext.Array",
53+
# boxtree
54+
"FMMTraversalInfo": "class:boxtree.traversal.FMMTraversalInfo",
5255
# sumpy
5356
"ArithmeticExpr": "obj:sumpy.kernel.ArithmeticExpr",
5457
}

sumpy/fmm.py

Lines changed: 59 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,24 @@
2626
__doc__ = """Integrates :mod:`boxtree` with :mod:`sumpy`.
2727
2828
.. autoclass:: SumpyTreeIndependentDataForWrangler
29+
.. autodata:: FMMLevelToOrder
30+
:no-index:
31+
.. class:: FMMLevelToOrder
32+
33+
See above.
34+
2935
.. autoclass:: SumpyExpansionWrangler
3036
3137
.. autoclass:: MultipoleExpansionFromOrderFactory
3238
.. autoclass:: LocalExpansionFromOrderFactory
3339
"""
3440

35-
from typing import TYPE_CHECKING, Protocol, cast
41+
from collections.abc import Callable, Mapping
42+
from typing import TYPE_CHECKING, Any, Protocol, TypeAlias, cast
3643

44+
import numpy as np
3745
from boxtree.fmm import ExpansionWranglerInterface, TreeIndependentDataForWrangler
46+
from boxtree.tree import Tree
3847

3948
import pyopencl as cl
4049
import pyopencl.array as cl_array
@@ -55,6 +64,7 @@
5564
P2EFromSingleBox,
5665
P2PFromCSR,
5766
)
67+
from sumpy.kernel import Kernel
5868
from sumpy.tools import (
5969
AggregateProfilingEvent,
6070
get_native_event,
@@ -67,9 +77,11 @@
6777
if TYPE_CHECKING:
6878
from collections.abc import Sequence
6979

80+
from boxtree.traversal import FMMTraversalInfo
81+
from numpy.typing import DTypeLike
82+
7083
from sumpy.expansion.local import LocalExpansionBase
7184
from sumpy.expansion.multipole import MultipoleExpansionBase
72-
from sumpy.kernel import Kernel
7385

7486

7587
class MultipoleExpansionFromOrderFactory(Protocol):
@@ -306,6 +318,11 @@ def done(self):
306318

307319
# {{{ expansion wrangler
308320

321+
FMMLevelToOrder: TypeAlias = Callable[
322+
[Kernel, frozenset[tuple[str, object]], Tree, int],
323+
int]
324+
325+
309326
class SumpyExpansionWrangler(ExpansionWranglerInterface):
310327
"""Implements the :class:`boxtree.fmm.ExpansionWranglerInterface`
311328
by using :mod:`sumpy` expansions/translations.
@@ -331,23 +348,44 @@ class SumpyExpansionWrangler(ExpansionWranglerInterface):
331348
Type for the preprocessed multipole expansion if used for M2L.
332349
"""
333350

334-
def __init__(self, tree_indep, traversal, dtype, fmm_level_to_order,
335-
source_extra_kwargs=None,
336-
kernel_extra_kwargs=None,
337-
self_extra_kwargs=None,
338-
translation_classes_data=None,
339-
preprocessed_mpole_dtype=None,
340-
*, _disable_translation_classes=False):
351+
tree_indep: SumpyTreeIndependentDataForWrangler
352+
traversal: FMMTraversalInfo
353+
354+
source_extra_kwargs: Mapping[str, object]
355+
kernel_extra_kwargs: Mapping[str, object]
356+
self_extra_kwargs: Mapping[str, object]
357+
extra_kwargs: Mapping[str, object]
358+
359+
dtype: np.dtype[Any]
360+
preprocessed_mpole_dtype: np.dtype[Any]
361+
362+
level_order: Sequence[int]
363+
364+
issued_timing_data_warning: bool
365+
366+
def __init__(self,
367+
tree_indep: SumpyTreeIndependentDataForWrangler,
368+
traversal: FMMTraversalInfo,
369+
dtype: DTypeLike,
370+
fmm_level_to_order: FMMLevelToOrder,
371+
source_extra_kwargs: Mapping[str, object] | None = None,
372+
kernel_extra_kwargs: Mapping[str, object] | None = None,
373+
self_extra_kwargs: Mapping[str, object] | None = None,
374+
translation_classes_data=None,
375+
preprocessed_mpole_dtype: DTypeLike | None = None,
376+
*,
377+
_disable_translation_classes=False
378+
):
341379
super().__init__(tree_indep, traversal)
342380
self.issued_timing_data_warning = False
343381

344-
self.dtype = dtype
382+
self.dtype = np.dtype(dtype)
345383

346384
if not self.tree_indep.m2l_translation.use_fft:
347385
# If not FFT, we don't need complex dtypes
348-
self.preprocessed_mpole_dtype = dtype
386+
self.preprocessed_mpole_dtype = self.dtype
349387
elif preprocessed_mpole_dtype is not None:
350-
self.preprocessed_mpole_dtype = preprocessed_mpole_dtype
388+
self.preprocessed_mpole_dtype = np.dtype(preprocessed_mpole_dtype)
351389
else:
352390
# FIXME: It is weird that the wrangler has to compute this.
353391
self.preprocessed_mpole_dtype = to_complex_dtype(dtype)
@@ -372,7 +410,7 @@ def __init__(self, tree_indep, traversal, dtype, fmm_level_to_order,
372410
self.kernel_extra_kwargs = kernel_extra_kwargs
373411
self.self_extra_kwargs = self_extra_kwargs
374412

375-
self.extra_kwargs = source_extra_kwargs.copy()
413+
self.extra_kwargs = dict(source_extra_kwargs)
376414
self.extra_kwargs.update(self.kernel_extra_kwargs)
377415

378416
if _disable_translation_classes or not base_kernel.is_translation_invariant:
@@ -390,7 +428,7 @@ def __init__(self, tree_indep, traversal, dtype, fmm_level_to_order,
390428

391429
self.translation_classes_data = translation_classes_data
392430

393-
def level_to_rscale(self, level):
431+
def level_to_rscale(self, level: int) -> float:
394432
tree = self.tree
395433
order = self.level_orders[level]
396434
r = tree.root_extent * (2**-level)
@@ -411,7 +449,7 @@ def level_to_rscale(self, level):
411449

412450
# {{{ data vector utilities
413451

414-
def _expansions_level_starts(self, order_to_size):
452+
def _expansions_level_starts(self, order_to_size: Callable[[int], int]):
415453
return build_csr_level_starts(self.level_orders, order_to_size,
416454
self.tree.level_start_box_nrs)
417455

@@ -433,7 +471,7 @@ def m2l_translation_class_level_start_box_nrs(self):
433471

434472
@memoize_method
435473
def m2l_translation_classes_dependent_data_level_starts(self):
436-
def order_to_size(order):
474+
def order_to_size(order: int):
437475
mpole_expn = self.tree_indep.multipole_expansion(order)
438476
local_expn = self.tree_indep.local_expansion(order)
439477
m2l_translation = local_expn.m2l_translation
@@ -623,7 +661,7 @@ def form_multipoles(self,
623661
src_weight_vecs):
624662
mpoles = self.multipole_expansion_zeros(src_weight_vecs[0])
625663

626-
kwargs = self.extra_kwargs.copy()
664+
kwargs = dict(self.extra_kwargs)
627665
kwargs.update(self.box_source_list_kwargs())
628666

629667
events = []
@@ -716,7 +754,7 @@ def eval_direct(self, target_boxes, source_box_starts,
716754
source_box_lists, src_weight_vecs):
717755
pot = self.output_zeros(src_weight_vecs[0])
718756

719-
kwargs = self.extra_kwargs.copy()
757+
kwargs = dict(self.extra_kwargs)
720758
kwargs.update(self.self_extra_kwargs)
721759
kwargs.update(self.box_source_list_kwargs())
722760
kwargs.update(self.box_target_list_kwargs())
@@ -963,7 +1001,7 @@ def eval_multipoles(self,
9631001
target_boxes_by_source_level, source_boxes_by_level, mpole_exps):
9641002
pot = self.output_zeros(mpole_exps)
9651003

966-
kwargs = self.kernel_extra_kwargs.copy()
1004+
kwargs = dict(self.kernel_extra_kwargs)
9671005
kwargs.update(self.box_target_list_kwargs())
9681006

9691007
events = []
@@ -1015,7 +1053,7 @@ def form_locals(self,
10151053
target_or_target_parent_boxes, starts, lists, src_weight_vecs):
10161054
local_exps = self.local_expansion_zeros(src_weight_vecs[0])
10171055

1018-
kwargs = self.extra_kwargs.copy()
1056+
kwargs = dict(self.extra_kwargs)
10191057
kwargs.update(self.box_source_list_kwargs())
10201058

10211059
events = []
@@ -1102,7 +1140,7 @@ def refine_locals(self,
11021140
def eval_locals(self, level_start_target_box_nrs, target_boxes, local_exps):
11031141
pot = self.output_zeros(local_exps)
11041142

1105-
kwargs = self.kernel_extra_kwargs.copy()
1143+
kwargs = dict(self.kernel_extra_kwargs)
11061144
kwargs.update(self.box_target_list_kwargs())
11071145

11081146
events = []

0 commit comments

Comments
 (0)