Skip to content

Commit 405df56

Browse files
committed
Simplify kernel dispatch in iw_accumulate.hpp
1 parent f03ad07 commit 405df56

1 file changed

Lines changed: 80 additions & 123 deletions

File tree

Lines changed: 80 additions & 123 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,8 @@
11
#pragma once
22
#include <xsimd/xsimd.hpp>
33

4+
#include "./../types.hpp"
5+
46
#if defined(_MSC_VER)
57
// For Microsoft Visual Studio
68
#define RESTRICT __restrict
@@ -46,7 +48,7 @@ namespace triqs_ctint::measures {
4648
static_assert(!std::is_void_v<make_complex_sized_batch_t<double, min_width>>, "failed to create min_width complex batch");
4749
static_assert(!std::is_void_v<make_complex_sized_batch_t<double, max_width>>, "failed to create max_width complex batch");
4850

49-
template <auto bl1_batch, auto bl2_batch>
51+
template <int bl1_batch, int bl2_batch>
5052
void process_inner_loop(mc_weight_t sign, const auto &M1a, const auto &M2a, const auto &M1b, const auto &M2b, auto &M4a, const auto bl1_size,
5153
const auto bl2_size, const auto bl1, const auto bl2) {
5254
using batch1_t = make_complex_sized_batch_t<double, std::max(std::min(bl1_batch, max_width), min_width)>;
@@ -93,14 +95,15 @@ namespace triqs_ctint::measures {
9395
}
9496
}
9597

96-
template <auto bl1_batch, auto bl2_batch>
97-
void iw4_accumulate_kernel(mc_weight_t sign, const auto &M, auto &M4_iw, const auto bl1, const auto bl2) {
98-
auto const &iw_mesh = std::get<0>(M4_iw(0, 0).mesh());
99-
auto const bl1_size = M[bl1].target_shape()[0];
100-
auto const bl2_size = M[bl2].target_shape()[0];
101-
auto const M1 = M[bl1];
102-
auto const M2 = M[bl2];
103-
auto &M4 = M4_iw(bl1, bl2);
98+
const auto iw4_accumulate_kernel = []<int bl1_batch, int bl2_batch>(mc_weight_t sign, const auto &M, auto &M4_iw, const auto bl1,
99+
const auto bl2) {
100+
//auto const &iw_mesh = std::get<0>(M4_iw(0, 0).mesh());
101+
auto &[iw_mesh, _, _] = M4_iw(0, 0).mesh();
102+
auto const M1 = M[bl1];
103+
auto const M2 = M[bl2];
104+
auto const bl1_size = M1.target_shape()[0];
105+
auto const bl2_size = M2.target_shape()[0];
106+
auto &M4 = M4_iw(bl1, bl2);
104107

105108
for (const auto &iw1 : iw_mesh) {
106109
for (const auto &iw2 : iw_mesh) {
@@ -115,58 +118,60 @@ namespace triqs_ctint::measures {
115118
}
116119
}
117120
}
118-
}
121+
};
119122

120-
template <auto bl1_batch, auto bl2_batch>
121-
void iw4ph_accumulate_kernel(mc_weight_t sign, const auto &M, auto &M4_iw, const auto bl1, const auto bl2) {
122-
auto const &iW_mesh = std::get<0>(M4_iw(0, 0).mesh());
123-
auto const &iw_mesh = std::get<1>(M4_iw(0, 0).mesh());
124-
auto const bl1_size = M[bl1].target_shape()[0];
125-
auto const bl2_size = M[bl2].target_shape()[0];
126-
auto const M1 = M[bl1];
127-
auto const M2 = M[bl2];
128-
auto &M4 = M4_iw(bl1, bl2);
123+
const auto iw4ph_accumulate_kernel =
124+
[]<int bl1_batch, int bl2_batch>(mc_weight_t sign, const auto &M, auto &M4_iw, const auto bl1, const auto bl2) {
125+
//auto const &iW_mesh = std::get<0>(M4_iw(0, 0).mesh());
126+
//auto const &iw_mesh = std::get<1>(M4_iw(0, 0).mesh());
127+
auto const &[iW_mesh, iw_mesh, _] = M4_iw(0, 0).mesh();
128+
auto const M1 = M[bl1];
129+
auto const M2 = M[bl2];
130+
auto const bl1_size = M1.target_shape()[0];
131+
auto const bl2_size = M2.target_shape()[0];
132+
auto &M4 = M4_iw(bl1, bl2);
129133

130-
for (auto iW : iW_mesh) {
131-
for (auto iw : iw_mesh) {
132-
for (auto iwp : iw_mesh) {
133-
const auto M1a = M1[iW + iw, iw.value()];
134-
const auto M2a = M2[iwp.value(), iW + iwp];
135-
const auto M1b = M1[iwp.value(), iw.value()];
136-
const auto M2b = M2[iW + iw, iW + iwp];
137-
auto M4a = M4[iW, iw, iwp];
138-
process_inner_loop<bl1_batch, bl2_batch>(sign, M1a, M2a, M1b, M2b, M4a, bl1_size, bl2_size, bl1, bl2);
139-
}
140-
}
141-
}
142-
}
134+
for (auto iW : iW_mesh) {
135+
for (auto iw : iw_mesh) {
136+
for (auto iwp : iw_mesh) {
137+
const auto M1a = M1[iW + iw, iw.value()];
138+
const auto M2a = M2[iwp.value(), iW + iwp];
139+
const auto M1b = M1[iwp.value(), iw.value()];
140+
const auto M2b = M2[iW + iw, iW + iwp];
141+
auto M4a = M4[iW, iw, iwp];
142+
process_inner_loop<bl1_batch, bl2_batch>(sign, M1a, M2a, M1b, M2b, M4a, bl1_size, bl2_size, bl1, bl2);
143+
}
144+
}
145+
}
146+
};
143147

144-
template <auto bl1_batch, auto bl2_batch>
145-
void iw4pp_accumulate_kernel(mc_weight_t sign, const auto &M, auto &M4_iw, const auto bl1, const auto bl2) {
146-
auto const &iW_mesh = std::get<0>(M4_iw(0, 0).mesh());
147-
auto const &iw_mesh = std::get<1>(M4_iw(0, 0).mesh());
148-
auto const bl1_size = M[bl1].target_shape()[0];
149-
auto const bl2_size = M[bl2].target_shape()[0];
150-
auto const M1 = M[bl1];
151-
auto const M2 = M[bl2];
152-
auto &M4 = M4_iw(bl1, bl2);
148+
const auto iw4pp_accumulate_kernel =
149+
[]<int bl1_batch, int bl2_batch>(mc_weight_t sign, const auto &M, auto &M4_iw, const auto bl1, const auto bl2) {
150+
//auto const &iW_mesh = std::get<0>(M4_iw(0, 0).mesh());
151+
//auto const &iw_mesh = std::get<1>(M4_iw(0, 0).mesh());
152+
auto const &[iW_mesh, iw_mesh, _] = M4_iw(0, 0).mesh();
153+
auto const bl1_size = M[bl1].target_shape()[0];
154+
auto const bl2_size = M[bl2].target_shape()[0];
155+
auto const M1 = M[bl1];
156+
auto const M2 = M[bl2];
157+
auto &M4 = M4_iw(bl1, bl2);
153158

154-
for (auto iW : iW_mesh) {
155-
for (auto iw : iw_mesh) {
156-
for (auto iwp : iw_mesh) {
157-
const auto M1a = M1[iW - iwp, iw.value()];
158-
const auto M2a = M2[iwp.value(), iW - iw];
159-
const auto M1b = M1[iwp.value(), iw.value()];
160-
const auto M2b = M2[iW - iwp, iW - iw];
161-
auto M4a = M4[iW, iw, iwp];
162-
process_inner_loop<bl1_batch, bl2_batch>(sign, M1a, M2a, M1b, M2b, M4a, bl1_size, bl2_size, bl1, bl2);
163-
}
164-
}
165-
}
166-
}
159+
for (auto iW : iW_mesh) {
160+
for (auto iw : iw_mesh) {
161+
for (auto iwp : iw_mesh) {
162+
const auto M1a = M1[iW - iwp, iw.value()];
163+
const auto M2a = M2[iwp.value(), iW - iw];
164+
const auto M1b = M1[iwp.value(), iw.value()];
165+
const auto M2b = M2[iW - iwp, iW - iw];
166+
auto M4a = M4[iW, iw, iwp];
167+
process_inner_loop<bl1_batch, bl2_batch>(sign, M1a, M2a, M1b, M2b, M4a, bl1_size, bl2_size, bl1, bl2);
168+
}
169+
}
170+
}
171+
};
167172

168-
template <auto bl1_batch, auto bl2_batch>
169-
void iw3pp_accumulate_kernel(mc_weight_t sign, const auto &GM, auto &M3_iw, const auto bl1, const auto bl2) {
173+
const auto iw3pp_accumulate_kernel = []<int bl1_batch, int bl2_batch>(mc_weight_t sign, const auto &GM, auto &M3_iw, const auto bl1,
174+
const auto bl2) {
170175
auto const [iW_mesh, iw_mesh] = M3_iw(0, 0).mesh();
171176
auto const bl1_size = GM[bl1].target_shape()[0];
172177
auto const bl2_size = GM[bl2].target_shape()[0];
@@ -182,11 +187,10 @@ namespace triqs_ctint::measures {
182187
process_inner_loop<bl1_batch, bl2_batch>(sign, M1a, M2a, M1a, M2a, M4a, bl1_size, bl2_size, bl1, bl2);
183188
}
184189
}
185-
}
190+
};
186191

187-
template <auto bl1_batch, auto bl2_batch>
188-
void iw3ph_accumulate_kernel(mc_weight_t sign, const auto &M, const auto &GMG, const auto &GM, const auto &MG, auto &M3_iw, const auto bl1,
189-
const auto bl2) {
192+
const auto iw3ph_accumulate_kernel = []<int bl1_batch, int bl2_batch>(mc_weight_t sign, const auto &M, const auto &GMG, const auto &GM,
193+
const auto &MG, auto &M3_iw, const auto bl1, const auto bl2) {
190194
auto const [iW_mesh, iw_mesh] = M3_iw(0, 0).mesh();
191195
auto const bl1_size = M[bl1].target_shape()[0];
192196
auto const bl2_size = M[bl2].target_shape()[0];
@@ -206,95 +210,48 @@ namespace triqs_ctint::measures {
206210
process_inner_loop<bl1_batch, bl2_batch>(sign, M1a, M2a, M1b, M2b, M4a, bl1_size, bl2_size, bl1, bl2);
207211
}
208212
}
209-
}
213+
};
210214
} // namespace
211215

212216
namespace simd {
213217

214-
void iw4_accumulate(mc_weight_t sign, const auto &M, auto &M4_iw, const auto bl1, const auto bl2, const auto bl2_size) {
218+
template <auto kernel> void kernel_dispatch(const auto bl2_size, auto &&...args) {
215219
// Dispatch to the correct SIMD instruction width based on the size of the blocks
216220
// It will try to use the widest SIMD instruction available for the given block sizes
217221
// TODO: fold expressions might be an option to simplify the code
218222
if (bl2_size >= 8) {
219-
return iw4_accumulate_kernel<8, 8>(sign, M, M4_iw, bl1, bl2);
223+
return kernel.template operator()<8, 8>(args...);
220224
} else if (bl2_size >= 4) {
221-
return iw4_accumulate_kernel<8, 4>(sign, M, M4_iw, bl1, bl2);
225+
return kernel.template operator()<8, 4>(args...);
222226
} else if (bl2_size >= 3) {
223-
return iw4_accumulate_kernel<8, 2>(sign, M, M4_iw, bl1, bl2);
227+
return kernel.template operator()<8, 2>(args...);
224228
} else if (bl2_size >= 2) {
225-
return iw4_accumulate_kernel<4, 2>(sign, M, M4_iw, bl1, bl2);
229+
return kernel.template operator()<4, 2>(args...);
226230
} else {
227-
return iw4_accumulate_kernel<1, 1>(sign, M, M4_iw, bl1, bl2);
231+
return kernel.template operator()<1, 1>(args...);
228232
}
229233
}
230234

235+
void iw4_accumulate(mc_weight_t sign, const auto &M, auto &M4_iw, const auto bl1, const auto bl2, const auto bl2_size) {
236+
kernel_dispatch<iw4_accumulate_kernel>(bl2_size, sign, M, M4_iw, bl1, bl2);
237+
}
238+
231239
void iw4ph_accumulate(mc_weight_t sign, const auto &M, auto &M4_iw, const auto bl1, const auto bl2, const auto bl2_size) {
232-
// Dispatch to the correct SIMD instruction width based on the size of the blocks
233-
// It will try to use the widest SIMD instruction available for the given block sizes
234-
// TODO: fold expressions might be an option to simplify the code
235-
if (bl2_size >= 8) {
236-
return iw4ph_accumulate_kernel<8, 8>(sign, M, M4_iw, bl1, bl2);
237-
} else if (bl2_size >= 4) {
238-
return iw4ph_accumulate_kernel<8, 4>(sign, M, M4_iw, bl1, bl2);
239-
} else if (bl2_size >= 3) {
240-
return iw4ph_accumulate_kernel<8, 2>(sign, M, M4_iw, bl1, bl2);
241-
} else if (bl2_size >= 2) {
242-
return iw4ph_accumulate_kernel<4, 2>(sign, M, M4_iw, bl1, bl2);
243-
} else {
244-
return iw4ph_accumulate_kernel<1, 1>(sign, M, M4_iw, bl1, bl2);
245-
}
240+
kernel_dispatch<iw4ph_accumulate_kernel>(bl2_size, sign, M, M4_iw, bl1, bl2);
246241
}
247242

248243
void iw4pp_accumulate(mc_weight_t sign, const auto &M, auto &M4_iw, const auto bl1, const auto bl2, const auto bl2_size) {
249-
// Dispatch to the correct SIMD instruction width based on the size of the blocks
250-
// It will try to use the widest SIMD instruction available for the given block sizes
251-
// TODO: fold expressions might be an option to simplify the code
252-
if (bl2_size >= 8) {
253-
return iw4pp_accumulate_kernel<8, 8>(sign, M, M4_iw, bl1, bl2);
254-
} else if (bl2_size >= 4) {
255-
return iw4pp_accumulate_kernel<8, 4>(sign, M, M4_iw, bl1, bl2);
256-
} else if (bl2_size >= 3) {
257-
return iw4pp_accumulate_kernel<8, 2>(sign, M, M4_iw, bl1, bl2);
258-
} else if (bl2_size >= 2) {
259-
return iw4pp_accumulate_kernel<4, 2>(sign, M, M4_iw, bl1, bl2);
260-
} else {
261-
return iw4pp_accumulate_kernel<1, 1>(sign, M, M4_iw, bl1, bl2);
262-
}
244+
kernel_dispatch<iw4pp_accumulate_kernel>(bl2_size, sign, M, M4_iw, bl1, bl2);
263245
}
264246

265247
void iw3ph_accumulate(mc_weight_t sign, const auto &M, const auto &GMG, const auto &GM, const auto &MG, auto &M4_iw, const auto bl1,
266248
const auto bl2, const auto bl2_size) {
267-
// Dispatch to the correct SIMD instruction width based on the size of the blocks
268-
// It will try to use the widest SIMD instruction available for the given block sizes
269-
// TODO: fold expressions might be an option to simplify the code
270-
if (bl2_size >= 8) {
271-
return iw3ph_accumulate_kernel<8, 8>(sign, M, GMG, GM, MG, M4_iw, bl1, bl2);
272-
} else if (bl2_size >= 4) {
273-
return iw3ph_accumulate_kernel<8, 4>(sign, M, GMG, GM, MG, M4_iw, bl1, bl2);
274-
} else if (bl2_size >= 3) {
275-
return iw3ph_accumulate_kernel<8, 2>(sign, M, GMG, GM, MG, M4_iw, bl1, bl2);
276-
} else if (bl2_size >= 2) {
277-
return iw3ph_accumulate_kernel<4, 2>(sign, M, GMG, GM, MG, M4_iw, bl1, bl2);
278-
} else {
279-
return iw3ph_accumulate_kernel<1, 1>(sign, M, GMG, GM, MG, M4_iw, bl1, bl2);
280-
}
249+
kernel_dispatch<iw3ph_accumulate_kernel>(bl2_size, sign, M, GMG, GM, MG, M4_iw, bl1, bl2);
281250
}
282251

283252
void iw3pp_accumulate(mc_weight_t sign, const auto &M, auto &M4_iw, const auto bl1, const auto bl2, const auto bl2_size) {
284-
// Dispatch to the correct SIMD instruction width based on the size of the blocks
285-
// It will try to use the widest SIMD instruction available for the given block sizes
286-
// TODO: fold expressions might be an option to simplify the code
287-
if (bl2_size >= 8) {
288-
return iw3pp_accumulate_kernel<8, 8>(sign, M, M4_iw, bl1, bl2);
289-
} else if (bl2_size >= 4) {
290-
return iw3pp_accumulate_kernel<8, 4>(sign, M, M4_iw, bl1, bl2);
291-
} else if (bl2_size >= 3) {
292-
return iw3pp_accumulate_kernel<8, 2>(sign, M, M4_iw, bl1, bl2);
293-
} else if (bl2_size >= 2) {
294-
return iw3pp_accumulate_kernel<4, 2>(sign, M, M4_iw, bl1, bl2);
295-
} else {
296-
return iw3pp_accumulate_kernel<1, 1>(sign, M, M4_iw, bl1, bl2);
297-
}
253+
kernel_dispatch<iw3pp_accumulate_kernel>(bl2_size, sign, M, M4_iw, bl1, bl2);
298254
}
255+
299256
} // namespace simd
300257
} // namespace triqs_ctint::measures

0 commit comments

Comments (0)