Skip to content

Commit b537581

Browse files
JDBetteridgemloubout
authored andcommitted
dsl: Allow symbolic derivatives of Dimension types, do not evaluate
1 parent 47abb19 commit b537581

File tree

2 files changed

+30
-35
lines changed

2 files changed

+30
-35
lines changed

devito/finite_differences/derivative.py

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,16 @@
55

66
import sympy
77

8-
from .finite_difference import generic_derivative, cross_derivative
9-
from .differentiable import Differentiable, diffify, interp_for_fd, Add, Mul
10-
from .tools import direct, transpose
11-
from .rsfd import d45
12-
from devito.tools import (as_mapper, as_tuple, frozendict, is_integer,
13-
Pickable)
8+
from devito.tools import Pickable, as_mapper, as_tuple, frozendict, is_integer
149
from devito.types.dimension import Dimension
1510
from devito.types.utils import DimensionTuple
1611
from devito.warnings import warn
1712

13+
from .differentiable import Add, Differentiable, Mul, diffify, interp_for_fd
14+
from .finite_difference import cross_derivative, generic_derivative
15+
from .rsfd import d45
16+
from .tools import direct, transpose
17+
1818
__all__ = ['Derivative']
1919

2020

@@ -112,27 +112,19 @@ def __new__(cls, expr, *dims, **kwargs):
112112

113113
# It is also possible that the expression itself is just a
114114
# `devito.Dimension` type which is:
115-
# - derivative 1 if the dimension coincides and the number of derivatives
115+
# - derivative 1 if the Dimension coincides and the number of derivatives
116116
# is 1 ie: `Derivative(x, (x, 1)) == 1`.
117-
# - derivative 0 if the dimension coincides and the total number of
117+
# - derivative 0 if the Dimension coincides and the total number of
118118
# derivatives is greater than 1 ie: `Derivative(x, (x, 2)) == 0` and
119119
# `Derivative(x, x, y) == 0`.
120-
# - An error otherwise.
121-
if isinstance(expr, Dimension):
122-
if expr in dcounter.keys():
123-
if dcounter[expr] == 0:
124-
raise ValueError(
125-
f'Cannot interpolate a dimension `{expr}` onto itself'
126-
)
127-
elif dcounter.pop(expr) == 1 and not dcounter:
128-
return 1
129-
else:
130-
return 0
120+
# - An unevaluated expression otherwise.
121+
if isinstance(expr, Dimension) and expr in dcounter:
122+
if dcounter[expr] == 0:
123+
pass
124+
elif dcounter.pop(expr) == 1 and not dcounter:
125+
return 1
131126
else:
132-
raise ValueError(
133-
f'Cannot differentiate one dimension `{expr}` with respect to'
134-
f' another {tuple(dcounter.keys())}'
135-
)
127+
return 0
136128

137129
# Validate the finite difference order `fd_order`
138130
fd_order = cls._validate_fd_order(kwargs.get('fd_order'), expr, dims, dcounter)
@@ -182,12 +174,12 @@ def _validate_expr(expr):
182174
convertible to "differentiable" type.
183175
"""
184176
if type(expr) is sympy.Derivative:
185-
raise ValueError('Cannot nest sympy.Derivative with devito.Derivative')
177+
raise ValueError("Cannot nest sympy.Derivative with devito.Derivative")
186178
if not isinstance(expr, Differentiable):
187179
try:
188180
expr = diffify(expr)
189181
except Exception as e:
190-
raise ValueError('`expr` must be a `Differentiable` type object') from e
182+
raise ValueError("`expr` must be a `Differentiable` type object") from e
191183
return expr
192184

193185
@staticmethod
@@ -257,7 +249,11 @@ def _validate_fd_order(fd_order, expr, dims, dcounter):
257249
Required: `expr`, `dims`, and the derivative counter to validate.
258250
If not provided, the maximum supported order will be used.
259251
"""
260-
if fd_order is not None:
252+
if isinstance(expr, Dimension):
253+
# If the expression is just a dimension `expr.time_order` and
254+
# `expr.space_order` are not defined
255+
fd_order = (99,)*len(dcounter)
256+
elif fd_order is not None:
261257
# If `fd_order` is specified, then validate
262258
fcounter = defaultdict(int)
263259
# First create a dictionary mapping variable wrt which to differentiate

tests/test_derivatives.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1276,12 +1276,11 @@ def test_null(self):
12761276
assert Derivative(self.x, self.y, self.t, self.x) == 0
12771277
assert Derivative(self.x, self.y, self.t, (self.x, 2)) == 0
12781278

1279-
def test_error(self):
1280-
with pytest.raises(ValueError):
1281-
Derivative(self.x, self.t)
1282-
1283-
with pytest.raises(ValueError):
1284-
Derivative(self.x, self.y, self.t)
1285-
1286-
with pytest.raises(ValueError):
1287-
Derivative(self.x, (self.x, 0))
1279+
def test_unevaluated(self):
1280+
"""
1281+
The following should all be instantiatible without raising an
1282+
exception, but should not simplify.
1283+
"""
1284+
assert Derivative(self.x, self.t)
1285+
assert Derivative(self.x, self.y, self.t)
1286+
assert Derivative(self.x, (self.x, 0))

0 commit comments

Comments
 (0)