@@ -249,7 +249,17 @@ def postprocess_at_target(self, expr, bvec):
249249 :arg expr: may be a :class:`sympy.core.expr.Expr` or a
250250 :class:`sumpy.derivative_taker.DifferentiatedExprDerivativeTaker`.
251251 """
252- return expr
252+ from sumpy .derivative_taker import (ExprDerivativeTaker ,
253+ DifferentiatedExprDerivativeTaker )
254+ expr_dict = {(0 ,)* self .dim : 1 }
255+ expr_dict = self .get_derivative_coeff_dict_at_target (expr_dict )
256+ if isinstance (expr , ExprDerivativeTaker ):
257+ return DifferentiatedExprDerivativeTaker (expr , expr_dict )
258+
259+ result = 0
260+ for mi , coeff in expr_dict .items ():
261+ result += coeff * self ._diff (expr , bvec , mi )
262+ return result
253263
254264 def get_derivative_coeff_dict_at_source (self , expr_dict ):
255265 r"""Get the derivative transformation of the expression at source
@@ -263,6 +273,18 @@ def get_derivative_coeff_dict_at_source(self, expr_dict):
263273 """
264274 return expr_dict
265275
276+ def get_derivative_coeff_dict_at_target (self , expr_dict ):
277+ r"""Get the derivative transformation of the expression at target
278+ represented by the dictionary expr_dict which is mapping from multi-index
279+ `mi` to coefficient `coeff`.
280+ Expression represented by the dictionary `expr_dict` is
281+ :math:`\sum_{mi} \frac{\partial^mi}{x^mi}G * coeff`. Returns an
282+ expression of the same type.
283+
284+ This function is meant to be overridden by child classes where necessary.
285+ """
286+ return expr_dict
287+
266288 def get_global_scaling_const (self ):
267289 r"""Return a global scaling constant of the kernel.
268290 Typically, this ensures that the kernel is scaled so that
@@ -959,8 +981,8 @@ def get_expression(self, scaled_dist_vec):
959981 def get_derivative_coeff_dict_at_source (self , expr_dict ):
960982 return self .inner_kernel .get_derivative_coeff_dict_at_source (expr_dict )
961983
962- def postprocess_at_target (self , expr , bvec ):
963- return self .inner_kernel .postprocess_at_target ( expr , bvec )
984+ def get_derivative_coeff_dict_at_target (self , expr_dict ):
985+ return self .inner_kernel .get_derivative_coeff_dict_at_target ( expr_dict )
964986
965987 def get_global_scaling_const (self ):
966988 return self .inner_kernel .get_global_scaling_const ()
@@ -1043,23 +1065,19 @@ def __str__(self):
10431065 def __repr__ (self ):
10441066 return f"AxisTargetDerivative({ self .axis } , { self .inner_kernel !r} )"
10451067
1046- def postprocess_at_target (self , expr , bvec ):
1047- from sumpy .derivative_taker import (DifferentiatedExprDerivativeTaker ,
1048- diff_derivative_coeff_dict )
1068+ def get_derivative_coeff_dict_at_target (self , expr_dict ):
10491069 from sumpy .symbolic import make_sym_vector as make_sympy_vector
1050-
10511070 target_vec = make_sympy_vector (self .target_array_name , self .dim )
10521071
1053- # bvec = tgt - ctr
1054- expr = self .inner_kernel .postprocess_at_target (expr , bvec )
1055- if isinstance (expr , DifferentiatedExprDerivativeTaker ):
1056- transformation = diff_derivative_coeff_dict (expr .derivative_coeff_dict ,
1057- self .axis , target_vec )
1058- return DifferentiatedExprDerivativeTaker (expr .taker , transformation )
1059- else :
1060- # Since `bvec` and `tgt` are two different symbolic variables
1061- # need to differentiate by both to get the correct answer
1062- return expr .diff (bvec [self .axis ]) + expr .diff (target_vec [self .axis ])
1072+ expr_dict = self .inner_kernel .get_derivative_coeff_dict_at_target (
1073+ expr_dict )
1074+ result = defaultdict (lambda : 0 )
1075+ for mi , coeff in expr_dict .items ():
1076+ new_mi = list (mi )
1077+ new_mi [self .axis ] += 1
1078+ result [tuple (new_mi )] += coeff
1079+ result [mi ] += sym .diff (coeff , target_vec [self .axis ])
1080+ return dict (result )
10631081
10641082 def replace_base_kernel (self , new_base_kernel ):
10651083 return type (self )(self .axis ,
@@ -1141,35 +1159,23 @@ def transform(expr):
11411159
11421160 return transform
11431161
1144- def postprocess_at_target (self , expr , bvec ):
1145- from sumpy .derivative_taker import (DifferentiatedExprDerivativeTaker ,
1146- diff_derivative_coeff_dict )
1147-
1162+ def get_derivative_coeff_dict_at_target (self , expr_dict ):
11481163 from sumpy .symbolic import make_sym_vector as make_sympy_vector
11491164 dir_vec = make_sympy_vector (self .dir_vec_name , self .dim )
11501165 target_vec = make_sympy_vector (self .target_array_name , self .dim )
11511166
1152- expr = self .inner_kernel .postprocess_at_target (expr , bvec )
1167+ expr_dict = self .inner_kernel .get_derivative_coeff_dict_at_target (
1168+ expr_dict )
11531169
1154- # bvec = tgt - center
1155- if not isinstance ( expr , DifferentiatedExprDerivativeTaker ):
1156- result = 0
1170+ # avec = tgt - center
1171+ result = defaultdict ( lambda : 0 )
1172+ for mi , coeff in expr_dict . items ():
11571173 for axis in range (self .dim ):
1158- # Since `bvec` and `tgt` are two different symbolic variables
1159- # need to differentiate by both to get the correct answer
1160- result += (expr .diff (bvec [axis ]) + expr .diff (target_vec [axis ])) \
1161- * dir_vec [axis ]
1162- return result
1163-
1164- new_transformation = defaultdict (lambda : 0 )
1165- for axis in range (self .dim ):
1166- axis_transformation = diff_derivative_coeff_dict (
1167- expr .derivative_coeff_dict , axis , target_vec )
1168- for mi , coeff in axis_transformation .items ():
1169- new_transformation [mi ] += coeff * dir_vec [axis ]
1170-
1171- return DifferentiatedExprDerivativeTaker (expr .taker ,
1172- dict (new_transformation ))
1174+ new_mi = list (mi )
1175+ new_mi [axis ] += 1
1176+ result [tuple (new_mi )] += coeff * dir_vec [axis ]
1177+ result [mi ] += sym .diff (coeff , target_vec [axis ])
1178+ return result
11731179
11741180 def get_args (self ):
11751181 return [
@@ -1268,25 +1274,14 @@ def replace_base_kernel(self, new_base_kernel):
12681274 def replace_inner_kernel (self , new_inner_kernel ):
12691275 return type (self )(self .axis , new_inner_kernel )
12701276
1271- def postprocess_at_target (self , expr , avec ):
1277+ def get_derivative_coeff_dict_at_target (self , expr_dict ):
12721278 from sumpy .symbolic import make_sym_vector as make_sympy_vector
1273- from sumpy .derivative_taker import (ExprDerivativeTaker ,
1274- DifferentiatedExprDerivativeTaker )
1275-
1276- expr = self .inner_kernel .postprocess_at_target (expr , avec )
12771279 target_vec = make_sympy_vector (self .target_array_name , self .dim )
12781280
1279- zeros = tuple ([ 0 ] * self .dim )
1280- mult = target_vec [ self . axis ]
1281+ expr_dict = self .inner_kernel . get_derivative_coeff_dict_at_target (
1282+ expr_dict )
12811283
1282- if isinstance (expr , DifferentiatedExprDerivativeTaker ):
1283- transform = {mi : coeff * mult for mi , coeff in
1284- expr .derivative_coeff_dict .items ()}
1285- return DifferentiatedExprDerivativeTaker (expr .taker , transform )
1286- elif isinstance (expr , ExprDerivativeTaker ):
1287- return DifferentiatedExprDerivativeTaker ({zeros : mult })
1288- else :
1289- return mult * expr
1284+ return {mi : coeff * target_vec [self .axis ] for mi , coeff in expr_dict .items ()}
12901285
12911286 def get_code_transformer (self ):
12921287 from sumpy .codegen import VectorComponentRewriter
0 commit comments