Skip to content

Commit 7c00d99

Browse files
Merge pull request #2600 from devitocodes/blocking
compiler: Blocked SubDimensions now use root dimension name
2 parents 3f9af54 + d4fb22c commit 7c00d99

File tree

6 files changed

+33
-33
lines changed

6 files changed

+33
-33
lines changed

devito/passes/clusters/blocking.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -354,15 +354,15 @@ def _derive_block_dims(self, clusters, prefix, d, blk_size_gen):
354354
except KeyError:
355355
pass
356356

357-
base = self.sregistry.make_name(prefix=d.name)
357+
base = self.sregistry.make_name(prefix=d.root.name)
358358

359-
name = self.sregistry.make_name(prefix="%s_blk" % base)
359+
name = self.sregistry.make_name(prefix=f"{base}_blk")
360360
bd = BlockDimension(name, d, d.symbolic_min, d.symbolic_max, step)
361361
step = bd.step
362362
block_dims = [bd]
363363

364364
for _ in range(1, self.levels):
365-
name = self.sregistry.make_name(prefix="%s_blk" % base)
365+
name = self.sregistry.make_name(prefix=f"{base}_blk")
366366
bd = BlockDimension(name, bd, bd, bd + bd.step - 1, size=step)
367367
block_dims.append(bd)
368368

tests/test_dimension.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2044,21 +2044,21 @@ def test_topofusion_w_subdims_conddims(self):
20442044

20452045
# Check generated code -- expect the gsave equation to be scheduled together
20462046
# in the same loop nest with the fsave equation
2047-
bns, _ = assert_blocking(op, {'x0_blk0', 'x1_blk0', 'ix0_blk0'})
2047+
bns, _ = assert_blocking(op, {'x0_blk0', 'x1_blk0', 'x2_blk0'})
20482048
exprs = FindNodes(Expression).visit(bns['x0_blk0'])
20492049
assert len(exprs) == 2
20502050
assert exprs[0].write is f
20512051
assert exprs[1].write is g
20522052

20532053
exprs = FindNodes(Expression).visit(bns['x1_blk0'])
2054+
assert len(exprs) == 1
2055+
assert exprs[0].write is h
2056+
2057+
exprs = FindNodes(Expression).visit(bns['x2_blk0'])
20542058
assert len(exprs) == 2
20552059
assert exprs[0].write is fsave
20562060
assert exprs[1].write is gsave
20572061

2058-
exprs = FindNodes(Expression).visit(bns['ix0_blk0'])
2059-
assert len(exprs) == 1
2060-
assert exprs[0].write is h
2061-
20622062
def test_topofusion_w_subdims_conddims_v2(self):
20632063
"""
20642064
Like `test_topofusion_w_subdims_conddims` but with more SubDomains,
@@ -2085,9 +2085,9 @@ def test_topofusion_w_subdims_conddims_v2(self):
20852085

20862086
# Check generated code -- expect the gsave equation to be scheduled together
20872087
# in the same loop nest with the fsave equation
2088-
bns, _ = assert_blocking(op, {'ix0_blk0', 'x0_blk0'})
2089-
assert len(FindNodes(Expression).visit(bns['ix0_blk0'])) == 3
2090-
exprs = FindNodes(Expression).visit(bns['x0_blk0'])
2088+
bns, _ = assert_blocking(op, {'x0_blk0', 'x1_blk0'})
2089+
assert len(FindNodes(Expression).visit(bns['x0_blk0'])) == 3
2090+
exprs = FindNodes(Expression).visit(bns['x1_blk0'])
20912091
assert len(exprs) == 2
20922092
assert exprs[0].write is fsave
20932093
assert exprs[1].write is gsave
@@ -2118,19 +2118,19 @@ def test_topofusion_w_subdims_conddims_v3(self):
21182118

21192119
# Check generated code -- expect the gsave equation to be scheduled together
21202120
# in the same loop nest with the fsave equation
2121-
bns, _ = assert_blocking(op, {'ix0_blk0', 'x0_blk0', 'ix1_blk0'})
2122-
exprs = FindNodes(Expression).visit(bns['ix0_blk0'])
2121+
bns, _ = assert_blocking(op, {'x0_blk0', 'x1_blk0', 'x2_blk0'})
2122+
exprs = FindNodes(Expression).visit(bns['x0_blk0'])
21232123
assert len(exprs) == 2
21242124
assert exprs[0].write is f
21252125
assert exprs[1].write is g
21262126

2127-
exprs = FindNodes(Expression).visit(bns['x0_blk0'])
2127+
exprs = FindNodes(Expression).visit(bns['x2_blk0'])
21282128
assert len(exprs) == 2
21292129
assert exprs[0].write is fsave
21302130
assert exprs[1].write is gsave
21312131

21322132
# Additional nest due to anti-dependence
2133-
exprs = FindNodes(Expression).visit(bns['ix1_blk0'])
2133+
exprs = FindNodes(Expression).visit(bns['x1_blk0'])
21342134
assert len(exprs) == 2
21352135
assert exprs[1].write is h
21362136

tests/test_dle.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,9 @@ def test_cache_blocking_structure_subdims():
130130
# Non-local SubDimension -> blocking expected
131131
op = Operator(Eq(f.forward, f.dx + 1, subdomain=grid.interior))
132132

133-
bns, _ = assert_blocking(op, {'ix0_blk0'})
133+
bns, _ = assert_blocking(op, {'x0_blk0'})
134134

135-
trees = retrieve_iteration_tree(bns['ix0_blk0'])
135+
trees = retrieve_iteration_tree(bns['x0_blk0'])
136136
tree = trees[0]
137137
assert len(tree) == 5
138138
assert tree[0].dim.is_Block and tree[0].dim.parent.name == 'ix' and\
@@ -257,7 +257,7 @@ def test_leftright_subdims(self):
257257

258258
op = Operator(eqns, opt=('fission', 'blocking', {'blockrelax': 'device-aware'}))
259259

260-
bns, _ = assert_blocking(op, {'x0_blk0', 'xl0_blk0', 'xr0_blk0'})
260+
bns, _ = assert_blocking(op, {'x0_blk0', 'x1_blk0', 'x2_blk0'})
261261
assert all(IsPerfectIteration().visit(i) for i in bns.values())
262262
assert all(len(FindNodes(Iteration).visit(i)) == 4 for i in bns.values())
263263

@@ -1362,9 +1362,9 @@ def test_nested_cache_blocking_structure_subdims(self, blocklevels):
13621362
'par-collapse-ncores': 2,
13631363
'par-dynamic-work': 0}))
13641364

1365-
bns, _ = assert_blocking(op, {'ix0_blk0'})
1365+
bns, _ = assert_blocking(op, {'x0_blk0'})
13661366

1367-
trees = retrieve_iteration_tree(bns['ix0_blk0'])
1367+
trees = retrieve_iteration_tree(bns['x0_blk0'])
13681368
assert len(trees) == 1
13691369
tree = trees[0]
13701370
assert len(tree) == 5 + (blocklevels - 1) * 2

tests/test_dse.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -599,11 +599,11 @@ def test_full_shape_w_subdims(self, rotate):
599599
'cire-rotate': rotate}))
600600

601601
# Check code generation
602-
bns, pbs = assert_blocking(op1, {'ix0_blk0'})
603-
xs, ys, zs = get_params(op1, 'ix0_blk0_size', 'iy0_blk0_size', 'z_size')
604-
arrays = [i for i in FindSymbols().visit(bns['ix0_blk0']) if i.is_Array]
602+
bns, pbs = assert_blocking(op1, {'x0_blk0'})
603+
xs, ys, zs = get_params(op1, 'x0_blk0_size', 'y0_blk0_size', 'z_size')
604+
arrays = [i for i in FindSymbols().visit(bns['x0_blk0']) if i.is_Array]
605605
assert len(arrays) == 1
606-
assert len(FindNodes(VExpanded).visit(pbs['ix0_blk0'])) == 1
606+
assert len(FindNodes(VExpanded).visit(pbs['x0_blk0'])) == 1
607607
check_array(arrays[0], ((1, 1), (1, 1), (1, 1)), (xs+2, ys+2, zs+2), rotate)
608608

609609
# Check numerical output
@@ -773,11 +773,11 @@ def test_mixed_shapes_v2_w_subdims(self, rotate):
773773
'cire-mingain': 0, 'cire-rotate': rotate}))
774774

775775
# Check code generation
776-
bns, pbs = assert_blocking(op1, {'ix0_blk0'})
777-
xs, ys, zs = get_params(op1, 'ix0_blk0_size', 'iy0_blk0_size', 'z_size')
778-
arrays = [i for i in FindSymbols().visit(bns['ix0_blk0']) if i.is_Array]
776+
bns, pbs = assert_blocking(op1, {'x0_blk0'})
777+
xs, ys, zs = get_params(op1, 'x0_blk0_size', 'y0_blk0_size', 'z_size')
778+
arrays = [i for i in FindSymbols().visit(bns['x0_blk0']) if i.is_Array]
779779
assert len(arrays) == 2
780-
assert len(FindNodes(VExpanded).visit(pbs['ix0_blk0'])) == 2
780+
assert len(FindNodes(VExpanded).visit(pbs['x0_blk0'])) == 2
781781
check_array(arrays[0], ((1, 0), (1, 0), (0, 0)), (xs+1, ys+1, zs), rotate)
782782
check_array(arrays[1], ((1, 1), (1, 0)), (ys+2, zs+1), rotate)
783783

tests/test_mpi.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1279,11 +1279,11 @@ def test_hoist_haloupdate_with_subdims(self, mode):
12791279
assert calls[1].name == 'haloupdate0'
12801280

12811281
# ... and none in the created efuncs
1282-
bns, _ = assert_blocking(op, {'ix0_blk0', 'x0_blk0'})
1283-
calls = FindNodes(Call).visit(bns['ix0_blk0'])
1284-
assert len(calls) == 0
1282+
bns, _ = assert_blocking(op, {'x0_blk0', 'x1_blk0'})
12851283
calls = FindNodes(Call).visit(bns['x0_blk0'])
12861284
assert len(calls) == 0
1285+
calls = FindNodes(Call).visit(bns['x1_blk0'])
1286+
assert len(calls) == 0
12871287

12881288
@pytest.mark.parallel(mode=1)
12891289
def test_hoist_haloupdate_from_innerloop(self, mode):

tests/test_subdomains.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -680,8 +680,8 @@ class Dummy(SubDomainSet):
680680
# Make sure it jit-compiles
681681
op.cfunction
682682

683-
assert_structure(op, ['t,n0', 't,n0,ix0_blk0,iy0_blk0,x,y,z'],
684-
't,n0,ix0_blk0,iy0_blk0,x,y,z')
683+
assert_structure(op, ['t,n0', 't,n0,x0_blk0,y0_blk0,x,y,z'],
684+
't,n0,x0_blk0,y0_blk0,x,y,z')
685685

686686
# Drag a rebuilt MultiSubDimension out of the operator
687687
dims = {d.name: d for d in FindSymbols('dimensions').visit(op)}

0 commit comments

Comments
 (0)