Skip to content

Commit 107aefd

Browse files
author
George Bisbas
authored
Merge pull request #41 from xdslproject/compiler_more
compiler: more cleanup
2 parents 68e9eaf + 97faa07 commit 107aefd

File tree

10 files changed

+201
-1194
lines changed

10 files changed

+201
-1194
lines changed

devito/core/cpu.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,3 +311,32 @@ class Cpu64FsgCOperator(Cpu64FsgOperator):
311311

312312
class Cpu64FsgOmpOperator(Cpu64FsgOperator):
313313
_Target = OmpTarget
314+
315+
316+
# -----------XDSL
# This is a collection of xDSL optimization pipelines
# Ideally they should follow the same type of subclassing as the rest of
# the Devito Operators
#
# Each pipeline is the textual pass-pipeline string handed to mlir-opt /
# xdsl-opt. The callable ones are parameterized: by GPU block sizes, by the
# number of tiled dimensions (see `generate_tiling_arg` below), or by the
# MPI decomposition.


MLIR_CPU_PIPELINE = '"builtin.module(canonicalize, cse, loop-invariant-code-motion, canonicalize, cse, loop-invariant-code-motion, cse, canonicalize, fold-memref-alias-ops, expand-strided-metadata, loop-invariant-code-motion, lower-affine, convert-scf-to-cf, convert-math-to-llvm, convert-func-to-llvm{use-bare-ptr-memref-call-conv}, finalize-memref-to-llvm, canonicalize, cse)"' # noqa

MLIR_OPENMP_PIPELINE = '"builtin.module(canonicalize, cse, loop-invariant-code-motion, canonicalize, cse, loop-invariant-code-motion,cse,canonicalize,fold-memref-alias-ops,expand-strided-metadata, loop-invariant-code-motion,lower-affine,finalize-memref-to-llvm,loop-invariant-code-motion,canonicalize,cse,convert-scf-to-openmp,finalize-memref-to-llvm,convert-scf-to-cf,convert-func-to-llvm{use-bare-ptr-memref-call-conv},convert-openmp-to-llvm,convert-math-to-llvm,reconcile-unrealized-casts,canonicalize,cse)"' # noqa

# gpu-launch-sink-index-computations seemed to have no impact
MLIR_GPU_PIPELINE = lambda block_sizes: f'"builtin.module(test-math-algebraic-simplification,scf-parallel-loop-tiling{{parallel-loop-tile-sizes={block_sizes}}},func.func(gpu-map-parallel-loops),convert-parallel-loops-to-gpu,lower-affine, canonicalize,cse, fold-memref-alias-ops, gpu-launch-sink-index-computations, gpu-kernel-outlining, canonicalize{{region-simplify}},cse,fold-memref-alias-ops,expand-strided-metadata,lower-affine,canonicalize,cse,func.func(gpu-async-region),canonicalize,cse,convert-arith-to-llvm{{index-bitwidth=64}},convert-scf-to-cf,convert-cf-to-llvm{{index-bitwidth=64}},canonicalize,cse,convert-func-to-llvm{{use-bare-ptr-memref-call-conv}},gpu.module(convert-gpu-to-nvvm,reconcile-unrealized-casts,canonicalize,gpu-to-cubin),gpu-to-llvm,canonicalize,cse)"' # noqa

XDSL_CPU_PIPELINE = lambda nb_tiled_dims: f'"stencil-shape-inference,convert-stencil-to-ll-mlir{{{generate_tiling_arg(nb_tiled_dims)}}},printf-to-llvm"' # noqa

XDSL_GPU_PIPELINE = "stencil-shape-inference,convert-stencil-to-ll-mlir{target=gpu},reconcile-unrealized-casts,printf-to-llvm" # noqa

XDSL_MPI_PIPELINE = lambda decomp, nb_tiled_dims: f'"dmp-decompose{decomp},canonicalize-dmp,convert-stencil-to-ll-mlir{{{generate_tiling_arg(nb_tiled_dims)}}},dmp-to-mpi{{mpi_init=false}},lower-mpi,printf-to-llvm"' # noqa
333+
334+
335+
def generate_tiling_arg(nb_tiled_dims: int) -> str:
    """
    Generate the tile-sizes arg for the convert-stencil-to-ll-mlir pass.

    Returns an empty string (i.e. no argument) if the nb_tiled_dims arg
    is 0, otherwise a comma-separated ``tile-sizes=`` option with a tile
    size of 64 per tiled dimension.
    """
    # Guard with <= 0 so a (nonsensical) negative count also yields no
    # argument, rather than a dangling "tile-sizes=" with no values.
    if nb_tiled_dims <= 0:
        return ''
    return "tile-sizes=" + ",".join(["64"] * nb_tiled_dims)

devito/ir/ietxdsl/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,5 @@
1-
from devito.ir.ietxdsl.lowering import LowerIetForToScfFor, LowerIetForToScfParallel, DropIetComments, iet_to_standard_mlir # noqa
2-
from devito.ir.ietxdsl.cluster_to_ssa import finalize_module_with_globals, convert_devito_stencil_to_xdsl_stencil # noqa
1+
from devito.ir.ietxdsl.lowering import (LowerIetForToScfFor, LowerIetForToScfParallel)
2+
from devito.ir.ietxdsl.cluster_to_ssa import (finalize_module_with_globals,
3+
convert_devito_stencil_to_xdsl_stencil)
4+
5+
# flake8: noqa

devito/ir/ietxdsl/cluster_to_ssa.py

Lines changed: 60 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,33 @@
11
# ------------- General imports -------------#
22

33
from typing import Any
4+
from dataclasses import dataclass, field
45
from sympy import Add, Expr, Float, Indexed, Integer, Mod, Mul, Pow, Symbol
56

67
# ------------- xdsl imports -------------#
7-
from xdsl.dialects import arith, builtin, func, memref, scf, stencil, gpu
8+
from xdsl.dialects import (arith, builtin, func, memref, scf,
9+
stencil, gpu, llvm)
810
from xdsl.dialects.experimental import math
911
from xdsl.ir import Block, Operation, OpResult, Region, SSAValue
12+
from xdsl.pattern_rewriter import (
13+
GreedyRewritePatternApplier,
14+
PatternRewriter,
15+
PatternRewriteWalker,
16+
RewritePattern,
17+
op_type_rewrite_pattern,
18+
)
1019

1120
# ------------- devito imports -------------#
1221
from devito import Grid, SteppingDimension
1322
from devito.ir.equations import LoweredEq
14-
from devito.symbolics import retrieve_indexed
23+
from devito.symbolics import retrieve_indexed, retrieve_function_carriers
1524
from devito.logger import perf
1625

1726
# ------------- devito-xdsl SSA imports -------------#
1827
from devito.ir.ietxdsl import iet_ssa
28+
from devito.ir.ietxdsl.utils import is_int, is_float
1929
from devito.ir.ietxdsl.ietxdsl_functions import dtypes_to_xdsltypes
30+
from devito.ir.ietxdsl.lowering import LowerIetForToScfFor
2031

2132
# flake8: noqa
2233

@@ -50,34 +61,34 @@ def _convert_eq(self, eq: LoweredEq):
5061
function = eq.lhs.function
5162
mlir_type = dtypes_to_xdsltypes[function.dtype]
5263
grid: Grid = function.grid
53-
# get the halo of the space dimensions only e.g [(2, 2), (2, 2)] for the 2d case
64+
65+
# Get the halo of the grid.dimensions
66+
# e.g [(2, 2), (2, 2)] for the 2D case
5467
# Do not forget the issue with Devito adding an extra point!
68+
# Check 'def halo_setup' for more
5569
# (for derivative regions)
56-
halo = [function.halo[function.dimensions.index(d)] for d in grid.dimensions]
70+
halo = [function.halo[d] for d in grid.dimensions]
5771

5872
# Shift all time values so that for all accesses at t + n, n>=0.
5973
self.time_offs = min(
6074
int(idx.indices[0] - grid.stepping_dim) for idx in retrieve_indexed(eq)
6175
)
6276

63-
# Calculate the actual size of our time dimension
64-
actual_time_size = (
65-
max(int(idx.indices[0] - grid.stepping_dim) for idx in retrieve_indexed(eq))
66-
- self.time_offs
67-
+ 1
68-
)
6977

78+
# Get the time_size
79+
time_size = max(d.function.time_size for d in retrieve_function_carriers(eq))
80+
7081
# Build the for loop
7182
perf("Build Time Loop")
72-
loop = self._build_iet_for(grid.stepping_dim, actual_time_size)
83+
loop = self._build_iet_for(grid.stepping_dim, time_size)
7384

7485
# build stencil
7586
perf("Initialize a stencil Op")
7687
stencil_op = iet_ssa.Stencil.get(
7788
loop.subindice_ssa_vals(),
7889
grid.shape_local,
7990
halo,
80-
actual_time_size,
91+
time_size,
8192
mlir_type,
8293
eq.lhs.function._C_name,
8394
)
@@ -87,7 +98,7 @@ def _convert_eq(self, eq: LoweredEq):
8798
# dims -> ssa vals
8899
perf("Apply time offsets")
89100
time_offset_to_field: dict[str, SSAValue] = {
90-
i: stencil_op.block.args[i] for i in range(actual_time_size - 1)
101+
i: stencil_op.block.args[i] for i in range(time_size - 1)
91102
}
92103

93104
# reset loaded values
@@ -103,8 +114,9 @@ def _convert_eq(self, eq: LoweredEq):
103114

104115
# emit return
105116
offsets = _get_dim_offsets(eq.lhs, self.time_offs)
117+
106118
assert (
107-
offsets[0] == actual_time_size - 1
119+
offsets[0] == time_size - 1
108120
), "result should be written to last time buffer"
109121
assert all(
110122
o == 0 for o in offsets[1:]
@@ -118,51 +130,47 @@ def _convert_eq(self, eq: LoweredEq):
118130
)
119131

120132
def _visit_math_nodes(self, node: Expr) -> SSAValue:
133+
# Handle Indexeds
121134
if isinstance(node, Indexed):
122135
offsets = _get_dim_offsets(node, self.time_offs)
123136
return self.loaded_values[offsets]
124-
if isinstance(node, Integer):
137+
# Handle Integers
138+
elif isinstance(node, Integer):
125139
cst = arith.Constant.from_int_and_width(int(node), builtin.i64)
126140
self.block.add_op(cst)
127141
return cst.result
128-
if isinstance(node, Float):
142+
# Handle Floats
143+
elif isinstance(node, Float):
129144
cst = arith.Constant.from_float_and_width(float(node), builtin.f32)
130145
self.block.add_op(cst)
131146
return cst.result
132-
# if isinstance(math, Constant):
133-
# symb = iet_ssa.LoadSymbolic.get(math.name, dtypes_to_xdsltypes[math.dtype])
134-
# self.block.add_op(symb)
135-
# return symb.result
136-
if isinstance(node, Symbol):
147+
# Handle Symbols
148+
elif isinstance(node, Symbol):
137149
symb = iet_ssa.LoadSymbolic.get(node.name, builtin.f32)
138150
self.block.add_op(symb)
139-
return symb.result
140-
141-
# handle all of the math
142-
if not isinstance(node, (Add, Mul, Pow, Mod)):
143-
raise ValueError(f"Unknown math: {node}", node)
144-
145-
args = [self._visit_math_nodes(arg) for arg in node.args]
146-
147-
# make sure all args are the same type:
148-
if isinstance(node, (Add, Mul)):
151+
return symb.result
152+
# Handle Add Mul
153+
elif isinstance(node, (Add, Mul)):
154+
args = [self._visit_math_nodes(arg) for arg in node.args]
149155
# add casts when necessary
150156
# get first element out, store the rest in args
151157
# this makes the reduction easier
152158
carry, *args = self._ensure_same_type(*args)
153159
# select the correct op from arith.addi, arith.addf, arith.muli, arith.mulf
154160
if isinstance(carry.type, builtin.IntegerType):
155161
op_cls = arith.Addi if isinstance(node, Add) else arith.Muli
156-
else:
162+
elif isinstance(carry.type, builtin.Float32Type):
157163
op_cls = arith.Addf if isinstance(node, Add) else arith.Mulf
158-
164+
else:
165+
raise("Add support for another type")
159166
for arg in args:
160167
op = op_cls(carry, arg)
161168
self.block.add_op(op)
162169
carry = op.result
163170
return carry
164-
165-
if isinstance(node, Pow):
171+
# Handle Pow
172+
elif isinstance(node, Pow):
173+
args = [self._visit_math_nodes(arg) for arg in node.args]
166174
assert len(args) == 2, "can't pow with != 2 args!"
167175
base, ex = args
168176
if is_int(base):
@@ -183,11 +191,12 @@ def _visit_math_nodes(self, node: Expr) -> SSAValue:
183191
op = op_cls.get(base, ex)
184192
self.block.add_op(op)
185193
return op.result
194+
# Handle Mod
195+
elif isinstance(node, Mod):
196+
raise NotImplementedError("Go away, no mod here. >:(")
197+
else:
198+
raise NotImplementedError(f"Unknown math: {node}", node)
186199

187-
if isinstance(node, Mod):
188-
raise ValueError("Go away, no mod here. >:(")
189-
190-
raise ValueError("Unknown math!")
191200

192201
def _add_access_ops(
193202
self, reads: list[Indexed], time_offset_to_field: dict[int, SSAValue]
@@ -202,10 +211,11 @@ def _add_access_ops(
202211
"""
203212
# get the compile time constant offsets for this read
204213
offsets = _get_dim_offsets(read, self.time_offs)
214+
205215
if offsets in self.loaded_values:
206216
continue
207217

208-
# assume time dimension is first dimension
218+
# Assume time dimension is first dimension
209219
t_offset = offsets[0]
210220
space_offsets = offsets[1:]
211221

@@ -251,10 +261,10 @@ def _ensure_same_type(self, *vals: SSAValue):
251261
if all(is_float(val) for val in vals):
252262
return vals
253263
# not everything homogeneous
254-
new_vals = []
264+
processed = []
255265
for val in vals:
256266
if is_float(val):
257-
new_vals.append(val)
267+
processed.append(val)
258268
continue
259269
# if the val is the result of a arith.constant with no uses,
260270
# we change the type of the arith.constant to our desired type
@@ -267,19 +277,20 @@ def _ensure_same_type(self, *vals: SSAValue):
267277
val.op.attributes["value"] = builtin.FloatAttr(
268278
float(val.op.value.value.data), builtin.f32
269279
)
270-
new_vals.append(val)
280+
processed.append(val)
271281
continue
272282
# insert an integer to float cast op
273283
conv = arith.SIToFPOp(val, builtin.f32)
274284
self.block.add_op(conv)
275-
new_vals.append(conv.result)
276-
return new_vals
285+
processed.append(conv.result)
286+
return processed
277287

278288

279289
def _get_dim_offsets(idx: Indexed, t_offset: int) -> tuple:
280290
# shift all time values so that for all accesses at t + n, n>=0.
281291
# time_offs = min(int(i - d) for i, d in zip(idx.indices, idx.function.dimensions))
282292
halo = ((t_offset, 0), *idx.function.halo[1:])
293+
283294
try:
284295
return tuple(
285296
int(i - d - halo_offset)
@@ -291,36 +302,13 @@ def _get_dim_offsets(idx: Indexed, t_offset: int) -> tuple:
291302
raise ValueError("Indices must be constant offset from dimension!") from ex
292303

293304

294-
def is_int(val: SSAValue):
295-
return isinstance(val.type, builtin.IntegerType)
296-
297-
298-
def is_float(val: SSAValue):
299-
return val.type in (builtin.f32, builtin.f64)
300-
301305

302306
# -------------------------------------------------------- ####
303307
# ####
304308
# devito.stencil ---> stencil dialect ####
305309
# ####
306310
# -------------------------------------------------------- ####
307311

308-
from dataclasses import dataclass, field
309-
310-
from xdsl.pattern_rewriter import (
311-
GreedyRewritePatternApplier,
312-
PatternRewriter,
313-
PatternRewriteWalker,
314-
RewritePattern,
315-
op_type_rewrite_pattern,
316-
)
317-
318-
from devito.ir.ietxdsl.lowering import (
319-
LowerIetForToScfFor,
320-
)
321-
322-
from xdsl.dialects import llvm
323-
324312
@dataclass
325313
class WrapFunctionWithTransfers(RewritePattern):
326314
func_name: str
@@ -545,8 +533,10 @@ def finalize_module_with_globals(module: builtin.ModuleOp, known_symbols: dict[s
545533
_InsertSymbolicConstants(known_symbols),
546534
_LowerLoadSymbolidToFuncArgs(),
547535
]
548-
grpa = GreedyRewritePatternApplier(patterns)
549-
PatternRewriteWalker(grpa).rewrite_module(module)
536+
rewriter = GreedyRewritePatternApplier(patterns)
537+
PatternRewriteWalker(rewriter).rewrite_module(module)
538+
539+
# GPU boilerplate
550540
if gpu_boilerplate:
551541
walker = PatternRewriteWalker(GreedyRewritePatternApplier([WrapFunctionWithTransfers('apply_kernel')]))
552542
walker.rewrite_module(module)

0 commit comments

Comments
 (0)