Skip to content

Commit caf9221

Browse files
committed
Add test_remove_stripe_fw_performance
1 parent 2c4ae43 commit caf9221

1 file changed

Lines changed: 25 additions & 0 deletions

File tree

tests/test_prep/test_stripe.py

Lines changed: 25 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -107,6 +107,31 @@ def test_remove_stripe_fw_calc_mem_big(wname, slices, level, dims, ensure_clean_
107107
assert estimated_mem_peak <= actual_mem_peak * 1.3
108108

109109

110+
@pytest.mark.perf
def test_remove_stripe_fw_performance(ensure_clean_memory):
    """Time ``remove_stripe_fw`` on a (1801, 5, 2560) float32 device array.

    One untimed call is made first, then ten calls are timed together and the
    mean wall-clock duration per call (in milliseconds) is computed. The
    ``ensure_clean_memory`` argument is not referenced in the body; it is
    requested only for whatever setup/teardown it performs.
    """
    # Random host data scaled by 2.0 and offset by 0.001, then moved to the GPU.
    host_projections = (
        np.random.random_sample(size=(1801, 5, 2560)).astype(np.float32) * 2.0 + 0.001
    )
    data = cp.asarray(host_projections, dtype=np.float32)

    # Cold run first, so it is excluded from the timed section below.
    remove_stripe_fw(cp.copy(data))

    # Wait for the cold run to finish on the device before starting the clock.
    device = cp.cuda.Device()
    device.synchronize()

    n_runs = 10
    t_begin = time.perf_counter_ns()
    nvtx.RangePush("Core")
    for _ in range(n_runs):
        # A fresh copy is passed each time, as data is modified in-place.
        remove_stripe_fw(cp.copy(data))
    nvtx.RangePop()
    # Synchronize again so the clock is read only after the device is done.
    device.synchronize()
    elapsed_ns = time.perf_counter_ns() - t_begin
    duration_ms = float(elapsed_ns) * 1e-6 / n_runs

    # NOTE(review): a str never equals a float, so this assertion always fails;
    # presumably intentional so that pytest's failure report displays the
    # measured duration - confirm against how the `perf` marker is used.
    assert "performance in ms" == duration_ms
133+
134+
110135
@pytest.mark.parametrize("angles", [180, 181])
111136
@pytest.mark.parametrize("det_x", [11, 18])
112137
@pytest.mark.parametrize("det_y", [5, 7, 8])

0 commit comments

Comments
 (0)