Skip to content

Commit 64a2196

Browse files
committed
introduce get_derivative_coeff_dict_at_source counterpart for target
1 parent 82d6d43 commit 64a2196

File tree

2 files changed

+50
-62
lines changed

2 files changed

+50
-62
lines changed

sumpy/expansion/multipole.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -90,17 +90,10 @@ def coefficients_from_source(self, kernel, avec, bvec, rscale, sac=None):
9090
rscale, (1,), sac=sac)
9191

9292
def evaluate(self, kernel, coeffs, bvec, rscale, sac=None):
93-
from sumpy.derivative_taker import DifferentiatedExprDerivativeTaker
9493
if not self.use_rscale:
9594
rscale = 1
9695

9796
base_taker = kernel.get_derivative_taker(bvec, rscale, sac)
98-
# Following is a no-op, but AxisTargetDerivative.postprocess_at_target and
99-
# DirectionalTargetDerivative.postprocess_at_target only handles
100-
# DifferentiatedExprDerivativeTaker and sympy expressions, so we need to
101-
# make the taker a DifferentitatedExprDerivativeTaker instance.
102-
base_taker = DifferentiatedExprDerivativeTaker(base_taker,
103-
{tuple([0]*self.dim): 1})
10497
taker = kernel.postprocess_at_target(base_taker, bvec)
10598

10699
result = []

sumpy/kernel.py

Lines changed: 50 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,17 @@ def postprocess_at_target(self, expr, bvec):
249249
:arg expr: may be a :class:`sympy.core.expr.Expr` or a
250250
:class:`sumpy.derivative_taker.DifferentiatedExprDerivativeTaker`.
251251
"""
252-
return expr
252+
from sumpy.derivative_taker import (ExprDerivativeTaker,
253+
DifferentiatedExprDerivativeTaker)
254+
expr_dict = {(0,)*self.dim: 1}
255+
expr_dict = self.get_derivative_coeff_dict_at_target(expr_dict)
256+
if isinstance(expr, ExprDerivativeTaker):
257+
return DifferentiatedExprDerivativeTaker(expr, expr_dict)
258+
259+
result = 0
260+
for mi, coeff in expr_dict.items():
261+
result += coeff * self._diff(expr, bvec, mi)
262+
return result
253263

254264
def get_derivative_coeff_dict_at_source(self, expr_dict):
255265
r"""Get the derivative transformation of the expression at source
@@ -263,6 +273,18 @@ def get_derivative_coeff_dict_at_source(self, expr_dict):
263273
"""
264274
return expr_dict
265275

276+
def get_derivative_coeff_dict_at_target(self, expr_dict):
277+
r"""Get the derivative transformation of the expression at target
278+
represented by the dictionary expr_dict which is mapping from multi-index
279+
`mi` to coefficient `coeff`.
280+
Expression represented by the dictionary `expr_dict` is
281+
:math:`\sum_{mi} \frac{\partial^mi}{x^mi}G * coeff`. Returns an
282+
expression of the same type.
283+
284+
This function is meant to be overridden by child classes where necessary.
285+
"""
286+
return expr_dict
287+
266288
def get_global_scaling_const(self):
267289
r"""Return a global scaling constant of the kernel.
268290
Typically, this ensures that the kernel is scaled so that
@@ -959,8 +981,8 @@ def get_expression(self, scaled_dist_vec):
959981
def get_derivative_coeff_dict_at_source(self, expr_dict):
960982
return self.inner_kernel.get_derivative_coeff_dict_at_source(expr_dict)
961983

962-
def postprocess_at_target(self, expr, bvec):
963-
return self.inner_kernel.postprocess_at_target(expr, bvec)
984+
def get_derivative_coeff_dict_at_target(self, expr_dict):
985+
return self.inner_kernel.get_derivative_coeff_dict_at_target(expr_dict)
964986

965987
def get_global_scaling_const(self):
966988
return self.inner_kernel.get_global_scaling_const()
@@ -1043,23 +1065,19 @@ def __str__(self):
10431065
def __repr__(self):
10441066
return f"AxisTargetDerivative({self.axis}, {self.inner_kernel!r})"
10451067

1046-
def postprocess_at_target(self, expr, bvec):
1047-
from sumpy.derivative_taker import (DifferentiatedExprDerivativeTaker,
1048-
diff_derivative_coeff_dict)
1068+
def get_derivative_coeff_dict_at_target(self, expr_dict):
10491069
from sumpy.symbolic import make_sym_vector as make_sympy_vector
1050-
10511070
target_vec = make_sympy_vector(self.target_array_name, self.dim)
10521071

1053-
# bvec = tgt - ctr
1054-
expr = self.inner_kernel.postprocess_at_target(expr, bvec)
1055-
if isinstance(expr, DifferentiatedExprDerivativeTaker):
1056-
transformation = diff_derivative_coeff_dict(expr.derivative_coeff_dict,
1057-
self.axis, target_vec)
1058-
return DifferentiatedExprDerivativeTaker(expr.taker, transformation)
1059-
else:
1060-
# Since `bvec` and `tgt` are two different symbolic variables
1061-
# need to differentiate by both to get the correct answer
1062-
return expr.diff(bvec[self.axis]) + expr.diff(target_vec[self.axis])
1072+
expr_dict = self.inner_kernel.get_derivative_coeff_dict_at_target(
1073+
expr_dict)
1074+
result = defaultdict(lambda: 0)
1075+
for mi, coeff in expr_dict.items():
1076+
new_mi = list(mi)
1077+
new_mi[self.axis] += 1
1078+
result[tuple(new_mi)] += coeff
1079+
result[mi] += sym.diff(coeff, target_vec[self.axis])
1080+
return dict(result)
10631081

10641082
def replace_base_kernel(self, new_base_kernel):
10651083
return type(self)(self.axis,
@@ -1141,35 +1159,23 @@ def transform(expr):
11411159

11421160
return transform
11431161

1144-
def postprocess_at_target(self, expr, bvec):
1145-
from sumpy.derivative_taker import (DifferentiatedExprDerivativeTaker,
1146-
diff_derivative_coeff_dict)
1147-
1162+
def get_derivative_coeff_dict_at_target(self, expr_dict):
11481163
from sumpy.symbolic import make_sym_vector as make_sympy_vector
11491164
dir_vec = make_sympy_vector(self.dir_vec_name, self.dim)
11501165
target_vec = make_sympy_vector(self.target_array_name, self.dim)
11511166

1152-
expr = self.inner_kernel.postprocess_at_target(expr, bvec)
1167+
expr_dict = self.inner_kernel.get_derivative_coeff_dict_at_target(
1168+
expr_dict)
11531169

1154-
# bvec = tgt - center
1155-
if not isinstance(expr, DifferentiatedExprDerivativeTaker):
1156-
result = 0
1170+
# avec = tgt - center
1171+
result = defaultdict(lambda: 0)
1172+
for mi, coeff in expr_dict.items():
11571173
for axis in range(self.dim):
1158-
# Since `bvec` and `tgt` are two different symbolic variables
1159-
# need to differentiate by both to get the correct answer
1160-
result += (expr.diff(bvec[axis]) + expr.diff(target_vec[axis])) \
1161-
* dir_vec[axis]
1162-
return result
1163-
1164-
new_transformation = defaultdict(lambda: 0)
1165-
for axis in range(self.dim):
1166-
axis_transformation = diff_derivative_coeff_dict(
1167-
expr.derivative_coeff_dict, axis, target_vec)
1168-
for mi, coeff in axis_transformation.items():
1169-
new_transformation[mi] += coeff * dir_vec[axis]
1170-
1171-
return DifferentiatedExprDerivativeTaker(expr.taker,
1172-
dict(new_transformation))
1174+
new_mi = list(mi)
1175+
new_mi[axis] += 1
1176+
result[tuple(new_mi)] += coeff * dir_vec[axis]
1177+
result[mi] += sym.diff(coeff, target_vec[axis])
1178+
return result
11731179

11741180
def get_args(self):
11751181
return [
@@ -1268,25 +1274,14 @@ def replace_base_kernel(self, new_base_kernel):
12681274
def replace_inner_kernel(self, new_inner_kernel):
12691275
return type(self)(self.axis, new_inner_kernel)
12701276

1271-
def postprocess_at_target(self, expr, avec):
1277+
def get_derivative_coeff_dict_at_target(self, expr_dict):
12721278
from sumpy.symbolic import make_sym_vector as make_sympy_vector
1273-
from sumpy.derivative_taker import (ExprDerivativeTaker,
1274-
DifferentiatedExprDerivativeTaker)
1275-
1276-
expr = self.inner_kernel.postprocess_at_target(expr, avec)
12771279
target_vec = make_sympy_vector(self.target_array_name, self.dim)
12781280

1279-
zeros = tuple([0]*self.dim)
1280-
mult = target_vec[self.axis]
1281+
expr_dict = self.inner_kernel.get_derivative_coeff_dict_at_target(
1282+
expr_dict)
12811283

1282-
if isinstance(expr, DifferentiatedExprDerivativeTaker):
1283-
transform = {mi: coeff * mult for mi, coeff in
1284-
expr.derivative_coeff_dict.items()}
1285-
return DifferentiatedExprDerivativeTaker(expr.taker, transform)
1286-
elif isinstance(expr, ExprDerivativeTaker):
1287-
return DifferentiatedExprDerivativeTaker({zeros: mult})
1288-
else:
1289-
return mult * expr
1284+
return {mi: coeff*target_vec[self.axis] for mi, coeff in expr_dict.items()}
12901285

12911286
def get_code_transformer(self):
12921287
from sumpy.codegen import VectorComponentRewriter

0 commit comments

Comments
 (0)