Skip to content

Commit f007d30

Browse files
committed
compiler: Add VectorAccess
1 parent 9d20f34 commit f007d30

File tree

3 files changed

+69
-3
lines changed

3 files changed

+69
-3
lines changed

devito/ir/cgen/printer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,7 @@ def _print_Fallback(self, expr):
401401
_print_IndexSum = _print_Fallback
402402
_print_ReservedWord = _print_Fallback
403403
_print_Basic = _print_Fallback
404+
_print_VectorAccess = _print_Fallback
404405

405406

406407
# Lifted from SymPy so that we go through our own `_print_math_func`

devito/symbolics/extended_sympy.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import numpy as np
77
import sympy
8-
from sympy import Expr, Function, Number, Tuple, sympify
8+
from sympy import Expr, Function, Number, Tuple, cacheit, sympify
99
from sympy.core.decorators import call_highest_priority
1010

1111
from devito.finite_differences.elementary import Min, Max
@@ -21,7 +21,8 @@
2121
'ListInitializer', 'Byref', 'IndexedPointer', 'Cast', 'DefFunction',
2222
'MathFunction', 'InlineIf', 'ReservedWord', 'Keyword', 'String',
2323
'Macro', 'Class', 'MacroArgument', 'Deref', 'Namespace',
24-
'Rvalue', 'Null', 'SizeOf', 'rfunc', 'BasicWrapperMixin', 'ValueLimit']
24+
'Rvalue', 'Null', 'SizeOf', 'rfunc', 'BasicWrapperMixin', 'ValueLimit',
25+
'VectorAccess']
2526

2627

2728
class CondEq(sympy.Eq):
@@ -793,6 +794,51 @@ def __str__(self):
793794
__repr__ = __str__
794795

795796

797+
class VectorAccess(Expr, Pickable):
798+
799+
"""
800+
Represent a vector access operation at high-level.
801+
"""
802+
803+
def __new__(cls, *args, **kwargs):
804+
return Expr.__new__(cls, *args)
805+
806+
def __str__(self):
807+
return f"VL4<{self.base}>"
808+
809+
__repr__ = __str__
810+
811+
func = Pickable._rebuild
812+
813+
def _sympystr(self, printer):
814+
return str(self)
815+
816+
@property
817+
def base(self):
818+
return self.args[0]
819+
820+
@property
821+
def function(self):
822+
return self.base.function
823+
824+
@property
825+
def indices(self):
826+
return self.base.indices
827+
828+
@property
829+
def dtype(self):
830+
return self.function.dtype
831+
832+
@cacheit
833+
def sort_key(self, order=None):
834+
# Ensure that the VectorAccess is sorted as the base
835+
return self.base.sort_key(order=order)
836+
837+
# Default assumptions correspond to those of the `base`
838+
for i in ('is_real', 'is_imaginary', 'is_commutative'):
839+
locals()[i] = property(lambda self, v=i: getattr(self.base, v))
840+
841+
796842
# Some other utility objects
797843
Null = Macro('NULL')
798844

tests/test_symbolics.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
CallFromPointer, Cast, DefFunction, FieldFromPointer,
1515
INT, FieldFromComposite, IntDiv, Namespace, Rvalue,
1616
ReservedWord, ListInitializer, uxreplace, pow_to_mul,
17-
retrieve_derivatives, BaseCast, SizeOf)
17+
retrieve_derivatives, BaseCast, SizeOf, VectorAccess)
1818
from devito.tools import as_tuple, CustomDtype
1919
from devito.types import (Array, Bundle, FIndexed, LocalObject, Object,
2020
ComponentAccess, StencilDimension, Symbol as dSymbol)
@@ -501,6 +501,25 @@ def test_component_access():
501501
assert cf2 == cf1
502502

503503

504+
def test_vector_access():
505+
grid = Grid(shape=(3, 3, 3))
506+
507+
f = Function(name='f', grid=grid)
508+
g = Function(name='g', grid=grid)
509+
510+
v = VectorAccess(f.indexify())
511+
512+
assert v.base == f.indexify()
513+
assert v.function is f
514+
515+
# Code generation
516+
assert ccode(v) == 'VL4<f[x, y, z]>'
517+
518+
# Reconstruction
519+
v1 = v.func(g.indexify())
520+
assert ccode(v1) == 'VL4<g[x, y, z]>'
521+
522+
504523
def test_canonical_ordering_of_weights():
505524
grid = Grid(shape=(3, 3, 3))
506525
x, y, z = grid.dimensions

0 commit comments

Comments
 (0)