diff --git a/include/ylt/coro_io/io_context_pool.hpp b/include/ylt/coro_io/io_context_pool.hpp index 7ebc11fb0..5ac868e72 100644 --- a/include/ylt/coro_io/io_context_pool.hpp +++ b/include/ylt/coro_io/io_context_pool.hpp @@ -17,10 +17,12 @@ #include #include +#include #include #include #include #include +#include #include #include #include @@ -28,9 +30,12 @@ #include #include #include +#include #include #include "asio/dispatch.hpp" +#include "asio/executor.hpp" +#include "async_simple/Common.h" #include "async_simple/Signal.h" #ifdef __linux__ #include @@ -48,12 +53,48 @@ template class ExecutorWrapper : public async_simple::Executor { private: ExecutorImpl executor_; + std::unique_ptr> user_defined_data_; public: ExecutorWrapper(ExecutorImpl executor) : executor_(executor) {} using context_t = std::remove_cvref_t; + template + std::any &set_data(std::string key, T data) { + if (!user_defined_data_) { + user_defined_data_ = + std::make_unique>(); + } + return (*user_defined_data_)[std::move(key)] = std::move(data); + } + template + T *get_data_with_default(std::string key) { + if (!user_defined_data_) { + user_defined_data_ = + std::make_unique>(); + } + auto [iter, _] = user_defined_data_->try_emplace(key, T{}); + return std::any_cast(&iter->second); + } + template + T *get_data(const std::string &key) { + if (!user_defined_data_) { + return nullptr; + } + auto iter = (*user_defined_data_).find(key); + if (iter == user_defined_data_->end()) { + return nullptr; + } + return std::any_cast(&iter->second); + } + + void clear_all_data() { + if (user_defined_data_) { + user_defined_data_ = nullptr; + } + } + virtual bool schedule(Func func) override { asio::post(executor_, std::move(func)); return true; @@ -85,14 +126,11 @@ class ExecutorWrapper : public async_simple::Executor { operator ExecutorImpl() { return executor_; } bool currentThreadInExecutor() const override { - auto ctx = get_current(); - return *ctx == &executor_.context(); + return executor_.running_in_this_thread(); } size_t currentContextId() const override { - auto ctx = get_current(); - auto ptr = *ctx; - return ptr ? (size_t)ptr : 0; + return (size_t)&executor_.context(); } private: @@ -216,6 +254,12 @@ class io_context_pool { bool has_run_or_stop = false; bool ok = has_run_or_stop_.compare_exchange_strong(has_run_or_stop, true); + for (auto &executor : executors) { + executor->schedule([&executor]() { + executor->clear_all_data(); + }); + } + work_.clear(); if (ok) { diff --git a/include/ylt/coro_rpc/impl/coro_rpc_server.hpp b/include/ylt/coro_rpc/impl/coro_rpc_server.hpp index 05548b11f..0f695dc0f 100644 --- a/include/ylt/coro_rpc/impl/coro_rpc_server.hpp +++ b/include/ylt/coro_rpc/impl/coro_rpc_server.hpp @@ -123,8 +123,11 @@ class coro_rpc_server_base { #endif // YLT_ENABLE_NTLS #endif #ifdef YLT_ENABLE_IBV + if (!config.ibv_config.has_value() && !config.ibv_dev_lists.empty()) { + ibv_config_ = coro_io::ib_socket_t::config_t{}; + } if (config.ibv_config) { - init_ibv(config.ibv_config.value()); + init_ibv(config.ibv_config.value(), std::move(config.ibv_dev_lists)); } #endif if (!config.acceptors.empty()) { @@ -154,9 +157,9 @@ class coro_rpc_server_base { #ifdef YLT_ENABLE_IBV void init_ibv( const coro_io::ib_socket_t::config_t &conf = {}, - std::vector> ib_dev_lists = {}) { + std::vector> ibv_dev_lists = {}) { ibv_config_ = conf; - ibv_dev_lists_ = std::move(ib_dev_lists); + ibv_dev_lists_ = std::move(ibv_dev_lists); } #endif @@ -180,8 +183,8 @@ class coro_rpc_server_base { } public: - const std::vector> - &get_acceptors() const noexcept { + const std::vector> & + get_acceptors() const noexcept { return acceptors_; } async_simple::Future async_start() noexcept { @@ -224,7 +227,7 @@ class coro_rpc_server_base { } } if (!errc_) { - if constexpr (requires(typename server_config::executor_pool_t & pool) { + if constexpr (requires(typename server_config::executor_pool_t &pool) { pool.run(); }) { thd_ = std::thread([this] { diff --git a/include/ylt/coro_rpc/impl/default_config/coro_rpc_config.hpp b/include/ylt/coro_rpc/impl/default_config/coro_rpc_config.hpp index 667f206c1..3f8b2e861 100644 --- a/include/ylt/coro_rpc/impl/default_config/coro_rpc_config.hpp +++ b/include/ylt/coro_rpc/impl/default_config/coro_rpc_config.hpp @@ -48,7 +48,7 @@ struct config_t { #endif #ifdef YLT_ENABLE_IBV std::optional ibv_config = std::nullopt; - std::vector> ib_dev_lists; + std::vector> ibv_dev_lists; coro_io::load_balance_algorithm ib_dev_load_balance_algorithm = coro_io::load_balance_algorithm::RR; #endif diff --git a/include/ylt/util/random.hpp b/include/ylt/util/random.hpp new file mode 100644 index 000000000..bad4eb918 --- /dev/null +++ b/include/ylt/util/random.hpp @@ -0,0 +1,8 @@ +#include +namespace ylt::util { +template +inline engine_type& random_engine() { + static thread_local std::default_random_engine e(std::random_device{}()); + return e; +} +} // namespace ylt::util \ No newline at end of file diff --git a/src/coro_io/tests/CMakeLists.txt b/src/coro_io/tests/CMakeLists.txt index 109fdb394..298b39d02 100644 --- a/src/coro_io/tests/CMakeLists.txt +++ b/src/coro_io/tests/CMakeLists.txt @@ -1,5 +1,6 @@ set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/output/tests) add_executable(coro_io_test + test_io_context_pool.cpp test_corofile.cpp test_load_balancer.cpp test_client_pool.cpp diff --git a/src/coro_io/tests/test_io_context_pool.cpp b/src/coro_io/tests/test_io_context_pool.cpp new file mode 100644 index 000000000..54ea12481 --- /dev/null +++ b/src/coro_io/tests/test_io_context_pool.cpp @@ -0,0 +1,186 @@ +#include +#include + +#include +#include +#include + +using namespace async_simple::coro; + +TEST_CASE("test ExecutorWrapper user data functionality") { + SUBCASE("test set_data and get_data") { + asio::io_context io_ctx; + coro_io::ExecutorWrapper<> executor(io_ctx.get_executor()); + + // 测试存储和获取字符串数据 + std::string test_key = "test_string"; + std::string test_value = "Hello World"; + executor.set_data(test_key, test_value); + + std::string* retrieved_value = executor.get_data(test_key); + CHECK(retrieved_value != nullptr); + CHECK(*retrieved_value == test_value); + } + + SUBCASE("test get_data with non-existent key") { + asio::io_context io_ctx; + coro_io::ExecutorWrapper<> executor(io_ctx.get_executor()); + + std::string* retrieved_value = + executor.get_data("non_existent_key"); + CHECK(retrieved_value == nullptr); + } + + SUBCASE("test set_data and get_data with different types") { + asio::io_context io_ctx; + coro_io::ExecutorWrapper<> executor(io_ctx.get_executor()); + + // 存储整数 + executor.set_data("int_value", 42); + int* int_ptr = executor.get_data("int_value"); + CHECK(int_ptr != nullptr); + CHECK(*int_ptr == 42); + + // 存储布尔值 + executor.set_data("bool_value", true); + bool* bool_ptr = executor.get_data("bool_value"); + CHECK(bool_ptr != nullptr); + CHECK(*bool_ptr == true); + + // 存储浮点数 + executor.set_data("float_value", 3.14f); + float* float_ptr = executor.get_data("float_value"); + CHECK(float_ptr != nullptr); + CHECK(*float_ptr == 3.14f); + } + + SUBCASE("test get_data_with_default") { + asio::io_context io_ctx; + coro_io::ExecutorWrapper<> executor(io_ctx.get_executor()); + + // 测试获取不存在的键,应该创建默认值 + int* default_int = executor.get_data_with_default("default_int"); + CHECK(default_int != nullptr); + CHECK(*default_int == 0); // int类型的默认值 + + // 修改默认值 + *default_int = 100; + int* retrieved_int = executor.get_data("default_int"); + CHECK(retrieved_int != nullptr); + CHECK(*retrieved_int == 100); + + // 测试自定义类型的默认值 + std::string* default_str = + executor.get_data_with_default("default_str"); + CHECK(default_str != nullptr); + CHECK(*default_str == ""); // string类型的默认值 + } + + SUBCASE("test clear_all_data") { + asio::io_context io_ctx; + coro_io::ExecutorWrapper<> executor(io_ctx.get_executor()); + + // 设置一些数据 + executor.set_data("key1", std::string("value1")); + executor.set_data("key2", 123); + + // 验证数据存在 + std::string* str_ptr = executor.get_data("key1"); + int* int_ptr = executor.get_data("key2"); + CHECK(str_ptr != nullptr); + CHECK(int_ptr != nullptr); + CHECK(*str_ptr == "value1"); + CHECK(*int_ptr == 123); + + // 清除所有数据 + executor.clear_all_data(); + + // 验证数据已被清除 + str_ptr = executor.get_data("key1"); + int_ptr = executor.get_data("key2"); + CHECK(str_ptr == nullptr); + CHECK(int_ptr == nullptr); + + // 但默认值应该被创建 + str_ptr = executor.get_data_with_default("key1"); + CHECK(str_ptr != nullptr); + CHECK(*str_ptr == ""); + } + + SUBCASE("test ExecutorWrapper with io_context_pool") { + coro_io::io_context_pool pool(2); + + // 获取执行器并测试数据存储 + auto* executor = pool.get_executor(); + executor->set_data("pool_data", std::string("test_data")); + + std::string* retrieved = executor->get_data("pool_data"); + CHECK(retrieved != nullptr); + CHECK(*retrieved == "test_data"); + + // 测试从不同线程获取的执行器具有独立的数据存储 + auto* executor2 = pool.get_executor(); + std::string* retrieved2 = executor2->get_data("pool_data"); + CHECK(retrieved2 == nullptr); + executor2->set_data("pool_data", std::string("test_data2")); + retrieved2 = executor2->get_data("pool_data"); + CHECK(retrieved2 != nullptr); + CHECK(*retrieved2 == "test_data2"); + CHECK(*retrieved == "test_data"); + CHECK(retrieved != retrieved2); + } + + SUBCASE("test multiple data types in same executor") { + asio::io_context io_ctx; + coro_io::ExecutorWrapper<> executor(io_ctx.get_executor()); + + // 存储多种类型的数据 + executor.set_data("string_val", std::string("test")); + executor.set_data("int_val", 42); + executor.set_data("bool_val", true); + executor.set_data("double_val", 3.14159); + + // 检索并验证所有数据 + auto* str_val = executor.get_data("string_val"); + auto* int_val = executor.get_data("int_val"); + auto* bool_val = executor.get_data("bool_val"); + auto* double_val = executor.get_data("double_val"); + + CHECK(str_val != nullptr); + CHECK(int_val != nullptr); + CHECK(bool_val != nullptr); + CHECK(double_val != nullptr); + + CHECK(*str_val == "test"); + CHECK(*int_val == 42); + CHECK(*bool_val == true); + CHECK(*double_val == 3.14159); + } + + SUBCASE("test executor data persistence after tasks") { + asio::io_context io_ctx; + coro_io::ExecutorWrapper<> executor(io_ctx.get_executor()); + + // 设置数据 + executor.set_data("persistent_data", 100); + + // 执行一个任务,确保数据仍然存在 + bool task_executed = false; + executor.schedule([&task_executed, &executor]() { + task_executed = true; + auto* data = executor.get_data("persistent_data"); + CHECK(data != nullptr); + CHECK(*data == 100); + }); + + // 运行一次循环来处理任务 + io_ctx.run_one(); + + CHECK(task_executed); + + // 验证任务执行后数据仍然存在 + auto* data_after_task = executor.get_data("persistent_data"); + CHECK(data_after_task != nullptr); + CHECK(*data_after_task == 100); + } +} \ No newline at end of file diff --git a/src/coro_rpc/tests/test_acceptor.cpp b/src/coro_rpc/tests/test_acceptor.cpp index 8b5cff970..04cafedec 100644 --- a/src/coro_rpc/tests/test_acceptor.cpp +++ b/src/coro_rpc/tests/test_acceptor.cpp @@ -76,15 +76,15 @@ TEST_CASE("test server acceptor") { #ifdef YLT_ENABLE_IBV SUBCASE("test multi rdma device for server") { - std::vector> ib_dev_lists; + std::vector> ibv_dev_lists; for (auto &dev : coro_io::g_ib_device_manager()->get_dev_list()) { - ib_dev_lists.push_back(dev.second); + ibv_dev_lists.push_back(dev.second); } coro_rpc_server server( coro_rpc::config_t{.port = 8824, .thread_num = 1, .ibv_config = {}, - .ib_dev_lists = std::move(ib_dev_lists)}); + .ibv_dev_lists = std::move(ibv_dev_lists)}); server.register_handler(); auto res = server.async_start();