2626__doc__ = """Integrates :mod:`boxtree` with :mod:`sumpy`.
2727
2828.. autoclass:: SumpyTreeIndependentDataForWrangler
29+ .. autodata:: FMMLevelToOrder
30+ :no-index:
31+ .. class:: FMMLevelToOrder
32+
33+ See above.
34+
2935.. autoclass:: SumpyExpansionWrangler
3036
3137.. autoclass:: MultipoleExpansionFromOrderFactory
3238.. autoclass:: LocalExpansionFromOrderFactory
3339"""
3440
35- from typing import TYPE_CHECKING , Protocol , cast
41+ from collections .abc import Callable , Mapping
42+ from typing import TYPE_CHECKING , Any , Protocol , TypeAlias , cast
3643
44+ import numpy as np
3745from boxtree .fmm import ExpansionWranglerInterface , TreeIndependentDataForWrangler
46+ from boxtree .tree import Tree
3847
3948import pyopencl as cl
4049import pyopencl .array as cl_array
5564 P2EFromSingleBox ,
5665 P2PFromCSR ,
5766)
67+ from sumpy .kernel import Kernel
5868from sumpy .tools import (
5969 AggregateProfilingEvent ,
6070 get_native_event ,
6777if TYPE_CHECKING :
6878 from collections .abc import Sequence
6979
80+ from boxtree .traversal import FMMTraversalInfo
81+ from numpy .typing import DTypeLike
82+
7083 from sumpy .expansion .local import LocalExpansionBase
7184 from sumpy .expansion .multipole import MultipoleExpansionBase
72- from sumpy .kernel import Kernel
7385
7486
7587class MultipoleExpansionFromOrderFactory (Protocol ):
@@ -306,6 +318,11 @@ def done(self):
306318
307319# {{{ expansion wrangler
308320
321+ FMMLevelToOrder : TypeAlias = Callable [
322+ [Kernel , frozenset [tuple [str , object ]], Tree , int ],
323+ int ]
324+
325+
309326class SumpyExpansionWrangler (ExpansionWranglerInterface ):
310327 """Implements the :class:`boxtree.fmm.ExpansionWranglerInterface`
311328 by using :mod:`sumpy` expansions/translations.
@@ -331,23 +348,44 @@ class SumpyExpansionWrangler(ExpansionWranglerInterface):
331348 Type for the preprocessed multipole expansion if used for M2L.
332349 """
333350
334- def __init__ (self , tree_indep , traversal , dtype , fmm_level_to_order ,
335- source_extra_kwargs = None ,
336- kernel_extra_kwargs = None ,
337- self_extra_kwargs = None ,
338- translation_classes_data = None ,
339- preprocessed_mpole_dtype = None ,
340- * , _disable_translation_classes = False ):
351+ tree_indep : SumpyTreeIndependentDataForWrangler
352+ traversal : FMMTraversalInfo
353+
354+ source_extra_kwargs : Mapping [str , object ]
355+ kernel_extra_kwargs : Mapping [str , object ]
356+ self_extra_kwargs : Mapping [str , object ]
357+ extra_kwargs : Mapping [str , object ]
358+
359+ dtype : np .dtype [Any ]
360+ preprocessed_mpole_dtype : np .dtype [Any ]
361+
362+ level_order : Sequence [int ]
363+
364+ issued_timing_data_warning : bool
365+
366+ def __init__ (self ,
367+ tree_indep : SumpyTreeIndependentDataForWrangler ,
368+ traversal : FMMTraversalInfo ,
369+ dtype : DTypeLike ,
370+ fmm_level_to_order : FMMLevelToOrder ,
371+ source_extra_kwargs : Mapping [str , object ] | None = None ,
372+ kernel_extra_kwargs : Mapping [str , object ] | None = None ,
373+ self_extra_kwargs : Mapping [str , object ] | None = None ,
374+ translation_classes_data = None ,
375+ preprocessed_mpole_dtype : DTypeLike | None = None ,
376+ * ,
377+ _disable_translation_classes = False
378+ ):
341379 super ().__init__ (tree_indep , traversal )
342380 self .issued_timing_data_warning = False
343381
344- self .dtype = dtype
382+ self .dtype = np . dtype ( dtype )
345383
346384 if not self .tree_indep .m2l_translation .use_fft :
347385 # If not FFT, we don't need complex dtypes
348- self .preprocessed_mpole_dtype = dtype
386+ self .preprocessed_mpole_dtype = self . dtype
349387 elif preprocessed_mpole_dtype is not None :
350- self .preprocessed_mpole_dtype = preprocessed_mpole_dtype
388+ self .preprocessed_mpole_dtype = np . dtype ( preprocessed_mpole_dtype )
351389 else :
352390 # FIXME: It is weird that the wrangler has to compute this.
353391 self .preprocessed_mpole_dtype = to_complex_dtype (dtype )
@@ -372,7 +410,7 @@ def __init__(self, tree_indep, traversal, dtype, fmm_level_to_order,
372410 self .kernel_extra_kwargs = kernel_extra_kwargs
373411 self .self_extra_kwargs = self_extra_kwargs
374412
375- self .extra_kwargs = source_extra_kwargs . copy ( )
413+ self .extra_kwargs = dict ( source_extra_kwargs )
376414 self .extra_kwargs .update (self .kernel_extra_kwargs )
377415
378416 if _disable_translation_classes or not base_kernel .is_translation_invariant :
@@ -390,7 +428,7 @@ def __init__(self, tree_indep, traversal, dtype, fmm_level_to_order,
390428
391429 self .translation_classes_data = translation_classes_data
392430
393- def level_to_rscale (self , level ) :
431+ def level_to_rscale (self , level : int ) -> float :
394432 tree = self .tree
395433 order = self .level_orders [level ]
396434 r = tree .root_extent * (2 ** - level )
@@ -411,7 +449,7 @@ def level_to_rscale(self, level):
411449
412450 # {{{ data vector utilities
413451
414- def _expansions_level_starts (self , order_to_size ):
452+ def _expansions_level_starts (self , order_to_size : Callable [[ int ], int ] ):
415453 return build_csr_level_starts (self .level_orders , order_to_size ,
416454 self .tree .level_start_box_nrs )
417455
@@ -433,7 +471,7 @@ def m2l_translation_class_level_start_box_nrs(self):
433471
434472 @memoize_method
435473 def m2l_translation_classes_dependent_data_level_starts (self ):
436- def order_to_size (order ):
474+ def order_to_size (order : int ):
437475 mpole_expn = self .tree_indep .multipole_expansion (order )
438476 local_expn = self .tree_indep .local_expansion (order )
439477 m2l_translation = local_expn .m2l_translation
@@ -623,7 +661,7 @@ def form_multipoles(self,
623661 src_weight_vecs ):
624662 mpoles = self .multipole_expansion_zeros (src_weight_vecs [0 ])
625663
626- kwargs = self .extra_kwargs . copy ( )
664+ kwargs = dict ( self .extra_kwargs )
627665 kwargs .update (self .box_source_list_kwargs ())
628666
629667 events = []
@@ -716,7 +754,7 @@ def eval_direct(self, target_boxes, source_box_starts,
716754 source_box_lists , src_weight_vecs ):
717755 pot = self .output_zeros (src_weight_vecs [0 ])
718756
719- kwargs = self .extra_kwargs . copy ( )
757+ kwargs = dict ( self .extra_kwargs )
720758 kwargs .update (self .self_extra_kwargs )
721759 kwargs .update (self .box_source_list_kwargs ())
722760 kwargs .update (self .box_target_list_kwargs ())
@@ -963,7 +1001,7 @@ def eval_multipoles(self,
9631001 target_boxes_by_source_level , source_boxes_by_level , mpole_exps ):
9641002 pot = self .output_zeros (mpole_exps )
9651003
966- kwargs = self .kernel_extra_kwargs . copy ( )
1004+ kwargs = dict ( self .kernel_extra_kwargs )
9671005 kwargs .update (self .box_target_list_kwargs ())
9681006
9691007 events = []
@@ -1015,7 +1053,7 @@ def form_locals(self,
10151053 target_or_target_parent_boxes , starts , lists , src_weight_vecs ):
10161054 local_exps = self .local_expansion_zeros (src_weight_vecs [0 ])
10171055
1018- kwargs = self .extra_kwargs . copy ( )
1056+ kwargs = dict ( self .extra_kwargs )
10191057 kwargs .update (self .box_source_list_kwargs ())
10201058
10211059 events = []
@@ -1102,7 +1140,7 @@ def refine_locals(self,
11021140 def eval_locals (self , level_start_target_box_nrs , target_boxes , local_exps ):
11031141 pot = self .output_zeros (local_exps )
11041142
1105- kwargs = self .kernel_extra_kwargs . copy ( )
1143+ kwargs = dict ( self .kernel_extra_kwargs )
11061144 kwargs .update (self .box_target_list_kwargs ())
11071145
11081146 events = []
0 commit comments