Skip to content

Commit 026eaf8

Browse files
committed
compiler: Introduce extract_dtype
1 parent 9db4915 commit 026eaf8

File tree

3 files changed

+13
-8
lines changed

3 files changed

+13
-8
lines changed

devito/finite_differences/differentiable.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from devito.finite_differences.tools import make_shift_x0, coeff_priority
1818
from devito.logger import warning
1919
from devito.tools import (as_tuple, filter_ordered, flatten, frozendict,
20-
infer_dtype, is_integer, split, is_number)
20+
infer_dtype, extract_dtype, is_integer, split, is_number)
2121
from devito.types import Array, DimensionTuple, Evaluable, StencilDimension
2222
from devito.types.basic import AbstractFunction
2323

@@ -665,8 +665,7 @@ class RealComplexPart(ComplexPart):
665665

666666
@cached_property
667667
def dtype(self):
668-
dtypes = {getattr(e, 'dtype', None) for e in self.free_symbols}
669-
dtype = infer_dtype(dtypes - {None})
668+
dtype = extract_dtype(self)
670669
return dtype(0).real.__class__
671670

672671

devito/passes/clusters/cse.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from devito.ir import Cluster, Scope, cluster_pass
1515
from devito.symbolics import estimate_cost, q_leaf, q_terminal
1616
from devito.symbolics.manipulation import _uxreplace
17-
from devito.tools import DAG, as_list, as_tuple, frozendict, infer_dtype
17+
from devito.tools import DAG, as_list, as_tuple, frozendict, extract_dtype
1818
from devito.types import Eq, Symbol, Temp
1919

2020
__all__ = ['cse']
@@ -278,9 +278,7 @@ def expr(self):
278278

279279
@property
280280
def dtype(self):
281-
dtypes = {getattr(e, 'dtype', None)
282-
for e in self.expr.free_symbols}
283-
return infer_dtype(dtypes - {None})
281+
return extract_dtype(self.expr)
284282

285283
@property
286284
def conditionals(self):

devito/tools/dtypes_lowering.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
'double3', 'double4', 'dtypes_vector_mapper', 'dtype_to_mpidtype',
1616
'dtype_to_cstr', 'dtype_to_ctype', 'infer_datasize', 'dtype_to_mpitype',
1717
'dtype_len', 'ctypes_to_cstr', 'c_restrict_void_p', 'ctypes_vector_mapper',
18-
'is_external_ctype', 'infer_dtype', 'CustomDtype', 'mpi4py_mapper']
18+
'is_external_ctype', 'infer_dtype', 'extract_dtype', 'CustomDtype',
19+
'mpi4py_mapper']
1920

2021

2122
# *** Custom np.dtypes
@@ -365,3 +366,10 @@ def infer_dtype(dtypes):
365366
else:
366367
# E.g., mixed integer arithmetic
367368
return max(dtypes, key=lambda i: np.dtype(i).itemsize, default=None)
369+
370+
371+
def extract_dtype(expr):
372+
"""Extract the "winning" dtype from an expression"""
373+
dtypes = {getattr(e, 'dtype', None)
374+
for e in expr.free_symbols}
375+
return infer_dtype(dtypes - {None})

0 commit comments

Comments
 (0)