Skip to content

Commit 45811ad

Browse files
committed
Add test to verify that data products can be output from providers
1 parent fa5c12c commit 45811ad

File tree

5 files changed

+80
-1
lines changed

5 files changed

+80
-1
lines changed

phlex/core/declared_output.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,10 @@ namespace phlex::experimental {
99
tbb::flow::graph& g,
1010
detail::output_function_t&& ft) :
1111
consumer{std::move(name), std::move(predicates)},
12-
node_{g, concurrency, [f = std::move(ft)](message const& msg) -> tbb::flow::continue_msg {
12+
node_{g, concurrency, [this, f = std::move(ft)](message const& msg) -> tbb::flow::continue_msg {
1313
if (not msg.store->is_flush()) {
1414
f(*msg.store);
15+
++calls_;
1516
}
1617
return {};
1718
}}

phlex/core/declared_output.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,11 @@ namespace phlex::experimental {
2929
detail::output_function_t&& ft);
3030

3131
tbb::flow::receiver<message>& port() noexcept;
32+
std::size_t num_calls() const { return calls_; }
3233

3334
private:
3435
tbb::flow::function_node<message> node_;
36+
std::atomic<std::size_t> calls_;
3537
};
3638

3739
using declared_output_ptr = std::unique_ptr<declared_output>;

phlex/core/node_catalog.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ namespace phlex::experimental {
2626
if (auto node = providers.get(node_name)) {
2727
return node->num_calls();
2828
}
29+
if (auto node = outputs.get(node_name)) {
30+
return node->num_calls();
31+
}
2932
throw std::runtime_error("Unknown node type with name: "s + node_name);
3033
}
3134
}

test/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,12 @@ cet_test(
143143
Boost::json
144144
phlex::core
145145
)
146+
cet_test(
147+
output_products
148+
USE_CATCH2_MAIN
149+
SOURCE output_products.cpp
150+
LIBRARIES layer_generator phlex::core spdlog::spdlog
151+
)
146152
cet_test(
147153
data_cell_counting
148154
USE_CATCH2_MAIN

test/output_products.cpp

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
// =======================================================================================
2+
// This is a simple test to ensure that data products are "written" or "output" to an
3+
// output node.
4+
//
5+
// N.B. Output nodes will eventually be replaced with preserver nodes.
6+
// =======================================================================================
7+
8+
#include "phlex/core/framework_graph.hpp"
9+
#include "phlex/model/data_cell_index.hpp"
10+
#include "plugins/layer_generator.hpp"
11+
12+
#include "catch2/catch_test_macros.hpp"
13+
14+
#include <ranges>
15+
#include <set>
16+
#include <string>
17+
18+
using namespace phlex;
19+
20+
namespace {
21+
class product_recorder {
22+
public:
23+
explicit product_recorder(std::set<std::string>& products) : products_{&products} {}
24+
25+
void record(experimental::product_store const& store)
26+
{
27+
for (auto const& product_name : store | std::views::keys) {
28+
products_->insert(product_name);
29+
}
30+
}
31+
32+
private:
33+
std::set<std::string>* products_;
34+
};
35+
}
36+
37+
TEST_CASE("Output data products", "[graph]")
38+
{
39+
experimental::layer_generator gen;
40+
gen.add_layer("spill", {"job", 1u});
41+
42+
experimental::framework_graph g{driver_for_test(gen)};
43+
44+
g.provide("provide_number", [](data_cell_index const&) -> int { return 17; })
45+
.output_product("number_from_provider"_in("spill"));
46+
47+
g.transform(
48+
"square_number",
49+
[](int const number) -> int { return number * number; },
50+
concurrency::unlimited)
51+
.input_family("number_from_provider"_in("spill"))
52+
.output_products("squared_number");
53+
54+
std::set<std::string> products_from_nodes;
55+
g.make<product_recorder>(products_from_nodes)
56+
.output("record_numbers", &product_recorder::record, concurrency::serial);
57+
58+
g.execute();
59+
60+
CHECK(g.execution_count("provide_number") == 1u);
61+
CHECK(g.execution_count("square_number") == 1u);
62+
// The "record_numbers" output node should be executed twice: once to receive the data
63+
// store from the "provide_number" provider, and once to receive the data store from the
64+
// "square_number" transform.
65+
CHECK(g.execution_count("record_numbers") == 2u);
66+
CHECK(products_from_nodes == std::set<std::string>{"number_from_provider", "squared_number"});
67+
}

0 commit comments

Comments
 (0)