Skip to content

Commit c04af5f

Browse files
committed
mpi: fix in-loop allreduce with multiple reductions
1 parent b19e707 commit c04af5f

File tree

2 files changed

+39
-3
lines changed

2 files changed

+39
-3
lines changed

devito/ir/clusters/algorithms.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -456,9 +456,9 @@ def reduction_comms(clusters):
456456
# Schedule the global distributed reductions encountered before `c`,
457457
# if `c`'s IterationSpace is such that the reduction can be carried out
458458
found, fifo = split(fifo, lambda dr: dr.ispace.is_subset(c.ispace))
459-
if found:
460-
exprs = [Eq(dr.var, dr) for dr in found]
461-
processed.append(c.rebuild(exprs=exprs))
459+
for ispace, reds in groupby(found, key=lambda r: r.ispace):
460+
exprs = [Eq(dr.var, dr) for dr in reds]
461+
processed.append(Cluster(exprs=exprs, ispace=ispace))
462462

463463
# Detect the global distributed reductions in `c`
464464
for e in c.exprs:

tests/test_mpi.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2029,8 +2029,44 @@ def test_allreduce_time(self, mode):
20292029
op = Operator([Eq(ux.forward, ux + tt), Inc(g, ux)], name="Op")
20302030
assert_structure(op, ['t,x,y', 't'], 'txy')
20312031

2032+
# The allreduce should be inside the time loop but not inside the space loops
2033+
iters = FindNodes(Iteration).visit(op)
2034+
for i in iters:
2035+
if i.dim.is_Time:
2036+
assert len(FindNodes(Call).visit(i)) == 1 # one allreduce
2037+
else:
2038+
assert len(FindNodes(Call).visit(i)) == 0
2039+
2040+
op.apply(time_m=0, time_M=nt-1)
2041+
assert np.isclose(np.max(g.data), 4356.0)
2042+
2043+
@pytest.mark.parallel(mode=2)
2044+
def test_multi_allreduce_time(self, mode):
2045+
space_order = 8
2046+
nx, ny = 11, 11
2047+
2048+
grid = Grid(shape=(nx, ny))
2049+
tt = grid.time_dim
2050+
nt = 10
2051+
2052+
ux = TimeFunction(name="ux", grid=grid, time_order=1, space_order=space_order)
2053+
g = TimeFunction(name="g", grid=grid, dimensions=(tt, ), shape=(nt,))
2054+
h = TimeFunction(name="h", grid=grid, dimensions=(tt, ), shape=(nt,))
2055+
2056+
op = Operator([Eq(ux.forward, ux + tt), Inc(g, ux), Inc(h, ux)], name="Op")
2057+
assert_structure(op, ['t,x,y', 't'], 'txy')
2058+
2059+
# Make sure the two allreduce calls are in the time loop
2060+
iters = FindNodes(Iteration).visit(op)
2061+
for i in iters:
2062+
if i.dim.is_Time:
2063+
assert len(FindNodes(Call).visit(i)) == 2 # Two allreduce
2064+
else:
2065+
assert len(FindNodes(Call).visit(i)) == 0
2066+
20322067
op.apply(time_m=0, time_M=nt-1)
20332068
assert np.isclose(np.max(g.data), 4356.0)
2069+
assert np.isclose(np.max(h.data), 4356.0)
20342070

20352071

20362072
class TestOperatorAdvanced:

0 commit comments

Comments
 (0)