Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 49 additions & 5 deletions include/ylt/coro_io/io_context_pool.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,25 @@
#include <async_simple/Executor.h>
#include <async_simple/coro/Lazy.h>

#include <any>
#include <asio/io_context.hpp>
#include <asio/post.hpp>
#include <asio/steady_timer.hpp>
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <future>
#include <iostream>
#include <memory>
#include <mutex>
#include <thread>
#include <type_traits>
#include <utility>
#include <vector>

#include "asio/dispatch.hpp"
#include "asio/executor.hpp"
#include "async_simple/Common.h"
#include "async_simple/Signal.h"
#ifdef __linux__
#include <pthread.h>
Expand All @@ -48,12 +53,48 @@ template <typename ExecutorImpl = asio::io_context::executor_type>
class ExecutorWrapper : public async_simple::Executor {
private:
ExecutorImpl executor_;
std::unique_ptr<std::unordered_map<std::string, std::any>> user_defined_data_;

public:
ExecutorWrapper(ExecutorImpl executor) : executor_(executor) {}

using context_t = std::remove_cvref_t<decltype(executor_.context())>;

template <typename T>
std::any &set_data(std::string key, T data) {
if (!user_defined_data_) {
user_defined_data_ =
std::make_unique<std::unordered_map<std::string, std::any>>();
}
return (*user_defined_data_)[std::move(key)] = std::move(data);
}
template <typename T>
T *get_data_with_default(std::string key) {
if (!user_defined_data_) {
user_defined_data_ =
std::make_unique<std::unordered_map<std::string, std::any>>();
}
auto [iter, _] = user_defined_data_->try_emplace(key, T{});
return std::any_cast<T>(&iter->second);
}
template <typename T>
T *get_data(const std::string &key) {
if (!user_defined_data_) {
return nullptr;
}
auto iter = (*user_defined_data_).find(key);
if (iter == user_defined_data_->end()) {
return nullptr;
}
return std::any_cast<T>(&iter->second);
}

void clear_all_data() {
if (user_defined_data_) {
user_defined_data_ = nullptr;
}
}

virtual bool schedule(Func func) override {
asio::post(executor_, std::move(func));
return true;
Expand Down Expand Up @@ -85,14 +126,11 @@ class ExecutorWrapper : public async_simple::Executor {
operator ExecutorImpl() { return executor_; }

bool currentThreadInExecutor() const override {
auto ctx = get_current();
return *ctx == &executor_.context();
return executor_.running_in_this_thread();
}

size_t currentContextId() const override {
auto ctx = get_current();
auto ptr = *ctx;
return ptr ? (size_t)ptr : 0;
return (size_t)&executor_.context();
}

private:
Expand Down Expand Up @@ -216,6 +254,12 @@ class io_context_pool {
bool has_run_or_stop = false;
bool ok = has_run_or_stop_.compare_exchange_strong(has_run_or_stop, true);

for (auto &executor : executors) {
executor->schedule([&executor]() {
executor->clear_all_data();
});
}

work_.clear();

if (ok) {
Expand Down
15 changes: 9 additions & 6 deletions include/ylt/coro_rpc/impl/coro_rpc_server.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,11 @@ class coro_rpc_server_base {
#endif // YLT_ENABLE_NTLS
#endif
#ifdef YLT_ENABLE_IBV
if (!config.ibv_config.has_value() && !config.ibv_dev_lists.empty()) {
ibv_config_ = coro_io::ib_socket_t::config_t{};
}
if (config.ibv_config) {
init_ibv(config.ibv_config.value());
init_ibv(config.ibv_config.value(), std::move(config.ibv_dev_lists));
}
#endif
if (!config.acceptors.empty()) {
Expand Down Expand Up @@ -154,9 +157,9 @@ class coro_rpc_server_base {
#ifdef YLT_ENABLE_IBV
void init_ibv(
const coro_io::ib_socket_t::config_t &conf = {},
std::vector<std::shared_ptr<coro_io::ib_device_t>> ib_dev_lists = {}) {
std::vector<std::shared_ptr<coro_io::ib_device_t>> ibv_dev_lists = {}) {
ibv_config_ = conf;
ibv_dev_lists_ = std::move(ib_dev_lists);
ibv_dev_lists_ = std::move(ibv_dev_lists);
}
#endif

Expand All @@ -180,8 +183,8 @@ class coro_rpc_server_base {
}

public:
const std::vector<std::unique_ptr<coro_io::server_acceptor_base>>
&get_acceptors() const noexcept {
const std::vector<std::unique_ptr<coro_io::server_acceptor_base>> &
get_acceptors() const noexcept {
return acceptors_;
}
async_simple::Future<coro_rpc::err_code> async_start() noexcept {
Expand Down Expand Up @@ -224,7 +227,7 @@ class coro_rpc_server_base {
}
}
if (!errc_) {
if constexpr (requires(typename server_config::executor_pool_t & pool) {
if constexpr (requires(typename server_config::executor_pool_t &pool) {
pool.run();
}) {
thd_ = std::thread([this] {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ struct config_t {
#endif
#ifdef YLT_ENABLE_IBV
std::optional<coro_io::ib_socket_t::config_t> ibv_config = std::nullopt;
std::vector<std::shared_ptr<coro_io::ib_device_t>> ib_dev_lists;
std::vector<std::shared_ptr<coro_io::ib_device_t>> ibv_dev_lists;
coro_io::load_balance_algorithm ib_dev_load_balance_algorithm =
coro_io::load_balance_algorithm::RR;
#endif
Expand Down
8 changes: 8 additions & 0 deletions include/ylt/util/random.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#include <random>
namespace ylt::util {
template <typename engine_type = std::default_random_engine>
inline engine_type& random_engine() {
static thread_local std::default_random_engine e(std::random_device{}());
return e;
}
} // namespace ylt::util
1 change: 1 addition & 0 deletions src/coro_io/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/output/tests)
add_executable(coro_io_test
test_io_context_pool.cpp
test_corofile.cpp
test_load_balancer.cpp
test_client_pool.cpp
Expand Down
186 changes: 186 additions & 0 deletions src/coro_io/tests/test_io_context_pool.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
#include <async_simple/coro/SyncAwait.h>
#include <doctest.h>

#include <string>
#include <thread>
#include <ylt/coro_io/io_context_pool.hpp>

using namespace async_simple::coro;

TEST_CASE("test ExecutorWrapper user data functionality") {
SUBCASE("test set_data and get_data") {
asio::io_context io_ctx;
coro_io::ExecutorWrapper<> executor(io_ctx.get_executor());

// 测试存储和获取字符串数据
std::string test_key = "test_string";
std::string test_value = "Hello World";
executor.set_data(test_key, test_value);

std::string* retrieved_value = executor.get_data<std::string>(test_key);
CHECK(retrieved_value != nullptr);
CHECK(*retrieved_value == test_value);
}

SUBCASE("test get_data with non-existent key") {
asio::io_context io_ctx;
coro_io::ExecutorWrapper<> executor(io_ctx.get_executor());

std::string* retrieved_value =
executor.get_data<std::string>("non_existent_key");
CHECK(retrieved_value == nullptr);
}

SUBCASE("test set_data and get_data with different types") {
asio::io_context io_ctx;
coro_io::ExecutorWrapper<> executor(io_ctx.get_executor());

// 存储整数
executor.set_data("int_value", 42);
int* int_ptr = executor.get_data<int>("int_value");
CHECK(int_ptr != nullptr);
CHECK(*int_ptr == 42);

// 存储布尔值
executor.set_data("bool_value", true);
bool* bool_ptr = executor.get_data<bool>("bool_value");
CHECK(bool_ptr != nullptr);
CHECK(*bool_ptr == true);

// 存储浮点数
executor.set_data("float_value", 3.14f);
float* float_ptr = executor.get_data<float>("float_value");
CHECK(float_ptr != nullptr);
CHECK(*float_ptr == 3.14f);
}

SUBCASE("test get_data_with_default") {
asio::io_context io_ctx;
coro_io::ExecutorWrapper<> executor(io_ctx.get_executor());

// 测试获取不存在的键,应该创建默认值
int* default_int = executor.get_data_with_default<int>("default_int");
CHECK(default_int != nullptr);
CHECK(*default_int == 0); // int类型的默认值

// 修改默认值
*default_int = 100;
int* retrieved_int = executor.get_data<int>("default_int");
CHECK(retrieved_int != nullptr);
CHECK(*retrieved_int == 100);

// 测试自定义类型的默认值
std::string* default_str =
executor.get_data_with_default<std::string>("default_str");
CHECK(default_str != nullptr);
CHECK(*default_str == ""); // string类型的默认值
}

SUBCASE("test clear_all_data") {
asio::io_context io_ctx;
coro_io::ExecutorWrapper<> executor(io_ctx.get_executor());

// 设置一些数据
executor.set_data("key1", std::string("value1"));
executor.set_data("key2", 123);

// 验证数据存在
std::string* str_ptr = executor.get_data<std::string>("key1");
int* int_ptr = executor.get_data<int>("key2");
CHECK(str_ptr != nullptr);
CHECK(int_ptr != nullptr);
CHECK(*str_ptr == "value1");
CHECK(*int_ptr == 123);

// 清除所有数据
executor.clear_all_data();

// 验证数据已被清除
str_ptr = executor.get_data<std::string>("key1");
int_ptr = executor.get_data<int>("key2");
CHECK(str_ptr == nullptr);
CHECK(int_ptr == nullptr);

// 但默认值应该被创建
str_ptr = executor.get_data_with_default<std::string>("key1");
CHECK(str_ptr != nullptr);
CHECK(*str_ptr == "");
}

SUBCASE("test ExecutorWrapper with io_context_pool") {
coro_io::io_context_pool pool(2);

// 获取执行器并测试数据存储
auto* executor = pool.get_executor();
executor->set_data("pool_data", std::string("test_data"));

std::string* retrieved = executor->get_data<std::string>("pool_data");
CHECK(retrieved != nullptr);
CHECK(*retrieved == "test_data");

// 测试从不同线程获取的执行器具有独立的数据存储
auto* executor2 = pool.get_executor();
std::string* retrieved2 = executor2->get_data<std::string>("pool_data");
CHECK(retrieved2 == nullptr);
executor2->set_data("pool_data", std::string("test_data2"));
retrieved2 = executor2->get_data<std::string>("pool_data");
CHECK(retrieved2 != nullptr);
CHECK(*retrieved2 == "test_data2");
CHECK(*retrieved == "test_data");
CHECK(retrieved != retrieved2);
}

SUBCASE("test multiple data types in same executor") {
asio::io_context io_ctx;
coro_io::ExecutorWrapper<> executor(io_ctx.get_executor());

// 存储多种类型的数据
executor.set_data("string_val", std::string("test"));
executor.set_data("int_val", 42);
executor.set_data("bool_val", true);
executor.set_data("double_val", 3.14159);

// 检索并验证所有数据
auto* str_val = executor.get_data<std::string>("string_val");
auto* int_val = executor.get_data<int>("int_val");
auto* bool_val = executor.get_data<bool>("bool_val");
auto* double_val = executor.get_data<double>("double_val");

CHECK(str_val != nullptr);
CHECK(int_val != nullptr);
CHECK(bool_val != nullptr);
CHECK(double_val != nullptr);

CHECK(*str_val == "test");
CHECK(*int_val == 42);
CHECK(*bool_val == true);
CHECK(*double_val == 3.14159);
}

SUBCASE("test executor data persistence after tasks") {
asio::io_context io_ctx;
coro_io::ExecutorWrapper<> executor(io_ctx.get_executor());

// 设置数据
executor.set_data("persistent_data", 100);

// 执行一个任务,确保数据仍然存在
bool task_executed = false;
executor.schedule([&task_executed, &executor]() {
task_executed = true;
auto* data = executor.get_data<int>("persistent_data");
CHECK(data != nullptr);
CHECK(*data == 100);
});

// 运行一次循环来处理任务
io_ctx.run_one();

CHECK(task_executed);

// 验证任务执行后数据仍然存在
auto* data_after_task = executor.get_data<int>("persistent_data");
CHECK(data_after_task != nullptr);
CHECK(*data_after_task == 100);
}
}
6 changes: 3 additions & 3 deletions src/coro_rpc/tests/test_acceptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,15 +76,15 @@ TEST_CASE("test server acceptor") {

#ifdef YLT_ENABLE_IBV
SUBCASE("test multi rdma device for server") {
std::vector<std::shared_ptr<coro_io::ib_device_t>> ib_dev_lists;
std::vector<std::shared_ptr<coro_io::ib_device_t>> ibv_dev_lists;
for (auto &dev : coro_io::g_ib_device_manager()->get_dev_list()) {
ib_dev_lists.push_back(dev.second);
ibv_dev_lists.push_back(dev.second);
}
coro_rpc_server server(
coro_rpc::config_t{.port = 8824,
.thread_num = 1,
.ibv_config = {},
.ib_dev_lists = std::move(ib_dev_lists)});
.ibv_dev_lists = std::move(ibv_dev_lists)});
server.register_handler<test_rdma_multi_dev_server>();

auto res = server.async_start();
Expand Down
Loading