Skip to content

Commit 9a3ff77

Browse files
committed
tests: Fix memory estimate with overrides
1 parent 1a1e406 commit 9a3ff77

File tree

2 files changed

+122
-36
lines changed

2 files changed

+122
-36
lines changed

devito/operator/operator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1373,9 +1373,10 @@ def get_nbytes(obj):
13731373
# FIXME: Probably wrong for streamed functions
13741374
# Will overreport memory usage currently
13751375
try:
1376+
# TODO: is _obj even needed?
13761377
v = get_nbytes(self[i.name]._obj)
13771378
except AttributeError:
1378-
v = get_nbytes(i)
1379+
v = get_nbytes(self.get(i.name, i))
13791380

13801381
if i._mem_host:
13811382
host += v

tests/test_operator.py

Lines changed: 120 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
SparseFunction, SparseTimeFunction, Dimension, error, SpaceDimension,
1313
NODE, CELL, dimensions, configuration, TensorFunction,
1414
TensorTimeFunction, VectorFunction, VectorTimeFunction,
15-
div, grad, switchconfig, exp)
15+
div, grad, switchconfig, exp, Buffer)
1616
from devito import Inc, Le, Lt, Ge, Gt # noqa
1717
from devito.exceptions import InvalidOperator
1818
from devito.finite_differences.differentiable import diff2sympy
@@ -2059,10 +2059,15 @@ def test_indirection(self):
20592059
class TestEstimateMemory:
20602060
"""Tests for the Operator.estimate_memory() utility"""
20612061

2062-
def parse_output(self, output):
2062+
def parse_output(self, output, expected):
20632063
"""Parse estimate_memory machine-readable output and compare to expected"""
2064-
name, disk, host, device = output.split()
2065-
return name, int(disk), int(host), int(device)
2064+
# Check that no allocation occurs as estimate_memory should avoid data touch
2065+
assert "Allocating" not in output.text
2066+
2067+
name, disk, host, device = output.records[-1].message.split()
2068+
extracted = (name, int(disk), int(host), int(device))
2069+
2070+
assert extracted == expected
20662071

20672072
@pytest.mark.parametrize('shape', [(11,), (101, 101), (101, 101, 101)])
20682073
@pytest.mark.parametrize('dtype', [np.int8, np.int16, np.float32,
@@ -2077,15 +2082,10 @@ def test_basic_usage(self, caplog, shape, dtype, so):
20772082
# Machine-readable output for parsing
20782083
op.estimate_memory(human_readable=False)
20792084

2080-
# Check that no allocation occurs as estimate_memory should avoid data touch
2081-
assert "Allocating" not in caplog.text
2082-
20832085
# Check output of estimate_memory
2084-
name, dist, host, device = self.parse_output(caplog.records[-1].message)
2085-
assert name == "Kernel"
2086-
assert dist == 0
2087-
assert device == 0
2088-
assert host == reduce(mul, [s + 2*so for s in shape])*np.dtype(dtype).itemsize
2086+
host = reduce(mul, [s + 2*so for s in shape])*np.dtype(dtype).itemsize
2087+
expected = ("Kernel", 0, host, 0)
2088+
self.parse_output(caplog, expected)
20892089

20902090
def test_multiple_objects(self, caplog):
20912091
grid = Grid(shape=(101, 101))
@@ -2094,42 +2094,127 @@ def test_multiple_objects(self, caplog):
20942094
g = Function(name='g', grid=grid, space_order=4, dtype=np.float64)
20952095
with switchconfig(log_level='DEBUG'), caplog.at_level(logging.DEBUG):
20962096
op = Operator([Eq(f, 1), Eq(g, 1)])
2097-
20982097
op.estimate_memory(human_readable=False)
2099-
assert "Allocating" not in caplog.text
2100-
name, dist, host, device = self.parse_output(caplog.records[-1].message)
2101-
assert name == "Kernel"
2102-
assert dist == 0
2103-
assert device == 0
2098+
21042099
check = sum(reduce(mul, func.shape_allocated)*np.dtype(func.dtype).itemsize
21052100
for func in (f, g))
2106-
assert host == check
2101+
expected = ("Kernel", 0, check, 0)
2102+
self.parse_output(caplog, expected)
21072103

2108-
def test_sparse(self, caplog):
2109-
# FIXME: Can be consolidated with previous test
2104+
@pytest.mark.parametrize('time', [True, False])
2105+
def test_sparse(self, caplog, time):
21102106
grid = Grid(shape=(101, 101))
21112107
f = Function(name='f', grid=grid, space_order=2)
2112-
src = SparseFunction(name='src', grid=grid, npoint=10000)
2108+
if time:
2109+
src = SparseTimeFunction(name='src', grid=grid, npoint=1000, nt=10)
2110+
else:
2111+
src = SparseFunction(name='src', grid=grid, npoint=1000)
21132112
src_term = src.inject(field=f, expr=src)
21142113

21152114
with switchconfig(log_level='DEBUG'), caplog.at_level(logging.DEBUG):
21162115
op = Operator(src_term)
21172116
op.estimate_memory(human_readable=False)
2118-
assert "Allocating" not in caplog.text
2119-
name, dist, host, device = self.parse_output(caplog.records[-1].message)
2120-
assert name == "Kernel"
2121-
assert dist == 0
2122-
assert device == 0
21232117

21242118
check = sum(reduce(mul, func.shape_allocated)*np.dtype(func.dtype).itemsize
21252119
for func in (f, src, src.coordinates))
2120+
expected = ("Kernel", 0, check, 0)
2121+
self.parse_output(caplog, expected)
2122+
2123+
@pytest.mark.parametrize('save', [None, Buffer(3), 10])
2124+
def test_timefunction(self, caplog, save):
2125+
grid = Grid(shape=(101, 101))
2126+
f = TimeFunction(name='f', grid=grid, space_order=2, save=save)
2127+
2128+
with switchconfig(log_level='DEBUG'), caplog.at_level(logging.DEBUG):
2129+
op = Operator(Eq(f, 1))
2130+
op.estimate_memory(human_readable=False)
2131+
check = reduce(mul, f.shape_allocated)*np.dtype(f.dtype).itemsize
2132+
expected = ("Kernel", 0, check, 0)
2133+
self.parse_output(caplog, expected)
2134+
2135+
def test_mashup(self, caplog):
2136+
grid = Grid(shape=(101, 101))
2137+
f = Function(name='f', grid=grid, space_order=4)
2138+
g = TimeFunction(name='g', grid=grid, space_order=4)
2139+
2140+
src0 = SparseFunction(name='src0', grid=grid, npoint=100)
2141+
src1 = SparseFunction(name='src1', grid=grid, npoint=100)
2142+
2143+
eq0 = Eq(f, 1)
2144+
eq1 = Eq(g, 1)
2145+
2146+
src_term0 = src0.inject(field=f, expr=src0)
2147+
src_term1 = src1.inject(field=f, expr=src1)
2148+
2149+
with switchconfig(log_level='DEBUG'), caplog.at_level(logging.DEBUG):
2150+
op = Operator([eq0, eq1] + src_term0 + src_term1)
2151+
op.estimate_memory(human_readable=False)
2152+
2153+
check = sum(reduce(mul, func.shape_allocated)*np.dtype(func.dtype).itemsize
2154+
for func in (f, g, src0, src0.coordinates,
2155+
src1, src1.coordinates))
2156+
expected = ("Kernel", 0, check, 0)
2157+
self.parse_output(caplog, expected)
2158+
2159+
def test_temp_array(self, caplog):
2160+
"""Check that temporary arrays will be factored into the memory calculation"""
2161+
grid = Grid(shape=(101, 101))
2162+
f = TimeFunction(name='f', grid=grid, space_order=2)
2163+
g = TimeFunction(name='g', grid=grid, space_order=2)
2164+
a = Function(name='a', grid=grid, space_order=2)
2165+
2166+
# Reuse an expensive function to encourage generation of an array temp
2167+
eq0 = Eq(f.forward, g + sympy.sin(a))
2168+
eq1 = Eq(g.forward, f + sympy.sin(a))
2169+
2170+
with switchconfig(log_level='DEBUG'), caplog.at_level(logging.DEBUG):
2171+
op = Operator([eq0, eq1])
2172+
2173+
# Regression to ensure this test functions as intended
2174+
# Ensure an array temporary is created
2175+
assert "r0[x][y]" in str(op.ccode)
2176+
2177+
op.estimate_memory(human_readable=False)
2178+
2179+
check = sum(reduce(mul, func.shape_allocated)*np.dtype(func.dtype).itemsize
2180+
for func in (f, g, a))
2181+
2182+
# Factor in the temp array
2183+
check += reduce(mul, a.shape)*np.dtype(a.dtype).itemsize
2184+
2185+
expected = ("Kernel", 0, check, 0)
2186+
self.parse_output(caplog, expected)
2187+
2188+
def test_overrides(self, caplog):
2189+
grid0 = Grid(shape=(101, 101))
2190+
# Original fields
2191+
f0 = Function(name='f0', grid=grid0, space_order=4)
2192+
tf0 = TimeFunction(name='tf0', grid=grid0, space_order=4)
2193+
s0 = SparseFunction(name='s0', grid=grid0, npoint=100)
2194+
st0 = SparseTimeFunction(name='st0', grid=grid0, npoint=100, nt=10)
2195+
2196+
grid1 = Grid(shape=(201, 201)) # Bigger grid so overrides are distinct
2197+
# Replacement fields
2198+
f1 = Function(name='f1', grid=grid1, space_order=4)
2199+
tf1 = TimeFunction(name='tf1', grid=grid1, space_order=4)
2200+
s1 = SparseFunction(name='s1', grid=grid1, npoint=200)
2201+
st1 = SparseTimeFunction(name='st1', grid=grid1, npoint=200, nt=20)
2202+
2203+
eq0 = Eq(f0, 1)
2204+
eq1 = Eq(tf0, 1)
2205+
s0_term = s0.inject(field=f0, expr=s0)
2206+
st0_term = st0.inject(field=tf0, expr=st0)
2207+
2208+
with switchconfig(log_level='DEBUG'), caplog.at_level(logging.DEBUG):
2209+
op = Operator([eq0, eq1] + s0_term + st0_term)
2210+
2211+
# Apply overrides for the check
2212+
op.estimate_memory(f0=f1, tf0=tf1, s0=s1, st0=st1, human_readable=False)
2213+
2214+
check = sum(reduce(mul, func.shape_allocated)*np.dtype(func.dtype).itemsize
2215+
for func in (f1, tf1, s1, s1.coordinates, st1, st1.coordinates))
21262216

2127-
assert host == check
2217+
expected = ("Kernel", 0, check, 0)
2218+
self.parse_output(caplog, expected)
21282219

2129-
# Check with timefunctions (with both buffer and save)
2130-
# Check with sparsetimefunctions
2131-
# Check mashup
2132-
# Check overrides for functions
2133-
# Check overrides for sparsefunctions
2134-
# Check overrides for timefunctions
2135-
# Check overrides for sparsetimefunctions
2220+
# Test with OpenACC

0 commit comments

Comments
 (0)