Skip to content

Commit 9ab28e0

Browse files
committed
mpi: fix allreduce iteration space
1 parent ac5097e commit 9ab28e0

File tree

3 files changed

+65
-27
lines changed

3 files changed

+65
-27
lines changed

devito/ir/clusters/algorithms.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from devito.exceptions import CompilationError
99
from devito.finite_differences.elementary import Max, Min
1010
from devito.ir.support import (Any, Backward, Forward, IterationSpace, erange,
11-
pull_dims, null_ispace)
11+
pull_dims)
1212
from devito.ir.equations import OpMin, OpMax, identity_mapper
1313
from devito.ir.clusters.analysis import analyze
1414
from devito.ir.clusters.cluster import Cluster, ClusterGroup
@@ -493,9 +493,9 @@ def reduction_comms(clusters):
493493
processed.append(c)
494494

495495
# Leftover reductions are placed at the very end
496-
if fifo:
497-
exprs = [Eq(dr.var, dr) for dr in fifo]
498-
processed.append(Cluster(exprs=exprs, ispace=null_ispace))
496+
for ispace, reds in groupby(fifo, key=lambda r: r.ispace):
497+
exprs = [Eq(dr.var, dr) for dr in reds]
498+
processed.append(Cluster(exprs=exprs, ispace=ispace))
499499

500500
return processed
501501

devito/ir/iet/visitors.py

Lines changed: 43 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1295,6 +1295,28 @@ def __init__(self, mapper, nested=False):
12951295
self.mapper = mapper
12961296
self.nested = nested
12971297

1298+
def transform(self, o, handle, **kwargs):
1299+
if handle is None:
1300+
# None -> drop `o`
1301+
return None
1302+
elif isinstance(handle, Iterable):
1303+
# Iterable -> inject `handle` into `o`'s children
1304+
if not o.children:
1305+
raise CompilationError("Cannot inject nodes in a leaf node")
1306+
if self.nested:
1307+
children = [self._visit(i, **kwargs) for i in o.children]
1308+
else:
1309+
children = o.children
1310+
children = (tuple(handle) + children[0],) + tuple(children[1:])
1311+
return o._rebuild(*children, **o.args_frozen)
1312+
else:
1313+
# Replace `o` with `handle`
1314+
if self.nested:
1315+
children = [self._visit(i, **kwargs) for i in handle.children]
1316+
return handle._rebuild(*children, **handle.args_frozen)
1317+
else:
1318+
return handle
1319+
12981320
def visit_object(self, o, **kwargs):
12991321
return o
13001322

@@ -1304,32 +1326,30 @@ def visit_tuple(self, o, **kwargs):
13041326

13051327
visit_list = visit_tuple
13061328

1329+
def visit_ExpressionBundle(self, o, **kwargs):
1330+
if o in self.mapper:
1331+
handle = self.mapper[o]
1332+
return self.transform(o, handle, **kwargs)
1333+
children = [self._visit(i) for i in o.children]
1334+
if not [i for i in children if i]:
1335+
return None
1336+
return o._rebuild(*children, **o.args_frozen)
1337+
1338+
def visit_Iteration(self, o, **kwargs):
1339+
if o in self.mapper:
1340+
handle = self.mapper[o]
1341+
return self.transform(o, handle, **kwargs)
1342+
children = [self._visit(i) for i in o.children]
1343+
if not [i for i in children if i]:
1344+
return None
1345+
return o._rebuild(*children, **o.args_frozen)
1346+
13071347
def visit_Node(self, o, **kwargs):
13081348
if o in self.mapper:
13091349
handle = self.mapper[o]
1310-
if handle is None:
1311-
# None -> drop `o`
1312-
return None
1313-
elif isinstance(handle, Iterable):
1314-
# Iterable -> inject `handle` into `o`'s children
1315-
if not o.children:
1316-
raise CompilationError("Cannot inject nodes in a leaf node")
1317-
if self.nested:
1318-
children = [self._visit(i, **kwargs) for i in o.children]
1319-
else:
1320-
children = o.children
1321-
children = (tuple(handle) + children[0],) + tuple(children[1:])
1322-
return o._rebuild(*children, **o.args_frozen)
1323-
else:
1324-
# Replace `o` with `handle`
1325-
if self.nested:
1326-
children = [self._visit(i, **kwargs) for i in handle.children]
1327-
return handle._rebuild(*children, **handle.args_frozen)
1328-
else:
1329-
return handle
1330-
else:
1331-
children = [self._visit(i, **kwargs) for i in o.children]
1332-
return o._rebuild(*children, **o.args_frozen)
1350+
return self.transform(o, handle, **kwargs)
1351+
children = [self._visit(i, **kwargs) for i in o.children]
1352+
return o._rebuild(*children, **o.args_frozen)
13331353

13341354
def visit_Operator(self, o, **kwargs):
13351355
raise ValueError("Cannot apply a Transformer visitor to an Operator directly")

tests/test_mpi.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2014,6 +2014,24 @@ def test_merge_smart_if_within_conditional(self, mode):
20142014
for n in FindNodes(Conditional).visit(op1):
20152015
assert len(FindNodes(HaloUpdateCall).visit(n)) == 0
20162016

2017+
@pytest.mark.parallel(mode=2)
2018+
def test_allreduce_time(self, mode):
2019+
space_order = 8
2020+
nx, ny = 11, 11
2021+
2022+
grid = Grid(shape=(nx, ny))
2023+
tt = grid.time_dim
2024+
nt = 10
2025+
2026+
ux = TimeFunction(name="ux", grid=grid, time_order=1, space_order=space_order)
2027+
g = TimeFunction(name="g", grid=grid, dimensions=(tt, ), shape=(nt,))
2028+
2029+
op = Operator([Eq(ux.forward, ux + tt), Inc(g, ux)], name="Op")
2030+
assert_structure(op, ['t,x,y', 't'], 'txy')
2031+
2032+
op.apply(time_m=0, time_M=nt-1)
2033+
assert np.isclose(np.max(g.data), 4356.0)
2034+
20172035

20182036
class TestOperatorAdvanced:
20192037

0 commit comments

Comments
 (0)