88#include " phlex/core/input_arguments.hpp"
99#include " phlex/core/message.hpp"
1010#include " phlex/core/products_consumer.hpp"
11- #include " phlex/core/registrar.hpp"
1211#include " phlex/core/specified_label.hpp"
1312#include " phlex/core/store_counters.hpp"
1413#include " phlex/model/algorithm_name.hpp"
@@ -52,8 +51,8 @@ namespace phlex::experimental {
5251
5352 // =====================================================================================
5453
55- template <typename AlgorithmBits>
56- class pre_fold {
54+ template <typename AlgorithmBits, typename InitTuple >
55+ class fold_node : public declared_fold , private count_stores {
5756 using all_parameter_types = typename AlgorithmBits::input_parameter_types;
5857 using input_parameter_types = skip_first_type<all_parameter_types>; // Skip fold object
5958 static constexpr auto N = std::tuple_size_v<input_parameter_types>;
@@ -62,146 +61,61 @@ namespace phlex::experimental {
6261 static constexpr std::size_t M = 1 ; // hard-coded for now
6362 using function_t = typename AlgorithmBits::bound_type;
6463
65- template <typename InitTuple>
66- class total_fold ;
67-
6864 public:
69- pre_fold (registrar<declared_fold_ptr> reg,
70- algorithm_name name,
71- std::size_t concurrency,
72- std::vector<std::string> predicates,
73- tbb::flow::graph& g,
74- AlgorithmBits alg,
75- specified_labels input_products) :
76- name_{std::move (name)},
77- concurrency_{concurrency},
78- predicates_{std::move (predicates)},
79- graph_{g},
80- ft_{alg.release_algorithm ()},
81- product_labels_{std::move (input_products)},
82- reg_{std::move (reg)}
83- {
84- }
85-
86- template <std::size_t Msize>
87- auto & to (std::array<std::string, Msize> output_keys)
88- {
89- static_assert (
90- M == Msize,
91- " The number of function parameters is not the same as the number of returned output "
92- " objects." );
93- std::ranges::transform (output_keys, output_names_.begin (), to_qualified_name{name_});
94- reg_.set_creator ([this ](auto ) { return create (std::make_tuple ()); });
95- return *this ;
96- }
97-
98- auto & to (std::convertible_to<std::string> auto &&... ts)
99- {
100- static_assert (
101- M == sizeof ...(ts),
102- " The number of function parameters is not the same as the number of returned output "
103- " objects." );
104- return to (std::array<std::string, M>{std::forward<decltype (ts)>(ts)...});
105- }
106-
107- auto & partitioned_by (std::string const & level_name)
108- {
109- partition_ = level_name;
110- return *this ;
111- }
112-
113- auto & initialized_with (auto &&... ts)
114- {
115- reg_.set_creator ([this , init = std::tuple{ts...}](auto ) { return create (std::move (init)); });
116- return *this ;
117- }
118-
119- private:
120- template <typename T>
121- declared_fold_ptr create (T init)
122- {
123- if (empty (partition_)) {
124- throw std::runtime_error (" The fold range must be specified using the 'over(...)' syntax." );
125- }
126- return std::make_unique<total_fold<decltype (init)>>(std::move (name_),
127- concurrency_,
128- std::move (predicates_),
129- graph_,
130- std::move (ft_),
131- std::move (init),
132- std::move (product_labels_),
133- std::move (output_names_),
134- std::move (partition_));
135- }
136-
137- algorithm_name name_;
138- std::size_t concurrency_;
139- std::vector<std::string> predicates_;
140- tbb::flow::graph& graph_;
141- function_t ft_;
142- specified_labels product_labels_;
143- std::string partition_{level_id::base ().level_name ()};
144- std::array<qualified_name, M> output_names_;
145- registrar<declared_fold_ptr> reg_;
146- };
147-
148- template <typename AlgorithmBits>
149- template <typename InitTuple>
150- class pre_fold <AlgorithmBits>::total_fold : public declared_fold, private count_stores {
151- public:
152- total_fold (algorithm_name name,
153- std::size_t concurrency,
154- std::vector<std::string> predicates,
155- tbb::flow::graph& g,
156- function_t && f,
157- InitTuple initializer,
158- specified_labels input_products,
159- std::array<qualified_name, M> output,
160- std::string partition) :
161- declared_fold{std::move (name), std::move (predicates), std::move (input_products)},
65+ fold_node (algorithm_name name,
66+ std::size_t concurrency,
67+ std::vector<std::string> predicates,
68+ tbb::flow::graph& g,
69+ AlgorithmBits alg,
70+ InitTuple initializer,
71+ specified_labels product_labels,
72+ std::vector<std::string> output,
73+ std::string partition) :
74+ declared_fold{std::move (name), std::move (predicates), std::move (product_labels)},
16275 initializer_{std::move (initializer)},
163- output_ (output.begin (), output.end()) ,
76+ output_{ to_qualified_names ( full_name (), std::move (output))} ,
16477 partition_{std::move (partition)},
16578 join_{make_join_or_none (g, std::make_index_sequence<N>{})},
166- fold_{
167- g, concurrency, [this , ft = std::move (f)](messages_t <N> const & messages, auto & outputs) {
168- // N.B. The assumption is that a fold will *never* need to cache
169- // the product store it creates. Any flush messages *do not* need
170- // to be propagated to downstream nodes.
171- auto const & msg = most_derived (messages);
172- auto const & [store, original_message_id] = std::tie (msg.store , msg.original_id );
173-
174- if (not store->is_flush () and not store->id ()->parent (partition_)) {
175- return ;
176- }
177-
178- if (store->is_flush ()) {
179- // Downstream nodes always get the flush.
180- get<0 >(outputs).try_put (msg);
181- if (store->id ()->level_name () != partition_) {
182- return ;
183- }
184- }
185-
186- auto const & fold_store = store->is_flush () ? store : store->parent (partition_);
187- assert (fold_store);
188- auto const & id_hash_for_counter = fold_store->id ()->hash ();
189-
190- if (store->is_flush ()) {
191- counter_for (id_hash_for_counter).set_flush_value (store, original_message_id);
192- } else {
193- call (ft, messages, std::make_index_sequence<N>{});
194- counter_for (id_hash_for_counter).increment (store->id ()->level_hash ());
195- }
196-
197- if (auto counter = done_with (id_hash_for_counter)) {
198- auto parent = fold_store->make_continuation (this ->full_name ());
199- commit_ (*parent);
200- ++product_count_;
201- // FIXME: This msg.eom value may be wrong!
202- get<0 >(outputs).try_put ({parent, msg.eom , counter->original_message_id ()});
203- }
204- }}
79+ fold_{g,
80+ concurrency,
81+ [this , ft = alg.release_algorithm ()](messages_t <N> const & messages, auto & outputs) {
82+ // N.B. The assumption is that a fold will *never* need to cache
83+ // the product store it creates. Any flush messages *do not* need
84+ // to be propagated to downstream nodes.
85+ auto const & msg = most_derived (messages);
86+ auto const & [store, original_message_id] = std::tie (msg.store , msg.original_id );
87+
88+ if (not store->is_flush () and not store->id ()->parent (partition_)) {
89+ return ;
90+ }
91+
92+ if (store->is_flush ()) {
93+ // Downstream nodes always get the flush.
94+ get<0 >(outputs).try_put (msg);
95+ if (store->id ()->level_name () != partition_) {
96+ return ;
97+ }
98+ }
99+
100+ auto const & fold_store = store->is_flush () ? store : store->parent (partition_);
101+ assert (fold_store);
102+ auto const & id_hash_for_counter = fold_store->id ()->hash ();
103+
104+ if (store->is_flush ()) {
105+ counter_for (id_hash_for_counter).set_flush_value (store, original_message_id);
106+ } else {
107+ call (ft, messages, std::make_index_sequence<N>{});
108+ counter_for (id_hash_for_counter).increment (store->id ()->level_hash ());
109+ }
110+
111+ if (auto counter = done_with (id_hash_for_counter)) {
112+ auto parent = fold_store->make_continuation (this ->full_name ());
113+ commit_ (*parent);
114+ ++product_count_;
115+ // FIXME: This msg.eom value may be wrong!
116+ get<0 >(outputs).try_put ({parent, msg.eom , counter->original_message_id ()});
117+ }
118+ }}
205119 {
206120 make_edge (join_, fold_);
207121 }
0 commit comments