Skip to content

Commit cb1a5c6

Browse files
committed
ci: add test for cond multi all-reduce
1 parent 893cb09 commit cb1a5c6

File tree

2 files changed

+33
-1
lines changed

2 files changed

+33
-1
lines changed

devito/mpi/reduction_scheme.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ def __new__(cls, var, op=None, grid=None, ispace=None, **kwargs):
1919
obj.op = op
2020
obj.grid = grid
2121
obj.ispace = ispace
22-
obj.guards = kwargs.get('guards', None)
2322
return obj
2423

2524
def __repr__(self):

tests/test_mpi.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2068,6 +2068,39 @@ def test_multi_allreduce_time(self, mode):
20682068
assert np.isclose(np.max(g.data), 4356.0)
20692069
assert np.isclose(np.max(h.data), 4356.0)
20702070

2071+
@pytest.mark.parallel(mode=1)
2072+
def test_multi_allreduce_time_cond(self, mode):
2073+
space_order = 8
2074+
nx, ny = 11, 11
2075+
2076+
grid = Grid(shape=(nx, ny))
2077+
tt = grid.time_dim
2078+
nt = 20
2079+
ct = ConditionalDimension(name="ct", parent=tt, factor=2)
2080+
2081+
ux = TimeFunction(name="ux", grid=grid, time_order=1, space_order=space_order)
2082+
g = TimeFunction(name="g", grid=grid, dimensions=(ct, ), shape=(int(nt/2),),
2083+
time_dim=ct)
2084+
h = TimeFunction(name="h", grid=grid, dimensions=(ct, ), shape=(int(nt/2),),
2085+
time_dim=ct)
2086+
2087+
op = Operator([Eq(g, 0), Eq(ux.forward, tt), Inc(g, ux), Inc(h, ux)], name="Op")
2088+
assert_structure(op, ['t', 't,x,y', 't,x,y'], 'txyxy')
2089+
2090+
# Make sure the two allreduce calls are in the time the loop
2091+
iters = FindNodes(Iteration).visit(op)
2092+
for i in iters:
2093+
if i.dim.is_Time:
2094+
assert len(FindNodes(Call).visit(i)) == 2 # Two allreduce
2095+
else:
2096+
assert len(FindNodes(Call).visit(i)) == 0
2097+
2098+
op.apply(time_m=0, time_M=nt-1)
2099+
2100+
expected = [nx * ny * max(t-1, 0) for t in range(0, nt, 2)]
2101+
assert np.allclose(g.data, expected)
2102+
assert np.allclose(h.data, expected)
2103+
20712104

20722105
class TestOperatorAdvanced:
20732106

0 commit comments

Comments
 (0)