Skip to content

Commit 11176cd

Browse files
committed
api: allow expression such as f(x)/h_x for custom weights
1 parent e19765e commit 11176cd

File tree

4 files changed

+31
-13
lines changed

4 files changed

+31
-13
lines changed

devito/finite_differences/derivative.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ def _process_weights(cls, **kwargs):
213213
weights = kwargs.get('weights', kwargs.get('w'))
214214
if weights is None:
215215
return None
216-
elif isinstance(weights, sympy.Function):
216+
elif isinstance(weights, Differentiable):
217217
return weights
218218
else:
219219
return as_tuple(weights)

devito/finite_differences/finite_difference.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def cross_derivative(expr, dims, fd_order, deriv_order, x0=None, side=None, **kw
5959
Semantically, this is equivalent to
6060
6161
>>> (f*g).dxdy
62-
Derivative(Derivative(f(x, y)*g(x, y), x), y)
62+
Derivative(f(x, y)*g(x, y), x, y)
6363
6464
The only difference is that in the latter case derivatives remain unevaluated.
6565
The expanded form is obtained via ``evaluate``
@@ -158,15 +158,15 @@ def make_derivative(expr, dim, fd_order, deriv_order, side, matvec, x0, coeffici
158158
# `coefficients` method (`taylor` or `symbolic`)
159159
if weights is None:
160160
weights = fd_weights_registry[coefficients](expr, deriv_order, indices, x0)
161-
elif wdim is not None:
161+
# Did fd_weights_registry return a new Function/Expression instead of a values?
162+
_, wdim, _ = process_weights(weights, expr, dim)
163+
if wdim is not None:
162164
weights = [weights._subs(wdim, i) for i in range(len(indices))]
163165

164166
# Enforce fixed precision FD coefficients to avoid variations in results
165167
if scale:
166168
scale = dim.spacing**(-deriv_order)
167-
else:
168-
scale = 1
169-
weights = [sympify(scale * w).evalf(_PRECISION) for w in weights]
169+
weights = [sympify(scale * w).evalf(_PRECISION) for w in weights]
170170

171171
# Transpose the FD, if necessary
172172
if matvec == transpose:

devito/finite_differences/tools.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from itertools import product
33

44
import numpy as np
5-
from sympy import S, finite_diff_weights, cacheit, sympify, Function, Rational
5+
from sympy import S, finite_diff_weights, cacheit, sympify, Rational
66

77
from devito.logger import warning
88
from devito.tools import Tag, as_tuple
@@ -48,11 +48,11 @@ def adjoint(self, matvec):
4848
def check_input(func):
4949
@wraps(func)
5050
def wrapper(expr, *args, **kwargs):
51-
try:
52-
return S.Zero if expr.is_Number else func(expr, *args, **kwargs)
53-
except AttributeError:
54-
raise ValueError("'%s' must be of type Differentiable, not %s"
55-
% (expr, type(expr)))
51+
# try:
52+
return S.Zero if expr.is_Number else func(expr, *args, **kwargs)
53+
# except AttributeError:
54+
# raise ValueError("'%s' must be of type Differentiable, not %s"
55+
# % (expr, type(expr)))
5656
return wrapper
5757

5858

@@ -326,9 +326,13 @@ def make_shift_x0(shift, ndim):
326326

327327

328328
def process_weights(weights, expr, dim):
329+
from devito.symbolics import retrieve_functions
330+
w_func = retrieve_functions(weights)
329331
if weights is None:
330332
return 0, None, False
331-
elif isinstance(weights, Function):
333+
elif w_func:
334+
assert len(w_func) == 1, "Only one function expected in weights"
335+
weights = w_func[0]
332336
if len(weights.dimensions) == 1:
333337
return weights.shape[0], weights.dimensions[0], False
334338
try:

tests/test_symbolic_coefficients.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,20 @@ def test_function_coefficients_xderiv(self, order):
120120
expr1 = f.dxdy(w=w).evaluate
121121
assert sp.simplify(expr0 - expr1) == 0
122122

123+
def test_coefficients_expr(self):
124+
p = Dimension('p')
125+
126+
grid = Grid(shape=(51, 51, 51))
127+
x, y, z = grid.dimensions
128+
129+
f = Function(name='f', grid=grid, space_order=4)
130+
w = Function(name='w', space_order=0, shape=(*grid.shape, 5),
131+
dimensions=(x, y, z, p))
132+
133+
expr0 = f.dx(w=w/x.spacing).evaluate
134+
expr1 = f.dx(w=w).evaluate / x.spacing
135+
assert sp.simplify(expr0 - expr1) == 0
136+
123137
def test_coefficients_w_xreplace(self):
124138
"""Test custom coefficients with an xreplace before they are applied"""
125139
grid = Grid(shape=(4, 4))

0 commit comments

Comments
 (0)