11#pragma once
22#include < xsimd/xsimd.hpp>
33
4+ #include " ./../types.hpp"
5+
46#if defined(_MSC_VER)
57// For Microsoft Visual Studio
68#define RESTRICT __restrict
@@ -46,7 +48,7 @@ namespace triqs_ctint::measures {
4648 static_assert (!std::is_void_v<make_complex_sized_batch_t <double , min_width>>, " failed to create min_width complex batch" );
4749 static_assert (!std::is_void_v<make_complex_sized_batch_t <double , max_width>>, " failed to create max_width complex batch" );
4850
49- template <auto bl1_batch, auto bl2_batch>
51+ template <int bl1_batch, int bl2_batch>
5052 void process_inner_loop (mc_weight_t sign, const auto &M1a, const auto &M2a, const auto &M1b, const auto &M2b, auto &M4a, const auto bl1_size,
5153 const auto bl2_size, const auto bl1, const auto bl2) {
5254 using batch1_t = make_complex_sized_batch_t <double , std::max (std::min (bl1_batch, max_width), min_width)>;
@@ -93,14 +95,15 @@ namespace triqs_ctint::measures {
9395 }
9496 }
9597
96- template <auto bl1_batch, auto bl2_batch>
97- void iw4_accumulate_kernel (mc_weight_t sign, const auto &M, auto &M4_iw, const auto bl1, const auto bl2) {
98- auto const &iw_mesh = std::get<0 >(M4_iw (0 , 0 ).mesh ());
99- auto const bl1_size = M[bl1].target_shape ()[0 ];
100- auto const bl2_size = M[bl2].target_shape ()[0 ];
101- auto const M1 = M[bl1];
102- auto const M2 = M[bl2];
103- auto &M4 = M4_iw (bl1, bl2);
98+ const auto iw4_accumulate_kernel = []<int bl1_batch, int bl2_batch>(mc_weight_t sign, const auto &M, auto &M4_iw, const auto bl1,
99+ const auto bl2) {
100+ // auto const &iw_mesh = std::get<0>(M4_iw(0, 0).mesh());
101+ auto &[iw_mesh, _, _] = M4_iw (0 , 0 ).mesh ();
102+ auto const M1 = M[bl1];
103+ auto const M2 = M[bl2];
104+ auto const bl1_size = M1.target_shape ()[0 ];
105+ auto const bl2_size = M2.target_shape ()[0 ];
106+ auto &M4 = M4_iw (bl1, bl2);
104107
105108 for (const auto &iw1 : iw_mesh) {
106109 for (const auto &iw2 : iw_mesh) {
@@ -115,58 +118,60 @@ namespace triqs_ctint::measures {
115118 }
116119 }
117120 }
118- }
121+ };
119122
120- template <auto bl1_batch, auto bl2_batch>
121- void iw4ph_accumulate_kernel (mc_weight_t sign, const auto &M, auto &M4_iw, const auto bl1, const auto bl2) {
122- auto const &iW_mesh = std::get<0 >(M4_iw (0 , 0 ).mesh ());
123- auto const &iw_mesh = std::get<1 >(M4_iw (0 , 0 ).mesh ());
124- auto const bl1_size = M[bl1].target_shape ()[0 ];
125- auto const bl2_size = M[bl2].target_shape ()[0 ];
126- auto const M1 = M[bl1];
127- auto const M2 = M[bl2];
128- auto &M4 = M4_iw (bl1, bl2);
123+ const auto iw4ph_accumulate_kernel =
124+ []<int bl1_batch, int bl2_batch>(mc_weight_t sign, const auto &M, auto &M4_iw, const auto bl1, const auto bl2) {
125+ // auto const &iW_mesh = std::get<0>(M4_iw(0, 0).mesh());
126+ // auto const &iw_mesh = std::get<1>(M4_iw(0, 0).mesh());
127+ auto const &[iW_mesh, iw_mesh, _] = M4_iw (0 , 0 ).mesh ();
128+ auto const M1 = M[bl1];
129+ auto const M2 = M[bl2];
130+ auto const bl1_size = M1.target_shape ()[0 ];
131+ auto const bl2_size = M2.target_shape ()[0 ];
132+ auto &M4 = M4_iw (bl1, bl2);
129133
130- for (auto iW : iW_mesh) {
131- for (auto iw : iw_mesh) {
132- for (auto iwp : iw_mesh) {
133- const auto M1a = M1[iW + iw, iw.value ()];
134- const auto M2a = M2[iwp.value (), iW + iwp];
135- const auto M1b = M1[iwp.value (), iw.value ()];
136- const auto M2b = M2[iW + iw, iW + iwp];
137- auto M4a = M4[iW, iw, iwp];
138- process_inner_loop<bl1_batch, bl2_batch>(sign, M1a, M2a, M1b, M2b, M4a, bl1_size, bl2_size, bl1, bl2);
139- }
140- }
141- }
142- }
134+ for (auto iW : iW_mesh) {
135+ for (auto iw : iw_mesh) {
136+ for (auto iwp : iw_mesh) {
137+ const auto M1a = M1[iW + iw, iw.value ()];
138+ const auto M2a = M2[iwp.value (), iW + iwp];
139+ const auto M1b = M1[iwp.value (), iw.value ()];
140+ const auto M2b = M2[iW + iw, iW + iwp];
141+ auto M4a = M4[iW, iw, iwp];
142+ process_inner_loop<bl1_batch, bl2_batch>(sign, M1a, M2a, M1b, M2b, M4a, bl1_size, bl2_size, bl1, bl2);
143+ }
144+ }
145+ }
146+ };
143147
144- template <auto bl1_batch, auto bl2_batch>
145- void iw4pp_accumulate_kernel (mc_weight_t sign, const auto &M, auto &M4_iw, const auto bl1, const auto bl2) {
146- auto const &iW_mesh = std::get<0 >(M4_iw (0 , 0 ).mesh ());
147- auto const &iw_mesh = std::get<1 >(M4_iw (0 , 0 ).mesh ());
148- auto const bl1_size = M[bl1].target_shape ()[0 ];
149- auto const bl2_size = M[bl2].target_shape ()[0 ];
150- auto const M1 = M[bl1];
151- auto const M2 = M[bl2];
152- auto &M4 = M4_iw (bl1, bl2);
148+ const auto iw4pp_accumulate_kernel =
149+ []<int bl1_batch, int bl2_batch>(mc_weight_t sign, const auto &M, auto &M4_iw, const auto bl1, const auto bl2) {
150+ // auto const &iW_mesh = std::get<0>(M4_iw(0, 0).mesh());
151+ // auto const &iw_mesh = std::get<1>(M4_iw(0, 0).mesh());
152+ auto const &[iW_mesh, iw_mesh, _] = M4_iw (0 , 0 ).mesh ();
153+ auto const bl1_size = M[bl1].target_shape ()[0 ];
154+ auto const bl2_size = M[bl2].target_shape ()[0 ];
155+ auto const M1 = M[bl1];
156+ auto const M2 = M[bl2];
157+ auto &M4 = M4_iw (bl1, bl2);
153158
154- for (auto iW : iW_mesh) {
155- for (auto iw : iw_mesh) {
156- for (auto iwp : iw_mesh) {
157- const auto M1a = M1[iW - iwp, iw.value ()];
158- const auto M2a = M2[iwp.value (), iW - iw];
159- const auto M1b = M1[iwp.value (), iw.value ()];
160- const auto M2b = M2[iW - iwp, iW - iw];
161- auto M4a = M4[iW, iw, iwp];
162- process_inner_loop<bl1_batch, bl2_batch>(sign, M1a, M2a, M1b, M2b, M4a, bl1_size, bl2_size, bl1, bl2);
163- }
164- }
165- }
166- }
159+ for (auto iW : iW_mesh) {
160+ for (auto iw : iw_mesh) {
161+ for (auto iwp : iw_mesh) {
162+ const auto M1a = M1[iW - iwp, iw.value ()];
163+ const auto M2a = M2[iwp.value (), iW - iw];
164+ const auto M1b = M1[iwp.value (), iw.value ()];
165+ const auto M2b = M2[iW - iwp, iW - iw];
166+ auto M4a = M4[iW, iw, iwp];
167+ process_inner_loop<bl1_batch, bl2_batch>(sign, M1a, M2a, M1b, M2b, M4a, bl1_size, bl2_size, bl1, bl2);
168+ }
169+ }
170+ }
171+ };
167172
168- template < auto bl1_batch, auto bl2_batch>
169- void iw3pp_accumulate_kernel ( mc_weight_t sign, const auto &GM, auto &M3_iw, const auto bl1, const auto bl2) {
173+ const auto iw3pp_accumulate_kernel = []< int bl1_batch, int bl2_batch>( mc_weight_t sign, const auto &GM, auto &M3_iw, const auto bl1,
174+ const auto bl2) {
170175 auto const [iW_mesh, iw_mesh] = M3_iw (0 , 0 ).mesh ();
171176 auto const bl1_size = GM[bl1].target_shape ()[0 ];
172177 auto const bl2_size = GM[bl2].target_shape ()[0 ];
@@ -182,11 +187,10 @@ namespace triqs_ctint::measures {
182187 process_inner_loop<bl1_batch, bl2_batch>(sign, M1a, M2a, M1a, M2a, M4a, bl1_size, bl2_size, bl1, bl2);
183188 }
184189 }
185- }
190+ };
186191
187- template <auto bl1_batch, auto bl2_batch>
188- void iw3ph_accumulate_kernel (mc_weight_t sign, const auto &M, const auto &GMG, const auto &GM, const auto &MG, auto &M3_iw, const auto bl1,
189- const auto bl2) {
192+ const auto iw3ph_accumulate_kernel = []<int bl1_batch, int bl2_batch>(mc_weight_t sign, const auto &M, const auto &GMG, const auto &GM,
193+ const auto &MG, auto &M3_iw, const auto bl1, const auto bl2) {
190194 auto const [iW_mesh, iw_mesh] = M3_iw (0 , 0 ).mesh ();
191195 auto const bl1_size = M[bl1].target_shape ()[0 ];
192196 auto const bl2_size = M[bl2].target_shape ()[0 ];
@@ -206,95 +210,48 @@ namespace triqs_ctint::measures {
206210 process_inner_loop<bl1_batch, bl2_batch>(sign, M1a, M2a, M1b, M2b, M4a, bl1_size, bl2_size, bl1, bl2);
207211 }
208212 }
209- }
213+ };
210214 } // namespace
211215
212216 namespace simd {
213217
214- void iw4_accumulate ( mc_weight_t sign, const auto &M, auto &M4_iw, const auto bl1, const auto bl2, const auto bl2_size ) {
218+ template < auto kernel> void kernel_dispatch ( const auto bl2_size, auto &&...args ) {
215219 // Dispatch to the correct SIMD instruction width based on the size of the blocks
216220 // It will try to use the widest SIMD instruction available for the given block sizes
217221 // TODO: fold expressions might be an option to simplify the code
218222 if (bl2_size >= 8 ) {
219- return iw4_accumulate_kernel <8 , 8 >(sign, M, M4_iw, bl1, bl2 );
223+ return kernel. template operator () <8 , 8 >(args... );
220224 } else if (bl2_size >= 4 ) {
221- return iw4_accumulate_kernel <8 , 4 >(sign, M, M4_iw, bl1, bl2 );
225+ return kernel. template operator () <8 , 4 >(args... );
222226 } else if (bl2_size >= 3 ) {
223- return iw4_accumulate_kernel <8 , 2 >(sign, M, M4_iw, bl1, bl2 );
227+ return kernel. template operator () <8 , 2 >(args... );
224228 } else if (bl2_size >= 2 ) {
225- return iw4_accumulate_kernel <4 , 2 >(sign, M, M4_iw, bl1, bl2 );
229+ return kernel. template operator () <4 , 2 >(args... );
226230 } else {
227- return iw4_accumulate_kernel <1 , 1 >(sign, M, M4_iw, bl1, bl2 );
231+ return kernel. template operator () <1 , 1 >(args... );
228232 }
229233 }
230234
235+ void iw4_accumulate (mc_weight_t sign, const auto &M, auto &M4_iw, const auto bl1, const auto bl2, const auto bl2_size) {
236+ kernel_dispatch<iw4_accumulate_kernel>(bl2_size, sign, M, M4_iw, bl1, bl2);
237+ }
238+
231239 void iw4ph_accumulate (mc_weight_t sign, const auto &M, auto &M4_iw, const auto bl1, const auto bl2, const auto bl2_size) {
232- // Dispatch to the correct SIMD instruction width based on the size of the blocks
233- // It will try to use the widest SIMD instruction available for the given block sizes
234- // TODO: fold expressions might be an option to simplify the code
235- if (bl2_size >= 8 ) {
236- return iw4ph_accumulate_kernel<8 , 8 >(sign, M, M4_iw, bl1, bl2);
237- } else if (bl2_size >= 4 ) {
238- return iw4ph_accumulate_kernel<8 , 4 >(sign, M, M4_iw, bl1, bl2);
239- } else if (bl2_size >= 3 ) {
240- return iw4ph_accumulate_kernel<8 , 2 >(sign, M, M4_iw, bl1, bl2);
241- } else if (bl2_size >= 2 ) {
242- return iw4ph_accumulate_kernel<4 , 2 >(sign, M, M4_iw, bl1, bl2);
243- } else {
244- return iw4ph_accumulate_kernel<1 , 1 >(sign, M, M4_iw, bl1, bl2);
245- }
240+ kernel_dispatch<iw4ph_accumulate_kernel>(bl2_size, sign, M, M4_iw, bl1, bl2);
246241 }
247242
248243 void iw4pp_accumulate (mc_weight_t sign, const auto &M, auto &M4_iw, const auto bl1, const auto bl2, const auto bl2_size) {
249- // Dispatch to the correct SIMD instruction width based on the size of the blocks
250- // It will try to use the widest SIMD instruction available for the given block sizes
251- // TODO: fold expressions might be an option to simplify the code
252- if (bl2_size >= 8 ) {
253- return iw4pp_accumulate_kernel<8 , 8 >(sign, M, M4_iw, bl1, bl2);
254- } else if (bl2_size >= 4 ) {
255- return iw4pp_accumulate_kernel<8 , 4 >(sign, M, M4_iw, bl1, bl2);
256- } else if (bl2_size >= 3 ) {
257- return iw4pp_accumulate_kernel<8 , 2 >(sign, M, M4_iw, bl1, bl2);
258- } else if (bl2_size >= 2 ) {
259- return iw4pp_accumulate_kernel<4 , 2 >(sign, M, M4_iw, bl1, bl2);
260- } else {
261- return iw4pp_accumulate_kernel<1 , 1 >(sign, M, M4_iw, bl1, bl2);
262- }
244+ kernel_dispatch<iw4pp_accumulate_kernel>(bl2_size, sign, M, M4_iw, bl1, bl2);
263245 }
264246
265247 void iw3ph_accumulate (mc_weight_t sign, const auto &M, const auto &GMG, const auto &GM, const auto &MG, auto &M4_iw, const auto bl1,
266248 const auto bl2, const auto bl2_size) {
267- // Dispatch to the correct SIMD instruction width based on the size of the blocks
268- // It will try to use the widest SIMD instruction available for the given block sizes
269- // TODO: fold expressions might be an option to simplify the code
270- if (bl2_size >= 8 ) {
271- return iw3ph_accumulate_kernel<8 , 8 >(sign, M, GMG, GM, MG, M4_iw, bl1, bl2);
272- } else if (bl2_size >= 4 ) {
273- return iw3ph_accumulate_kernel<8 , 4 >(sign, M, GMG, GM, MG, M4_iw, bl1, bl2);
274- } else if (bl2_size >= 3 ) {
275- return iw3ph_accumulate_kernel<8 , 2 >(sign, M, GMG, GM, MG, M4_iw, bl1, bl2);
276- } else if (bl2_size >= 2 ) {
277- return iw3ph_accumulate_kernel<4 , 2 >(sign, M, GMG, GM, MG, M4_iw, bl1, bl2);
278- } else {
279- return iw3ph_accumulate_kernel<1 , 1 >(sign, M, GMG, GM, MG, M4_iw, bl1, bl2);
280- }
249+ kernel_dispatch<iw3ph_accumulate_kernel>(bl2_size, sign, M, GMG, GM, MG, M4_iw, bl1, bl2);
281250 }
282251
283252 void iw3pp_accumulate (mc_weight_t sign, const auto &M, auto &M4_iw, const auto bl1, const auto bl2, const auto bl2_size) {
284- // Dispatch to the correct SIMD instruction width based on the size of the blocks
285- // It will try to use the widest SIMD instruction available for the given block sizes
286- // TODO: fold expressions might be an option to simplify the code
287- if (bl2_size >= 8 ) {
288- return iw3pp_accumulate_kernel<8 , 8 >(sign, M, M4_iw, bl1, bl2);
289- } else if (bl2_size >= 4 ) {
290- return iw3pp_accumulate_kernel<8 , 4 >(sign, M, M4_iw, bl1, bl2);
291- } else if (bl2_size >= 3 ) {
292- return iw3pp_accumulate_kernel<8 , 2 >(sign, M, M4_iw, bl1, bl2);
293- } else if (bl2_size >= 2 ) {
294- return iw3pp_accumulate_kernel<4 , 2 >(sign, M, M4_iw, bl1, bl2);
295- } else {
296- return iw3pp_accumulate_kernel<1 , 1 >(sign, M, M4_iw, bl1, bl2);
297- }
253+ kernel_dispatch<iw3pp_accumulate_kernel>(bl2_size, sign, M, M4_iw, bl1, bl2);
298254 }
255+
299256 } // namespace simd
300257} // namespace triqs_ctint::measures
0 commit comments