|
| 1 | +// ======================================================================================= |
| 2 | +// This is a simple test to ensure that data products are "written" or "output" to an |
| 3 | +// output node. |
| 4 | +// |
| 5 | +// N.B. Output nodes will eventually be replaced with preserver nodes. |
| 6 | +// ======================================================================================= |
| 7 | + |
| 8 | +#include "phlex/core/framework_graph.hpp" |
| 9 | +#include "phlex/model/data_cell_index.hpp" |
| 10 | +#include "plugins/layer_generator.hpp" |
| 11 | + |
| 12 | +#include "catch2/catch_test_macros.hpp" |
| 13 | + |
| 14 | +#include <ranges> |
| 15 | +#include <set> |
| 16 | +#include <string> |
| 17 | + |
| 18 | +using namespace phlex; |
| 19 | + |
| 20 | +namespace { |
| 21 | + class product_recorder { |
| 22 | + public: |
| 23 | + explicit product_recorder(std::set<std::string>& products) : products_{&products} {} |
| 24 | + |
| 25 | + void record(experimental::product_store const& store) |
| 26 | + { |
| 27 | + for (auto const& product_name : store | std::views::keys) { |
| 28 | + products_->insert(product_name); |
| 29 | + } |
| 30 | + } |
| 31 | + |
| 32 | + private: |
| 33 | + std::set<std::string>* products_; |
| 34 | + }; |
| 35 | +} |
| 36 | + |
| 37 | +TEST_CASE("Output data products", "[graph]") |
| 38 | +{ |
| 39 | + experimental::layer_generator gen; |
| 40 | + gen.add_layer("spill", {"job", 1u}); |
| 41 | + |
| 42 | + experimental::framework_graph g{driver_for_test(gen)}; |
| 43 | + |
| 44 | + g.provide("provide_number", [](data_cell_index const&) -> int { return 17; }) |
| 45 | + .output_product("number_from_provider"_in("spill")); |
| 46 | + |
| 47 | + g.transform( |
| 48 | + "square_number", |
| 49 | + [](int const number) -> int { return number * number; }, |
| 50 | + concurrency::unlimited) |
| 51 | + .input_family("number_from_provider"_in("spill")) |
| 52 | + .output_products("squared_number"); |
| 53 | + |
| 54 | + std::set<std::string> products_from_nodes; |
| 55 | + g.make<product_recorder>(products_from_nodes) |
| 56 | + .output("record_numbers", &product_recorder::record, concurrency::serial); |
| 57 | + |
| 58 | + g.execute(); |
| 59 | + |
| 60 | + CHECK(g.execution_count("provide_number") == 1u); |
| 61 | + CHECK(g.execution_count("square_number") == 1u); |
| 62 | + // The "record_numbers" output node should be executed twice: once to receive the data |
| 63 | + // store from the "provide_number" provider, and once to receive the data store from the |
| 64 | + // "square_number" transform. |
| 65 | + CHECK(g.execution_count("record_numbers") == 2u); |
| 66 | + CHECK(products_from_nodes == std::set<std::string>{"number_from_provider", "squared_number"}); |
| 67 | +} |
0 commit comments