11# ------------- General imports -------------#
22
33from typing import Any
4+ from dataclasses import dataclass , field
45from sympy import Add , Expr , Float , Indexed , Integer , Mod , Mul , Pow , Symbol
56
67# ------------- xdsl imports -------------#
7- from xdsl .dialects import arith , builtin , func , memref , scf , stencil , gpu
8+ from xdsl .dialects import (arith , builtin , func , memref , scf ,
9+ stencil , gpu , llvm )
810from xdsl .dialects .experimental import math
911from xdsl .ir import Block , Operation , OpResult , Region , SSAValue
12+ from xdsl .pattern_rewriter import (
13+ GreedyRewritePatternApplier ,
14+ PatternRewriter ,
15+ PatternRewriteWalker ,
16+ RewritePattern ,
17+ op_type_rewrite_pattern ,
18+ )
1019
1120# ------------- devito imports -------------#
1221from devito import Grid , SteppingDimension
1322from devito .ir .equations import LoweredEq
14- from devito .symbolics import retrieve_indexed
23+ from devito .symbolics import retrieve_indexed , retrieve_function_carriers
1524from devito .logger import perf
1625
1726# ------------- devito-xdsl SSA imports -------------#
1827from devito .ir .ietxdsl import iet_ssa
28+ from devito .ir .ietxdsl .utils import is_int , is_float
1929from devito .ir .ietxdsl .ietxdsl_functions import dtypes_to_xdsltypes
30+ from devito .ir .ietxdsl .lowering import LowerIetForToScfFor
2031
2132# flake8: noqa
2233
@@ -50,34 +61,34 @@ def _convert_eq(self, eq: LoweredEq):
5061 function = eq .lhs .function
5162 mlir_type = dtypes_to_xdsltypes [function .dtype ]
5263 grid : Grid = function .grid
53- # get the halo of the space dimensions only e.g [(2, 2), (2, 2)] for the 2d case
64+
65+ # Get the halo of the grid.dimensions
66+ # e.g [(2, 2), (2, 2)] for the 2D case
5467 # Do not forget the issue with Devito adding an extra point!
68+ # Check 'def halo_setup' for more
5569 # (for derivative regions)
56- halo = [function .halo [function . dimensions . index ( d ) ] for d in grid .dimensions ]
70+ halo = [function .halo [d ] for d in grid .dimensions ]
5771
5872 # Shift all time values so that for all accesses at t + n, n>=0.
5973 self .time_offs = min (
6074 int (idx .indices [0 ] - grid .stepping_dim ) for idx in retrieve_indexed (eq )
6175 )
6276
63- # Calculate the actual size of our time dimension
64- actual_time_size = (
65- max (int (idx .indices [0 ] - grid .stepping_dim ) for idx in retrieve_indexed (eq ))
66- - self .time_offs
67- + 1
68- )
6977
78+ # Get the time_size
79+ time_size = max (d .function .time_size for d in retrieve_function_carriers (eq ))
80+
7081 # Build the for loop
7182 perf ("Build Time Loop" )
72- loop = self ._build_iet_for (grid .stepping_dim , actual_time_size )
83+ loop = self ._build_iet_for (grid .stepping_dim , time_size )
7384
7485 # build stencil
7586 perf ("Initialize a stencil Op" )
7687 stencil_op = iet_ssa .Stencil .get (
7788 loop .subindice_ssa_vals (),
7889 grid .shape_local ,
7990 halo ,
80- actual_time_size ,
91+ time_size ,
8192 mlir_type ,
8293 eq .lhs .function ._C_name ,
8394 )
@@ -87,7 +98,7 @@ def _convert_eq(self, eq: LoweredEq):
8798 # dims -> ssa vals
8899 perf ("Apply time offsets" )
89100 time_offset_to_field : dict [str , SSAValue ] = {
90- i : stencil_op .block .args [i ] for i in range (actual_time_size - 1 )
101+ i : stencil_op .block .args [i ] for i in range (time_size - 1 )
91102 }
92103
93104 # reset loaded values
@@ -103,8 +114,9 @@ def _convert_eq(self, eq: LoweredEq):
103114
104115 # emit return
105116 offsets = _get_dim_offsets (eq .lhs , self .time_offs )
117+
106118 assert (
107- offsets [0 ] == actual_time_size - 1
119+ offsets [0 ] == time_size - 1
108120 ), "result should be written to last time buffer"
109121 assert all (
110122 o == 0 for o in offsets [1 :]
@@ -118,51 +130,47 @@ def _convert_eq(self, eq: LoweredEq):
118130 )
119131
120132 def _visit_math_nodes (self , node : Expr ) -> SSAValue :
133+ # Handle Indexeds
121134 if isinstance (node , Indexed ):
122135 offsets = _get_dim_offsets (node , self .time_offs )
123136 return self .loaded_values [offsets ]
124- if isinstance (node , Integer ):
137+ # Handle Integers
138+ elif isinstance (node , Integer ):
125139 cst = arith .Constant .from_int_and_width (int (node ), builtin .i64 )
126140 self .block .add_op (cst )
127141 return cst .result
128- if isinstance (node , Float ):
142+ # Handle Floats
143+ elif isinstance (node , Float ):
129144 cst = arith .Constant .from_float_and_width (float (node ), builtin .f32 )
130145 self .block .add_op (cst )
131146 return cst .result
132- # if isinstance(math, Constant):
133- # symb = iet_ssa.LoadSymbolic.get(math.name, dtypes_to_xdsltypes[math.dtype])
134- # self.block.add_op(symb)
135- # return symb.result
136- if isinstance (node , Symbol ):
147+ # Handle Symbols
148+ elif isinstance (node , Symbol ):
137149 symb = iet_ssa .LoadSymbolic .get (node .name , builtin .f32 )
138150 self .block .add_op (symb )
139- return symb .result
140-
141- # handle all of the math
142- if not isinstance (node , (Add , Mul , Pow , Mod )):
143- raise ValueError (f"Unknown math: { node } " , node )
144-
145- args = [self ._visit_math_nodes (arg ) for arg in node .args ]
146-
147- # make sure all args are the same type:
148- if isinstance (node , (Add , Mul )):
151+ return symb .result
152+ # Handle Add Mul
153+ elif isinstance (node , (Add , Mul )):
154+ args = [self ._visit_math_nodes (arg ) for arg in node .args ]
149155 # add casts when necessary
150156 # get first element out, store the rest in args
151157 # this makes the reduction easier
152158 carry , * args = self ._ensure_same_type (* args )
153159 # select the correct op from arith.addi, arith.addf, arith.muli, arith.mulf
154160 if isinstance (carry .type , builtin .IntegerType ):
155161 op_cls = arith .Addi if isinstance (node , Add ) else arith .Muli
156- else :
162+ elif isinstance ( carry . type , builtin . Float32Type ) :
157163 op_cls = arith .Addf if isinstance (node , Add ) else arith .Mulf
158-
164+ else :
165+ raise ("Add support for another type" )
159166 for arg in args :
160167 op = op_cls (carry , arg )
161168 self .block .add_op (op )
162169 carry = op .result
163170 return carry
164-
165- if isinstance (node , Pow ):
171+ # Handle Pow
172+ elif isinstance (node , Pow ):
173+ args = [self ._visit_math_nodes (arg ) for arg in node .args ]
166174 assert len (args ) == 2 , "can't pow with != 2 args!"
167175 base , ex = args
168176 if is_int (base ):
@@ -183,11 +191,12 @@ def _visit_math_nodes(self, node: Expr) -> SSAValue:
183191 op = op_cls .get (base , ex )
184192 self .block .add_op (op )
185193 return op .result
194+ # Handle Mod
195+ elif isinstance (node , Mod ):
196+ raise NotImplementedError ("Go away, no mod here. >:(" )
197+ else :
198+ raise NotImplementedError (f"Unknown math: { node } " , node )
186199
187- if isinstance (node , Mod ):
188- raise ValueError ("Go away, no mod here. >:(" )
189-
190- raise ValueError ("Unknown math!" )
191200
192201 def _add_access_ops (
193202 self , reads : list [Indexed ], time_offset_to_field : dict [int , SSAValue ]
@@ -202,10 +211,11 @@ def _add_access_ops(
202211 """
203212 # get the compile time constant offsets for this read
204213 offsets = _get_dim_offsets (read , self .time_offs )
214+
205215 if offsets in self .loaded_values :
206216 continue
207217
208- # assume time dimension is first dimension
218+ # Assume time dimension is first dimension
209219 t_offset = offsets [0 ]
210220 space_offsets = offsets [1 :]
211221
@@ -251,10 +261,10 @@ def _ensure_same_type(self, *vals: SSAValue):
251261 if all (is_float (val ) for val in vals ):
252262 return vals
253263 # not everything homogeneous
254- new_vals = []
264+ processed = []
255265 for val in vals :
256266 if is_float (val ):
257- new_vals .append (val )
267+ processed .append (val )
258268 continue
259269 # if the val is the result of a arith.constant with no uses,
260270 # we change the type of the arith.constant to our desired type
@@ -267,19 +277,20 @@ def _ensure_same_type(self, *vals: SSAValue):
267277 val .op .attributes ["value" ] = builtin .FloatAttr (
268278 float (val .op .value .value .data ), builtin .f32
269279 )
270- new_vals .append (val )
280+ processed .append (val )
271281 continue
272282 # insert an integer to float cast op
273283 conv = arith .SIToFPOp (val , builtin .f32 )
274284 self .block .add_op (conv )
275- new_vals .append (conv .result )
276- return new_vals
285+ processed .append (conv .result )
286+ return processed
277287
278288
279289def _get_dim_offsets (idx : Indexed , t_offset : int ) -> tuple :
280290 # shift all time values so that for all accesses at t + n, n>=0.
281291 # time_offs = min(int(i - d) for i, d in zip(idx.indices, idx.function.dimensions))
282292 halo = ((t_offset , 0 ), * idx .function .halo [1 :])
293+
283294 try :
284295 return tuple (
285296 int (i - d - halo_offset )
@@ -291,36 +302,13 @@ def _get_dim_offsets(idx: Indexed, t_offset: int) -> tuple:
291302 raise ValueError ("Indices must be constant offset from dimension!" ) from ex
292303
293304
294- def is_int (val : SSAValue ):
295- return isinstance (val .type , builtin .IntegerType )
296-
297-
298- def is_float (val : SSAValue ):
299- return val .type in (builtin .f32 , builtin .f64 )
300-
301305
302306# -------------------------------------------------------- ####
303307# ####
304308# devito.stencil ---> stencil dialect ####
305309# ####
306310# -------------------------------------------------------- ####
307311
308- from dataclasses import dataclass , field
309-
310- from xdsl .pattern_rewriter import (
311- GreedyRewritePatternApplier ,
312- PatternRewriter ,
313- PatternRewriteWalker ,
314- RewritePattern ,
315- op_type_rewrite_pattern ,
316- )
317-
318- from devito .ir .ietxdsl .lowering import (
319- LowerIetForToScfFor ,
320- )
321-
322- from xdsl .dialects import llvm
323-
324312@dataclass
325313class WrapFunctionWithTransfers (RewritePattern ):
326314 func_name : str
@@ -545,8 +533,10 @@ def finalize_module_with_globals(module: builtin.ModuleOp, known_symbols: dict[s
545533 _InsertSymbolicConstants (known_symbols ),
546534 _LowerLoadSymbolidToFuncArgs (),
547535 ]
548- grpa = GreedyRewritePatternApplier (patterns )
549- PatternRewriteWalker (grpa ).rewrite_module (module )
536+ rewriter = GreedyRewritePatternApplier (patterns )
537+ PatternRewriteWalker (rewriter ).rewrite_module (module )
538+
539+ # GPU boilerplate
550540 if gpu_boilerplate :
551541 walker = PatternRewriteWalker (GreedyRewritePatternApplier ([WrapFunctionWithTransfers ('apply_kernel' )]))
552542 walker .rewrite_module (module )
0 commit comments