Skip to content

Commit 3f9af54

Browse files
Merge pull request #2601 from devitocodes/hotfix-factor-override
compiler: Fixup factor overriding
2 parents e050f23 + ec1840f commit 3f9af54

File tree

2 files changed

+51
-0
lines changed

2 files changed

+51
-0
lines changed

devito/types/dimension.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -987,6 +987,8 @@ def _arg_values(self, interval, grid=None, args=None, **kwargs):
987987
args = args or {}
988988
fname = self.symbolic_factor.name
989989
fact = kwargs.get(fname, args.get(fname, self.factor_data))
990+
if isinstance(fact, Constant):
991+
fact = fact.data
990992

991993
toint = lambda x: math.ceil(x / fact)
992994
vals = {}

tests/test_dimension.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1933,6 +1933,55 @@ def test_const_factor(self):
19331933
assert t2.factor.data == f1
19341934
assert t2.spacing == t1.spacing
19351935

1936+
def test_symbolic_factor_override_legacy(self):
1937+
grid = Grid(shape=(4, 4))
1938+
time = grid.time_dim
1939+
1940+
fact = Constant(name='fact', dtype=np.int32, value=4)
1941+
cd = ConditionalDimension(name='cd', parent=time, factor=fact)
1942+
1943+
u = TimeFunction(name='u', grid=grid, time_order=0)
1944+
usave = TimeFunction(name='usave', grid=grid, time_dim=cd, save=4)
1945+
1946+
eqns = [Eq(usave, u),
1947+
Eq(u.forward, u + 1)]
1948+
1949+
op = Operator(eqns)
1950+
1951+
op.apply()
1952+
1953+
assert all(np.all(usave.data[i] == i*4) for i in range(4))
1954+
1955+
# Now override the factor
1956+
fact1 = Constant(name='fact1', dtype=np.int32, value=8)
1957+
1958+
op.apply(time_M=31, fact=fact1)
1959+
1960+
assert all(np.all(usave.data[i] == 16 + i*8) for i in range(4))
1961+
1962+
def test_symbolic_factor_override(self):
1963+
grid = Grid(shape=(4, 4))
1964+
time = grid.time_dim
1965+
1966+
cd = ConditionalDimension(name='cd', parent=time, factor=4)
1967+
1968+
u = TimeFunction(name='u', grid=grid, time_order=0)
1969+
usave = TimeFunction(name='usave', grid=grid, time_dim=cd, save=4)
1970+
1971+
eqns = [Eq(usave, u),
1972+
Eq(u.forward, u + 1)]
1973+
1974+
op = Operator(eqns)
1975+
1976+
op.apply()
1977+
1978+
assert all(np.all(usave.data[i] == i*4) for i in range(4))
1979+
1980+
# Now override the factor
1981+
op.apply(time_M=31, **{cd.symbolic_factor.name: 8})
1982+
1983+
assert all(np.all(usave.data[i] == 16 + i*8) for i in range(4))
1984+
19361985

19371986
class TestCustomDimension:
19381987

0 commit comments

Comments
 (0)