11# ------------- General imports -------------#
22
33from typing import Any
4+ from dataclasses import dataclass , field
45from sympy import Add , Expr , Float , Indexed , Integer , Mod , Mul , Pow , Symbol
56
67# ------------- xdsl imports -------------#
7- from xdsl .dialects import arith , builtin , func , memref , scf , stencil , gpu
8+ from xdsl .dialects import (arith , builtin , func , memref , scf ,
9+ stencil , gpu , llvm )
810from xdsl .dialects .experimental import math
911from xdsl .ir import Block , Operation , OpResult , Region , SSAValue
12+ from xdsl .pattern_rewriter import (
13+ GreedyRewritePatternApplier ,
14+ PatternRewriter ,
15+ PatternRewriteWalker ,
16+ RewritePattern ,
17+ op_type_rewrite_pattern ,
18+ )
1019
1120# ------------- devito imports -------------#
1221from devito import Grid , SteppingDimension
1827from devito .ir .ietxdsl import iet_ssa
1928from devito .ir .ietxdsl .utils import is_int , is_float
2029from devito .ir .ietxdsl .ietxdsl_functions import dtypes_to_xdsltypes
21-
30+ from devito . ir . ietxdsl . lowering import LowerIetForToScfFor
2231
2332# flake8: noqa
2433
@@ -121,34 +130,28 @@ def _convert_eq(self, eq: LoweredEq):
121130 )
122131
123132 def _visit_math_nodes (self , node : Expr ) -> SSAValue :
133+ # Handle Indexeds
124134 if isinstance (node , Indexed ):
125135 offsets = _get_dim_offsets (node , self .time_offs )
126136 return self .loaded_values [offsets ]
127- if isinstance (node , Integer ):
137+ # Handle Integers
138+ elif isinstance (node , Integer ):
128139 cst = arith .Constant .from_int_and_width (int (node ), builtin .i64 )
129140 self .block .add_op (cst )
130141 return cst .result
131- if isinstance (node , Float ):
142+ # Handle Floats
143+ elif isinstance (node , Float ):
132144 cst = arith .Constant .from_float_and_width (float (node ), builtin .f32 )
133145 self .block .add_op (cst )
134146 return cst .result
135- # if isinstance(math, Constant):
136- # symb = iet_ssa.LoadSymbolic.get(math.name, dtypes_to_xdsltypes[math.dtype])
137- # self.block.add_op(symb)
138- # return symb.result
139- if isinstance (node , Symbol ):
147+ # Handle Symbols
148+ elif isinstance (node , Symbol ):
140149 symb = iet_ssa .LoadSymbolic .get (node .name , builtin .f32 )
141150 self .block .add_op (symb )
142- return symb .result
143-
144- # handle all of the math
145- if not isinstance (node , (Add , Mul , Pow , Mod )):
146- raise ValueError (f"Unknown math: { node } " , node )
147-
148- args = [self ._visit_math_nodes (arg ) for arg in node .args ]
149-
150- # make sure all args are the same type:
151- if isinstance (node , (Add , Mul )):
151+ return symb .result
152+ # Handle Add Mul
153+ elif isinstance (node , (Add , Mul )):
154+ args = [self ._visit_math_nodes (arg ) for arg in node .args ]
152155 # add casts when necessary
153156 # get first element out, store the rest in args
154157 # this makes the reduction easier
@@ -160,14 +163,14 @@ def _visit_math_nodes(self, node: Expr) -> SSAValue:
160163 op_cls = arith .Addf if isinstance (node , Add ) else arith .Mulf
161164 else :
162165 raise ("Add support for another type" )
163-
164166 for arg in args :
165167 op = op_cls (carry , arg )
166168 self .block .add_op (op )
167169 carry = op .result
168170 return carry
169-
170- if isinstance (node , Pow ):
171+ # Handle Pow
172+ elif isinstance (node , Pow ):
173+ args = [self ._visit_math_nodes (arg ) for arg in node .args ]
171174 assert len (args ) == 2 , "can't pow with != 2 args!"
172175 base , ex = args
173176 if is_int (base ):
@@ -188,11 +191,12 @@ def _visit_math_nodes(self, node: Expr) -> SSAValue:
188191 op = op_cls .get (base , ex )
189192 self .block .add_op (op )
190193 return op .result
194+ # Handle Mod
195+ elif isinstance (node , Mod ):
196+ raise NotImplementedError ("Go away, no mod here. >:(" )
197+ else :
198+ raise NotImplementedError (f"Unknown math: { node } " , node )
191199
192- if isinstance (node , Mod ):
193- raise ValueError ("Go away, no mod here. >:(" )
194-
195- raise ValueError ("Unknown math!" )
196200
197201 def _add_access_ops (
198202 self , reads : list [Indexed ], time_offset_to_field : dict [int , SSAValue ]
@@ -257,10 +261,10 @@ def _ensure_same_type(self, *vals: SSAValue):
257261 if all (is_float (val ) for val in vals ):
258262 return vals
259263 # not everything homogeneous
260- new_vals = []
264+ processed = []
261265 for val in vals :
262266 if is_float (val ):
263- new_vals .append (val )
267+ processed .append (val )
264268 continue
265269 # if the val is the result of a arith.constant with no uses,
266270 # we change the type of the arith.constant to our desired type
@@ -273,13 +277,13 @@ def _ensure_same_type(self, *vals: SSAValue):
273277 val .op .attributes ["value" ] = builtin .FloatAttr (
274278 float (val .op .value .value .data ), builtin .f32
275279 )
276- new_vals .append (val )
280+ processed .append (val )
277281 continue
278282 # insert an integer to float cast op
279283 conv = arith .SIToFPOp (val , builtin .f32 )
280284 self .block .add_op (conv )
281- new_vals .append (conv .result )
282- return new_vals
285+ processed .append (conv .result )
286+ return processed
283287
284288
285289def _get_dim_offsets (idx : Indexed , t_offset : int ) -> tuple :
@@ -305,22 +309,6 @@ def _get_dim_offsets(idx: Indexed, t_offset: int) -> tuple:
305309# ####
306310# -------------------------------------------------------- ####
307311
308- from dataclasses import dataclass , field
309-
310- from xdsl .pattern_rewriter import (
311- GreedyRewritePatternApplier ,
312- PatternRewriter ,
313- PatternRewriteWalker ,
314- RewritePattern ,
315- op_type_rewrite_pattern ,
316- )
317-
318- from devito .ir .ietxdsl .lowering import (
319- LowerIetForToScfFor ,
320- )
321-
322- from xdsl .dialects import llvm
323-
324312@dataclass
325313class WrapFunctionWithTransfers (RewritePattern ):
326314 func_name : str
0 commit comments