diff --git a/doc/cfg/mududb_cfg.toml b/doc/cfg/mududb_cfg.toml index c486038..bcd42a4 100644 --- a/doc/cfg/mududb_cfg.toml +++ b/doc/cfg/mududb_cfg.toml @@ -6,7 +6,9 @@ mpk_path = "/your database path/mpk" # The path of database files -data_path = "/your database path/data" +# `data_path` is still accepted as a compatibility alias; new files should use +# the field name used by the Rust configuration struct. +db_path = "/your database path/data" listen_ip = "127.0.0.1" http_listen_port = 8300 @@ -16,13 +18,21 @@ pg_listen_port = 5432 # Enable the WASI component runtime used by the new backend. enable_async = true +# Component ABI target. Omit this field to use the default `p2`. +# Allowed values: "p2", "p3". +component_target = "p2" + # 0 = Legacy # 1 = IOUring +# 2 = Tokio server_mode = 1 # TCP port used by the custom framed client/server protocol. tcp_listen_port = 9527 +# Enable one TCP listener per worker when supported by the selected backend. +tcp_multi_port = false + # 0 means: detect the worker count from available CPU cores. io_uring_worker_threads = 0 diff --git a/doc/cn/how_to_start.cn.md b/doc/cn/how_to_start.cn.md index 80a52e5..1dd4db5 100644 --- a/doc/cn/how_to_start.cn.md +++ b/doc/cn/how_to_start.cn.md @@ -116,18 +116,23 @@ python script/build/install_binaries.py --all-workspace-bins ``` -## 创建配置文件 +## 配置文件 [mududb_cfg.toml 示例](../cfg/mududb_cfg.toml) -在以下位置创建配置文件: +`mudud` 会从以下位置读取配置: ```bash mkdir -p ${HOME}/.mududb -touch ${HOME}/.mududb/mududb_cfg.toml ``` -如果该文件不存在,`mudud` 首次启动时也会按默认值自动创建 `${HOME}/.mududb/mududb_cfg.toml`。 +不要创建空的 `mududb_cfg.toml`:服务端只要发现文件存在,就会把它当作用户配置解析。如果该文件不存在,`mudud` 首次启动时会按默认值自动创建 `${HOME}/.mududb/mududb_cfg.toml`。 + +如果要使用示例配置,可以复制: + +```bash +cp doc/cfg/mududb_cfg.toml ${HOME}/.mududb/mududb_cfg.toml +``` ## 使用 MuduDB diff --git a/doc/cn/partition.cn.md b/doc/cn/partition.cn.md index e44879e..414cdc8 100644 --- a/doc/cn/partition.cn.md +++ b/doc/cn/partition.cn.md @@ -226,6 +226,80 @@ worker 会按需懒创建自己需要访问的 partition relation。 当路由命中远端 worker 时,请求会按 partition placement 转发到该 worker 处理。 +## Port Sharding 策略 + +partition placement 负责决定一个逻辑 partition 归属哪个 worker。Port sharding 则把这些 worker 暴露为稳定的 +TCP 端口,让客户端和工具可以直接连接到目标数据所在的 worker。 + +当前 IOUring 和 Tokio 后端使用以下端口策略: + +- `tcp_listen_port` 是基础 TCP 端口 +- worker `0` 监听 `tcp_listen_port` +- worker `1` 监听 `tcp_listen_port + 1` +- worker `N` 监听 `tcp_listen_port + N` + +当后端 worker 数量大于 1 时,启动路径会自动开启 `tcp_multi_port`。例如 `tcp_listen_port = 9527` 且有 4 个 +worker 时,端口分布如下: + +```text +worker 0 -> 9527 +worker 1 -> 9528 +worker 2 -> 9529 +worker 3 -> 9530 +``` + +这是传输层的分片机制,不替代 partition placement。placement 仍然负责把 `partition_id` 映射到 `worker_id`; +port sharding 负责把目标 worker 映射到可连接的 TCP endpoint。 + +### 发现 Worker 端口 + +可以通过 HTTP 管理接口查看当前 topology: + +```bash +mcli --http-addr 127.0.0.1:8300 server-topology +``` + +返回结果包含: + +- `worker_count` +- `tcp_multi_port` +- `tcp_base_listen_port` +- 每个 worker 的 `worker_index` +- 每个 worker 的 `worker_id` +- 每个 worker 的 `tcp_listen_port` +- 当前与该 worker 关联的 partition id 列表 + +可以用 `partition-route` 查询某个 rule key 或 range 会路由到哪些 partition 和 worker: + +```bash +mcli --http-addr 127.0.0.1:8300 partition-route --rule-name r_orders --key 1001,50001 +``` + +客户端可以组合使用这两个接口: + +1. `partition-route` 解析目标 `partition_id` 和 `worker_id`。 +2. `server-topology` 把该 `worker_id` 解析为 `tcp_listen_port`。 +3. 客户端连接 `listen_ip:tcp_listen_port`。 + +### 作用 + +当调用方能在打开或重新绑定 session 前确定目标 partition 时,port sharding 可以减少不必要的 worker 间转发。 + +它提供的能力包括: + +- 让客户端直接亲和到 partition 所在 worker +- 减少 partition-local 流量的跨 worker 转发 +- 为 benchmark 和运维工具提供稳定的 per-worker endpoint +- 从外部观察 placement 如何映射到网络端口 + +### 限制 + +port sharding 目前绑定在多 worker 的 IOUring 和 Tokio 后端路径上。legacy 后端仍保持较早的单 TCP/PG serving +模型,不应当把它当作 per-worker 端口分片部署来使用。 + +客户端仍然必须尊重 partition placement。连接某个 worker 端口只是在选择执行 worker;它不会移动数据,也不会改变某个 +partition 的归属关系。 + ## 当前语义与限制 当前实现有明确边界。 @@ -235,6 +309,7 @@ worker 会按需懒创建自己需要访问的 partition relation。 - partition pruning 目前只围绕 key 列进行 - placement 是显式元数据,不是自动调度 - 远端 partition 访问通过 worker-to-worker RPC 完成 +- per-worker port sharding 可用于多 worker 的 IOUring 和 Tokio 后端路径 目前还存在一个重要限制: @@ -267,6 +342,7 @@ worker 会按需懒创建自己需要访问的 partition relation。 - worker 放置 - 物理存储 - 执行期路由 +- 可选的 per-worker TCP port sharding 这样的拆分可以保持 schema 模型干净,也为后续扩展打基础,例如: diff --git a/doc/en/how_to_start.md b/doc/en/how_to_start.md index af7f420..a62910c 100644 --- a/doc/en/how_to_start.md +++ b/doc/en/how_to_start.md @@ -115,18 +115,23 @@ python script/build/install_binaries.py --all-workspace-bins ``` -## Create a Configuration File +## Configuration File [mududb_cfg.toml example](../cfg/mududb_cfg.toml) -Create the configuration file at: +`mudud` reads its configuration from: ```bash mkdir -p ${HOME}/.mududb -touch ${HOME}/.mududb/mududb_cfg.toml ``` -If the file does not exist, `mudud` also creates `${HOME}/.mududb/mududb_cfg.toml` automatically on first start with default values. +Do not create an empty `mududb_cfg.toml`: the server treats an existing file as user configuration and parses it. If the file does not exist, `mudud` creates `${HOME}/.mududb/mududb_cfg.toml` automatically on first start with default values. + +To use the example configuration instead: + +```bash +cp doc/cfg/mududb_cfg.toml ${HOME}/.mududb/mududb_cfg.toml +``` ## Use MuduDB diff --git a/doc/en/partition.md b/doc/en/partition.md index 8ec1abd..41aeecf 100644 --- a/doc/en/partition.md +++ b/doc/en/partition.md @@ -228,6 +228,80 @@ Supported remote actions: Remote requests are routed by partition placement and executed by the worker that owns the target partition. +## Port Sharding + +Partition placement decides which worker owns a logical partition. Port sharding exposes those workers through stable +TCP ports so clients and tools can connect directly to the worker that owns the data they want to use. + +The current IOUring and Tokio backends use this policy: + +- `tcp_listen_port` is the base TCP port +- worker `0` listens on `tcp_listen_port` +- worker `1` listens on `tcp_listen_port + 1` +- worker `N` listens on `tcp_listen_port + N` + +When the backend has more than one worker, `tcp_multi_port` is enabled by the backend startup path. For example, with +`tcp_listen_port = 9527` and four workers, the worker ports are: + +```text +worker 0 -> 9527 +worker 1 -> 9528 +worker 2 -> 9529 +worker 3 -> 9530 +``` + +This is a transport-level sharding mechanism. It does not replace partition placement. Placement still maps +`partition_id` to `worker_id`; port sharding maps the selected worker to a reachable TCP endpoint. + +### Discovering Worker Ports + +Use the HTTP management API to read the live topology: + +```bash +mcli --http-addr 127.0.0.1:8300 server-topology +``` + +The response includes: + +- `worker_count` +- `tcp_multi_port` +- `tcp_base_listen_port` +- each worker's `worker_index` +- each worker's `worker_id` +- each worker's `tcp_listen_port` +- the partition ids currently associated with that worker + +Use `partition-route` to resolve a rule key or range to partition and worker ids: + +```bash +mcli --http-addr 127.0.0.1:8300 partition-route --rule-name r_orders --key 1001,50001 +``` + +A client can combine both calls: + +1. `partition-route` resolves the target `partition_id` and `worker_id`. +2. `server-topology` resolves that `worker_id` to `tcp_listen_port`. +3. The client connects to `listen_ip:tcp_listen_port`. + +### Why It Exists + +Port sharding is useful when the caller can choose the target partition before opening or rebinding a session. + +It enables: + +- direct client affinity to the owning worker +- fewer cross-worker forwards for partition-local traffic +- stable per-worker endpoints for benchmark and operational tooling +- an external way to inspect how placement maps onto network endpoints + +### Limits + +Port sharding is currently tied to the multi-worker IOUring and Tokio backend paths. The legacy backend keeps the older +single TCP/PG serving model and should not be treated as a per-worker port-sharded deployment. + +Clients still need to respect partition placement. Connecting to a worker port only selects an execution worker; it +does not move data or change which worker owns a partition. + ## Current Semantics and Limits The current implementation is intentionally scoped. @@ -237,6 +311,7 @@ The current implementation is intentionally scoped. - partition pruning is based on key columns, not arbitrary predicates - placement is explicit metadata - remote partition access uses worker-to-worker RPC +- per-worker port sharding is available on the multi-worker IOUring and Tokio backend paths There is still an important transactional limit: @@ -258,13 +333,17 @@ Avoid using the current implementation as if it were already a general distribut ## Summary -The partition subsystem separates: +The current partition subsystem separates: - logical partition definition - table binding - worker placement - physical storage -- execution-time routing +- runtime routing +- optional per-worker TCP port sharding + +This keeps the schema model separate from deployment topology while still allowing clients and tools to discover the +worker endpoint that owns a routed partition. This keeps the schema model clean and allows the engine to evolve toward partition rebalance, partition split or merge, and distributed commit in later iterations. diff --git a/mudu_cli/README.md b/mudu_cli/README.md index 29901bd..624efb9 100644 --- a/mudu_cli/README.md +++ b/mudu_cli/README.md @@ -19,6 +19,11 @@ It exposes these operations: - `invoke` - `app-install` - `app-invoke` +- `app-list` +- `app-detail` +- `app-uninstall` +- `server-topology` +- `partition-route` ## Examples @@ -106,6 +111,17 @@ mcli --addr 127.0.0.1:9527 --http-addr 127.0.0.1:8300 app-invoke --app kv --modu }' ``` +Management commands: + +```bash +mcli --http-addr 127.0.0.1:8300 app-list +mcli --http-addr 127.0.0.1:8300 app-detail --app wallet +mcli --http-addr 127.0.0.1:8300 app-detail --app wallet --module wallet --proc create_user +mcli --http-addr 127.0.0.1:8300 app-uninstall --app wallet +mcli --http-addr 127.0.0.1:8300 server-topology +mcli --http-addr 127.0.0.1:8300 partition-route --rule-name user_rule --key user-100 +``` + ## JSON input JSON request bodies can be supplied in three ways: diff --git a/mudu_cli/src/binding_api.rs b/mudu_cli/src/binding_api.rs index 4badc67..a587cca 100644 --- a/mudu_cli/src/binding_api.rs +++ b/mudu_cli/src/binding_api.rs @@ -42,6 +42,7 @@ pub struct MuduKeyValueBinding { #[derive(Debug, Clone, uniffi::Record)] pub struct WorkerTopologyBinding { pub worker_index: u64, + pub tcp_listen_port: u16, pub worker_id: String, pub partitions: Vec, } @@ -49,6 +50,8 @@ pub struct WorkerTopologyBinding { #[derive(Debug, Clone, uniffi::Record)] pub struct ServerTopologyBinding { pub worker_count: u64, + pub tcp_multi_port: bool, + pub tcp_base_listen_port: u16, pub workers: Vec, } @@ -278,11 +281,14 @@ fn to_server_topology_binding( ) -> ServerTopologyBinding { ServerTopologyBinding { worker_count: topology.worker_count as u64, + tcp_multi_port: topology.tcp_multi_port, + tcp_base_listen_port: topology.tcp_base_listen_port, workers: topology .workers .into_iter() .map(|worker| WorkerTopologyBinding { worker_index: worker.worker_index as u64, + tcp_listen_port: worker.tcp_listen_port, worker_id: worker.worker_id.to_string(), partitions: worker .partitions diff --git a/mudu_cli/src/management.rs b/mudu_cli/src/management.rs index a0abebb..2644e44 100644 --- a/mudu_cli/src/management.rs +++ b/mudu_cli/src/management.rs @@ -50,6 +50,8 @@ where #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub struct WorkerTopology { pub worker_index: usize, + #[serde(default)] + pub tcp_listen_port: u16, #[serde( serialize_with = "serialize_oid_as_unioid", deserialize_with = "deserialize_oid_from_unioid" @@ -65,9 +67,39 @@ pub struct WorkerTopology { #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub struct ServerTopology { pub worker_count: usize, + #[serde(default)] + pub tcp_multi_port: bool, + #[serde(default)] + pub tcp_base_listen_port: u16, pub workers: Vec, } +impl ServerTopology { + pub fn worker_port_by_index(&self, worker_index: usize) -> Option { + self.workers + .iter() + .find(|w| w.worker_index == worker_index) + .map(|w| w.tcp_listen_port) + } + + pub fn worker_port_by_id(&self, worker_id: OID) -> Option { + self.workers + .iter() + .find(|w| w.worker_id == worker_id) + .map(|w| w.tcp_listen_port) + } + + pub fn worker_addr_by_index(&self, listen_ip: &str, worker_index: usize) -> Option { + self.worker_port_by_index(worker_index) + .map(|port| format!("{}:{}", listen_ip, port)) + } + + pub fn worker_addr_by_id(&self, listen_ip: &str, worker_id: OID) -> Option { + self.worker_port_by_id(worker_id) + .map(|port| format!("{}:{}", listen_ip, port)) + } +} + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub struct PartitionRouteEntry { #[serde( @@ -280,6 +312,7 @@ mod tests { fn worker_topology_round_trips_oid_as_unioid() { let worker = WorkerTopology { worker_index: 0, + tcp_listen_port: 9527, worker_id: (1u128 << 100) + 7, partitions: vec![(1u128 << 99) + 3], }; @@ -316,4 +349,35 @@ mod tests { )); assert!(!is_server_topology_unsupported("connection refused")); } + + #[test] + fn topology_resolves_worker_addr() { + let topology = ServerTopology { + worker_count: 2, + tcp_multi_port: true, + tcp_base_listen_port: 9527, + workers: vec![ + WorkerTopology { + worker_index: 0, + tcp_listen_port: 9527, + worker_id: 11, + partitions: vec![], + }, + WorkerTopology { + worker_index: 1, + tcp_listen_port: 9528, + worker_id: 22, + partitions: vec![], + }, + ], + }; + assert_eq!( + topology.worker_addr_by_index("127.0.0.1", 1), + Some("127.0.0.1:9528".to_string()) + ); + assert_eq!( + topology.worker_addr_by_id("127.0.0.1", 11), + Some("127.0.0.1:9527".to_string()) + ); + } } diff --git a/mudu_kernel/Cargo.toml b/mudu_kernel/Cargo.toml index 3b764ca..4610d8c 100644 --- a/mudu_kernel/Cargo.toml +++ b/mudu_kernel/Cargo.toml @@ -31,7 +31,7 @@ tracing = { workspace = true } pgwire = { workspace = true } futures = { workspace = true } -scc = {workspace = true} +scc = { workspace = true } project-root = { workspace = true } csv-async = { version = "1.3.0" } diff --git a/mudu_kernel/fuzz/fuzz_targets/_de_en_x_l_batch.rs b/mudu_kernel/fuzz/fuzz_targets/_de_en_x_l_batch.rs index d1134d4..94adf80 100644 --- a/mudu_kernel/fuzz/fuzz_targets/_de_en_x_l_batch.rs +++ b/mudu_kernel/fuzz/fuzz_targets/_de_en_x_l_batch.rs @@ -4,6 +4,6 @@ extern crate libfuzzer_sys; use mududb::fuzz::_fuzz_run::_target; -fuzz_target!(|param:&[u8]| { +fuzz_target!(|param: &[u8]| { _target("_de_en_x_l_batch", param); -}); \ No newline at end of file +}); diff --git a/mudu_kernel/fuzz/fuzz_targets/_de_en_x_l_up_tuple.rs b/mudu_kernel/fuzz/fuzz_targets/_de_en_x_l_up_tuple.rs index 8a37bd4..f31992c 100644 --- a/mudu_kernel/fuzz/fuzz_targets/_de_en_x_l_up_tuple.rs +++ b/mudu_kernel/fuzz/fuzz_targets/_de_en_x_l_up_tuple.rs @@ -4,6 +4,6 @@ extern crate libfuzzer_sys; use mududb::fuzz::_fuzz_run::_target; -fuzz_target!(|param:&[u8]| { +fuzz_target!(|param: &[u8]| { _target("_de_en_x_l_up_tuple", param); -}); \ No newline at end of file +}); diff --git a/mudu_kernel/fuzz/fuzz_targets/_delta_apply.rs b/mudu_kernel/fuzz/fuzz_targets/_delta_apply.rs index 26c2294..654ae28 100644 --- a/mudu_kernel/fuzz/fuzz_targets/_delta_apply.rs +++ b/mudu_kernel/fuzz/fuzz_targets/_delta_apply.rs @@ -4,6 +4,6 @@ extern crate libfuzzer_sys; use mududb::fuzz::_fuzz_run::_target; -fuzz_target!(|param:&[u8]| { +fuzz_target!(|param: &[u8]| { _target("_delta_apply", param); }); diff --git a/mudu_kernel/fuzz/fuzz_targets/_gen_order_csv.rs b/mudu_kernel/fuzz/fuzz_targets/_gen_order_csv.rs index 02dddeb..8e7c6e1 100644 --- a/mudu_kernel/fuzz/fuzz_targets/_gen_order_csv.rs +++ b/mudu_kernel/fuzz/fuzz_targets/_gen_order_csv.rs @@ -4,6 +4,6 @@ extern crate libfuzzer_sys; use mududb::fuzz::_fuzz_run::_target; -fuzz_target!(|param:&[u8]| { +fuzz_target!(|param: &[u8]| { _target("_gen_order_csv", param); }); diff --git a/mudu_kernel/fuzz/fuzz_targets/_schema_table.rs b/mudu_kernel/fuzz/fuzz_targets/_schema_table.rs index 6a32df5..8941426 100644 --- a/mudu_kernel/fuzz/fuzz_targets/_schema_table.rs +++ b/mudu_kernel/fuzz/fuzz_targets/_schema_table.rs @@ -4,6 +4,6 @@ extern crate libfuzzer_sys; use mududb::fuzz::_fuzz_run::_target; -fuzz_target!(|param:&[u8]| { +fuzz_target!(|param: &[u8]| { _target("_schema_table", param); -}); \ No newline at end of file +}); diff --git a/mudu_kernel/fuzz/fuzz_targets/_type_convert.rs b/mudu_kernel/fuzz/fuzz_targets/_type_convert.rs index 3357676..4a202f4 100644 --- a/mudu_kernel/fuzz/fuzz_targets/_type_convert.rs +++ b/mudu_kernel/fuzz/fuzz_targets/_type_convert.rs @@ -4,6 +4,6 @@ extern crate libfuzzer_sys; use mududb::fuzz::_fuzz_run::_target; -fuzz_target!(|param:&[u8]| { +fuzz_target!(|param: &[u8]| { _target("_type_convert", param); }); diff --git a/mudu_kernel/fuzz/fuzz_targets/_x_log_append.rs b/mudu_kernel/fuzz/fuzz_targets/_x_log_append.rs index 5967793..8523822 100644 --- a/mudu_kernel/fuzz/fuzz_targets/_x_log_append.rs +++ b/mudu_kernel/fuzz/fuzz_targets/_x_log_append.rs @@ -4,6 +4,6 @@ extern crate libfuzzer_sys; use mududb::fuzz::_fuzz_run::_target; -fuzz_target!(|param:&[u8]| { +fuzz_target!(|param: &[u8]| { _target("_x_log_append", param); }); diff --git a/mudu_kernel/fuzz/src/lib.rs b/mudu_kernel/fuzz/src/lib.rs index e69de29..8b13789 100644 --- a/mudu_kernel/fuzz/src/lib.rs +++ b/mudu_kernel/fuzz/src/lib.rs @@ -0,0 +1 @@ + diff --git a/mudu_kernel/src/command/load_from_file.rs b/mudu_kernel/src/command/load_from_file.rs index a520f4e..411d246 100644 --- a/mudu_kernel/src/command/load_from_file.rs +++ b/mudu_kernel/src/command/load_from_file.rs @@ -10,6 +10,7 @@ use mudu::common::result::RS; use mudu::error::ec::EC as ER; use mudu::m_error; use mudu_type::dat_type_id::DatTypeID; +use mudu_utils::scoped_task_trace; use mudu_utils::sync::a_mutex::AMutex; use std::io::Cursor; use std::sync::Arc; @@ -94,6 +95,7 @@ impl _LoadFromFile { } async fn load_table(&self) -> RS { + scoped_task_trace!(); debug!( table_id = self.table_id, csv_file = %self.csv_file, @@ -290,6 +292,7 @@ impl CmdExec for LoadFromFile { } async fn run(&self) -> RS<()> { + scoped_task_trace!(); let mut inner = self.inner.lock().await; let rows = inner.load_table().await?; inner.set_affected_rows(rows); diff --git a/mudu_kernel/src/io/linux/socket.rs b/mudu_kernel/src/io/linux/socket.rs index f454ddf..45f6a20 100644 --- a/mudu_kernel/src/io/linux/socket.rs +++ b/mudu_kernel/src/io/linux/socket.rs @@ -159,10 +159,6 @@ impl IoSocket { pub(crate) fn from_raw_fd(fd: RawFd) -> Self { Self { fd } } - - pub(crate) fn into_raw_fd(self) -> RawFd { - self.fd - } } impl SocketOpenRequest { diff --git a/mudu_kernel/src/meta/meta_mgr.rs b/mudu_kernel/src/meta/meta_mgr.rs index bbfd827..da2d69c 100644 --- a/mudu_kernel/src/meta/meta_mgr.rs +++ b/mudu_kernel/src/meta/meta_mgr.rs @@ -490,11 +490,11 @@ unsafe impl Send for MetaMgrImpl {} #[cfg(test)] mod tests { - use std::env::temp_dir; - use std::future::Future; + use crate::contract::schema_column::SchemaColumn; use mudu_type::dat_type_id::DatTypeID; use mudu_type::dt_info::DTInfo; - use crate::contract::schema_column::SchemaColumn; + use std::env::temp_dir; + use std::future::Future; use super::*; @@ -545,14 +545,12 @@ mod tests { let reopened = MetaMgrImpl::new(&dir).unwrap(); let schema_id = schema.id(); - let table = block_on(async move { - reopened.get_table_by_id(schema_id).await - }).unwrap(); + let table = block_on(async move { reopened.get_table_by_id(schema_id).await }).unwrap(); assert_eq!(table.name(), schema.table_name()); } #[test] - fn meta_mgr_broadcasts_ddl_to_peer_instances() { + fn meta_mgr_broadcasts_ddl_to_peer_instances() { block_on(async move { let r = _meta_mgr_broadcasts_ddl_to_peer_instances().await; assert!(r.is_ok()); @@ -574,7 +572,4 @@ mod tests { assert!(mgr1.get_table_by_id(schema.id()).await.is_err()); Ok(()) } - - - } diff --git a/mudu_kernel/src/server/async_func_task.rs b/mudu_kernel/src/server/async_func_task.rs index 1c50c2b..ae89b84 100644 --- a/mudu_kernel/src/server/async_func_task.rs +++ b/mudu_kernel/src/server/async_func_task.rs @@ -1,7 +1,5 @@ #![allow(dead_code)] -use crate::server::routing::SessionOpenTransferAction; -use mudu::common::id::OID; use mudu::common::result::RS; use mudu_utils::task_id::TaskID; use std::future::Future; @@ -21,40 +19,6 @@ pub struct AsyncFuncTask { pub(in crate::server) enum HandleResult { Response(Vec), - Transfer(SessionTransferDispatch), -} - -#[derive(Clone)] -pub(in crate::server) struct SessionTransferDispatch { - target_worker: usize, - session_ids: Vec, - action: SessionOpenTransferAction, -} - -impl SessionTransferDispatch { - pub(in crate::server) fn new( - target_worker: usize, - session_ids: Vec, - action: SessionOpenTransferAction, - ) -> Self { - Self { - target_worker, - session_ids, - action, - } - } - - pub(in crate::server) fn target_worker(&self) -> usize { - self.target_worker - } - - pub(in crate::server) fn session_ids(&self) -> &[OID] { - &self.session_ids - } - - pub(in crate::server) fn action(&self) -> SessionOpenTransferAction { - self.action - } } impl AsyncFuncTask { diff --git a/mudu_kernel/src/server/linux/connection_worker_task.rs b/mudu_kernel/src/server/linux/connection_worker_task.rs index 742aa5b..74a41a9 100644 --- a/mudu_kernel/src/server/linux/connection_worker_task.rs +++ b/mudu_kernel/src/server/linux/connection_worker_task.rs @@ -1,14 +1,9 @@ use crate::io::socket::{close, IoSocket}; -use crate::server::async_func_task::{HandleResult, SessionTransferDispatch}; +use crate::server::async_func_task::HandleResult; use crate::server::frame_dispatch::dispatch_frame_async; use crate::server::protocol_codec::{read_next_frame, write_response}; -use crate::server::routing::ConnectionTransfer; -use crate::server::transferred_connection::TransferredConnection; use crate::server::worker::WorkerRuntime; -use crate::server::worker_mailbox::WorkerMailboxMsg; -use crate::server::worker_ring_loop::WorkerRingLoop; use crate::server::worker_task::WorkerTaskFuture; -use crossbeam_queue::SegQueue; use mudu::common::result::RS; use mudu_contract::protocol::encode_merror_response; use std::net::SocketAddr; @@ -18,8 +13,6 @@ use tracing::trace; pub(in crate::server) fn spawn_connection_worker_task( worker: WorkerRuntime, - mailbox_fds: Vec, - mailboxes: Vec>>, connections: Arc>, conn_id: u64, socket: IoSocket, @@ -29,8 +22,6 @@ pub(in crate::server) fn spawn_connection_worker_task( Box::pin(async move { run_connection_worker_task( worker, - mailbox_fds, - mailboxes, connections, conn_id, socket, @@ -43,8 +34,6 @@ pub(in crate::server) fn spawn_connection_worker_task( async fn run_connection_worker_task( worker: WorkerRuntime, - mailbox_fds: Vec, - mailboxes: Vec>>, connections: Arc>, conn_id: u64, socket: IoSocket, @@ -52,23 +41,13 @@ async fn run_connection_worker_task( initial_response: Option>, ) -> RS<()> { mudu_utils::scoped_task_trace!(); - let r = _run_connection_worker_task( - worker, - mailbox_fds, - mailboxes, - conn_id, - socket, - remote_addr, - initial_response, - ) - .await; + let r = + _run_connection_worker_task(worker, conn_id, socket, remote_addr, initial_response).await; let _ = connections.remove_sync(&conn_id); r } async fn _run_connection_worker_task( worker: WorkerRuntime, - mailbox_fds: Vec, - mailboxes: Vec>>, conn_id: u64, socket: IoSocket, remote_addr: SocketAddr, @@ -127,28 +106,6 @@ async fn _run_connection_worker_task( ); write_response(socket.as_ref().unwrap(), &response).await?; } - Ok(HandleResult::Transfer(transfer)) => { - trace!( - conn_id, - request_id, - target_worker = transfer.target_worker(), - session_count = transfer.session_ids().len(), - "dispatch requested connection transfer" - ); - let connection = build_transfer( - conn_id, - remote_addr, - socket.take().unwrap(), - transfer.clone(), - ); - WorkerRingLoop::dispatch_mailbox_message( - &mailbox_fds, - &mailboxes, - connection.transfer().target_worker(), - WorkerMailboxMsg::AdoptConnection(connection), - )?; - break; - } Err(err) => { trace!( conn_id, @@ -165,22 +122,3 @@ async fn _run_connection_worker_task( trace!(conn_id, "io_uring connection worker stopped"); Ok(()) } - -fn build_transfer( - conn_id: u64, - remote_addr: SocketAddr, - socket: IoSocket, - transfer: SessionTransferDispatch, -) -> TransferredConnection { - TransferredConnection::new( - ConnectionTransfer::new( - conn_id, - transfer.target_worker(), - crate::server::connection_state::ConnectionState::Active, - remote_addr, - ), - socket.into_raw_fd(), - transfer.session_ids().to_vec(), - Some(transfer.action()), - ) -} diff --git a/mudu_kernel/src/server/linux/message_bus_runtime.rs b/mudu_kernel/src/server/linux/message_bus_runtime.rs index bb45caf..a8ce0f6 100644 --- a/mudu_kernel/src/server/linux/message_bus_runtime.rs +++ b/mudu_kernel/src/server/linux/message_bus_runtime.rs @@ -80,7 +80,10 @@ impl WorkerMessageBus { "message_bus dispatching callback task" ); let future = (callback)(envelope); - task::spawn_system("iouring-message-bus-callback", spawn_system_worker_task(future)); + task::spawn_system( + "iouring-message-bus-callback", + spawn_system_worker_task(future), + ); } Ok(()) } @@ -195,7 +198,10 @@ impl MessageBus for WorkerMessageBus { }; if let Some(envelope) = maybe_envelope { let future = (callback)(envelope); - task::spawn_system("iouring-message-bus-on-recv", spawn_system_worker_task(future)); + task::spawn_system( + "iouring-message-bus-on-recv", + spawn_system_worker_task(future), + ); } Ok(callback_id) } diff --git a/mudu_kernel/src/server/linux/perf_test.rs b/mudu_kernel/src/server/linux/perf_test.rs index a29af80..3df51d9 100644 --- a/mudu_kernel/src/server/linux/perf_test.rs +++ b/mudu_kernel/src/server/linux/perf_test.rs @@ -1,5 +1,8 @@ -use crate::server::routing::{route_worker, RoutingContext, RoutingMode}; -use crate::server::server::{WorkerTcpBackend, WorkerTcpServerConfig}; +use crate::server::routing::RoutingMode; +use crate::server::server::WorkerTcpBackend; +use crate::server::server_cfg::ServerCfg; +use crate::server::server_launch::ServerLaunch; +use crate::server::server_runtime_deps::ServerRuntimeDeps; use crate::server::worker_registry::{load_or_create_worker_registry, WorkerRegistry}; use mudu::common::result::RS; use mudu::error::ec::EC; @@ -164,6 +167,27 @@ impl AsyncPerfClient { } } +fn test_worker_port(base_port: u16, worker_index: usize) -> RS { + let offset = u16::try_from(worker_index).map_err(|_| { + m_error!( + EC::ParseErr, + format!( + "worker index too large for test port mapping: {}", + worker_index + ) + ) + })?; + base_port.checked_add(offset).ok_or_else(|| { + m_error!( + EC::ParseErr, + format!( + "test worker port overflow: base_port={}, worker_index={}", + base_port, worker_index + ) + ) + }) +} + fn loopback_shard_ip(shard: usize) -> Ipv4Addr { // Linux routes the entire 127.0.0.0/8 block to loopback. Spreading // clients across multiple source IPs expands the available 4-tuple space @@ -345,6 +369,15 @@ async fn wait_until_server_ready_or_exit_async(port: u16, server: &TestServerHan )) } +async fn wait_until_worker_port_ready_or_exit_async( + base_port: u16, + worker_index: usize, + server: &TestServerHandle, +) -> RS<()> { + let worker_port = test_worker_port(base_port, worker_index)?; + wait_until_server_ready_or_exit_async(worker_port, server).await +} + fn spawn_iouring_server( listener: TcpListener, worker_count: usize, @@ -355,23 +388,26 @@ fn spawn_iouring_server( let (stop_notifier, server_stop) = notify_wait(); let (exit_tx, exit_rx) = mpsc::channel(); let port = listener.local_addr().unwrap().port(); - let mut server_cfg = WorkerTcpServerConfig::new( + let server_cfg = ServerCfg::new( worker_count, "127.0.0.1".to_string(), port, data_dir.to_string_lossy().into_owned(), data_dir.to_string_lossy().into_owned(), RoutingMode::ConnectionId, - None, ) .unwrap() - .with_prebound_listener(listener) + .with_multi_port(worker_count > 1) .with_log_chunk_size(log_chunk_size); + let mut server_deps = ServerRuntimeDeps::from_cfg(&server_cfg).unwrap(); if let Some(worker_registry) = worker_registry { - server_cfg = server_cfg.with_worker_registry(worker_registry).unwrap(); + server_deps = server_deps + .with_worker_registry(&server_cfg, worker_registry) + .unwrap(); } + let server_launch = ServerLaunch::new(server_cfg, server_deps).with_prebound_listener(listener); let join_handle = thread::spawn(move || { - let result = WorkerTcpBackend::sync_serve_with_stop(server_cfg, server_stop); + let result = WorkerTcpBackend::sync_serve_with_stop(server_launch, server_stop); let exit_msg = match &result { Ok(()) => Ok(()), Err(err) => Err(err.to_string()), @@ -635,7 +671,7 @@ async fn iouring_backend_recovery_replays_worker_logs() -> RS<()> { } { - let mut client = AsyncPerfClient::connect(port).await?; + let mut client = AsyncPerfClient::connect(test_worker_port(port, 0)?).await?; client.session_id = client .create_session(Some( serde_json::json!({ @@ -675,7 +711,7 @@ async fn iouring_backend_recovery_replays_worker_logs() -> RS<()> { } { - let mut client = AsyncPerfClient::connect(restart_port).await?; + let mut client = AsyncPerfClient::connect(test_worker_port(restart_port, 0)?).await?; client.session_id = client .create_session(Some( serde_json::json!({ @@ -796,12 +832,7 @@ async fn iouring_backend_open_session_routes_connection_to_requested_partition() )); std::fs::create_dir_all(&data_dir).unwrap(); - let initial_partition = route_worker( - &RoutingContext::new(1, "127.0.0.1:10000".parse().unwrap(), None), - RoutingMode::ConnectionId, - worker_count, - ); - let target_partition = (initial_partition + 1) % worker_count; + let target_partition = 1usize; let registry = load_or_create_worker_registry(&data_dir, worker_count)?; let target_worker = registry.worker(target_partition).unwrap(); @@ -819,9 +850,19 @@ async fn iouring_backend_open_session_routes_connection_to_requested_partition() } return Err(err); } + if let Err(err) = + wait_until_worker_port_ready_or_exit_async(port, target_partition, &server_thread).await + { + if should_skip_iouring_perf(&err) { + eprintln!("skip io_uring route test: {}", err); + return Ok(()); + } + return Err(err); + } { - let mut client = AsyncPerfClient::connect(port).await?; + let mut client = + AsyncPerfClient::connect(test_worker_port(port, target_partition)?).await?; let session_id = client .create_session(Some( serde_json::json!({ @@ -882,12 +923,7 @@ async fn iouring_backend_open_session_rebind_keeps_same_session_id() -> RS<()> { )); std::fs::create_dir_all(&data_dir).unwrap(); - let initial_partition = route_worker( - &RoutingContext::new(1, "127.0.0.1:10001".parse().unwrap(), None), - RoutingMode::ConnectionId, - worker_count, - ); - let target_partition = (initial_partition + 1) % worker_count; + let target_partition = 1usize; let registry = load_or_create_worker_registry(&data_dir, worker_count)?; let target_worker = registry.worker(target_partition).unwrap(); @@ -905,9 +941,19 @@ async fn iouring_backend_open_session_rebind_keeps_same_session_id() -> RS<()> { } return Err(err); } + if let Err(err) = + wait_until_worker_port_ready_or_exit_async(port, target_partition, &server_thread).await + { + if should_skip_iouring_perf(&err) { + eprintln!("skip io_uring rebind test: {}", err); + return Ok(()); + } + return Err(err); + } { - let mut client = AsyncPerfClient::connect(port).await?; + let mut client = + AsyncPerfClient::connect(test_worker_port(port, target_partition)?).await?; let original_session_id = client.session_id; let rebound_session_id = client .create_session(Some( diff --git a/mudu_kernel/src/server/linux/server_iouring.rs b/mudu_kernel/src/server/linux/server_iouring.rs index 247611d..42fd4a9 100644 --- a/mudu_kernel/src/server/linux/server_iouring.rs +++ b/mudu_kernel/src/server/linux/server_iouring.rs @@ -1,5 +1,4 @@ -use crate::server::frame_dispatch::{dispatch_frame_async, try_decode_next_frame}; -use crate::server::server::WorkerTcpServerConfig; +use crate::server::server_launch::ServerLaunch; use crate::server::worker::WorkerRuntime; use crate::server::worker_loop_stats::WorkerLoopStats; use crate::server::worker_mailbox::WorkerMailboxMsg; @@ -8,7 +7,6 @@ use crossbeam_queue::SegQueue; use mudu::common::result::RS; use mudu::error::ec::EC; use mudu::m_error; -use mudu_contract::protocol::Frame; use mudu_utils::notifier::{Notifier, Waiter}; use mudu_utils::task_async::{build_current_thread_runtime, CurrentThreadTaskRuntime}; use std::os::fd::{IntoRawFd, RawFd}; @@ -30,26 +28,23 @@ struct RecoveryState { } pub(crate) fn sync_serve_iouring( - mut cfg: WorkerTcpServerConfig, + mut cfg: ServerLaunch, stop: Waiter, ready: Option, ) -> RS<()> { - if cfg.worker_count() == 0 { + if cfg.cfg().worker_count() == 0 { return Err(m_error!(EC::ParseErr, "invalid io_uring worker count")); } - let listen_addr: std::net::SocketAddr = format!("{}:{}", cfg.listen_ip(), cfg.listen_port()) - .parse() - .map_err(|e| m_error!(EC::ParseErr, "parse io_uring tcp listen address error", e))?; let prebound_listener = cfg.take_prebound_listener(); let conn_id_alloc = Arc::new(AtomicU64::new(1)); - let mailboxes: Vec<_> = (0..cfg.worker_count()) + let mailboxes: Vec<_> = (0..cfg.cfg().worker_count()) .map(|_| Arc::new(SegQueue::::new())) .collect(); - let mailbox_fds: Vec<_> = (0..cfg.worker_count()) + let mailbox_fds: Vec<_> = (0..cfg.cfg().worker_count()) .map(|_| create_mailbox_event_fd()) .collect::>>()?; let stop_flag = Arc::new(AtomicBool::new(false)); - let recovery_coordinator = Arc::new(RecoveryCoordinator::new(cfg.worker_count(), ready)); + let recovery_coordinator = Arc::new(RecoveryCoordinator::new(cfg.cfg().worker_count(), ready)); let stop_for_notifier = stop.clone(); let shutdown_mailboxes = mailboxes.clone(); @@ -78,15 +73,22 @@ pub(crate) fn sync_serve_iouring( Ok(()) })?; - let mut handles = Vec::with_capacity(cfg.worker_count()); - for worker_id in 0..cfg.worker_count() { - let listen_addr = listen_addr; + let mut handles = Vec::with_capacity(cfg.cfg().worker_count()); + for worker_id in 0..cfg.cfg().worker_count() { + let worker_port = cfg.cfg().listen_port_for_worker(worker_id)?; + let listen_addr: std::net::SocketAddr = + format!("{}:{}", cfg.cfg().listen_ip(), worker_port) + .parse() + .map_err(|e| { + m_error!(EC::ParseErr, "parse io_uring tcp listen address error", e) + })?; let conn_id_alloc = conn_id_alloc.clone(); let mailbox = mailboxes[worker_id].clone(); let all_mailboxes = mailboxes.clone(); let all_mailbox_fds = mailbox_fds.clone(); - let procedure_runtime = cfg.procedure_runtime_for_worker(worker_id); + let procedure_runtime = cfg.deps().procedure_runtime_for_worker(worker_id); let worker_identity = cfg + .deps() .worker_registry() .worker(worker_id) .cloned() @@ -96,15 +98,15 @@ pub(crate) fn sync_serve_iouring( format!("missing worker identity {}", worker_id) ) })?; - let worker_registry = cfg.worker_registry(); - let routing_mode = cfg.routing_mode(); - let data_dir = cfg.data_dir().to_string(); - let log_dir = cfg.log_dir().to_string(); - let log_chunk_size = cfg.log_chunk_size(); - let log_batching = cfg.log_batching(); - let worker_count = cfg.worker_count(); - let server_instance_id = cfg.server_instance_id(); + let worker_registry = cfg.deps().worker_registry(); + let data_dir = cfg.cfg().data_dir().to_string(); + let log_dir = cfg.cfg().log_dir().to_string(); + let log_chunk_size = cfg.cfg().log_chunk_size(); + let log_batching = cfg.deps().log_batching(); + let worker_count = cfg.cfg().worker_count(); + let server_instance_id = cfg.cfg().server_instance_id(); let listener = match &prebound_listener { + Some(_) if worker_id != 0 => None, Some(listener) => Some( listener .try_clone() @@ -115,45 +117,51 @@ pub(crate) fn sync_serve_iouring( let stop = stop_flag.clone(); let recovery_coordinator = recovery_coordinator.clone(); let mailbox_fd = mailbox_fds[worker_id]; - let async_runtime = cfg.async_runtime(); + let async_runtime = cfg.deps().async_runtime(); + let recovery_coordinator_for_failure = recovery_coordinator.clone(); let handle = mudu_sys::task_sync::spawn_thread_named(format!("worker-{worker_id}"), move || { - let runtime = CurrentThreadTaskRuntime::new().map_err(|e| { - m_error!(EC::TokioErr, "create runtime for io_uring worker error", e) - })?; - let listener_fd = match listener { - Some(listener) => listener.into_raw_fd(), - None => create_listener_fd(listen_addr)?, - }; - let worker = WorkerRuntime::new_with_log_batching_and_runtime( - worker_identity, - worker_count, - routing_mode, - log_dir.clone(), - data_dir.clone(), - log_chunk_size, - log_batching, - procedure_runtime, - worker_registry, - async_runtime, - server_instance_id, - )?; - let mut loop_state = WorkerRingLoop::new( - worker, - listener_fd, - mailbox_fd, - mailbox, - all_mailboxes, - all_mailbox_fds, - conn_id_alloc, - recovery_coordinator, - stop, - )?; - runtime.block_on(async move { loop_state.run() }) + let result = (|| { + let runtime = CurrentThreadTaskRuntime::new().map_err(|e| { + m_error!(EC::TokioErr, "create runtime for io_uring worker error", e) + })?; + let listener_fd = match listener { + Some(listener) => listener.into_raw_fd(), + None => create_listener_fd(listen_addr)?, + }; + let worker = WorkerRuntime::new_with_log_batching_and_runtime( + worker_identity, + worker_count, + log_dir.clone(), + data_dir.clone(), + log_chunk_size, + log_batching, + procedure_runtime, + worker_registry, + async_runtime, + server_instance_id, + )?; + let mut loop_state = WorkerRingLoop::new( + worker, + listener_fd, + mailbox_fd, + mailbox, + all_mailboxes, + all_mailbox_fds, + conn_id_alloc, + recovery_coordinator, + stop, + )?; + runtime.block_on(async move { loop_state.run() }) + })(); + if result.is_err() { + recovery_coordinator_for_failure.worker_failed(); + } + result })?; handles.push(handle); } - let mut worker_stats = Vec::::with_capacity(cfg.worker_count()); + let mut worker_stats = Vec::::with_capacity(cfg.cfg().worker_count()); let mut first_error: Option = None; for handle in handles { @@ -269,21 +277,6 @@ impl RecoveryCoordinator { } } -#[allow(dead_code)] -pub async fn dispatch_frame_iouring( - worker: &WorkerRuntime, - conn_id: u64, - frame: &Frame, -) -> RS { - mudu_utils::scoped_task_trace!(); - dispatch_frame_async(worker, conn_id, frame).await -} - -#[allow(dead_code)] -pub fn try_decode_next_frame_iouring(buf: &[u8]) -> RS> { - try_decode_next_frame(buf) -} - fn create_listener_fd(listen_addr: std::net::SocketAddr) -> RS { mudu_sys::net::create_tcp_listener_fd(listen_addr, 1024) } @@ -345,8 +338,6 @@ fn log_worker_stats(stats: &[WorkerLoopStats]) { #[cfg(test)] mod tests { use super::*; - use crate::server::routing::ConnectionTransfer; - use crate::server::transferred_connection::TransferredConnection; #[test] fn mailbox_eventfd_accumulates_wakeups() { @@ -361,29 +352,9 @@ mod tests { } #[test] - fn mailbox_can_store_shutdown_and_transfer_messages() { + fn mailbox_can_store_shutdown_messages() { let mailbox = SegQueue::new(); - mailbox.push(WorkerMailboxMsg::AdoptConnection( - TransferredConnection::new( - ConnectionTransfer::new( - 11, - 1, - crate::server::connection_state::ConnectionState::Accepted, - "127.0.0.1:9527".parse().unwrap(), - ), - -1, - Vec::new(), - None, - ), - )); mailbox.push(WorkerMailboxMsg::Shutdown); - match mailbox.pop() { - Some(WorkerMailboxMsg::AdoptConnection(connection)) => { - assert_eq!(connection.transfer().conn_id(), 11); - assert_eq!(connection.transfer().target_worker(), 1); - } - other => panic!("unexpected first mailbox message: {other:?}"), - } assert!(matches!(mailbox.pop(), Some(WorkerMailboxMsg::Shutdown))); assert!(mailbox.pop().is_none()); } diff --git a/mudu_kernel/src/server/linux/transferred_connection.rs b/mudu_kernel/src/server/linux/transferred_connection.rs deleted file mode 100644 index 7648983..0000000 --- a/mudu_kernel/src/server/linux/transferred_connection.rs +++ /dev/null @@ -1,43 +0,0 @@ -use crate::server::routing::{ConnectionTransfer, SessionOpenTransferAction}; -use mudu::common::id::OID; -use std::os::fd::RawFd; - -#[derive(Debug)] -pub(in crate::server) struct TransferredConnection { - transfer: ConnectionTransfer, - fd: RawFd, - session_ids: Vec, - session_open_action: Option, -} - -impl TransferredConnection { - pub(in crate::server) fn new( - transfer: ConnectionTransfer, - fd: RawFd, - session_ids: Vec, - session_open_action: Option, - ) -> Self { - Self { - transfer, - fd, - session_ids, - session_open_action, - } - } - - pub(in crate::server) fn transfer(&self) -> &ConnectionTransfer { - &self.transfer - } - - pub(in crate::server) fn fd(&self) -> RawFd { - self.fd - } - - pub(in crate::server) fn session_ids(&self) -> &[OID] { - &self.session_ids - } - - pub(in crate::server) fn session_open_action(&self) -> Option { - self.session_open_action - } -} diff --git a/mudu_kernel/src/server/linux/worker_mailbox.rs b/mudu_kernel/src/server/linux/worker_mailbox.rs index 8603a05..ea80946 100644 --- a/mudu_kernel/src/server/linux/worker_mailbox.rs +++ b/mudu_kernel/src/server/linux/worker_mailbox.rs @@ -1,9 +1,7 @@ use crate::server::message_bus_api::Envelope; -use crate::server::transferred_connection::TransferredConnection; #[derive(Debug)] pub(in crate::server) enum WorkerMailboxMsg { - AdoptConnection(TransferredConnection), BusMessage(Envelope), Shutdown, } diff --git a/mudu_kernel/src/server/linux/worker_ring_loop.rs b/mudu_kernel/src/server/linux/worker_ring_loop.rs index 0202555..279ca47 100644 --- a/mudu_kernel/src/server/linux/worker_ring_loop.rs +++ b/mudu_kernel/src/server/linux/worker_ring_loop.rs @@ -33,9 +33,7 @@ use crossbeam_queue::SegQueue; use mudu::common::result::RS; use mudu::error::ec::EC; use mudu::m_error; -use mudu_contract::protocol::{ - encode_merror_response, encode_session_create_response, SessionCreateResponse, -}; +use mudu_utils::scoped_task_trace; use mudu_utils::task_context::TaskContext; use std::collections::HashMap; use std::fs::OpenOptions; @@ -46,7 +44,6 @@ use std::sync::Arc; use std::thread; use std::time::Duration; use tracing::{debug, trace}; -use mudu_utils::scoped_task_trace; #[path = "worker_ring_loop/recovery.rs"] mod recovery; @@ -69,8 +66,6 @@ pub(in crate::server) struct WorkerRingLoop { listener_fd: RawFd, mailbox_fd: RawFd, mailbox: Arc>, - mailboxes: Vec>>, - mailbox_fds: Vec, conn_id_alloc: Arc, recovery_coordinator: Arc, worker_local_ring: Arc, @@ -135,8 +130,6 @@ impl WorkerRingLoop { listener_fd, mailbox_fd, mailbox, - mailboxes, - mailbox_fds, conn_id_alloc, recovery_coordinator, worker_local_ring, @@ -239,29 +232,7 @@ impl WorkerRingLoop { let remote_addr = server_iouring::sockaddr_to_socket_addr(op.addr())?; server_iouring::set_connection_options(conn_fd)?; let conn_id = self.conn_id_alloc.fetch_add(1, Ordering::Relaxed); - let target_worker = self.worker.route_connection(conn_id, remote_addr); - if target_worker == self.worker.worker_index() { - self.register_connection(conn_id, conn_fd, remote_addr)?; - } else { - Self::dispatch_mailbox_message( - &self.mailbox_fds, - &self.mailboxes, - target_worker, - WorkerMailboxMsg::AdoptConnection( - crate::server::transferred_connection::TransferredConnection::new( - crate::server::routing::ConnectionTransfer::new( - conn_id, - target_worker, - crate::server::connection_state::ConnectionState::Accepted, - remote_addr, - ), - conn_fd, - Vec::new(), - None, - ), - ), - )?; - } + self.register_connection(conn_id, conn_fd, remote_addr)?; } } InflightOp::MailboxRead { .. } => { @@ -300,48 +271,6 @@ impl WorkerRingLoop { fn handle_mailbox_message(&self, msg: WorkerMailboxMsg) -> RS<()> { match msg { - WorkerMailboxMsg::AdoptConnection(connection) => { - debug!( - worker_id = self.worker.worker_id(), - conn_id = connection.transfer().conn_id(), - remote_addr = %connection.transfer().remote_addr(), - session_ids = ?connection.session_ids(), - has_session_open_action = connection.session_open_action().is_some(), - "worker_ring_loop handling adopt connection mailbox message" - ); - server_iouring::set_connection_options(connection.fd())?; - self.worker.adopt_connection_sessions( - connection.transfer().conn_id(), - connection.session_ids(), - )?; - let initial_response = if let Some(action) = connection.session_open_action() { - Some( - match self.worker.open_session_with_config( - connection.transfer().conn_id(), - action.config(), - ) { - Ok(session_id) => encode_session_create_response( - action.request_id(), - &SessionCreateResponse::new(session_id), - )?, - Err(err) => encode_merror_response(action.request_id(), &err)?, - }, - ) - } else { - None - }; - self.start_connection_task( - connection.transfer().conn_id(), - connection.fd(), - connection.transfer().remote_addr(), - initial_response, - )?; - debug!( - worker_id = self.worker.worker_id(), - conn_id = connection.transfer().conn_id(), - "worker_ring_loop handled adopt connection mailbox message" - ); - } WorkerMailboxMsg::BusMessage(envelope) => { debug!( worker_id = self.worker.worker_id(), @@ -427,31 +356,6 @@ impl WorkerRingLoop { submit_user_io(&mut ctx) } - pub fn dispatch_mailbox_message( - mailbox_fds: &[RawFd], - mailboxes: &[Arc>], - target_worker: usize, - msg: WorkerMailboxMsg, - ) -> RS<()> { - let Some(mailbox) = mailboxes.get(target_worker) else { - return Err(m_error!( - EC::InternalErr, - format!("mailbox target worker {} is out of range", target_worker) - )); - }; - let Some(&fd) = mailbox_fds.get(target_worker) else { - return Err(m_error!( - EC::InternalErr, - format!( - "mailbox eventfd target worker {} is out of range", - target_worker - ) - )); - }; - mailbox.push(msg); - server_iouring::notify_mailbox_fd(fd) - } - pub(in crate::server) fn alloc_token(&mut self) -> u64 { let token = self.next_token; self.next_token += 1; @@ -471,8 +375,6 @@ impl WorkerRingLoop { Some(conn_id), spawn_connection_worker_task( self.worker.clone(), - self.mailbox_fds.clone(), - self.mailboxes.clone(), self.connection_task_fds.clone(), conn_id, socket, @@ -546,7 +448,6 @@ mod tests { accept, close as close_socket, connect, recv, send, shutdown, socket, IoSocket, }; use crate::server::callback_registry::{CallbackDomain, CallbackEventKey, CallbackTrigger}; - use crate::server::routing::RoutingMode; use crate::server::worker_registry::load_or_create_worker_registry; use mudu::common::id::gen_oid; use mudu_sys::tokio::task::{yield_now, JoinHandle}; @@ -564,7 +465,6 @@ mod tests { let worker = WorkerRuntime::new( identity, 1, - RoutingMode::ConnectionId, dir.clone(), dir.clone(), 4096, diff --git a/mudu_kernel/src/server/linux/worker_ring_loop/runtime.rs b/mudu_kernel/src/server/linux/worker_ring_loop/runtime.rs index 7daa6ee..6d7c621 100644 --- a/mudu_kernel/src/server/linux/worker_ring_loop/runtime.rs +++ b/mudu_kernel/src/server/linux/worker_ring_loop/runtime.rs @@ -91,7 +91,6 @@ impl WorkerRingLoop { return Ok(()); } self.shutting_down = true; - self.close_pending_mailbox_connections()?; self.shutdown_connection_tasks(); if self.listener_fd >= 0 { let rc = unsafe { libc::close(self.listener_fd) }; @@ -107,22 +106,6 @@ impl WorkerRingLoop { Ok(()) } - fn close_pending_mailbox_connections(&mut self) -> RS<()> { - while let Some(msg) = self.mailbox.pop() { - if let WorkerMailboxMsg::AdoptConnection(connection) = msg { - let rc = unsafe { libc::close(connection.fd()) }; - if rc != 0 { - return Err(m_error!( - EC::NetErr, - "close transferred io_uring connection during shutdown error", - std::io::Error::last_os_error() - )); - } - } - } - Ok(()) - } - pub(in crate::server) fn poll_ready_worker_tasks(&mut self) -> RS<()> { for completed in self.worker_local_ring.worker_task_registry().poll_ready() { if completed.is_system() { diff --git a/mudu_kernel/src/server/mod.rs b/mudu_kernel/src/server/mod.rs index 3b3ef33..d80add1 100644 --- a/mudu_kernel/src/server/mod.rs +++ b/mudu_kernel/src/server/mod.rs @@ -44,9 +44,12 @@ mod request_ctx; mod request_response_worker; pub mod routing; pub mod server; +pub mod server_cfg; #[cfg(target_os = "linux")] #[path = "linux/server_iouring.rs"] mod server_iouring; +pub mod server_launch; +pub mod server_runtime_deps; mod session_bound_worker_runtime; mod task; #[cfg(target_os = "linux")] @@ -54,9 +57,6 @@ mod task; pub(crate) mod task_registry; #[cfg(test)] pub(crate) mod test_meta_mgr; -#[cfg(target_os = "linux")] -#[path = "linux/transferred_connection.rs"] -mod transferred_connection; pub mod worker; pub mod worker_local; mod worker_loop_stats; @@ -76,3 +76,4 @@ mod worker_task; mod worker_tx_manager; pub mod x_contract; mod x_lock_mgr; +mod procedure_runtimes; diff --git a/mudu_kernel/src/server/procedure_runtimes.rs b/mudu_kernel/src/server/procedure_runtimes.rs new file mode 100644 index 0000000..45aad05 --- /dev/null +++ b/mudu_kernel/src/server/procedure_runtimes.rs @@ -0,0 +1,24 @@ +use crate::server::async_func_runtime::AsyncFuncInvokerPtr; + +/// Procedure invokers are runtime dependencies, not static server settings. +pub enum ProcedureRuntimes { + None, + Shared(AsyncFuncInvokerPtr), + PerWorker(Vec), +} + +impl ProcedureRuntimes { + pub fn for_worker(&self, worker_id: usize) -> Option { + match self { + Self::None => None, + Self::Shared(runtime) => Some(runtime.clone()), + Self::PerWorker(runtimes) => runtimes.get(worker_id).cloned(), + } + } +} + +impl Default for ProcedureRuntimes { + fn default() -> Self { + Self::None + } +} \ No newline at end of file diff --git a/mudu_kernel/src/server/request_ctx.rs b/mudu_kernel/src/server/request_ctx.rs index 7567bb5..dd13ba9 100644 --- a/mudu_kernel/src/server/request_ctx.rs +++ b/mudu_kernel/src/server/request_ctx.rs @@ -15,7 +15,7 @@ use std::sync::Arc; use crate::server::async_func_task::HandleResult; use crate::server::request_response_worker::WorkerRuntimeRef; use crate::server::routing::parse_session_open_config; -use crate::server::routing::{SessionOpenConfig, SessionOpenTransferAction}; +use crate::server::routing::SessionOpenConfig; use crate::server::worker_registry::WorkerRegistry; #[derive(Clone)] @@ -204,16 +204,15 @@ impl RequestCtx { ), )?)) } else { - let action = SessionOpenTransferAction::new(self.request_id, config); - let session_ids = self - .worker - .prepare_connection_transfer(self.conn_id, Some(action))?; - Ok(HandleResult::Transfer( - crate::server::async_func_task::SessionTransferDispatch::new( + Err(mudu::m_error!( + mudu::error::ec::EC::NetErr, + format!( + "session create landed on worker index {} worker id {}, expected worker index {} worker id {}; reconnect to the target worker port", + self.worker.worker_index(), + self.worker.worker_id(), config.target_worker_index(), - session_ids, - action, - ), + config.worker_id() + ) )) } } diff --git a/mudu_kernel/src/server/request_response_worker.rs b/mudu_kernel/src/server/request_response_worker.rs index 21b2e87..36dc1e1 100644 --- a/mudu_kernel/src/server/request_response_worker.rs +++ b/mudu_kernel/src/server/request_response_worker.rs @@ -6,7 +6,7 @@ use mudu::common::result::RS; use mudu_contract::protocol::{ProcedureInvokeRequest, ProcedureInvokeResponse}; use std::sync::Arc; -use crate::server::routing::{SessionOpenConfig, SessionOpenTransferAction}; +use crate::server::routing::SessionOpenConfig; #[async_trait] pub trait RequestResponseWorker: Send + Sync { @@ -18,12 +18,6 @@ pub trait RequestResponseWorker: Send + Sync { fn open_session_with_config(&self, conn_id: u64, config: SessionOpenConfig) -> RS; - fn prepare_connection_transfer( - &self, - conn_id: u64, - action: Option, - ) -> RS>; - fn close_session_for_connection(&self, conn_id: u64, session_id: OID) -> RS; async fn handle_procedure_request( diff --git a/mudu_kernel/src/server/routing.rs b/mudu_kernel/src/server/routing.rs index 5eae326..17d6e72 100644 --- a/mudu_kernel/src/server/routing.rs +++ b/mudu_kernel/src/server/routing.rs @@ -1,4 +1,3 @@ -use crate::server::connection_state::ConnectionState; use crate::server::worker_registry::WorkerRegistry; use mudu::common::id::OID; use mudu::common::result::RS; @@ -6,9 +5,6 @@ use mudu::error::ec::EC; use mudu::m_error; use serde::de::{self, Deserializer}; use serde::Deserialize; -use std::collections::hash_map::DefaultHasher; -use std::hash::{Hash, Hasher}; -use std::net::SocketAddr; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum RoutingMode { @@ -17,21 +13,6 @@ pub enum RoutingMode { RemoteHash, } -#[derive(Debug, Clone)] -pub struct RoutingContext { - conn_id: u64, - remote_addr: SocketAddr, - opt_player_id: Option, -} - -#[derive(Debug, Clone)] -pub struct ConnectionTransfer { - conn_id: u64, - target_worker: usize, - state: ConnectionState, - remote_addr: SocketAddr, -} - #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct SessionOpenConfig { session_id: OID, @@ -39,12 +20,6 @@ pub struct SessionOpenConfig { target_worker_index: usize, } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct SessionOpenTransferAction { - request_id: u64, - config: SessionOpenConfig, -} - #[derive(Debug, Deserialize)] struct RawSessionOpenConfig { #[serde(deserialize_with = "deserialize_oid_json")] @@ -91,60 +66,6 @@ where .transpose() } -impl RoutingContext { - pub fn new(conn_id: u64, remote_addr: SocketAddr, opt_player_id: Option) -> Self { - Self { - conn_id, - remote_addr, - opt_player_id, - } - } - - pub fn conn_id(&self) -> u64 { - self.conn_id - } - - pub fn remote_addr(&self) -> SocketAddr { - self.remote_addr - } - - pub fn opt_player_id(&self) -> Option<&str> { - self.opt_player_id.as_deref() - } -} - -impl ConnectionTransfer { - pub fn new( - conn_id: u64, - target_worker: usize, - state: ConnectionState, - remote_addr: SocketAddr, - ) -> Self { - Self { - conn_id, - target_worker, - state, - remote_addr, - } - } - - pub fn conn_id(&self) -> u64 { - self.conn_id - } - - pub fn target_worker(&self) -> usize { - self.target_worker - } - - pub fn state(&self) -> ConnectionState { - self.state - } - - pub fn remote_addr(&self) -> SocketAddr { - self.remote_addr - } -} - impl SessionOpenConfig { pub fn new(session_id: OID, worker_id: OID, target_worker_index: usize) -> Self { Self { @@ -167,32 +88,6 @@ impl SessionOpenConfig { } } -impl SessionOpenTransferAction { - pub fn new(request_id: u64, config: SessionOpenConfig) -> Self { - Self { request_id, config } - } - - pub fn request_id(&self) -> u64 { - self.request_id - } - - pub fn config(&self) -> SessionOpenConfig { - self.config - } -} - -pub fn route_worker(ctx: &RoutingContext, mode: RoutingMode, worker_count: usize) -> usize { - let key = match mode { - RoutingMode::ConnectionId => ctx.conn_id().to_string(), - RoutingMode::PlayerId => ctx - .opt_player_id() - .map(ToOwned::to_owned) - .unwrap_or_else(|| ctx.conn_id().to_string()), - RoutingMode::RemoteHash => ctx.remote_addr().to_string(), - }; - stable_hash(&key) % worker_count.max(1) -} - pub fn parse_session_open_config( config_json: Option<&str>, default_worker_index: usize, @@ -233,9 +128,3 @@ pub fn parse_session_open_config( )), } } - -fn stable_hash(value: &str) -> usize { - let mut hasher = DefaultHasher::new(); - value.hash(&mut hasher); - hasher.finish() as usize -} diff --git a/mudu_kernel/src/server/server.rs b/mudu_kernel/src/server/server.rs index d978094..6d71870 100644 --- a/mudu_kernel/src/server/server.rs +++ b/mudu_kernel/src/server/server.rs @@ -1,9 +1,7 @@ -#![allow(dead_code)] - use crate::async_rt::contract::AsyncRuntime; use crate::server::async_func_runtime::AsyncFuncInvokerPtr; -use crate::server::async_func_task::{AsyncFuncFuture, AsyncFuncTask, HandleResult}; -use crate::server::async_func_task_waker::AsyncFuncTaskWaker; +use crate::server::async_func_task::HandleResult; + use crate::server::frame_dispatch::{dispatch_frame_async, try_decode_next_frame}; use crate::server::message_bus_api::{ register_worker_message_bus, set_current_message_bus, unregister_worker_message_bus, @@ -11,217 +9,46 @@ use crate::server::message_bus_api::{ OnRecvCallback, OutgoingMessage, RecvFilter, ServerInstanceId, SubscriptionId, }; use crate::server::message_bus_state::WorkerMessageBusState; -use crate::server::routing::{ConnectionTransfer, RoutingMode, SessionOpenTransferAction}; use crate::server::session_bound_worker_runtime::{ as_worker_local_ref, new_session_bound_worker_runtime, }; use crate::server::worker::WorkerRuntime; use crate::server::worker_local::{set_current_worker_local, unset_current_worker_local}; -use crate::server::worker_registry::{ - load_or_create_worker_registry, WorkerIdentity, WorkerRegistry, -}; +use crate::server::worker_registry::{WorkerIdentity, WorkerRegistry}; use crate::wal::worker_log::WorkerLogBatching; +use crate::wal::worker_log::{decode_frames, WorkerLogBackend}; +use crate::wal::xl_batch::decode_xl_batches; use async_trait::async_trait; use crossbeam_queue::SegQueue; -use futures::future::poll_fn; -use futures::task::{waker, Context}; -use futures::Future; -use mudu::common::id::{gen_oid, OID}; + +use mudu::common::id::OID; use mudu::common::result::RS; use mudu::error::ec::EC; use mudu::m_error; -use mudu_contract::protocol::{ - encode_merror_response, encode_session_create_response, Frame, SessionCreateResponse, -}; +use mudu_contract::protocol::encode_merror_response; +use mudu_sys::sync::stop_flag::{stop_channel, StopRx, StopTx}; +use mudu_sys::tokio; +use mudu_sys::tokio::io::{AsyncReadExt, AsyncWriteExt}; use mudu_sys::tokio::net::{TcpListener as TokioTcpListener, TcpStream as TokioTcpStream}; +use mudu_sys::tokio::sync::Notify; use mudu_utils::notifier::{notify_wait, Notifier, Waiter}; +use mudu_utils::scoped_task_trace; use mudu_utils::task_async::{ build_current_thread_runtime, spawn_local_detached, spawn_local_task, CurrentThreadTaskRuntime, - PollTaskIdGuard, }; -use mudu_utils::task_context::TaskContext; -use mudu_utils::task_id::new_task_id; + use socket2::{Domain, Protocol, Socket, Type}; -use std::collections::HashMap; -use std::io::ErrorKind; -use std::net::{SocketAddr, TcpListener, TcpStream}; + +use std::net::{SocketAddr, TcpListener}; use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::mpsc; use std::sync::{atomic::AtomicBool, Arc, Mutex}; -use std::task::Poll; + +use crate::server::server_launch::{ServerLaunch, WorkerTcpBackendConfig}; use std::thread; use std::thread::JoinHandle; -use std::time::Duration; use tracing::trace; -/// Configuration shared by both execution paths of the `client` backend. -/// -/// The same configuration is consumed by both the io_uring worker-ring backend -/// and the Tokio backend so they keep the worker model and protocol surface -/// aligned. -pub struct WorkerTcpServerConfig { - server_instance_id: ServerInstanceId, - worker_count: usize, - listen_ip: String, - listen_port: u16, - prebound_listener: Option, - data_dir: String, - log_dir: String, - log_chunk_size: u64, - log_batching: WorkerLogBatching, - routing_mode: RoutingMode, - procedure_runtime: Option, - worker_procedure_runtimes: Option>, - worker_registry: Arc, - async_runtime: Option>, -} - -/// Backward-compatible name for callers that still refer to the historical -/// io_uring-only server configuration. -pub type IoUringTcpServerConfig = WorkerTcpServerConfig; - -/// Alias used by backend construction code that does not need a transport- -/// specific name. -pub type WorkerTcpBackendConfig = WorkerTcpServerConfig; - -impl WorkerTcpServerConfig { - /// Creates a backend configuration. - /// - /// The resulting value can be used by both the io_uring and Tokio TCP - /// backends with the same externally visible behavior. - pub fn new( - worker_count: usize, - listen_ip: String, - listen_port: u16, - data_dir: String, - log_dir: String, - routing_mode: RoutingMode, - procedure_runtime: Option, - ) -> RS { - let worker_registry = load_or_create_worker_registry(&log_dir, worker_count)?; - Ok(Self { - server_instance_id: gen_oid(), - worker_count, - listen_ip, - listen_port, - prebound_listener: None, - data_dir, - log_dir, - log_chunk_size: 64 * 1024 * 1024, - log_batching: WorkerLogBatching::default(), - routing_mode, - procedure_runtime, - worker_procedure_runtimes: None, - worker_registry, - async_runtime: None, - }) - } - - pub fn with_log_chunk_size(mut self, log_chunk_size: u64) -> Self { - self.log_chunk_size = log_chunk_size; - self - } - - pub fn with_log_batching(mut self, log_batching: WorkerLogBatching) -> Self { - self.log_batching = log_batching; - self - } - - pub fn with_prebound_listener(mut self, listener: TcpListener) -> Self { - self.prebound_listener = Some(listener); - self - } - - pub fn with_worker_registry(mut self, worker_registry: Arc) -> RS { - if worker_registry.workers().len() != self.worker_count { - return Err(m_error!( - EC::ParseErr, - format!( - "worker registry count {} does not match expected {}", - worker_registry.workers().len(), - self.worker_count - ) - )); - } - self.worker_registry = worker_registry; - Ok(self) - } - - /// Installs per-worker procedure runtimes. - /// - /// When this is not set, every worker uses `procedure_runtime()`. This hook - /// exists so upper layers can give each worker an isolated invoker instance - /// while keeping the transport API unchanged across io_uring and Tokio - /// implementations. - pub fn with_worker_procedure_runtimes(mut self, runtimes: Vec) -> Self { - self.worker_procedure_runtimes = Some(runtimes); - self - } - - pub fn with_async_runtime(mut self, async_runtime: Arc) -> Self { - self.async_runtime = Some(async_runtime); - self - } - - pub fn server_instance_id(&self) -> ServerInstanceId { - self.server_instance_id - } - - pub fn worker_count(&self) -> usize { - self.worker_count - } - - pub fn listen_ip(&self) -> &str { - &self.listen_ip - } - - pub fn listen_port(&self) -> u16 { - self.listen_port - } - - pub fn take_prebound_listener(&mut self) -> Option { - self.prebound_listener.take() - } - - pub fn log_dir(&self) -> &str { - &self.log_dir - } - - pub fn data_dir(&self) -> &str { - &self.data_dir - } - - pub fn log_chunk_size(&self) -> u64 { - self.log_chunk_size - } - - pub fn log_batching(&self) -> WorkerLogBatching { - self.log_batching - } - - pub fn routing_mode(&self) -> RoutingMode { - self.routing_mode - } - - pub fn worker_registry(&self) -> Arc { - self.worker_registry.clone() - } - - pub fn procedure_runtime(&self) -> Option { - self.procedure_runtime.clone() - } - - pub fn procedure_runtime_for_worker(&self, worker_id: usize) -> Option { - self.worker_procedure_runtimes - .as_ref() - .and_then(|runtimes| runtimes.get(worker_id).cloned()) - .or_else(|| self.procedure_runtime()) - } - - pub fn async_runtime(&self) -> Option> { - self.async_runtime.clone() - } -} - /// Backend entry point for the `client` transport. /// /// Actual behavior is target-specific: Linux runs the native `io_uring` @@ -234,120 +61,11 @@ pub struct TokioTcpBackend; /// io_uring-only backend entry point. pub type IoUringTcpBackend = WorkerTcpBackend; -#[derive(Debug)] -struct TransferredConnection { - transfer: ConnectionTransfer, - stream: TcpStream, - session_ids: Vec, - session_open_action: Option, -} - -struct TokioWorkerConnection { - core: ConnectionCore, - stream: Option, -} - -struct ConnectionCore { - conn_id: u64, - state: crate::server::connection_state::ConnectionState, - remote_addr: SocketAddr, - transferred: bool, - read_buf: Vec, - write_buf: Vec, -} - -impl ConnectionCore { - fn new(conn_id: u64, remote_addr: SocketAddr) -> Self { - Self { - conn_id, - state: crate::server::connection_state::ConnectionState::Active, - remote_addr, - transferred: false, - read_buf: Vec::with_capacity(4096), - write_buf: Vec::with_capacity(4096), - } - } -} - -trait BackendConnection { - fn core(&self) -> &ConnectionCore; - fn core_mut(&mut self) -> &mut ConnectionCore; - fn read_available(&mut self) -> RS; - fn write_pending(&mut self) -> RS; - fn take_transfer_stream(&mut self) -> RS; -} - -impl BackendConnection for TokioWorkerConnection { - fn core(&self) -> &ConnectionCore { - &self.core - } - - fn core_mut(&mut self) -> &mut ConnectionCore { - &mut self.core - } - - fn read_available(&mut self) -> RS { - let mut progressed = false; - let Some(stream) = self.stream.as_mut() else { - return Ok(progressed); - }; - let mut buf = [0u8; 8192]; - loop { - match stream.try_read(&mut buf) { - Ok(0) => { - self.core.state = crate::server::connection_state::ConnectionState::Closing; - break; - } - Ok(read) => { - progressed = true; - self.core.read_buf.extend_from_slice(&buf[..read]); - } - Err(err) if err.kind() == ErrorKind::WouldBlock => break, - Err(err) => return Err(m_error!(EC::NetErr, "read tokio tcp request error", err)), - } - } - Ok(progressed) - } - - fn write_pending(&mut self) -> RS { - let mut progressed = false; - let Some(stream) = self.stream.as_mut() else { - return Ok(progressed); - }; - while !self.core.write_buf.is_empty() { - match stream.try_write(&self.core.write_buf) { - Ok(0) => { - self.core.state = crate::server::connection_state::ConnectionState::Closing; - break; - } - Ok(written) => { - progressed = true; - self.core.write_buf.drain(0..written); - } - Err(err) if err.kind() == ErrorKind::WouldBlock => break, - Err(err) => { - return Err(m_error!(EC::NetErr, "write tokio tcp response error", err)) - } - } - } - Ok(progressed) - } - - fn take_transfer_stream(&mut self) -> RS { - let stream = self - .stream - .take() - .ok_or_else(|| m_error!(EC::InternalErr, "tokio connection stream missing"))?; - stream - .into_std() - .map_err(|e| m_error!(EC::NetErr, "convert tokio stream for transfer error", e)) - } -} - struct TokioWorkerMessageBus { local_worker_id: OID, registry: Arc, mailboxes: Vec>>, + mailbox_wakes: Vec>, next_msg_id: AtomicU64, state: Mutex, } @@ -357,11 +75,13 @@ impl TokioWorkerMessageBus { local_worker_id: OID, registry: Arc, mailboxes: Vec>>, + mailbox_wakes: Vec>, ) -> Arc { Arc::new(Self { local_worker_id, registry, mailboxes, + mailbox_wakes, next_msg_id: AtomicU64::new(1), state: Mutex::new(WorkerMessageBusState::new()), }) @@ -417,6 +137,7 @@ impl MessageBus for TokioWorkerMessageBus { } async fn send(&self, dst: EndpointId, message: OutgoingMessage) -> RS { + scoped_task_trace!(); let msg_id = self.next_msg_id.fetch_add(1, Ordering::Relaxed); let envelope = Envelope::new( msg_id, @@ -435,6 +156,9 @@ impl MessageBus for TokioWorkerMessageBus { )); }; mailbox.push(envelope); + if let Some(wake) = self.mailbox_wakes.get(target_worker) { + wake.notify_one(); + } Ok(msg_id) } @@ -490,7 +214,6 @@ struct WorkerBuildConfig { data_dir: String, log_chunk_size: u64, log_batching: WorkerLogBatching, - routing_mode: RoutingMode, procedure_runtime: Option, worker_identity: WorkerIdentity, worker_registry: Arc, @@ -499,7 +222,9 @@ struct WorkerBuildConfig { impl WorkerBuildConfig { fn from_server_config(cfg: &WorkerTcpBackendConfig, worker_id: usize) -> RS { - let worker_identity = cfg + let server_cfg = cfg.cfg(); + let deps = cfg.deps(); + let worker_identity = deps .worker_registry() .worker(worker_id) .cloned() @@ -510,17 +235,16 @@ impl WorkerBuildConfig { ) })?; Ok(Self { - server_instance_id: cfg.server_instance_id(), - worker_count: cfg.worker_count(), - log_dir: cfg.log_dir().to_string(), - data_dir: cfg.data_dir().to_string(), - log_chunk_size: cfg.log_chunk_size(), - log_batching: cfg.log_batching(), - routing_mode: cfg.routing_mode(), - procedure_runtime: cfg.procedure_runtime_for_worker(worker_id), + server_instance_id: server_cfg.server_instance_id(), + worker_count: server_cfg.worker_count(), + log_dir: server_cfg.log_dir().to_string(), + data_dir: server_cfg.data_dir().to_string(), + log_chunk_size: server_cfg.log_chunk_size(), + log_batching: deps.log_batching(), + procedure_runtime: deps.procedure_runtime_for_worker(worker_id), worker_identity, - worker_registry: cfg.worker_registry(), - async_runtime: cfg.async_runtime(), + worker_registry: deps.worker_registry(), + async_runtime: deps.async_runtime(), }) } @@ -528,7 +252,6 @@ impl WorkerBuildConfig { WorkerRuntime::new_with_log_batching_and_runtime( self.worker_identity, self.worker_count, - self.routing_mode, self.log_dir, self.data_dir, self.log_chunk_size, @@ -545,6 +268,8 @@ fn spawn_stop_bridge( name: &'static str, stop: Waiter, stop_flag: Arc, + service_ready: Arc, + stop_tx: StopTx, ) -> RS>> { thread::Builder::new() .name(name.to_string()) @@ -555,7 +280,9 @@ fn spawn_stop_bridge( trace!(bridge = name, "tokio stop bridge waiting for stop"); runtime.block_on(stop.wait()); trace!(bridge = name, "tokio stop bridge observed stop"); + service_ready.store(false, Ordering::Relaxed); stop_flag.store(true, Ordering::Relaxed); + stop_tx.stop(); Ok(()) }) .map_err(|e| m_error!(EC::ThreadErr, format!("spawn {name} error"), e)) @@ -567,201 +294,13 @@ fn wait_stop_bridge(name: &'static str, handle: JoinHandle>) -> RS<()> { .map_err(|_| m_error!(EC::ThreadErr, format!("join {name} error")))? } -fn apply_handle_result_to_connection( - connection: &mut C, - inboxes: &[Arc>], - result: HandleResult, -) -> RS<()> { - match result { - HandleResult::Response(payload) => { - connection.core_mut().write_buf.extend_from_slice(&payload); - } - HandleResult::Transfer(transfer) => { - let stream = connection.take_transfer_stream()?; - enqueue_transfer( - inboxes, - connection.core().conn_id, - transfer.target_worker(), - connection.core().remote_addr, - stream, - transfer.session_ids().to_vec(), - Some(transfer.action()), - )?; - let core = connection.core_mut(); - core.transferred = true; - core.state = crate::server::connection_state::ConnectionState::Closing; - core.write_buf.clear(); - } - } - Ok(()) -} - -fn apply_handle_result( - connections: &mut HashMap, - inboxes: &[Arc>], - conn_id: u64, - result: HandleResult, -) -> RS<()> { - let Some(connection) = connections.get_mut(&conn_id) else { - return Ok(()); - }; - apply_handle_result_to_connection(connection, inboxes, result) -} - -struct FallbackAsyncFuncState { - next_task_id: u64, - next_op_id: u64, - tasks: HashMap, - ready_queue: Arc>, - completion_queue: Arc>, - op_registry: HashMap, -} - -impl FallbackAsyncFuncState { - fn new() -> Self { - Self { - next_task_id: 1, - next_op_id: 1, - tasks: HashMap::new(), - ready_queue: Arc::new(SegQueue::new()), - completion_queue: Arc::new(SegQueue::new()), - op_registry: HashMap::new(), - } - } - - fn enqueue_future(&mut self, conn_id: u64, request_id: u64, future: AsyncFuncFuture) { - let task_id = self.next_task_id; - self.next_task_id += 1; - let trace_task_id = new_task_id(); - let _ = TaskContext::new_context( - trace_task_id, - format!("tokio_async_func conn={conn_id} req={request_id}"), - false, - ); - self.tasks.insert( - task_id, - AsyncFuncTask::new( - conn_id, - trace_task_id, - request_id, - future, - Arc::new(AtomicBool::new(false)), - ), - ); - self.ready_queue.push(task_id); - } - - fn drain_completions(&mut self) -> bool { - let mut progressed = false; - while let Some(op_id) = self.completion_queue.pop() { - let Some(task_id) = self.op_registry.remove(&op_id) else { - continue; - }; - let Some(task) = self.tasks.get(&task_id) else { - continue; - }; - if let Some(ctx) = TaskContext::get(task.trace_task_id()) { - ctx.watch("state", "ready"); - ctx.watch("wake_op_id", &op_id.to_string()); - } - if !task.queued().swap(true, Ordering::AcqRel) { - self.ready_queue.push(task_id); - progressed = true; - } - } - progressed - } - - fn poll_ready( - &mut self, - connections: &mut HashMap, - inboxes: &[Arc>], - ) -> RS { - let mut progressed = false; - while let Some(task_id) = self.ready_queue.pop() { - let Some(mut task) = self.tasks.remove(&task_id) else { - continue; - }; - let trace_task_id = task.trace_task_id(); - trace!( - task_id, - conn_id = task.conn_id(), - request_id = task.request_id(), - "tokio async task poll begin" - ); - progressed = true; - task.clear_queued(); - if let Some(waiting_on) = task.take_waiting_on() { - self.op_registry.remove(&waiting_on); - } - - let op_id = self.next_op_id; - self.next_op_id += 1; - let waker = waker(Arc::new(AsyncFuncTaskWaker::new( - op_id, - self.completion_queue.clone(), - task.completed().clone(), - ))); - let mut cx = Context::from_waker(&waker); - let _guard = PollTaskIdGuard::enter(trace_task_id); - if let Some(ctx) = TaskContext::get(trace_task_id) { - ctx.watch("state", "polling"); - ctx.watch("poll_task_id", &task_id.to_string()); - } - match task.future_mut().poll(&mut cx) { - Poll::Ready(Ok(result)) => { - trace!( - task_id, - conn_id = task.conn_id(), - request_id = task.request_id(), - "tokio async task poll ready ok" - ); - TaskContext::remove_context(trace_task_id); - apply_handle_result(connections, inboxes, task.conn_id(), result)?; - } - Poll::Ready(Err(err)) => { - trace!( - task_id, - conn_id = task.conn_id(), - request_id = task.request_id(), - err = %err, - "tokio async task poll ready err" - ); - TaskContext::remove_context(trace_task_id); - if let Some(connection) = connections.get_mut(&task.conn_id()) { - let response = encode_merror_response(task.request_id(), &err)?; - connection.core_mut().write_buf.extend_from_slice(&response); - } - } - Poll::Pending => { - trace!( - task_id, - conn_id = task.conn_id(), - request_id = task.request_id(), - op_id, - "tokio async task poll pending" - ); - task.set_waiting_on(op_id); - if let Some(ctx) = TaskContext::get(trace_task_id) { - ctx.watch("state", "pending"); - ctx.watch("waiting_waker_op_id", &op_id.to_string()); - } - self.op_registry.insert(op_id, task_id); - self.tasks.insert(task_id, task); - } - } - } - Ok(progressed) - } -} - impl WorkerTcpBackend { /// Starts the backend until shutdown. /// /// This method keeps the old public entry point stable. It dispatches to /// the io_uring implementation on Linux. Select `TokioTcpBackend` /// explicitly when the Tokio worker loop is desired on any target. - pub fn sync_serve(cfg: WorkerTcpServerConfig) -> RS<()> { + pub fn sync_serve(cfg: ServerLaunch) -> RS<()> { let (_stop_notifier, stop_waiter) = notify_wait(); Self::sync_serve_with_stop(cfg, stop_waiter) } @@ -771,12 +310,12 @@ impl WorkerTcpBackend { /// The io_uring backend is Linux-only. The Tokio backend is available as a /// separate implementation and bridges the async stop signal into its /// worker loop. - pub fn sync_serve_with_stop(cfg: WorkerTcpServerConfig, stop: Waiter) -> RS<()> { + pub fn sync_serve_with_stop(cfg: ServerLaunch, stop: Waiter) -> RS<()> { Self::sync_serve_with_stop_and_ready(cfg, stop, None) } pub fn sync_serve_with_stop_and_ready( - cfg: WorkerTcpServerConfig, + cfg: ServerLaunch, stop: Waiter, ready: Option, ) -> RS<()> { @@ -791,64 +330,115 @@ impl WorkerTcpBackend { } impl TokioTcpBackend { - pub fn sync_serve(cfg: WorkerTcpServerConfig) -> RS<()> { + pub fn sync_serve(cfg: ServerLaunch) -> RS<()> { let (_stop_notifier, stop_waiter) = notify_wait(); Self::sync_serve_with_stop(cfg, stop_waiter) } - pub fn sync_serve_with_stop(cfg: WorkerTcpServerConfig, stop: Waiter) -> RS<()> { + pub fn sync_serve_with_stop(cfg: ServerLaunch, stop: Waiter) -> RS<()> { Self::sync_serve_with_stop_and_ready(cfg, stop, None) } pub fn sync_serve_with_stop_and_ready( - cfg: WorkerTcpServerConfig, + cfg: ServerLaunch, stop: Waiter, ready: Option, ) -> RS<()> { let stop_flag = Arc::new(AtomicBool::new(false)); - let notifier = spawn_stop_bridge("tokio-stop-bridge", stop, stop_flag.clone())?; - let result = sync_serve_tokio(cfg, stop_flag, ready); + let service_ready = Arc::new(AtomicBool::new(false)); + let (stop_tx, stop_rx) = stop_channel(); + let notifier = spawn_stop_bridge( + "tokio-stop-bridge", + stop, + stop_flag.clone(), + service_ready.clone(), + stop_tx, + )?; + let result = sync_serve_tokio(cfg, stop_flag, stop_rx, service_ready, ready); wait_stop_bridge("tokio-stop-bridge", notifier)?; result } } +#[derive(Clone)] +struct TokioConnTaskState { + active: Arc, + drained: Arc, +} + +impl TokioConnTaskState { + fn new() -> Self { + Self { + active: Arc::new(std::sync::atomic::AtomicU64::new(0)), + drained: Arc::new(Notify::new()), + } + } + + fn on_spawn(&self) { + self.active.fetch_add(1, Ordering::Relaxed); + } + + fn on_finish(&self) { + if self.active.fetch_sub(1, Ordering::Relaxed) == 1 { + self.drained.notify_waiters(); + } + } + + async fn wait_drained(&self) { + while self.active.load(Ordering::Relaxed) > 0 { + self.drained.notified().await; + } + } +} + fn sync_serve_tokio( - mut cfg: WorkerTcpServerConfig, + mut cfg: ServerLaunch, stop: Arc, + stop_rx: StopRx, + service_ready: Arc, ready: Option, ) -> RS<()> { - if cfg.worker_count() == 0 { + if cfg.cfg().worker_count() == 0 { return Err(m_error!(EC::ParseErr, "invalid tokio worker count")); } - let listen_addr: SocketAddr = format!("{}:{}", cfg.listen_ip(), cfg.listen_port()) - .parse() - .map_err(|e| m_error!(EC::ParseErr, "parse tokio tcp listen address error", e))?; - let conn_id_alloc = Arc::new(AtomicU64::new(1)); - let inboxes: Vec<_> = (0..cfg.worker_count()) - .map(|_| Arc::new(SegQueue::::new())) - .collect(); - let bus_mailboxes: Vec<_> = (0..cfg.worker_count()) + let bus_mailboxes: Vec<_> = (0..cfg.cfg().worker_count()) .map(|_| Arc::new(SegQueue::::new())) .collect(); - let listener = match cfg.take_prebound_listener() { - Some(listener) => listener, - None => create_listener(listen_addr)?, - }; + let bus_wakes: Vec<_> = (0..cfg.cfg().worker_count()) + .map(|_| Arc::new(Notify::new())) + .collect(); + let (started_tx, started_rx) = mpsc::channel::>(); + let (rpc_ready_tx, rpc_ready_rx) = mpsc::channel::>(); - let mut handles = Vec::with_capacity(cfg.worker_count()); - for worker_id in 0..cfg.worker_count() { + let mut handles = Vec::with_capacity(cfg.cfg().worker_count()); + for worker_id in 0..cfg.cfg().worker_count() { let worker_cfg = WorkerBuildConfig::from_server_config(&cfg, worker_id)?; - let inbox = inboxes[worker_id].clone(); - let all_inboxes = inboxes.clone(); let bus_inbox = bus_mailboxes[worker_id].clone(); + let bus_wake = bus_wakes[worker_id].clone(); let all_bus_mailboxes = bus_mailboxes.clone(); + let all_bus_wakes = bus_wakes.clone(); let conn_id_alloc = conn_id_alloc.clone(); let stop = stop.clone(); - let listener = listener - .try_clone() - .map_err(|e| m_error!(EC::NetErr, "clone tokio tcp listener error", e))?; + let stop_rx = stop_rx.clone(); + let service_ready = service_ready.clone(); + let started_tx = started_tx.clone(); + let rpc_ready_tx = rpc_ready_tx.clone(); + let listener = if let Some(prebound) = cfg.take_prebound_listener() { + prebound + } else { + let worker_port = cfg.cfg().listen_port_for_worker(worker_id)?; + let listen_addr: SocketAddr = format!("{}:{}", cfg.cfg().listen_ip(), worker_port) + .parse() + .map_err(|e| { + m_error!( + EC::ParseErr, + format!("parse tokio tcp listen address error: {}", worker_port), + e + ) + })?; + create_listener(listen_addr)? + }; let handle = thread::Builder::new() .name(format!("tokio-tcp-worker-{worker_id}")) .spawn(move || { @@ -861,9 +451,11 @@ fn sync_serve_tokio( worker.worker_id(), worker.registry().clone(), all_bus_mailboxes, + all_bus_wakes, ); let worker_id = worker.worker_id(); let server_instance_id = worker.server_instance_id(); + let conn_tasks = TokioConnTaskState::new(); let runtime = CurrentThreadTaskRuntime::new() .map_err(|e| m_error!(EC::TokioErr, "build tokio worker runtime error", e))?; set_current_worker_local(as_worker_local_ref(new_session_bound_worker_runtime( @@ -882,6 +474,7 @@ fn sync_serve_tokio( let listener = TokioTcpListener::from_std(listener) .map_err(|e| m_error!(EC::NetErr, "convert tokio tcp listener error", e))?; worker.ensure_partition_rpc_handler()?; + recover_worker_log_tokio(&worker)?; let (_task_notifier, task_waiter) = notify_wait(); let join = spawn_local_task( task_waiter.into(), @@ -889,14 +482,18 @@ fn sync_serve_tokio( run_worker_loop_tokio( worker, listener, - inbox, - all_inboxes, bus_inbox, message_bus, + bus_wake, conn_id_alloc, stop, + stop_rx, + service_ready, + conn_tasks.clone(), + Some(rpc_ready_tx), ), )?; + let _ = started_tx.send(Ok(())); match join.await.map_err(|e| { m_error!(EC::TokioErr, "join tokio worker loop task error", e) })? { @@ -914,6 +511,32 @@ fn sync_serve_tokio( .map_err(|e| m_error!(EC::ThreadErr, "spawn tokio worker error", e))?; handles.push(handle); } + drop(started_tx); + drop(rpc_ready_tx); + + for _ in 0..cfg.cfg().worker_count() { + let started = started_rx.recv().map_err(|_| { + m_error!( + EC::ThreadErr, + "tokio worker startup barrier channel closed unexpectedly" + ) + })?; + started?; + } + + // RPC-ready barrier: every worker must report that its message bus, + // partition rpc handler and main loop are active before the backend is + // externally considered ready. + for _ in 0..cfg.cfg().worker_count() { + let ready = rpc_ready_rx.recv().map_err(|_| { + m_error!( + EC::ThreadErr, + "tokio worker rpc-ready barrier channel closed unexpectedly" + ) + })?; + ready?; + } + service_ready.store(true, Ordering::Relaxed); // Tokio mode has no separate recovery barrier after the listener is bound // and the worker threads are spawned, so this is the earliest point where @@ -936,199 +559,183 @@ fn sync_serve_tokio( async fn run_worker_loop_tokio( worker: WorkerRuntime, listener: TokioTcpListener, - inbox: Arc>, - inboxes: Vec>>, bus_inbox: Arc>, message_bus: Arc, + bus_wake: Arc, conn_id_alloc: Arc, stop: Arc, + mut stop_rx: StopRx, + service_ready: Arc, + conn_tasks: TokioConnTaskState, + rpc_ready_tx: Option>>, ) -> RS<()> { - let mut connections = HashMap::::new(); - let mut async_funcs = FallbackAsyncFuncState::new(); - let idle_sleep = Duration::from_millis(1); - + scoped_task_trace!(); + if let Some(tx) = rpc_ready_tx { + let _ = tx.send(Ok(())); + } while !stop.load(Ordering::Relaxed) { - let mut progressed = false; - trace!( - worker_id = worker.worker_id(), - connection_count = connections.len(), - pending_async_tasks = async_funcs.tasks.len(), - "tokio worker loop iteration begin" - ); - progressed |= drain_accepted_connections_tokio( - &listener, - &worker, - &inboxes, - &mut connections, - &conn_id_alloc, - ) - .await?; - trace!( - worker_id = worker.worker_id(), - progressed, - connection_count = connections.len(), - pending_async_tasks = async_funcs.tasks.len(), - "tokio worker loop after accept" - ); - progressed |= - drain_transferred_connections_tokio(&worker, inbox.as_ref(), &mut connections)?; - trace!( - worker_id = worker.worker_id(), - progressed, - connection_count = connections.len(), - pending_async_tasks = async_funcs.tasks.len(), - "tokio worker loop after transfer" - ); - progressed |= drain_message_bus_tokio(bus_inbox.as_ref(), message_bus.as_ref())?; - trace!( - worker_id = worker.worker_id(), - progressed, - connection_count = connections.len(), - pending_async_tasks = async_funcs.tasks.len(), - "tokio worker loop after message bus" - ); - progressed |= async_funcs.drain_completions(); - trace!( - worker_id = worker.worker_id(), - progressed, - connection_count = connections.len(), - pending_async_tasks = async_funcs.tasks.len(), - "tokio worker loop after drain completions" - ); - progressed |= async_funcs.poll_ready(&mut connections, &inboxes)?; - trace!( - worker_id = worker.worker_id(), - progressed, - connection_count = connections.len(), - pending_async_tasks = async_funcs.tasks.len(), - "tokio worker loop after poll_ready" - ); - progressed |= drive_connections(&worker, &mut async_funcs, &mut connections, &inboxes)?; - trace!( - worker_id = worker.worker_id(), - progressed, - connection_count = connections.len(), - pending_async_tasks = async_funcs.tasks.len(), - "tokio worker loop after drive_connections" - ); - - if !progressed { - mudu_sys::task_async::sleep(idle_sleep).await?; + if stop_rx.is_stopped() { + break; + } + while drain_message_bus_tokio(bus_inbox.as_ref(), message_bus.as_ref())? {} + tokio::select! { + accept_result = listener.accept() => { + let (stream, remote_addr) = accept_result + .map_err(|err| m_error!(EC::NetErr, "accept tokio tcp connection error", err))?; + let conn_id = conn_id_alloc.fetch_add(1, Ordering::Relaxed); + let worker = worker.clone(); + let stop = stop.clone(); + let service_ready = service_ready.clone(); + let conn_tasks = conn_tasks.clone(); + trace!( + worker_id = worker.worker_id(), + conn_id, + remote = %remote_addr, + "tokio accepted connection" + ); + conn_tasks.on_spawn(); + let stop_rx_conn = stop_rx.clone(); + let _ = spawn_local_detached( + &format!("tokio_conn_{conn_id}"), + async move { + let result = + handle_tokio_connection( + worker, + stream, + conn_id, + remote_addr, + stop, + stop_rx_conn, + service_ready, + ) + .await; + conn_tasks.on_finish(); + result + }, + ); + } + _ = bus_wake.notified() => {} + changed = stop_rx.changed() => { + if !changed || stop_rx.is_stopped() { + break; + } + } + else => { + break; + } } } + let _ = + tokio::time::timeout(std::time::Duration::from_secs(3), conn_tasks.wait_drained()).await; trace!( worker_id = worker.worker_id(), - remaining_connections = connections.len(), - pending_async_tasks = async_funcs.tasks.len(), "tokio worker loop observed stop" ); Ok(()) } -fn drain_message_bus_tokio( - inbox: &SegQueue, - message_bus: &TokioWorkerMessageBus, -) -> RS { - let mut progressed = false; - while let Some(envelope) = inbox.pop() { - progressed = true; - message_bus.handle_incoming(envelope)?; - } - Ok(progressed) -} - -async fn drain_accepted_connections_tokio( - listener: &TokioTcpListener, - worker: &WorkerRuntime, - inboxes: &[Arc>], - connections: &mut HashMap, - conn_id_alloc: &AtomicU64, -) -> RS { - let mut progressed = false; - loop { - match poll_accept_once(listener).await? { - Some((stream, remote_addr)) => { - progressed = true; - route_accepted_connection( - worker, - inboxes, - connections, - conn_id_alloc, - stream, - remote_addr, - register_connection_tokio, - |stream| { - stream.into_std().map_err(|e| { - m_error!(EC::NetErr, "convert accepted tokio stream to std error", e) - }) - }, - )?; - } - None => break, - } - } - Ok(progressed) -} - -fn drain_transferred_connections_tokio( - worker: &WorkerRuntime, - inbox: &SegQueue, - connections: &mut HashMap, -) -> RS { - drain_transferred_connections_common(worker, inbox, connections, |connections, connection| { - connection.stream.set_nonblocking(true).map_err(|e| { - m_error!( - EC::NetErr, - "set transferred tokio stream nonblocking error", - e - ) - })?; - let stream = TokioTcpStream::from_std(connection.stream).map_err(|e| { +fn recover_worker_log_tokio(worker: &WorkerRuntime) -> RS<()> { + let Some(log) = worker.worker_log() else { + return Ok(()); + }; + let chunk_paths = log.chunk_paths_sorted()?; + for path in chunk_paths { + let bytes = std::fs::read(&path).map_err(|e| { m_error!( - EC::NetErr, - "convert transferred std stream to tokio error", + EC::IOErr, + format!("read worker log chunk {} error", path.display()), e ) })?; - register_connection_tokio( - connections, - connection.transfer.conn_id(), - connection.transfer.remote_addr(), - stream, - ) - }) + if bytes.is_empty() { + continue; + } + let frames = decode_frames(&bytes)?; + let batches = decode_xl_batches(&frames)?; + for batch in batches { + worker.replay_log_batch(batch)?; + } + } + Ok(()) } -fn register_connection_tokio( - connections: &mut HashMap, +async fn handle_tokio_connection( + worker: WorkerRuntime, + mut stream: TokioTcpStream, conn_id: u64, remote_addr: SocketAddr, - stream: TokioTcpStream, + stop: Arc, + mut stop_rx: StopRx, + service_ready: Arc, ) -> RS<()> { + scoped_task_trace!(); stream .set_nodelay(true) .map_err(|e| m_error!(EC::NetErr, "set tokio connection nodelay error", e))?; - connections.insert( - conn_id, - TokioWorkerConnection { - core: ConnectionCore::new(conn_id, remote_addr), - stream: Some(stream), - }, - ); + let mut read_buf: Vec = Vec::with_capacity(8192); + let mut chunk = vec![0u8; 8192]; + loop { + if stop.load(Ordering::Relaxed) || stop_rx.is_stopped() { + break; + } + let read = tokio::select! { + read_result = stream.read(&mut chunk) => { + read_result.map_err(|e| m_error!(EC::NetErr, "read tokio tcp request error", e))? + } + changed = stop_rx.changed() => { + if !changed || stop_rx.is_stopped() { + break; + } + continue; + } + }; + if read == 0 { + break; + } + read_buf.extend_from_slice(&chunk[..read]); + while let Some((frame, consumed)) = try_decode_next_frame(&read_buf)? { + read_buf.drain(0..consumed); + if !service_ready.load(Ordering::Relaxed) { + let err = m_error!(EC::InternalErr, "server is not ready"); + let payload = encode_merror_response(frame.header().request_id(), &err)?; + stream + .write_all(&payload) + .await + .map_err(|e| m_error!(EC::NetErr, "write tokio tcp response error", e))?; + continue; + } + match dispatch_frame_async(&worker, conn_id, &frame).await { + Ok(HandleResult::Response(payload)) => { + stream + .write_all(&payload) + .await + .map_err(|e| m_error!(EC::NetErr, "write tokio tcp response error", e))?; + } + Err(err) => { + let payload = encode_merror_response(frame.header().request_id(), &err)?; + stream + .write_all(&payload) + .await + .map_err(|e| m_error!(EC::NetErr, "write tokio tcp response error", e))?; + } + } + } + } + worker.close_connection_sessions(conn_id)?; + trace!(worker_id = worker.worker_id(), conn_id, remote = %remote_addr, "tokio connection closed"); Ok(()) } -async fn poll_accept_once(listener: &TokioTcpListener) -> RS> { - poll_fn(|cx| match listener.poll_accept(cx) { - Poll::Ready(Ok(pair)) => Poll::Ready(Ok(Some(pair))), - Poll::Ready(Err(err)) => Poll::Ready(Err(m_error!( - EC::NetErr, - "accept tokio tcp connection error", - err - ))), - Poll::Pending => Poll::Ready(Ok(None)), - }) - .await +fn drain_message_bus_tokio( + inbox: &SegQueue, + message_bus: &TokioWorkerMessageBus, +) -> RS { + let mut progressed = false; + while let Some(envelope) = inbox.pop() { + progressed = true; + message_bus.handle_incoming(envelope)?; + } + Ok(progressed) } fn create_listener(listen_addr: SocketAddr) -> RS { @@ -1154,214 +761,6 @@ fn create_listener(listen_addr: SocketAddr) -> RS { Ok(socket.into()) } -fn enqueue_transfer( - inboxes: &[Arc>], - conn_id: u64, - target_worker: usize, - remote_addr: SocketAddr, - stream: TcpStream, - session_ids: Vec, - session_open_action: Option, -) -> RS<()> { - let target_inbox = inboxes.get(target_worker).ok_or_else(|| { - m_error!( - EC::InternalErr, - format!("route target worker {} is out of range", target_worker) - ) - })?; - target_inbox.push(TransferredConnection { - transfer: ConnectionTransfer::new( - conn_id, - target_worker, - crate::server::connection_state::ConnectionState::Accepted, - remote_addr, - ), - stream, - session_ids, - session_open_action, - }); - Ok(()) -} - -fn route_accepted_connection( - worker: &WorkerRuntime, - inboxes: &[Arc>], - connections: &mut HashMap, - conn_id_alloc: &AtomicU64, - stream: S, - remote_addr: SocketAddr, - register_local: RegisterLocal, - into_transfer: IntoTransfer, -) -> RS<()> -where - RegisterLocal: FnOnce(&mut HashMap, u64, SocketAddr, S) -> RS<()>, - IntoTransfer: FnOnce(S) -> RS, -{ - let conn_id = conn_id_alloc.fetch_add(1, Ordering::Relaxed); - let target_worker = worker.route_connection(conn_id, remote_addr); - if target_worker == worker.worker_index() { - register_local(connections, conn_id, remote_addr, stream) - } else { - enqueue_transfer( - inboxes, - conn_id, - target_worker, - remote_addr, - into_transfer(stream)?, - Vec::new(), - None, - ) - } -} - -fn drain_transferred_connections_common( - worker: &WorkerRuntime, - inbox: &SegQueue, - connections: &mut HashMap, - mut register: Register, -) -> RS -where - C: BackendConnection, - Register: FnMut(&mut HashMap, TransferredConnection) -> RS<()>, -{ - let mut progressed = false; - while let Some(connection) = inbox.pop() { - progressed = true; - worker.adopt_connection_sessions(connection.transfer.conn_id(), &connection.session_ids)?; - let conn_id = connection.transfer.conn_id(); - let action = connection.session_open_action; - register(connections, connection)?; - if let Some(action) = action { - let payload = match worker.open_session_with_config(conn_id, action.config()) { - Ok(session_id) => encode_session_create_response( - action.request_id(), - &SessionCreateResponse::new(session_id), - )?, - Err(err) => encode_merror_response(action.request_id(), &err)?, - }; - if let Some(registered) = connections.get_mut(&conn_id) { - registered.core_mut().write_buf.extend_from_slice(&payload); - } - } - } - Ok(progressed) -} - -fn drive_connections( - worker: &WorkerRuntime, - async_funcs: &mut FallbackAsyncFuncState, - connections: &mut HashMap, - inboxes: &[Arc>], -) -> RS { - let mut progressed = false; - let conn_ids: Vec = connections.keys().copied().collect(); - let mut closed = Vec::new(); - - for conn_id in conn_ids { - trace!(conn_id, "tokio drive connection begin"); - let Some(connection) = connections.get_mut(&conn_id) else { - continue; - }; - progressed |= connection.write_pending()?; - trace!( - conn_id, - state = ?connection.core().state, - write_buf_len = connection.core().write_buf.len(), - read_buf_len = connection.core().read_buf.len(), - "tokio drive connection after write_pending" - ); - let connection_progress = read_and_dispatch(worker, async_funcs, connection, inboxes)?; - progressed |= connection_progress; - trace!( - conn_id, - connection_progress, - state = ?connection.core().state, - write_buf_len = connection.core().write_buf.len(), - read_buf_len = connection.core().read_buf.len(), - "tokio drive connection after read_and_dispatch" - ); - if connection.core().state == crate::server::connection_state::ConnectionState::Closing - && connection.core().write_buf.is_empty() - { - closed.push((conn_id, connection.core().transferred)); - } - } - - for (conn_id, transferred) in closed { - if !transferred { - worker.close_connection_sessions(conn_id)?; - } - connections.remove(&conn_id); - } - Ok(progressed) -} - -fn read_and_dispatch( - worker: &WorkerRuntime, - async_funcs: &mut FallbackAsyncFuncState, - connection: &mut C, - inboxes: &[Arc>], -) -> RS { - let mut progressed = connection.read_available()?; - - while let Some((frame, consumed)) = try_decode_next_frame(&connection.core().read_buf)? { - progressed = true; - let response = dispatch_frame(worker, connection.core().conn_id, async_funcs, &frame); - connection.core_mut().read_buf.drain(0..consumed); - match response { - Ok(Some(result)) => { - apply_handle_result_to_connection(connection, inboxes, result)?; - if connection.core().transferred { - return Ok(true); - } - } - Ok(None) => {} - Err(err) => { - let payload = encode_merror_response(frame.header().request_id(), &err)?; - connection.core_mut().write_buf.extend_from_slice(&payload); - } - } - } - Ok(progressed) -} - -fn dispatch_frame( - worker: &WorkerRuntime, - conn_id: u64, - async_funcs: &mut FallbackAsyncFuncState, - frame: &Frame, -) -> RS> { - let request_id = frame.header().request_id(); - trace!(conn_id, request_id, "tokio dispatch frame begin"); - let worker = worker.clone(); - let frame = frame.clone(); - let mut future = Box::pin(async move { - mudu_utils::scoped_task_trace!(); - dispatch_frame_async(&worker, conn_id, &frame).await - }); - let waker = waker(Arc::new(AsyncFuncTaskWaker::new( - 0, - Arc::new(SegQueue::new()), - Arc::new(AtomicBool::new(false)), - ))); - let mut cx = Context::from_waker(&waker); - match future.as_mut().poll(&mut cx) { - Poll::Ready(Ok(result)) => { - trace!(conn_id, request_id, "tokio dispatch frame ready ok"); - Ok(Some(result)) - } - Poll::Ready(Err(err)) => { - trace!(conn_id, request_id, err = %err, "tokio dispatch frame ready err"); - Err(err) - } - Poll::Pending => { - trace!(conn_id, request_id, "tokio dispatch frame pending"); - async_funcs.enqueue_future(conn_id, request_id, future); - Ok(None) - } - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/mudu_kernel/src/server/server_cfg.rs b/mudu_kernel/src/server/server_cfg.rs new file mode 100644 index 0000000..ed8c83e --- /dev/null +++ b/mudu_kernel/src/server/server_cfg.rs @@ -0,0 +1,117 @@ +use crate::server::message_bus_api::ServerInstanceId; +use crate::server::routing::RoutingMode; +use mudu::common::id::gen_oid; +use mudu::common::result::RS; +use mudu::error::ec::EC; +use mudu::m_error; + +/// Configuration shared by both execution paths of the `client` backend. +/// +/// The same configuration is consumed by both the io_uring worker-ring backend +/// and the Tokio backend so they keep the worker model and protocol surface +/// aligned. +pub struct ServerCfg { + server_instance_id: ServerInstanceId, + worker_count: usize, + listen_ip: String, + listen_port: u16, + multi_port: bool, + data_dir: String, + log_dir: String, + log_chunk_size: u64, + routing_mode: RoutingMode, +} + +impl ServerCfg { + /// Creates a backend configuration. + /// + /// The resulting value can be used by both the io_uring and Tokio TCP + /// backends with the same externally visible behavior. + pub fn new( + worker_count: usize, + listen_ip: String, + listen_port: u16, + data_dir: String, + log_dir: String, + routing_mode: RoutingMode, + ) -> RS { + Ok(Self { + server_instance_id: gen_oid(), + worker_count, + listen_ip, + listen_port, + multi_port: false, + data_dir, + log_dir, + log_chunk_size: 64 * 1024 * 1024, + routing_mode, + }) + } + + pub fn with_log_chunk_size(mut self, log_chunk_size: u64) -> Self { + self.log_chunk_size = log_chunk_size; + self + } + + pub fn with_multi_port(mut self, multi_port: bool) -> Self { + self.multi_port = multi_port; + self + } + + pub fn server_instance_id(&self) -> ServerInstanceId { + self.server_instance_id + } + + pub fn worker_count(&self) -> usize { + self.worker_count + } + + pub fn listen_ip(&self) -> &str { + &self.listen_ip + } + + pub fn listen_port(&self) -> u16 { + self.listen_port + } + + pub fn multi_port(&self) -> bool { + self.multi_port + } + + pub fn listen_port_for_worker(&self, worker_index: usize) -> RS { + if !self.multi_port { + return Ok(self.listen_port); + } + let worker_offset = u16::try_from(worker_index).map_err(|_| { + m_error!( + EC::ParseErr, + format!("worker index too large for port mapping: {}", worker_index) + ) + })?; + self.listen_port.checked_add(worker_offset).ok_or_else(|| { + m_error!( + EC::ParseErr, + format!( + "worker listen port overflow: base_port={}, worker_index={}", + self.listen_port, worker_index + ) + ) + }) + } + + pub fn log_dir(&self) -> &str { + &self.log_dir + } + + pub fn data_dir(&self) -> &str { + &self.data_dir + } + + pub fn log_chunk_size(&self) -> u64 { + self.log_chunk_size + } + + pub fn routing_mode(&self) -> RoutingMode { + self.routing_mode + } +} diff --git a/mudu_kernel/src/server/server_launch.rs b/mudu_kernel/src/server/server_launch.rs new file mode 100644 index 0000000..a039954 --- /dev/null +++ b/mudu_kernel/src/server/server_launch.rs @@ -0,0 +1,48 @@ +use std::net::TcpListener; + +use mudu::common::result::RS; + +use crate::server::server_cfg::ServerCfg; +use crate::server::server_runtime_deps::ServerRuntimeDeps; + +/// A single server start request, including one-shot resources such as listeners. +pub struct ServerLaunch { + cfg: ServerCfg, + deps: ServerRuntimeDeps, + prebound_listener: Option, +} + +impl ServerLaunch { + pub fn new(cfg: ServerCfg, deps: ServerRuntimeDeps) -> Self { + Self { + cfg, + deps, + prebound_listener: None, + } + } + + pub fn from_cfg(cfg: ServerCfg) -> RS { + let deps = ServerRuntimeDeps::from_cfg(&cfg)?; + Ok(Self::new(cfg, deps)) + } + + pub fn with_prebound_listener(mut self, listener: TcpListener) -> Self { + self.prebound_listener = Some(listener); + self + } + + pub fn cfg(&self) -> &ServerCfg { + &self.cfg + } + + pub fn deps(&self) -> &ServerRuntimeDeps { + &self.deps + } + + pub fn take_prebound_listener(&mut self) -> Option { + self.prebound_listener.take() + } +} + +/// Alias used by backend construction code that does not need a transport-specific name. +pub type WorkerTcpBackendConfig = ServerLaunch; diff --git a/mudu_kernel/src/server/server_runtime_deps.rs b/mudu_kernel/src/server/server_runtime_deps.rs new file mode 100644 index 0000000..de14aa7 --- /dev/null +++ b/mudu_kernel/src/server/server_runtime_deps.rs @@ -0,0 +1,88 @@ +use std::sync::Arc; + +use mudu::common::result::RS; +use mudu::error::ec::EC; +use mudu::m_error; + +use crate::async_rt::contract::AsyncRuntime; +use crate::server::async_func_runtime::AsyncFuncInvokerPtr; +use crate::server::procedure_runtimes::ProcedureRuntimes; +use crate::server::server_cfg::ServerCfg; +use crate::server::worker_registry::{load_or_create_worker_registry, WorkerRegistry}; +use crate::wal::worker_log::WorkerLogBatching; + +/// Dependencies assembled for one server process after pure configuration is known. +pub struct ServerRuntimeDeps { + log_batching: WorkerLogBatching, + procedure_runtimes: ProcedureRuntimes, + worker_registry: Arc, + async_runtime: Option>, +} + +impl ServerRuntimeDeps { + pub fn from_cfg(cfg: &ServerCfg) -> RS { + let worker_registry = load_or_create_worker_registry(cfg.log_dir(), cfg.worker_count())?; + Ok(Self { + log_batching: WorkerLogBatching::default(), + procedure_runtimes: ProcedureRuntimes::default(), + worker_registry, + async_runtime: None, + }) + } + + pub fn with_log_batching(mut self, log_batching: WorkerLogBatching) -> Self { + self.log_batching = log_batching; + self + } + + pub fn with_shared_procedure_runtime(mut self, runtime: AsyncFuncInvokerPtr) -> Self { + self.procedure_runtimes = ProcedureRuntimes::Shared(runtime); + self + } + + /// Installs isolated procedure invokers for each worker thread. + pub fn with_worker_procedure_runtimes(mut self, runtimes: Vec) -> Self { + self.procedure_runtimes = ProcedureRuntimes::PerWorker(runtimes); + self + } + + pub fn with_worker_registry( + mut self, + cfg: &ServerCfg, + worker_registry: Arc, + ) -> RS { + if worker_registry.workers().len() != cfg.worker_count() { + return Err(m_error!( + EC::ParseErr, + format!( + "worker registry count {} does not match expected {}", + worker_registry.workers().len(), + cfg.worker_count() + ) + )); + } + self.worker_registry = worker_registry; + Ok(self) + } + + pub fn with_async_runtime(mut self, async_runtime: Option>) -> Self { + self.async_runtime = async_runtime; + self + } + + pub fn log_batching(&self) -> WorkerLogBatching { + self.log_batching + } + + pub fn procedure_runtime_for_worker(&self, worker_id: usize) -> Option { + self.procedure_runtimes.for_worker(worker_id) + } + + pub fn worker_registry(&self) -> Arc { + self.worker_registry.clone() + } + + pub fn async_runtime(&self) -> Option> { + self.async_runtime.clone() + } +} diff --git a/mudu_kernel/src/server/session_bound_worker_runtime.rs b/mudu_kernel/src/server/session_bound_worker_runtime.rs index fa76464..652d4a5 100644 --- a/mudu_kernel/src/server/session_bound_worker_runtime.rs +++ b/mudu_kernel/src/server/session_bound_worker_runtime.rs @@ -1,7 +1,7 @@ use crate::contract::meta_mgr::MetaMgr; use crate::server::message_bus_api::{message_bus_for_worker, MessageBusRef}; use crate::server::request_response_worker::{RequestResponseWorker, WorkerRuntimeRef}; -use crate::server::routing::{SessionOpenConfig, SessionOpenTransferAction}; +use crate::server::routing::SessionOpenConfig; use crate::server::worker::WorkerRuntime; use crate::server::worker_local::{WorkerExecute, WorkerLocal, WorkerLocalRef}; use crate::server::worker_registry::WorkerRegistry; @@ -141,14 +141,6 @@ impl RequestResponseWorker for SessionBoundWorkerRuntime { self.worker.open_session_with_config(conn_id, config) } - fn prepare_connection_transfer( - &self, - conn_id: u64, - action: Option, - ) -> RS> { - self.worker.prepare_connection_transfer(conn_id, action) - } - fn close_session_for_connection(&self, conn_id: u64, session_id: OID) -> RS { self.worker.close_session(conn_id, session_id) } diff --git a/mudu_kernel/src/server/worker.rs b/mudu_kernel/src/server/worker.rs index 82ff296..3693d74 100644 --- a/mudu_kernel/src/server/worker.rs +++ b/mudu_kernel/src/server/worker.rs @@ -3,9 +3,7 @@ use crate::contract::meta_mgr::MetaMgr; use crate::mudu_conn::mudu_conn_core::MuduConnCore; use crate::server::async_func_runtime::AsyncFuncInvokerPtr; use crate::server::message_bus_api::ServerInstanceId; -use crate::server::routing::{ - route_worker, RoutingContext, RoutingMode, SessionOpenConfig, SessionOpenTransferAction, -}; +use crate::server::routing::SessionOpenConfig; use crate::server::session_bound_worker_runtime::{ as_worker_local_ref, new_session_bound_worker_runtime, }; @@ -31,7 +29,6 @@ use mudu_contract::database::sql_stmt::SQLStmt; use mudu_contract::protocol::{ProcedureInvokeRequest, ProcedureInvokeResponse}; use mudu_utils::task_trace; use std::collections::BTreeMap; -use std::net::SocketAddr; use std::sync::atomic::AtomicUsize; use std::sync::Arc; @@ -53,7 +50,6 @@ pub struct WorkerRuntime { worker_id: OID, partition_ids: Vec, worker_count: usize, - routing_mode: RoutingMode, contract: Arc, log_layout: WorkerLogLayout, procedure_runtime: Option, @@ -69,7 +65,6 @@ impl WorkerRuntime { pub fn new( identity: WorkerIdentity, worker_count: usize, - routing_mode: RoutingMode, log_dir: String, data_dir: String, log_chunk_size: u64, @@ -80,7 +75,6 @@ impl WorkerRuntime { Self::new_with_log_batching( identity, worker_count, - routing_mode, log_dir, data_dir, log_chunk_size, @@ -94,7 +88,6 @@ impl WorkerRuntime { pub fn new_with_log_batching( identity: WorkerIdentity, worker_count: usize, - routing_mode: RoutingMode, log_dir: String, data_dir: String, log_chunk_size: u64, @@ -106,7 +99,6 @@ impl WorkerRuntime { Self::new_with_log_batching_and_runtime( identity, worker_count, - routing_mode, log_dir, data_dir, log_chunk_size, @@ -121,7 +113,6 @@ impl WorkerRuntime { pub fn new_with_log_batching_and_runtime( identity: WorkerIdentity, worker_count: usize, - routing_mode: RoutingMode, log_dir: String, data_dir: String, log_chunk_size: u64, @@ -170,7 +161,6 @@ impl WorkerRuntime { worker_id, partition_ids: identity.partition_ids, worker_count, - routing_mode, contract: contract.clone(), log_layout, procedure_runtime, @@ -183,11 +173,6 @@ impl WorkerRuntime { self.server_instance_id } - pub fn route_connection(&self, conn_id: u64, remote_addr: SocketAddr) -> usize { - let ctx = RoutingContext::new(conn_id, remote_addr, None); - route_worker(&ctx, self.routing_mode, self.worker_count) - } - pub async fn delete_async(&self, key: &[u8]) -> RS<()> { self.contract.worker_delete_async(key).await } @@ -307,7 +292,7 @@ impl WorkerRuntime { self.range_in_session(session_id, start_key, end_key).await } - #[allow(dead_code)] + #[cfg(test)] fn execute_tx(&self, session_id: OID, instruction: WorkerExecute) -> RS<()> { match instruction { WorkerExecute::BeginTx => self @@ -805,48 +790,6 @@ impl WorkerRuntime { Ok(config.session_id()) } } - - pub fn prepare_connection_transfer( - &self, - conn_id: u64, - action: Option, - ) -> RS> { - if self.connection_has_active_tx(conn_id)? { - return Err(m_error!( - EC::TxErr, - format!( - "connection {} cannot be transferred while a session transaction is active", - conn_id - ) - )); - } - if let Some(action) = action { - let config = action.config(); - if config.session_id() != 0 { - self.ensure_session_owned_by_connection(conn_id, config.session_id())?; - } - } - self.session_manager.detach_connection_sessions(conn_id) - } - - pub fn adopt_connection_sessions(&self, conn_id: u64, session_ids: &[OID]) -> RS<()> { - self.session_manager - .adopt_connection_sessions(conn_id, session_ids) - } - - fn connection_has_active_tx(&self, conn_id: u64) -> RS { - self.session_manager.connection_has_active_tx(conn_id) - } -} - -#[allow(dead_code)] -fn worker_log_oid(worker_id: usize) -> OID { - worker_id as u128 + 1 -} - -#[allow(dead_code)] -fn is_key_in_range(key: &[u8], start_key: &[u8], end_key: &[u8]) -> bool { - key >= start_key && (end_key.is_empty() || key < end_key) } #[cfg(test)] @@ -913,7 +856,6 @@ mod tests { WorkerRuntime::new( identity, worker_count, - RoutingMode::ConnectionId, log_dir.to_string(), data_dir.to_string(), 4096, @@ -1112,7 +1054,6 @@ mod tests { let _worker = WorkerRuntime::new( identity, 1, - RoutingMode::ConnectionId, log_dir.clone(), log_dir.clone(), 4096, @@ -1231,71 +1172,6 @@ mod tests { assert_eq!(worker.get(b"a").unwrap(), Some(b"1".to_vec())); } - #[test] - fn worker_can_transfer_connection_sessions_between_partitions() { - let (log_dir, registry) = test_registry(2); - let source = test_worker(0, 2, &log_dir, &log_dir, registry.clone(), None); - let target = test_worker(1, 2, &log_dir, &log_dir, registry.clone(), None); - - let conn_id = 41; - let session_a = source.create_session(conn_id).unwrap(); - let session_b = source.create_session(conn_id).unwrap(); - let target_identity = registry.worker(1).unwrap(); - let action = SessionOpenTransferAction::new( - 7, - SessionOpenConfig::new(session_a, target_identity.worker_id, 1), - ); - - let transferred = source - .prepare_connection_transfer(conn_id, Some(action)) - .unwrap(); - assert_eq!(transferred.len(), 2); - assert!( - futures::executor::block_on(source.get_for_connection(conn_id, session_a, b"k")) - .is_err() - ); - - target - .adopt_connection_sessions(conn_id, &transferred) - .unwrap(); - assert_eq!( - target - .open_session_with_config(conn_id, action.config()) - .unwrap(), - session_a - ); - target - .put_for_connection(conn_id, session_b, b"k".to_vec(), b"v".to_vec()) - .unwrap(); - assert_eq!( - futures::executor::block_on(target.get_for_connection(conn_id, session_b, b"k")) - .unwrap(), - Some(b"v".to_vec()) - ); - } - - #[test] - fn worker_rejects_transfer_with_active_transaction() { - let (log_dir, registry) = test_registry(2); - let worker = test_worker(0, 2, &log_dir, &log_dir, registry.clone(), None); - let conn_id = 51; - let session_id = worker.create_session(conn_id).unwrap(); - worker - .execute_tx(session_id, WorkerExecute::BeginTx) - .unwrap(); - - let err = worker - .prepare_connection_transfer( - conn_id, - Some(SessionOpenTransferAction::new( - 1, - SessionOpenConfig::new(session_id, registry.worker(1).unwrap().worker_id, 1), - )), - ) - .unwrap_err(); - assert!(err.to_string().contains("cannot be transferred")); - } - #[tokio::test(flavor = "current_thread")] async fn worker_snapshot_isolation_hides_later_commits_from_existing_tx() { let (log_dir, registry) = test_registry(1); diff --git a/mudu_kernel/src/server/worker_session_manager.rs b/mudu_kernel/src/server/worker_session_manager.rs index 3d2a0da..b7860d5 100644 --- a/mudu_kernel/src/server/worker_session_manager.rs +++ b/mudu_kernel/src/server/worker_session_manager.rs @@ -166,53 +166,6 @@ impl WorkerSessionManager { } } - pub(crate) fn adopt_connection_sessions(&self, conn_id: u64, session_ids: &[OID]) -> RS<()> { - if session_ids.is_empty() { - return Ok(()); - } - let conn_sessions = self.connection_sessions(conn_id); - for &session_id in session_ids { - self.session_owner - .insert_sync(session_id, conn_id) - .map_err(|_| { - m_error!( - EC::ExistingSuchElement, - format!("session {} already exists on target worker", session_id) - ) - })?; - if self - .session_contexts - .insert_sync( - session_id, - Arc::new(SessionContext::new(self.meta_mgr.clone())), - ) - .is_err() - { - let _ = self.session_owner.remove_sync(&session_id); - return Err(m_error!( - EC::ExistingSuchElement, - format!( - "session {} context already exists on target worker", - session_id - ) - )); - } - let _ = conn_sessions.insert_sync(session_id, ()); - self.active_sessions.fetch_add(1, Ordering::Relaxed); - } - Ok(()) - } - - pub(crate) fn connection_has_active_tx(&self, conn_id: u64) -> RS { - let session_ids = self.connection_session_ids(conn_id); - for session_id in session_ids { - if self.has_session_tx(session_id)? { - return Ok(true); - } - } - Ok(false) - } - pub(crate) fn has_session_tx(&self, session_id: OID) -> RS { Ok(self .session_context(session_id)? @@ -250,24 +203,6 @@ impl WorkerSessionManager { f(session.tx_manager_cloned()) } - pub(crate) fn detach_connection_sessions(&self, conn_id: u64) -> RS> { - let Some((_conn_id, conn_sessions)) = self.connection_sessions.remove_sync(&conn_id) else { - return Ok(Vec::new()); - }; - let mut session_ids = Vec::new(); - conn_sessions.iter_sync(|session_id, _| { - session_ids.push(*session_id); - true - }); - for &session_id in &session_ids { - if self.session_owner.remove_sync(&session_id).is_some() { - self.active_sessions.fetch_sub(1, Ordering::Relaxed); - } - let _ = self.session_contexts.remove_sync(&session_id); - } - Ok(session_ids) - } - fn connection_sessions(&self, conn_id: u64) -> Arc> { if let Some(existing) = self.connection_sessions.get_sync(&conn_id) { return existing.get().clone(); @@ -287,18 +222,6 @@ impl WorkerSessionManager { } } } - - fn connection_session_ids(&self, conn_id: u64) -> Vec { - let Some(conn_sessions) = self.connection_sessions.get_sync(&conn_id) else { - return Vec::new(); - }; - let mut session_ids = Vec::new(); - conn_sessions.get().iter_sync(|session_id, _| { - session_ids.push(*session_id); - true - }); - session_ids - } } impl SessionContext { diff --git a/mudu_kernel/src/server/worker_storage.rs b/mudu_kernel/src/server/worker_storage.rs index bc216cd..6bc7a5b 100644 --- a/mudu_kernel/src/server/worker_storage.rs +++ b/mudu_kernel/src/server/worker_storage.rs @@ -1142,7 +1142,9 @@ mod tests { async fn _worker_storage_broadcasts_create_and_drop_to_peer_workers() -> RS<()> { let (mgr, _storage1, storage2, oid) = test_shared_storage().await?; let mut tx = begin_tx(1, vec![]); - storage2.put(oid, i32_bytes(7), i32_bytes(70), &mut tx).await?; + storage2 + .put(oid, i32_bytes(7), i32_bytes(70), &mut tx) + .await?; storage2.commit_tx(&mut tx).await?; assert!(mgr.get_table_by_id(oid).await.is_ok()); @@ -1150,7 +1152,10 @@ mod tests { assert!(mgr.get_table_by_id(oid).await.is_err()); let mut tx = begin_tx(2, vec![]); - let err = storage2.put(oid, i32_bytes(8), i32_bytes(80), &mut tx).await.unwrap_err(); + let err = storage2 + .put(oid, i32_bytes(8), i32_bytes(80), &mut tx) + .await + .unwrap_err(); assert!(format!("{err}").contains("no such table")); Ok(()) } @@ -1183,7 +1188,9 @@ mod tests { storage.bootstrap_existing_tables_async().await?; let mut tx = begin_tx(1, vec![]); - storage.put(oid, i32_bytes(1), i32_bytes(10), &mut tx).await?; + storage + .put(oid, i32_bytes(1), i32_bytes(10), &mut tx) + .await?; storage.commit_tx(&mut tx).await?; let mut read_tx = begin_tx(2, vec![]); assert_eq!( @@ -1205,7 +1212,9 @@ mod tests { let (storage, oid) = test_storage().await?; let mut tx = begin_tx(10, vec![]); - storage.put(oid, i32_bytes(1), i32_bytes(11), &mut tx).await?; + storage + .put(oid, i32_bytes(1), i32_bytes(11), &mut tx) + .await?; assert_eq!( storage.get(oid, &i32_bytes(1), &mut tx).await?, @@ -1230,12 +1239,16 @@ mod tests { async fn _worker_storage_snapshot_hides_later_commit() -> RS<()> { let (storage, oid) = test_storage().await?; let mut tx1 = begin_tx(1, vec![]); - storage.put(oid, i32_bytes(1), i32_bytes(10), &mut tx1).await?; + storage + .put(oid, i32_bytes(1), i32_bytes(10), &mut tx1) + .await?; storage.commit_tx(&mut tx1).await?; let mut old_tx = begin_tx(2, vec![]); let mut new_tx = begin_tx(3, vec![2]); - storage.put(oid, i32_bytes(1), i32_bytes(20), &mut new_tx).await?; + storage + .put(oid, i32_bytes(1), i32_bytes(20), &mut new_tx) + .await?; storage.commit_tx(&mut new_tx).await?; assert_eq!( @@ -1256,23 +1269,27 @@ mod tests { async fn _worker_storage_range_is_stable_with_snapshot() -> RS<()> { let (storage, oid) = test_storage().await?; let mut seed = begin_tx(1, vec![]); - storage.put(oid, i32_bytes(1), i32_bytes(10), &mut seed).await?; + storage + .put(oid, i32_bytes(1), i32_bytes(10), &mut seed) + .await?; storage.commit_tx(&mut seed).await?; let mut old_tx = begin_tx(2, vec![]); let mut new_tx = begin_tx(3, vec![2]); - storage.put(oid, i32_bytes(2), i32_bytes(20), &mut new_tx).await?; + storage + .put(oid, i32_bytes(2), i32_bytes(20), &mut new_tx) + .await?; storage.commit_tx(&mut new_tx).await?; let rows = storage .range( - oid, - ( - Included(i32_bytes(1).as_slice()), - Included(i32_bytes(9).as_slice()), - ), - &mut old_tx, - ) + oid, + ( + Included(i32_bytes(1).as_slice()), + Included(i32_bytes(9).as_slice()), + ), + &mut old_tx, + ) .await?; assert_eq!(rows, vec![(i32_bytes(1), i32_bytes(10))]); Ok(()) @@ -1289,13 +1306,19 @@ mod tests { async fn _worker_storage_first_committer_wins() -> RS<()> { let (storage, oid) = test_storage().await?; let mut seed = begin_tx(1, vec![]); - storage.put(oid, i32_bytes(1), i32_bytes(10), &mut seed).await?; + storage + .put(oid, i32_bytes(1), i32_bytes(10), &mut seed) + .await?; storage.commit_tx(&mut seed).await?; let mut tx1 = begin_tx(2, vec![]); let mut tx2 = begin_tx(3, vec![2]); - storage.put(oid, i32_bytes(1), i32_bytes(11), &mut tx1).await?; - storage.put(oid, i32_bytes(1), i32_bytes(12), &mut tx2).await?; + storage + .put(oid, i32_bytes(1), i32_bytes(11), &mut tx1) + .await?; + storage + .put(oid, i32_bytes(1), i32_bytes(12), &mut tx2) + .await?; storage.commit_tx(&mut tx1).await?; let err = storage.commit_tx(&mut tx2).await.unwrap_err(); @@ -1314,7 +1337,9 @@ mod tests { async fn _worker_storage_delete_respects_snapshot() -> RS<()> { let (storage, oid) = test_storage().await?; let mut seed = begin_tx(1, vec![]); - storage.put(oid, i32_bytes(1), i32_bytes(10), &mut seed).await?; + storage + .put(oid, i32_bytes(1), i32_bytes(10), &mut seed) + .await?; storage.commit_tx(&mut seed).await?; let mut old_tx = begin_tx(2, vec![]); @@ -1330,10 +1355,7 @@ mod tests { Some(i32_bytes(10)) ); let mut fresh_tx = begin_tx(4, vec![]); - assert_eq!( - storage.get(oid, &i32_bytes(1), &mut fresh_tx).await?, - None - ); + assert_eq!(storage.get(oid, &i32_bytes(1), &mut fresh_tx).await?, None); Ok(()) } @@ -1347,8 +1369,7 @@ mod tests { async fn _worker_storage_kv_snapshot_hides_later_commit() -> RS<()> { let (storage, _oid) = test_storage().await?; - storage - .worker_put_local(b"a".to_vec(), b"0".to_vec(), 1)?; + storage.worker_put_local(b"a".to_vec(), b"0".to_vec(), 1)?; let snapshot = WorkerSnapshot::new(2, vec![]); let prepared = storage.prepare_worker_kv_autocommit( @@ -1363,10 +1384,7 @@ mod tests { storage.kv_get(b"a", Some(&snapshot)).await?, Some(b"0".to_vec()) ); - assert_eq!( - storage.kv_get(b"a", None).await?, - Some(b"1".to_vec()) - ); + assert_eq!(storage.kv_get(b"a", None).await?, Some(b"1".to_vec())); Ok(()) } @@ -1380,11 +1398,9 @@ mod tests { async fn _worker_storage_kv_range_is_stable_with_snapshot() -> RS<()> { let (storage, _oid) = test_storage().await?; - storage - .worker_put_local(b"a".to_vec(), b"1".to_vec(), 1)?; + storage.worker_put_local(b"a".to_vec(), b"1".to_vec(), 1)?; let snapshot = WorkerSnapshot::new(2, vec![]); - storage - .worker_put_local(b"b".to_vec(), b"2".to_vec(), 3)?; + storage.worker_put_local(b"b".to_vec(), b"2".to_vec(), 3)?; let rows = storage.kv_range(b"a", b"z", Some(&snapshot)).await?; assert_eq!( @@ -1410,32 +1426,24 @@ mod tests { let snapshot1 = WorkerSnapshot::new(1, vec![]); let snapshot2 = WorkerSnapshot::new(2, vec![1]); - let prepared1 = storage - .prepare_worker_kv_commit( - &snapshot1, - snapshot1.xid(), - BTreeMap::from([(b"a".to_vec(), Some(b"1".to_vec()))]), - XLBatch::new(vec![]), - )?; - let prepared2 = storage - .prepare_worker_kv_commit( - &snapshot2, - snapshot2.xid(), - BTreeMap::from([(b"b".to_vec(), Some(b"2".to_vec()))]), - XLBatch::new(vec![]), - )?; + let prepared1 = storage.prepare_worker_kv_commit( + &snapshot1, + snapshot1.xid(), + BTreeMap::from([(b"a".to_vec(), Some(b"1".to_vec()))]), + XLBatch::new(vec![]), + )?; + let prepared2 = storage.prepare_worker_kv_commit( + &snapshot2, + snapshot2.xid(), + BTreeMap::from([(b"b".to_vec(), Some(b"2".to_vec()))]), + XLBatch::new(vec![]), + )?; storage.apply_prepared_commit(prepared1)?; storage.apply_prepared_commit(prepared2)?; - assert_eq!( - storage.kv_get(b"a", None).await?, - Some(b"1".to_vec()) - ); - assert_eq!( - storage.kv_get(b"b", None).await?, - Some(b"2".to_vec()) - ); + assert_eq!(storage.kv_get(b"a", None).await?, Some(b"1".to_vec())); + assert_eq!(storage.kv_get(b"b", None).await?, Some(b"2".to_vec())); Ok(()) } @@ -1473,10 +1481,7 @@ mod tests { storage.replay_batch(batch)?; - assert_eq!( - storage.kv_get(b"k", None).await?, - Some(b"v".to_vec()) - ); + assert_eq!(storage.kv_get(b"k", None).await?, Some(b"v".to_vec())); let mut tx = begin_tx(10, vec![]); assert_eq!( storage.get(oid, &i32_bytes(7), &mut tx).await?, @@ -1495,8 +1500,7 @@ mod tests { async fn _worker_storage_replay_batch_applies_kv_delete() -> RS<()> { let (storage, _oid) = test_storage().await?; - storage - .worker_put_local(b"k".to_vec(), b"v".to_vec(), 1)?; + storage.worker_put_local(b"k".to_vec(), b"v".to_vec(), 1)?; let batch = XLBatch::new(vec![crate::wal::xl_entry::XLEntry { xid: 2, diff --git a/mudu_kernel/src/server/x_contract.rs b/mudu_kernel/src/server/x_contract.rs index bd23560..f47ab15 100644 --- a/mudu_kernel/src/server/x_contract.rs +++ b/mudu_kernel/src/server/x_contract.rs @@ -9,7 +9,7 @@ use mudu_contract::tuple::build_tuple::build_tuple; use mudu_contract::tuple::tuple_binary::TupleBinary as TupleRaw; use mudu_contract::tuple::update_tuple::update_tuple; use mudu_type::dt_function::send_binary; -use mudu_utils::task_trace; +use mudu_utils::{scoped_task_trace, task_trace}; use std::ops::Bound; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{Arc, Mutex}; @@ -1337,6 +1337,7 @@ impl XContract for WorkerXContract { values: &VecDatum, opt_insert: &OptInsert, ) -> RS<()> { + scoped_task_trace!(); let desc = self.meta_mgr.get_table_by_id(table_id).await?; self._insert(desc, tx_mgr, table_id, keys, values, opt_insert) .await @@ -1504,9 +1505,7 @@ fn build_tuple_for( let completed = completed .into_iter() .collect::>>() - .ok_or_else(|| { - m_error!(EC::TupleErr) - })?; + .ok_or_else(|| m_error!(EC::TupleErr))?; build_tuple(&completed, tuple_desc) } @@ -1810,13 +1809,7 @@ mod tests { let values = value_row(10); let opt_insert = OptInsert::default(); contract - .insert( - tx_mgr.clone(), - table_id, - &keys, - &values, - &opt_insert, - ) + .insert(tx_mgr.clone(), table_id, &keys, &values, &opt_insert) .await?; contract.commit_tx(tx_mgr).await?; @@ -1893,13 +1886,7 @@ mod tests { let select = VecSelTerm::new(vec![1]); let opt_read = OptRead::default(); let relation = contract - .read_key( - xid, - table_id, - &pred_key, - &select, - &opt_read, - ) + .read_key(xid, table_id, &pred_key, &select, &opt_read) .await?; assert_eq!(relation, Some(vec![datum(30)])); Ok(()) @@ -1979,12 +1966,12 @@ mod tests { let opt_insert = OptInsert::default(); contract .insert( - insert_tx.clone(), - table_id, - &insert_key, - &insert_value, - &opt_insert, - ) + insert_tx.clone(), + table_id, + &insert_key, + &insert_value, + &opt_insert, + ) .await?; contract.commit_tx(insert_tx).await?; @@ -1994,13 +1981,13 @@ mod tests { let update_value = value_row(20); let updated = contract .update( - update_tx.clone(), - table_id, - &update_key, - &pred_non_key, - &update_value, - &OptUpdate {}, - ) + update_tx.clone(), + table_id, + &update_key, + &pred_non_key, + &update_value, + &OptUpdate {}, + ) .await?; assert_eq!(updated, 1); contract.commit_tx(update_tx).await?; @@ -2010,13 +1997,7 @@ mod tests { let select = VecSelTerm::new(vec![1]); let opt_read = OptRead::default(); let relation = contract - .read_key( - read_tx, - table_id, - &read_key, - &select, - &opt_read, - ) + .read_key(read_tx, table_id, &read_key, &select, &opt_read) .await?; assert_eq!(relation, Some(vec![datum(20)])); Ok(()) diff --git a/mudu_kernel/src/sql/planner.rs b/mudu_kernel/src/sql/planner.rs index 5fbf27d..ae2fe92 100644 --- a/mudu_kernel/src/sql/planner.rs +++ b/mudu_kernel/src/sql/planner.rs @@ -338,13 +338,7 @@ mod tests { fn get(&self, _key: &[u8]) -> Option>> { None } - fn put_relation( - &self, - _relation_id: PhysicalRelationId, - _key: Vec, - _value: Vec, - ) { - } + fn put_relation(&self, _relation_id: PhysicalRelationId, _key: Vec, _value: Vec) {} fn delete_relation(&self, _relation_id: PhysicalRelationId, _key: Vec) {} fn get_relation( &self, diff --git a/mudu_runtime/src/backend/http_api/kernel_http_api.rs b/mudu_runtime/src/backend/http_api/kernel_http_api.rs index 9957cdd..ffbe981 100644 --- a/mudu_runtime/src/backend/http_api/kernel_http_api.rs +++ b/mudu_runtime/src/backend/http_api/kernel_http_api.rs @@ -28,6 +28,8 @@ use std::sync::Arc; pub struct KernelHttpApi { app_mgr: Arc, tcp_addr: String, + tcp_multi_port: bool, + tcp_base_listen_port: u16, worker_registry: Arc, meta_mgr: Arc, partition_router: PartitionRouter, @@ -45,6 +47,8 @@ impl KernelHttpApi { Ok(Self::with_client_factory( app_mgr, format!("{}:{}", cfg.listen_ip, cfg.tcp_listen_port), + cfg.tcp_multi_port, + cfg.tcp_listen_port, worker_registry, meta_mgr, Arc::new(KernelInvokeClientFactory), @@ -54,6 +58,8 @@ impl KernelHttpApi { pub fn with_client_factory( app_mgr: Arc, tcp_addr: String, + tcp_multi_port: bool, + tcp_base_listen_port: u16, worker_registry: Arc, meta_mgr: Arc, client_factory: Arc, @@ -63,6 +69,8 @@ impl KernelHttpApi { Self { app_mgr, tcp_addr, + tcp_multi_port, + tcp_base_listen_port, worker_registry, partition_router: PartitionRouter::new(meta_mgr.clone()), meta_mgr, @@ -154,12 +162,20 @@ impl HttpApi for KernelHttpApi { async fn server_topology(&self) -> RS { Ok(ServerTopology { worker_count: self.worker_registry.workers().len(), + tcp_multi_port: self.tcp_multi_port, + tcp_base_listen_port: self.tcp_base_listen_port, workers: self .worker_registry .workers() .iter() .map(|worker| WorkerTopology { worker_index: worker.worker_index, + tcp_listen_port: if self.tcp_multi_port { + self.tcp_base_listen_port + .saturating_add(worker.worker_index as u16) + } else { + self.tcp_base_listen_port + }, worker_id: worker.worker_id, partitions: worker.partition_ids.clone(), }) diff --git a/mudu_runtime/src/backend/http_api/mod.rs b/mudu_runtime/src/backend/http_api/mod.rs index 8a9027f..5d4cd13 100644 --- a/mudu_runtime/src/backend/http_api/mod.rs +++ b/mudu_runtime/src/backend/http_api/mod.rs @@ -36,13 +36,14 @@ use mudu_contract::procedure::proc_desc::ProcDesc; use mudu_contract::procedure::procedure_param::ProcedureParam; use mudu_contract::tuple::datum_desc::DatumDesc; use mudu_utils::notifier::Waiter; +use mudu_utils::scoped_task_trace; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use serde_json::{Map, Value}; use std::collections::HashMap; +use std::io::{Cursor, Read}; use std::net::TcpListener; use std::sync::Arc; use tracing::error; -use mudu_utils::scoped_task_trace; fn serialize_oid_as_unioid(oid: &OID, serializer: S) -> Result where @@ -79,6 +80,7 @@ where #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub struct WorkerTopology { pub worker_index: usize, + pub tcp_listen_port: u16, #[serde( serialize_with = "serialize_oid_as_unioid", deserialize_with = "deserialize_oid_from_unioid" @@ -94,6 +96,8 @@ pub struct WorkerTopology { #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub struct ServerTopology { pub worker_count: usize, + pub tcp_multi_port: bool, + pub tcp_base_listen_port: u16, pub workers: Vec, } @@ -386,11 +390,14 @@ async fn app_proc_detail( async fn install(body: web::Bytes, context: web::Data) -> impl Responder { let body_str = String::from_utf8_lossy(&body).to_string(); match decode_install_request(&body_str) { - Ok(binary) => match context.api.install_mpk(binary).await { - Ok(()) => http_ok(JsonValue::Null), - Err(e) => http_err(format!("fail to install package {:?}", body_str), &e), - }, - Err(e) => http_err(format!("fail to install package {:?}", body_str), &e), + Ok(binary) => { + let package_name = mpk_package_name(&binary).unwrap_or_else(|| "".to_string()); + match context.api.install_mpk(binary).await { + Ok(()) => http_ok(JsonValue::Null), + Err(e) => http_err(format!("fail to install package {}", package_name), &e), + } + } + Err(e) => http_err("fail to install package ", &e), } } @@ -433,6 +440,22 @@ fn decode_install_request(body_str: &str) -> RS> { .map_err(|e| m_error!(EC::DecodeErr, "decode error", e)) } +fn mpk_package_name(binary: &[u8]) -> Option { + let cursor = Cursor::new(binary); + let mut archive = zip::ZipArchive::new(cursor).ok()?; + let mut package_cfg = String::new(); + archive + .by_name("package.cfg.json") + .ok()? + .read_to_string(&mut package_cfg) + .ok()?; + serde_json::from_str::(&package_cfg) + .ok()? + .get("name")? + .as_str() + .map(str::to_string) +} + fn to_param(argv: &Map, desc: &[DatumDesc]) -> RS { let mut vec = vec![]; for datum_desc in desc.iter() { @@ -736,6 +759,8 @@ mod test { let api = KernelHttpApi::with_client_factory( Arc::new(MockAppMgr), "127.0.0.1:9527".to_string(), + false, + 9527, registry, MetaMgrFactory::create( std::env::temp_dir() @@ -813,6 +838,8 @@ mod test { let api = KernelHttpApi::with_client_factory( Arc::new(MockAppMgr), "127.0.0.1:9527".to_string(), + false, + 9527, registry, meta_mgr, Arc::new(MockClientFactory { @@ -869,6 +896,8 @@ mod test { let api = KernelHttpApi::with_client_factory( Arc::new(MockAppMgr), "127.0.0.1:9527".to_string(), + false, + 9527, registry.clone(), meta_mgr, Arc::new(MockClientFactory { @@ -890,6 +919,8 @@ mod test { let topology = api.server_topology().await.unwrap(); assert_eq!(topology.worker_count, registry.workers().len()); + assert!(!topology.tcp_multi_port); + assert_eq!(topology.tcp_base_listen_port, 9527); assert_eq!(topology.workers.len(), registry.workers().len()); } @@ -912,6 +943,8 @@ mod test { let api = KernelHttpApi::with_client_factory( Arc::new(MockAppMgr), "127.0.0.1:9527".to_string(), + false, + 9527, registry, meta_mgr, Arc::new(MockClientFactory { @@ -957,6 +990,8 @@ mod test { let api = KernelHttpApi::with_client_factory( Arc::new(MockAppMgr), "127.0.0.1:9527".to_string(), + false, + 9527, registry, meta_mgr, Arc::new(MockClientFactory { diff --git a/mudu_runtime/src/backend/linux/server_ur/server.rs b/mudu_runtime/src/backend/linux/server_ur/server.rs index 58b67ad..68811ca 100644 --- a/mudu_runtime/src/backend/linux/server_ur/server.rs +++ b/mudu_runtime/src/backend/linux/server_ur/server.rs @@ -10,13 +10,20 @@ use mudu_kernel::mudu_conn::mudu_conn_async::{ }; use mudu_kernel::server::routing::RoutingMode; use mudu_kernel::server::server::WorkerTcpBackend as KernelWorkerTcpBackend; -use mudu_kernel::server::server::WorkerTcpServerConfig; +use mudu_kernel::server::server_cfg::ServerCfg; +use mudu_kernel::server::server_launch::ServerLaunch; +use mudu_kernel::server::server_runtime_deps::ServerRuntimeDeps; use mudu_sys::task_async; use mudu_utils::notifier::{Notifier, Waiter, notify_wait}; -use std::sync::Arc; +use std::sync::{Arc, Mutex, OnceLock}; pub struct IoUringBackend; +fn default_remote_scope_lock() -> &'static Mutex<()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) +} + impl IoUringBackend { pub fn sync_serve(cfg: MuduDBCfg) -> RS<()> { let (_stop_notifier, stop_waiter) = notify_wait(); @@ -28,11 +35,20 @@ impl IoUringBackend { } pub fn sync_serve_with_stop_and_ready( - cfg: MuduDBCfg, + mut cfg: MuduDBCfg, stop: Waiter, ready: Option, ) -> RS<()> { + let _default_remote_guard = default_remote_scope_lock().lock().map_err(|_| { + mudu::m_error!( + mudu::error::ec::EC::MutexError, + "default remote scope lock poisoned" + ) + })?; let worker_count = cfg.effective_worker_threads(); + if worker_count > 1 { + cfg.tcp_multi_port = true; + } let async_runtime = RuntimeOpt::build_async_runtime(cfg.server_mode); let app_mgr = Arc::new(MuduAppMgr::new_with_async_runtime( cfg.clone(), @@ -43,23 +59,22 @@ impl IoUringBackend { crate::backend::mududb_cfg::RoutingMode::PlayerId => RoutingMode::PlayerId, crate::backend::mududb_cfg::RoutingMode::RemoteHash => RoutingMode::RemoteHash, }; - let base_server_cfg = WorkerTcpServerConfig::new( + let base_server_cfg = ServerCfg::new( worker_count, cfg.listen_ip.clone(), cfg.tcp_listen_port, cfg.db_path.clone(), cfg.db_path.clone(), routing_mode, - None, )? - .with_log_chunk_size(cfg.io_uring_log_chunk_size); - let base_server_cfg = match async_runtime { - Some(async_runtime) => base_server_cfg.with_async_runtime(async_runtime), - None => base_server_cfg, - }; + .with_log_chunk_size(cfg.io_uring_log_chunk_size) + .with_multi_port(cfg.tcp_multi_port); + let mut server_deps = ServerRuntimeDeps::from_cfg(&base_server_cfg)? + .with_async_runtime(async_runtime.clone()); let default_remote_addr = format!("{}:{}", cfg.listen_ip, cfg.tcp_listen_port); - let default_remote_worker_id = base_server_cfg.worker_registry().default_global_worker_id(); - set_default_remote_async_runtime(base_server_cfg.async_runtime()); + let worker_registry = server_deps.worker_registry(); + let default_remote_worker_id = worker_registry.default_global_worker_id(); + set_default_remote_async_runtime(server_deps.async_runtime()); set_default_remote_addr(Some(default_remote_addr.clone())); set_default_remote_worker_id(default_remote_worker_id); let procedure_cfg = cfg.clone(); @@ -71,15 +86,11 @@ impl IoUringBackend { } Ok::<_, mudu::error::err::MError>(runtimes) })??; - let server_cfg = base_server_cfg.with_worker_procedure_runtimes(procedure_runtimes); - spawn_management_thread( - cfg.clone(), - app_mgr.clone(), - server_cfg.worker_registry(), - stop.clone(), - )?; + server_deps = server_deps.with_worker_procedure_runtimes(procedure_runtimes); + let server_launch = ServerLaunch::new(base_server_cfg, server_deps); + spawn_management_thread(cfg.clone(), app_mgr.clone(), worker_registry, stop.clone())?; let result = - KernelWorkerTcpBackend::sync_serve_with_stop_and_ready(server_cfg, stop, ready); + KernelWorkerTcpBackend::sync_serve_with_stop_and_ready(server_launch, stop, ready); clear_default_remote_if_current(&default_remote_addr, default_remote_worker_id); result } diff --git a/mudu_runtime/src/backend/linux/server_ur/test_mpk.rs b/mudu_runtime/src/backend/linux/server_ur/test_mpk.rs index 31820d2..3940ca6 100644 --- a/mudu_runtime/src/backend/linux/server_ur/test_mpk.rs +++ b/mudu_runtime/src/backend/linux/server_ur/test_mpk.rs @@ -11,7 +11,10 @@ use mudu_contract::procedure::procedure_param::ProcedureParam; use mudu_contract::tuple::tuple_datum::TupleDatum; use mudu_kernel::server::async_func_runtime::AsyncFuncInvokerPtr; use mudu_kernel::server::routing::RoutingMode as KernelRoutingMode; -use mudu_kernel::server::server::{TokioTcpBackend, WorkerTcpBackend, WorkerTcpServerConfig}; +use mudu_kernel::server::server::{TokioTcpBackend, WorkerTcpBackend}; +use mudu_kernel::server::server_cfg::ServerCfg; +use mudu_kernel::server::server_launch::ServerLaunch; +use mudu_kernel::server::server_runtime_deps::ServerRuntimeDeps; use mudu_utils::log::log_setup; use mudu_utils::notifier::notify_wait; use std::env::temp_dir; @@ -220,21 +223,22 @@ fn run_kv_mpk_can_be_used_by_kernel_backend(server_mode: ServerMode) -> RS<()> { let procedure_runtimes = create_procedure_runtimes(&app_mgr, &cfg)?; let (stop_notifier, server_stop) = notify_wait(); - let server_cfg = WorkerTcpServerConfig::new( + let server_cfg = ServerCfg::new( cfg.effective_worker_threads(), cfg.listen_ip.clone(), cfg.tcp_listen_port, cfg.db_path.clone(), cfg.db_path.clone(), KernelRoutingMode::ConnectionId, - None, )? - .with_log_chunk_size(cfg.io_uring_log_chunk_size) - .with_worker_procedure_runtimes(procedure_runtimes); + .with_log_chunk_size(cfg.io_uring_log_chunk_size); + let server_deps = ServerRuntimeDeps::from_cfg(&server_cfg)? + .with_worker_procedure_runtimes(procedure_runtimes); + let server_launch = ServerLaunch::new(server_cfg, server_deps); let server_thread = thread::spawn(move || match server_mode { - ServerMode::IOUring => WorkerTcpBackend::sync_serve_with_stop(server_cfg, server_stop), - ServerMode::Tokio => TokioTcpBackend::sync_serve_with_stop(server_cfg, server_stop), + ServerMode::IOUring => WorkerTcpBackend::sync_serve_with_stop(server_launch, server_stop), + ServerMode::Tokio => TokioTcpBackend::sync_serve_with_stop(server_launch, server_stop), ServerMode::Legacy => unreachable!("legacy mode is not a kernel backend"), }); diff --git a/mudu_runtime/src/backend/mududb_cfg.rs b/mudu_runtime/src/backend/mududb_cfg.rs index 5796144..57f0ff8 100644 --- a/mudu_runtime/src/backend/mududb_cfg.rs +++ b/mudu_runtime/src/backend/mududb_cfg.rs @@ -45,6 +45,8 @@ pub struct MuduDBCfg { #[serde(default = "default_tcp_listen_port")] pub tcp_listen_port: u16, #[serde(default)] + pub tcp_multi_port: bool, + #[serde(default)] pub io_uring_worker_threads: usize, #[serde(default = "default_ring_entries")] pub io_uring_ring_entries: u32, @@ -81,6 +83,7 @@ impl Display for MuduDBCfg { write!(f, " -> Enable Async: {}\n", self.enable_async)?; write!(f, " -> Server mode: {:?}\n", self.server_mode)?; write!(f, " -> TCP Listening port: {}\n", self.tcp_listen_port)?; + write!(f, " -> TCP Multi-port: {}\n", self.tcp_multi_port)?; write!( f, " -> io_uring workers: {}\n", @@ -135,6 +138,7 @@ impl Default for MuduDBCfg { enable_async: true, server_mode: ServerMode::Legacy, tcp_listen_port: default_tcp_listen_port(), + tcp_multi_port: false, io_uring_worker_threads: 0, io_uring_ring_entries: default_ring_entries(), io_uring_accept_multishot: true, diff --git a/mudu_runtime/src/backend/tokio_backend.rs b/mudu_runtime/src/backend/tokio_backend.rs index 3725e23..a245fe3 100644 --- a/mudu_runtime/src/backend/tokio_backend.rs +++ b/mudu_runtime/src/backend/tokio_backend.rs @@ -10,13 +10,20 @@ use mudu_kernel::mudu_conn::mudu_conn_async::{ }; use mudu_kernel::server::routing::RoutingMode; use mudu_kernel::server::server::TokioTcpBackend as KernelTokioTcpBackend; -use mudu_kernel::server::server::WorkerTcpServerConfig; +use mudu_kernel::server::server_cfg::ServerCfg; +use mudu_kernel::server::server_launch::ServerLaunch; +use mudu_kernel::server::server_runtime_deps::ServerRuntimeDeps; use mudu_sys::task_async; use mudu_utils::notifier::{Notifier, Waiter, notify_wait}; -use std::sync::Arc; +use std::sync::{Arc, Mutex, OnceLock}; pub struct TokioBackend; +fn default_remote_scope_lock() -> &'static Mutex<()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) +} + impl TokioBackend { pub fn sync_serve(cfg: MuduDBCfg) -> RS<()> { let (_stop_notifier, stop_waiter) = notify_wait(); @@ -28,11 +35,20 @@ impl TokioBackend { } pub fn sync_serve_with_stop_and_ready( - cfg: MuduDBCfg, + mut cfg: MuduDBCfg, stop: Waiter, ready: Option, ) -> RS<()> { + let _default_remote_guard = default_remote_scope_lock().lock().map_err(|_| { + mudu::m_error!( + mudu::error::ec::EC::MutexError, + "default remote scope lock poisoned" + ) + })?; let worker_count = cfg.effective_worker_threads(); + if worker_count > 1 { + cfg.tcp_multi_port = true; + } let async_runtime = RuntimeOpt::build_async_runtime(cfg.server_mode); let app_mgr = Arc::new(MuduAppMgr::new_with_async_runtime( cfg.clone(), @@ -43,23 +59,22 @@ impl TokioBackend { crate::backend::mududb_cfg::RoutingMode::PlayerId => RoutingMode::PlayerId, crate::backend::mududb_cfg::RoutingMode::RemoteHash => RoutingMode::RemoteHash, }; - let base_server_cfg = WorkerTcpServerConfig::new( + let base_server_cfg = ServerCfg::new( worker_count, cfg.listen_ip.clone(), cfg.tcp_listen_port, cfg.db_path.clone(), cfg.db_path.clone(), routing_mode, - None, )? - .with_log_chunk_size(cfg.io_uring_log_chunk_size); - let base_server_cfg = match async_runtime { - Some(async_runtime) => base_server_cfg.with_async_runtime(async_runtime), - None => base_server_cfg, - }; + .with_log_chunk_size(cfg.io_uring_log_chunk_size) + .with_multi_port(cfg.tcp_multi_port); + let mut server_deps = ServerRuntimeDeps::from_cfg(&base_server_cfg)? + .with_async_runtime(async_runtime.clone()); let default_remote_addr = format!("{}:{}", cfg.listen_ip, cfg.tcp_listen_port); - let default_remote_worker_id = base_server_cfg.worker_registry().default_global_worker_id(); - set_default_remote_async_runtime(base_server_cfg.async_runtime()); + let worker_registry = server_deps.worker_registry(); + let default_remote_worker_id = worker_registry.default_global_worker_id(); + set_default_remote_async_runtime(server_deps.async_runtime()); set_default_remote_addr(Some(default_remote_addr.clone())); set_default_remote_worker_id(default_remote_worker_id); let procedure_cfg = cfg.clone(); @@ -71,14 +86,11 @@ impl TokioBackend { } Ok::<_, mudu::error::err::MError>(runtimes) })??; - let server_cfg = base_server_cfg.with_worker_procedure_runtimes(procedure_runtimes); - spawn_management_thread( - cfg.clone(), - app_mgr.clone(), - server_cfg.worker_registry(), - stop.clone(), - )?; - let result = KernelTokioTcpBackend::sync_serve_with_stop_and_ready(server_cfg, stop, ready); + server_deps = server_deps.with_worker_procedure_runtimes(procedure_runtimes); + let server_launch = ServerLaunch::new(base_server_cfg, server_deps); + spawn_management_thread(cfg.clone(), app_mgr.clone(), worker_registry, stop.clone())?; + let result = + KernelTokioTcpBackend::sync_serve_with_stop_and_ready(server_launch, stop, ready); clear_default_remote_if_current(&default_remote_addr, default_remote_worker_id); result } diff --git a/mudu_sys/src/sync.rs b/mudu_sys/src/sync.rs index 1cd014c..12c0231 100644 --- a/mudu_sys/src/sync.rs +++ b/mudu_sys/src/sync.rs @@ -12,5 +12,7 @@ pub mod f_mutex; #[cfg(not(target_arch = "wasm32"))] pub mod notify_wait; #[cfg(not(target_arch = "wasm32"))] +pub mod stop_flag; +#[cfg(not(target_arch = "wasm32"))] pub use crate::sync_async::*; pub use crate::sync_sync::*; diff --git a/mudu_sys/src/sync/stop_flag.rs b/mudu_sys/src/sync/stop_flag.rs new file mode 100644 index 0000000..93f715f --- /dev/null +++ b/mudu_sys/src/sync/stop_flag.rs @@ -0,0 +1,32 @@ +use tokio::sync::watch; + +#[derive(Clone)] +pub struct StopTx { + inner: watch::Sender, +} + +#[derive(Clone)] +pub struct StopRx { + inner: watch::Receiver, +} + +pub fn stop_channel() -> (StopTx, StopRx) { + let (tx, rx) = watch::channel(false); + (StopTx { inner: tx }, StopRx { inner: rx }) +} + +impl StopTx { + pub fn stop(&self) { + let _ = self.inner.send(true); + } +} + +impl StopRx { + pub fn is_stopped(&self) -> bool { + *self.inner.borrow() + } + + pub async fn changed(&mut self) -> bool { + self.inner.changed().await.is_ok() + } +} diff --git a/mudu_sys/src/task_async.rs b/mudu_sys/src/task_async.rs index 8b98664..2504359 100644 --- a/mudu_sys/src/task_async.rs +++ b/mudu_sys/src/task_async.rs @@ -78,7 +78,6 @@ where F: Future + 'static, F::Output: 'static, { - let id = { let id = task_id::new_task_id(); let _ = TaskContext::new_context(id, name.to_string(), false); @@ -261,12 +260,12 @@ where F: Future + 'static, F::Output: 'static, { - let runtime = build_current_thread_runtime()?; let ls = LocalSet::new(); let task = ls.run_until(async move { let join = spawn_local_detached("block-on", fut)?; - join.await.map_err(|e| m_error!(EC::TokioErr, "task runtime error", e)) + join.await + .map_err(|e| m_error!(EC::TokioErr, "task runtime error", e)) }); let r = runtime.block_on(async move { let r = task.await; @@ -274,8 +273,8 @@ where }); let opt = r.map_err(|e| m_error!(EC::TokioErr, "tokio tokio error", e))?; match opt { - None => { Err(m_error!(EC::TokioErr, "return none")) }, - Some(output) => { Ok(output) } + None => Err(m_error!(EC::TokioErr, "return none")), + Some(output) => Ok(output), } } diff --git a/mudu_utils/src/debug.rs b/mudu_utils/src/debug.rs index 5194d34..4d98158 100644 --- a/mudu_utils/src/debug.rs +++ b/mudu_utils/src/debug.rs @@ -2,6 +2,8 @@ // https://github.com/hyperium/hyper/blob/master/examples/echo.rs use std::net::SocketAddr; +#[cfg(feature = "debug_trace")] +use std::net::TcpListener as StdTcpListener; #[cfg(feature = "debug_trace")] use bytes::Bytes; @@ -29,6 +31,8 @@ use crate::dump_task_trace; use crate::notifier::NotifyWait; #[cfg(feature = "debug_trace")] use crate::task_async::CurrentThreadTaskRuntime; +#[cfg(feature = "debug_trace")] +use crate::notifier::Notifier; use mudu::common::result::RS; #[cfg(feature = "debug_trace")] use mudu::error::ec::EC; @@ -94,6 +98,15 @@ async fn handle_request(req: Request) -> Result>, #[cfg(feature = "debug_trace")] pub async fn async_debug_serve_until(addr: SocketAddr, stop: NotifyWait) -> Result<(), MError> { + async_debug_serve_until_with_ready(addr, stop, None).await +} + +#[cfg(feature = "debug_trace")] +pub async fn async_debug_serve_until_with_ready( + addr: SocketAddr, + stop: NotifyWait, + ready: Option, +) -> Result<(), MError> { crate::scoped_task_trace!(); let port = addr.port(); let r = SERVER.insert_sync(port); @@ -109,6 +122,18 @@ pub async fn async_debug_serve_until(addr: SocketAddr, stop: NotifyWait) -> Resu return Err(m_error!(EC::IOErr, "bind to address error", e)); } }; + if let Some(ready) = ready { + ready.notify_all(); + } + async_debug_serve_with_tokio_listener(listener, port, stop).await +} + +#[cfg(feature = "debug_trace")] +async fn async_debug_serve_with_tokio_listener( + listener: TcpListener, + port: u16, + stop: NotifyWait, +) -> Result<(), MError> { let mut tasks = JoinSet::new(); loop { let accepted = mudu_sys::tokio::select! { @@ -165,7 +190,66 @@ pub async fn async_debug_serve(_addr: SocketAddr) -> Result<(), MError> { #[cfg(feature = "debug_trace")] pub fn debug_serve(canceler: NotifyWait, port: u16) { - let async_debug_serve = async_debug_serve_until(([0, 0, 0, 0], port).into(), canceler.clone()); + let async_debug_serve = + async_debug_serve_until(([0, 0, 0, 0], port).into(), canceler.clone()); + let runtime = CurrentThreadTaskRuntime::new().unwrap(); + let join = runtime + .local() + .spawn(canceler, "debug_server", async_debug_serve) + .unwrap(); + runtime.block_on(async { + let _ = join.await; + }); +} + +#[cfg(feature = "debug_trace")] +pub fn debug_serve_until_with_ready(canceler: NotifyWait, port: u16, ready: Notifier) { + let async_debug_serve = async_debug_serve_until_with_ready( + ([0, 0, 0, 0], port).into(), + canceler.clone(), + Some(ready), + ); + let runtime = CurrentThreadTaskRuntime::new().unwrap(); + let join = runtime + .local() + .spawn(canceler, "debug_server", async_debug_serve) + .unwrap(); + runtime.block_on(async { + let _ = join.await; + }); +} + +#[cfg(feature = "debug_trace")] +pub fn debug_serve_with_listener( + canceler: NotifyWait, + listener: StdTcpListener, + ready: Notifier, +) { + let port = listener.local_addr().map(|addr| addr.port()).unwrap_or(0); + let canceler_for_future = canceler.clone(); + let async_debug_serve = async move { + if let Err(e) = listener.set_nonblocking(true) { + let _ = SERVER.remove_sync(&port); + return Err(m_error!( + EC::IOErr, + "set debug server listener nonblocking error", + e + )); + } + let listener = match TcpListener::from_std(listener) { + Ok(listener) => listener, + Err(e) => { + let _ = SERVER.remove_sync(&port); + return Err(m_error!( + EC::IOErr, + "create tokio listener from std listener error", + e + )); + } + }; + ready.notify_all(); + async_debug_serve_with_tokio_listener(listener, port, canceler_for_future).await + }; let runtime = CurrentThreadTaskRuntime::new().unwrap(); let join = runtime .local() diff --git a/mudu_utils/src/test_debug_server.rs b/mudu_utils/src/test_debug_server.rs index 141df0f..1e16f1f 100644 --- a/mudu_utils/src/test_debug_server.rs +++ b/mudu_utils/src/test_debug_server.rs @@ -1,8 +1,13 @@ #[cfg(test)] mod test { + #[cfg(feature = "debug_trace")] + use crate::debug::debug_serve_with_listener; + #[cfg(not(feature = "debug_trace"))] use crate::debug::debug_serve; use crate::log::log_setup; use crate::notifier::notify_wait; + #[cfg(feature = "debug_trace")] + use crate::task_async::build_current_thread_runtime; use crate::task_sync::spawn_thread_named; #[cfg(feature = "debug_trace")] use std::io::{Read, Write}; @@ -20,36 +25,45 @@ mod test { } }; let addr: SocketAddr = listener.local_addr().unwrap(); - drop(listener); - let port = addr.port(); let (notifier, waiter) = notify_wait(); let server_stop = waiter.into(); + #[cfg(feature = "debug_trace")] + let (ready_notifier, ready_waiter) = notify_wait(); + #[cfg(feature = "debug_trace")] let server = spawn_thread_named("test_server", move || { - debug_serve(server_stop, port); + debug_serve_with_listener(server_stop, listener, ready_notifier); }) .unwrap(); + #[cfg(not(feature = "debug_trace"))] + let server = { + drop(listener); + spawn_thread_named("test_server", move || { + debug_serve(server_stop, addr.port()); + }) + .unwrap() + }; #[cfg(feature = "debug_trace")] { - let mut response = None; - for _ in 0..20 { - std::thread::sleep(Duration::from_millis(50)); - let attempt = (|| -> std::io::Result { - let mut stream = std::net::TcpStream::connect(addr)?; - stream.write_all( - b"GET /task HTTP/1.1\r\nHost: 127.0.0.1\r\nConnection: close\r\n\r\n", - )?; - let mut buf = String::new(); - stream.read_to_string(&mut buf)?; - Ok(buf) - })(); - if let Ok(buf) = attempt { - response = Some(buf); - break; - } - } - let response = response.expect("debug server did not accept requests"); + let runtime = build_current_thread_runtime().unwrap(); + runtime.block_on(async { + ready_waiter.wait().await; + }); + } + + #[cfg(feature = "debug_trace")] + { + let response = (|| -> std::io::Result { + let mut stream = std::net::TcpStream::connect(addr)?; + stream.write_all( + b"GET /task HTTP/1.1\r\nHost: 127.0.0.1\r\nConnection: close\r\n\r\n", + )?; + let mut buf = String::new(); + stream.read_to_string(&mut buf)?; + Ok(buf) + })() + .expect("debug server did not accept requests"); assert!(response.starts_with("HTTP/1.1 200")); } @@ -60,7 +74,10 @@ mod test { } std::thread::sleep(Duration::from_millis(50)); } - assert!(server.is_finished(), "debug_serve thread did not stop after notify"); + assert!( + server.is_finished(), + "debug_serve thread did not stop after notify" + ); server.join().unwrap(); } } diff --git a/testing/src/lib.rs b/testing/src/lib.rs index 41859fa..df85745 100644 --- a/testing/src/lib.rs +++ b/testing/src/lib.rs @@ -22,6 +22,36 @@ pub fn reserve_port() -> RS> { } } +pub fn reserve_port_block(count: usize) -> RS> { + if count == 0 { + return Ok(None); + } + for _ in 0..128 { + let Some(base_port) = reserve_port()? else { + return Ok(None); + }; + let mut listeners = Vec::with_capacity(count); + let mut ok = true; + for offset in 0..count { + let Some(port) = base_port.checked_add(offset as u16) else { + ok = false; + break; + }; + match TcpListener::bind(("127.0.0.1", port)) { + Ok(listener) => listeners.push(listener), + Err(_) => { + ok = false; + break; + } + } + } + if ok { + return Ok(Some(base_port)); + } + } + Ok(None) +} + pub fn wait_until_port_ready(port: u16, service_name: &str) -> RS<()> { let deadline = mudu_sys::time::instant_now() + Duration::from_secs(10); while mudu_sys::time::instant_now() < deadline { diff --git a/testing/tests/linux/test_tpcc_concurrent_procedure.rs b/testing/tests/linux/test_tpcc_concurrent_procedure.rs index cc56e53..3c3be72 100644 --- a/testing/tests/linux/test_tpcc_concurrent_procedure.rs +++ b/testing/tests/linux/test_tpcc_concurrent_procedure.rs @@ -18,7 +18,7 @@ use std::path::{Path, PathBuf}; use std::sync::{LazyLock, Mutex, Once}; use std::thread::{self, JoinHandle}; use std::time::Instant; -use testing::{reserve_port, wait_until_port_ready}; +use testing::{reserve_port, reserve_port_block, wait_until_port_ready}; use tokio::sync::mpsc; use tokio::time::{Duration, timeout}; use tracing::{debug, info}; @@ -860,7 +860,11 @@ impl TestContext { let Some(pg_port) = reserve_port()? else { return Ok(None); }; - let Some(tcp_port) = reserve_port()? else { + let tcp_port_count = match server_mode { + ServerMode::IOUring | ServerMode::Tokio => 2, + ServerMode::Legacy => 1, + }; + let Some(tcp_port) = reserve_port_block(tcp_port_count)? else { return Ok(None); }; diff --git a/testing/tests/linux/wallet_mpk.rs b/testing/tests/linux/wallet_mpk.rs index c192337..73f90f8 100644 --- a/testing/tests/linux/wallet_mpk.rs +++ b/testing/tests/linux/wallet_mpk.rs @@ -18,7 +18,7 @@ use std::fs; use std::path::{Path, PathBuf}; use std::sync::{LazyLock, Mutex}; use std::thread::{self, JoinHandle}; -use testing::{reserve_port, wait_until_port_ready}; +use testing::{reserve_port, reserve_port_block, wait_until_port_ready}; use tracing::info; static WALLET_MPK_TEST_LOCK: LazyLock> = LazyLock::new(|| Mutex::new(())); @@ -154,12 +154,6 @@ fn run_wallet_mpk_via_mudu_cli_library_for_mode(server_mode: ServerMode) -> RS<( .block_on(install_app_package(&ctx.http_addr(), mpk_binary)) .map_err(to_mudu_error)?; - let mut client = runtime - .block_on(AsyncClientImpl::connect(&format!( - "127.0.0.1:{}", - ctx.client_port() - ))) - .map_err(|e| to_mudu_error(e.to_string()))?; let topology = runtime .block_on(fetch_server_topology(&ctx.http_addr())) .map_err(to_mudu_error)?; @@ -169,6 +163,16 @@ fn run_wallet_mpk_via_mudu_cli_library_for_mode(server_mode: ServerMode) -> RS<( .find(|worker| worker.worker_index == 0) .map(|worker| worker.worker_id) .ok_or_else(|| to_mudu_error("server topology does not contain worker 0".to_string()))?; + let default_worker_addr = topology + .worker_addr_by_id("127.0.0.1", default_worker_id) + .ok_or_else(|| { + to_mudu_error(format!( + "server topology does not contain worker id {default_worker_id} address" + )) + })?; + let mut client = runtime + .block_on(AsyncClientImpl::connect(&default_worker_addr)) + .map_err(|e| to_mudu_error(e.to_string()))?; let session_id = runtime .block_on( client.create_session(mudu_contract::protocol::SessionCreateRequest::new(Some( @@ -401,7 +405,11 @@ impl TestContext { let Some(pg_port) = reserve_port()? else { return Ok(None); }; - let Some(tcp_port) = reserve_port()? else { + let tcp_port_count = match server_mode { + ServerMode::IOUring | ServerMode::Tokio => 2, + ServerMode::Legacy => 1, + }; + let Some(tcp_port) = reserve_port_block(tcp_port_count)? else { return Ok(None); }; let base_dir = @@ -659,7 +667,7 @@ fn wait_until_backend_ready(waiter: Waiter, service_name: &str) -> RS<()> { // Wallet end-to-end tests exercise the service immediately after startup, // so they must wait for logical readiness instead of only for a bound // socket. - let result = mudu_sys::task_async::block_on_tokio_current_thread(async move{ + let result = mudu_sys::task_async::block_on_tokio_current_thread(async move { tokio::time::timeout(std::time::Duration::from_secs(10), waiter.wait()).await }) .map_err(|e| { diff --git a/testing/tests/test_copy_roundtrip.rs b/testing/tests/test_copy_roundtrip.rs index 2ea3c3f..4f96d7b 100644 --- a/testing/tests/test_copy_roundtrip.rs +++ b/testing/tests/test_copy_roundtrip.rs @@ -4,13 +4,15 @@ use mudu_runtime::backend::backend::Backend; use mudu_runtime::backend::mududb_cfg::ServerMode; use mudu_runtime::backend::mududb_cfg::{MuduDBCfg, RoutingMode}; use mudu_runtime::service::runtime_opt::ComponentTarget; +use mudu_sys::sync::NotifyWait; +use mudu_utils::debug::debug_serve; use mudu_utils::log::log_setup; use mudu_utils::notifier::{Notifier, Waiter, notify_wait}; use serde_json::{Value, json}; use std::fs; use std::net::{TcpListener, TcpStream}; use std::path::PathBuf; -use std::sync::mpsc; +use std::sync::{Mutex, OnceLock}; use std::thread::{self, JoinHandle}; use std::time::Duration; use tracing::{debug, info}; @@ -35,17 +37,30 @@ fn copy_from_to_roundtrip_tokio() -> RS<()> { } fn run_copy_from_to_roundtrip(server_mode: ServerMode) -> RS<()> { + let _test_guard = test_runtime_domain_lock().lock().map_err(|_| { + mudu::m_error!( + mudu::error::ec::EC::MutexError, + "test runtime domain lock poisoned" + ) + })?; let Some(ctx) = TestContext::new(server_mode)? else { eprintln!("skip copy roundtrip test: local TCP/HTTP bind is not permitted"); return Ok(()); }; + let notifier = NotifyWait::new(); + { + let _n = notifier.clone(); + let _ = thread::spawn(move || { + debug_serve(_n, 1800); + }); + }; let server = ctx.start_server()?; let suffix = mudu_sys::random::uuid_v4(); let copy_from_path = ctx.base_dir.join(format!("copy_from_{suffix}.csv")); let copy_to_path = ctx.base_dir.join(format!("copy_to_{suffix}.csv")); - let copy_from_file = format!("'{}'", copy_from_path.display()); - let copy_to_file = format!("'{}'", copy_to_path.display()); + let copy_from_file = sql_path_literal(©_from_path); + let copy_to_file = sql_path_literal(©_to_path); let input_csv = "id,name\n1,Alice\n2,Bob\n"; fs::write(©_from_path, input_csv).map_err(|e| { mudu::m_error!( @@ -66,7 +81,8 @@ fn run_copy_from_to_roundtrip(server_mode: ServerMode) -> RS<()> { ), copy_from_file, copy_to_file ); - let outputs = run_shell_script_outputs(&ctx, "demo", &script)?; + let app = format!("demo_{}", mudu_sys::random::uuid_v4()); + let outputs = run_shell_script_outputs(&ctx, &app, &script)?; let output_text = outputs .iter() @@ -86,8 +102,9 @@ fn run_copy_from_to_roundtrip(server_mode: ServerMode) -> RS<()> { e ) })?; - assert!( - exported.lines().next() == Some("id,name"), + assert_eq!( + exported.lines().next(), + Some("id,name"), "COPY TO should export csv header, exported: {}", exported ); @@ -110,6 +127,64 @@ fn supports_server_mode(server_mode: ServerMode) -> bool { } } +fn test_runtime_domain_lock() -> &'static Mutex<()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) +} + +async fn handle_client_request(input: String, app: String, addr: String) -> RS> { + let mut client = JsonClient::connect(&addr).await?; + let mut current_app = app; + let mut buffer = String::new(); + let mut outputs: Vec = Vec::new(); + + for line in input.lines() { + let trimmed = line.trim(); + + if buffer.trim().is_empty() && trimmed.starts_with('\\') { + if handle_shell_meta(trimmed, &mut current_app) { + break; + } + continue; + } + + if trimmed.is_empty() && buffer.is_empty() { + continue; + } + + buffer.push_str(line); + buffer.push('\n'); + + if !statement_complete(&buffer) { + continue; + } + + let statement = finalize_statement(&buffer); + buffer.clear(); + if statement.is_empty() { + continue; + } + + let request = if looks_like_query(&statement) { + json!({ "app_name": current_app, "sql": statement }) + } else { + json!({ "app_name": current_app, "sql": statement, "kind": "execute" }) + }; + debug!(sql = %statement, is_query = looks_like_query(&statement), "sending sql"); + let output = tokio::time::timeout(Duration::from_secs(20), client.command(request)) + .await + .map_err(|_| { + mudu::m_error!( + mudu::error::ec::EC::TokioErr, + format!("copy roundtrip command timed out: {}", statement) + ) + })??; + outputs.push(output); + debug!("received sql response"); + } + + Ok(outputs) +} fn run_shell_script_outputs(ctx: &TestContext, app: &str, input: &str) -> RS> { let addr = format!("127.0.0.1:{}", ctx.client_port()); let app = app.to_string(); @@ -128,57 +203,8 @@ fn run_shell_script_outputs(ctx: &TestContext, app: &str, input: &str) -> RS = Vec::new(); - - for line in input.lines() { - let trimmed = line.trim(); - - if buffer.trim().is_empty() && trimmed.starts_with('\\') { - if handle_shell_meta(trimmed, &mut current_app) { - break; - } - continue; - } - - if trimmed.is_empty() && buffer.is_empty() { - continue; - } - - buffer.push_str(line); - buffer.push('\n'); - - if !statement_complete(&buffer) { - continue; - } - - let statement = finalize_statement(&buffer); - buffer.clear(); - if statement.is_empty() { - continue; - } - - let request = if looks_like_query(&statement) { - json!({ "app_name": current_app, "sql": statement }) - } else { - json!({ "app_name": current_app, "sql": statement, "kind": "execute" }) - }; - debug!(sql = %statement, is_query = looks_like_query(&statement), "sending sql"); - let output = tokio::time::timeout(Duration::from_secs(10), client.command(request)) - .await - .map_err(|_| { - mudu::m_error!( - mudu::error::ec::EC::TokioErr, - "interactive mcli command timed out" - ) - })??; - outputs.push(output); - debug!("received sql response"); - } - - Ok(outputs) + let r = handle_client_request(input, app, addr).await; + r }) }); @@ -210,7 +236,16 @@ fn statement_complete(buf: &str) -> bool { } fn finalize_statement(buf: &str) -> String { - buf.trim().to_string() + let stmt = buf.trim(); + let stmt = stmt.strip_suffix(';').unwrap_or(stmt); + stmt.trim().to_string() +} + +fn sql_path_literal(path: &std::path::Path) -> String { + // Use forward slashes so COPY path parsing is stable across platforms. + let normalized = path.to_string_lossy().replace('\\', "/"); + let escaped = normalized.replace('\'', "''"); + format!("'{escaped}'") } fn looks_like_query(sql: &str) -> bool { @@ -228,6 +263,8 @@ fn looks_like_query(sql: &str) -> bool { struct RunningServer { stop: Notifier, + http_port: u16, + tcp_port: u16, handle: Option>>, } @@ -236,6 +273,12 @@ impl Drop for RunningServer { debug!("test copy_roundtrip dropping running server"); self.stop.notify_all(); if let Some(handle) = self.handle.take() { + let deadline = mudu_sys::time::instant_now() + Duration::from_secs(15); + while !handle.is_finished() && mudu_sys::time::instant_now() < deadline { + let _ = TcpStream::connect(("127.0.0.1", self.http_port)); + let _ = TcpStream::connect(("127.0.0.1", self.tcp_port)); + mudu_sys::task_sync::sleep_blocking(Duration::from_millis(25)); + } let join_result = handle.join().expect("join server thread"); if let Err(err) = join_result { panic!("server stopped with error: {err}"); @@ -263,7 +306,11 @@ impl TestContext { let Some(pg_port) = reserve_port()? else { return Ok(None); }; - let Some(tcp_port) = reserve_port()? else { + let tcp_port_count = match server_mode { + ServerMode::IOUring | ServerMode::Tokio => 2, + ServerMode::Legacy => 1, + }; + let Some(tcp_port) = reserve_port_block(tcp_port_count)? else { return Ok(None); }; @@ -298,20 +345,19 @@ impl TestContext { ); let (stop, waiter) = notify_wait(); let (ready, ready_waiter) = notify_wait(); - let (exit_tx, exit_rx) = mpsc::channel(); let handle = thread::spawn(move || { - let result = Backend::sync_serve_with_stop_and_ready(cfg, waiter, Some(ready)); - let _ = exit_tx.send(result.clone()); - result + Backend::sync_serve_with_stop_and_ready(cfg, waiter, Some(ready)) }); wait_until_port_ready(self.http_port, "HTTP", BACKEND_STARTUP_TIMEOUT)?; if matches!(self.server_mode, ServerMode::IOUring | ServerMode::Tokio) { wait_until_port_ready(self.tcp_port, "TCP", BACKEND_STARTUP_TIMEOUT)?; } - wait_until_backend_ready(ready_waiter, &exit_rx, "backend", BACKEND_STARTUP_TIMEOUT)?; + wait_until_backend_ready(ready_waiter, "backend", BACKEND_STARTUP_TIMEOUT)?; debug!("backend server ready"); Ok(RunningServer { stop, + http_port: self.http_port, + tcp_port: self.tcp_port, handle: Some(handle), }) } @@ -366,6 +412,36 @@ fn reserve_port() -> RS> { } } +fn reserve_port_block(count: usize) -> RS> { + if count == 0 { + return Ok(None); + } + for _ in 0..128 { + let Some(base_port) = reserve_port()? else { + return Ok(None); + }; + let mut listeners = Vec::with_capacity(count); + let mut ok = true; + for offset in 0..count { + let Some(port) = base_port.checked_add(offset as u16) else { + ok = false; + break; + }; + match TcpListener::bind(("127.0.0.1", port)) { + Ok(listener) => listeners.push(listener), + Err(_) => { + ok = false; + break; + } + } + } + if ok { + return Ok(Some(base_port)); + } + } + Ok(None) +} + fn wait_until_port_ready(port: u16, service_name: &str, timeout: Duration) -> RS<()> { let deadline = mudu_sys::time::instant_now() + timeout; while mudu_sys::time::instant_now() < deadline { @@ -383,62 +459,27 @@ fn wait_until_port_ready(port: u16, service_name: &str, timeout: Duration) -> RS )) } -fn wait_until_backend_ready( - waiter: Waiter, - exit_rx: &mpsc::Receiver>, - service_name: &str, - timeout: Duration, -) -> RS<()> { - // The ready barrier keeps tests from racing startup with the later point - // where the backend can actually serve requests. Poll the barrier in small - // slices so an early server-thread failure surfaces immediately instead of - // degenerating into a generic timeout. - let deadline = mudu_sys::time::instant_now() + timeout; - while mudu_sys::time::instant_now() < deadline { - match exit_rx.try_recv() { - Ok(Ok(())) => { - return Err(mudu::m_error!( - mudu::error::ec::EC::ThreadErr, - format!("{service_name} stopped before publishing ready barrier") - )); - } - Ok(Err(err)) => { - return Err(mudu::m_error!( - mudu::error::ec::EC::ThreadErr, - format!("{service_name} exited before publishing ready barrier"), - err - )); - } - Err(mpsc::TryRecvError::Disconnected) => { - return Err(mudu::m_error!( - mudu::error::ec::EC::ThreadErr, - format!("{service_name} startup monitor disconnected") - )); - } - Err(mpsc::TryRecvError::Empty) => {} - } - - let result = mudu_sys::task_async::block_on_tokio_current_thread({ - let waiter = waiter.clone(); - async move { tokio::time::timeout(Duration::from_millis(100), waiter.wait()).await } - }) - .map_err(|e| { - mudu::m_error!( - mudu::error::ec::EC::TokioErr, - format!("wait for {} ready barrier runtime error", service_name), - e +fn wait_until_backend_ready(waiter: Waiter, service_name: &str, timeout: Duration) -> RS<()> { + // Listener readiness is not enough for io_uring mode because worker + // recovery continues after the port starts accepting connections. + let result = mudu_sys::task_async::block_on_tokio_current_thread(async move { + tokio::time::timeout(timeout, waiter.wait()).await + }) + .map_err(|e| { + mudu::m_error!( + mudu::error::ec::EC::TokioErr, + format!("wait for {} ready barrier runtime error", service_name), + e + ) + })?; + result.map_err(|_| { + mudu::m_error!( + mudu::error::ec::EC::TokioErr, + format!( + "{} ready barrier timed out after {:?}", + service_name, timeout ) - })?; - if result.is_ok() { - return Ok(()); - } - } - - Err(mudu::m_error!( - mudu::error::ec::EC::TokioErr, - format!( - "{} ready barrier timed out after {:?}", - service_name, timeout ) - )) + })?; + Ok(()) } diff --git a/testing/tests/test_interactive.rs b/testing/tests/test_interactive.rs index 4aae267..38c1d30 100644 --- a/testing/tests/test_interactive.rs +++ b/testing/tests/test_interactive.rs @@ -11,6 +11,7 @@ use serde_json::{Value, json}; use std::fs; use std::net::{TcpListener, TcpStream}; use std::path::PathBuf; +use std::sync::{Mutex, OnceLock}; use std::thread::{self, JoinHandle}; use std::time::Duration; use tracing::info; @@ -50,13 +51,20 @@ fn interactive_mcli_shell_tokio_tui() -> RS<()> { } fn run_interactive_mcli_shell_test(server_mode: ServerMode) -> RS<()> { + let _test_guard = test_runtime_domain_lock().lock().map_err(|_| { + mudu::m_error!( + mudu::error::ec::EC::MutexError, + "test runtime domain lock poisoned" + ) + })?; let Some(ctx) = TestContext::new(server_mode)? else { eprintln!("skip interactive mcli test: local TCP/HTTP bind is not permitted"); return Ok(()); }; let server = ctx.start_server()?; + let app = format!("demo_{}", mudu_sys::random::uuid_v4()); - let shell_output = run_interactive_mcli_shell(&ctx, "demo", crud_script())?; + let shell_output = run_interactive_mcli_shell(&ctx, &app, crud_script())?; assert!(shell_output.contains("'Eve'")); assert!(shell_output.contains("'Eva'")); drop(server); @@ -64,13 +72,20 @@ fn run_interactive_mcli_shell_test(server_mode: ServerMode) -> RS<()> { } fn run_interactive_mcli_tui_test(server_mode: ServerMode) -> RS<()> { + let _test_guard = test_runtime_domain_lock().lock().map_err(|_| { + mudu::m_error!( + mudu::error::ec::EC::MutexError, + "test runtime domain lock poisoned" + ) + })?; let Some(ctx) = TestContext::new(server_mode)? else { eprintln!("skip interactive mcli tui test: local TCP/HTTP bind is not permitted"); return Ok(()); }; let server = ctx.start_server()?; + let app = format!("demo_{}", mudu_sys::random::uuid_v4()); - let outputs = run_shell_script_outputs(&ctx, "demo", tui_script())?; + let outputs = run_shell_script_outputs(&ctx, &app, tui_script())?; let table = outputs .iter() .find_map(extract_query_table) @@ -104,6 +119,11 @@ fn supports_server_mode(server_mode: ServerMode) -> bool { } } +fn test_runtime_domain_lock() -> &'static Mutex<()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) +} + fn crud_script() -> &'static str { concat!( "DROP TABLE IF EXISTS t_crud;\n", @@ -198,7 +218,14 @@ fn run_shell_script_outputs(ctx: &TestContext, app: &str, input: &str) -> RS bool { struct RunningServer { stop: Notifier, + http_port: u16, + tcp_port: u16, handle: Option>>, } @@ -261,6 +290,16 @@ impl Drop for RunningServer { fn drop(&mut self) { self.stop.notify_all(); if let Some(handle) = self.handle.take() { + let deadline = mudu_sys::time::instant_now() + Duration::from_secs(15); + while !handle.is_finished() && mudu_sys::time::instant_now() < deadline { + let _ = TcpStream::connect(("127.0.0.1", self.http_port)); + let _ = TcpStream::connect(("127.0.0.1", self.tcp_port)); + mudu_sys::task_sync::sleep_blocking(Duration::from_millis(25)); + } + assert!( + handle.is_finished(), + "join server thread timed out after 15s in test_interactive" + ); let join_result = handle.join().expect("join server thread"); if let Err(err) = join_result { panic!("server stopped with error: {err}"); @@ -287,7 +326,11 @@ impl TestContext { let Some(pg_port) = reserve_port()? else { return Ok(None); }; - let Some(tcp_port) = reserve_port()? else { + let tcp_port_count = match server_mode { + ServerMode::IOUring | ServerMode::Tokio => 2, + ServerMode::Legacy => 1, + }; + let Some(tcp_port) = reserve_port_block(tcp_port_count)? else { return Ok(None); }; @@ -327,6 +370,8 @@ impl TestContext { wait_until_backend_ready(ready_waiter, "backend")?; Ok(RunningServer { stop, + http_port: self.http_port, + tcp_port: self.tcp_port, handle: Some(handle), }) } @@ -381,6 +426,36 @@ fn reserve_port() -> RS> { } } +fn reserve_port_block(count: usize) -> RS> { + if count == 0 { + return Ok(None); + } + for _ in 0..128 { + let Some(base_port) = reserve_port()? else { + return Ok(None); + }; + let mut listeners = Vec::with_capacity(count); + let mut ok = true; + for offset in 0..count { + let Some(port) = base_port.checked_add(offset as u16) else { + ok = false; + break; + }; + match TcpListener::bind(("127.0.0.1", port)) { + Ok(listener) => listeners.push(listener), + Err(_) => { + ok = false; + break; + } + } + } + if ok { + return Ok(Some(base_port)); + } + } + Ok(None) +} + fn wait_until_port_ready(port: u16, service_name: &str) -> RS<()> { let deadline = mudu_sys::time::instant_now() + Duration::from_secs(10); while mudu_sys::time::instant_now() < deadline { @@ -401,7 +476,7 @@ fn wait_until_port_ready(port: u16, service_name: &str) -> RS<()> { fn wait_until_backend_ready(waiter: Waiter, service_name: &str) -> RS<()> { // Listener readiness is not enough for io_uring mode because worker // recovery continues after the port starts accepting connections. - let result = mudu_sys::task_async::block_on_tokio_current_thread(async move{ + let result = mudu_sys::task_async::block_on_tokio_current_thread(async move { tokio::time::timeout(Duration::from_secs(10), waiter.wait()).await }) .map_err(|e| { diff --git a/testing/tests/test_restart.rs b/testing/tests/test_restart.rs index 823de88..4bfc813 100644 --- a/testing/tests/test_restart.rs +++ b/testing/tests/test_restart.rs @@ -4,12 +4,15 @@ use mudu_runtime::backend::backend::Backend; use mudu_runtime::backend::mududb_cfg::ServerMode; use mudu_runtime::backend::mududb_cfg::{MuduDBCfg, RoutingMode}; use mudu_runtime::service::runtime_opt::ComponentTarget; +use mudu_sys::sync::NotifyWait; +use mudu_utils::debug::debug_serve; use mudu_utils::log::log_setup; use mudu_utils::notifier::{Notifier, Waiter, notify_wait}; use serde_json::{Value, json}; use std::fs; use std::net::{TcpListener, TcpStream}; use std::path::PathBuf; +use std::sync::{Mutex, OnceLock}; use std::thread::{self, JoinHandle}; use std::time::Duration; use tracing::info; @@ -32,12 +35,26 @@ fn test_mudud_restart_persistence_tokio() -> RS<()> { } fn run_restart_persistence(server_mode: ServerMode) -> RS<()> { + let _test_guard = test_runtime_domain_lock().lock().map_err(|_| { + mudu::m_error!( + mudu::error::ec::EC::MutexError, + "test runtime domain lock poisoned" + ) + })?; + let notifier = NotifyWait::new(); + { + let _n = notifier.clone(); + let _ = thread::spawn(move || { + debug_serve(_n, 1800); + }); + }; let Some(ctx) = TestContext::new(server_mode)? else { eprintln!("skip test: local TCP/HTTP bind is not permitted"); return Ok(()); }; println!("Step 1: Start mudud ({server_mode:?} mode)"); + let app = format!("demo_{}", mudu_sys::random::uuid_v4()); { let server = ctx.start_server()?; @@ -45,9 +62,16 @@ fn run_restart_persistence(server_mode: ServerMode) -> RS<()> { let script = concat!( "CREATE TABLE t_restart(id INT PRIMARY KEY, name TEXT);\n", "INSERT INTO t_restart(id, name) VALUES (100, 'Mudu');\n", + "SELECT name FROM t_restart WHERE id = 100;\n", "\\q\n" ); - let _ = run_shell_script_outputs(&ctx, "demo", script)?; + let outputs = run_shell_script_outputs(&ctx, &app, script)?; + let inserted_visible = outputs.iter().any(|val| val.to_string().contains("Mudu")); + assert!( + inserted_visible, + "Inserted row should be visible before stop. Outputs: {:?}", + outputs + ); println!("Step 3: Stop mudud"); drop(server); @@ -62,7 +86,7 @@ fn run_restart_persistence(server_mode: ServerMode) -> RS<()> { println!("Step 5: mcli reconnect and verify data"); let script = "SELECT name FROM t_restart WHERE id = 100;\n\\q\n"; - let outputs = run_shell_script_outputs(&ctx, "demo", script)?; + let outputs = run_shell_script_outputs(&ctx, &app, script)?; let found_mudu = outputs.iter().any(|val| val.to_string().contains("Mudu")); assert!( @@ -85,6 +109,11 @@ fn supports_server_mode(server_mode: ServerMode) -> bool { } } +fn test_runtime_domain_lock() -> &'static Mutex<()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) +} + // Helpers adapted from test_interactive.rs fn run_shell_script_outputs(ctx: &TestContext, app: &str, input: &str) -> RS> { @@ -212,6 +241,8 @@ fn looks_like_query(sql: &str) -> bool { struct RunningServer { stop: Notifier, + http_port: u16, + tcp_port: u16, handle: Option>>, } @@ -219,6 +250,12 @@ impl Drop for RunningServer { fn drop(&mut self) { self.stop.notify_all(); if let Some(handle) = self.handle.take() { + let deadline = mudu_sys::time::instant_now() + Duration::from_secs(15); + while !handle.is_finished() && mudu_sys::time::instant_now() < deadline { + let _ = TcpStream::connect(("127.0.0.1", self.http_port)); + let _ = TcpStream::connect(("127.0.0.1", self.tcp_port)); + mudu_sys::task_sync::sleep_blocking(Duration::from_millis(25)); + } let join_result = handle.join().expect("join server thread"); if let Err(err) = join_result { panic!("server stopped with error: {err}"); @@ -245,7 +282,11 @@ impl TestContext { let Some(pg_port) = reserve_port()? else { return Ok(None); }; - let Some(tcp_port) = reserve_port()? else { + let tcp_port_count = match server_mode { + ServerMode::IOUring | ServerMode::Tokio => 2, + ServerMode::Legacy => 1, + }; + let Some(tcp_port) = reserve_port_block(tcp_port_count)? else { return Ok(None); }; @@ -289,6 +330,8 @@ impl TestContext { println!(" [server] Server ready."); Ok(RunningServer { stop, + http_port: self.http_port, + tcp_port: self.tcp_port, handle: Some(handle), }) } @@ -343,6 +386,36 @@ fn reserve_port() -> RS> { } } +fn reserve_port_block(count: usize) -> RS> { + if count == 0 { + return Ok(None); + } + for _ in 0..128 { + let Some(base_port) = reserve_port()? else { + return Ok(None); + }; + let mut listeners = Vec::with_capacity(count); + let mut ok = true; + for offset in 0..count { + let Some(port) = base_port.checked_add(offset as u16) else { + ok = false; + break; + }; + match TcpListener::bind(("127.0.0.1", port)) { + Ok(listener) => listeners.push(listener), + Err(_) => { + ok = false; + break; + } + } + } + if ok { + return Ok(Some(base_port)); + } + } + Ok(None) +} + fn wait_until_port_ready(port: u16, service_name: &str) -> RS<()> { let deadline = mudu_sys::time::instant_now() + Duration::from_secs(10); while mudu_sys::time::instant_now() < deadline { @@ -364,7 +437,7 @@ fn wait_until_backend_ready(waiter: Waiter, service_name: &str) -> RS<()> { // A listening socket only proves that bind/listen completed. In io_uring // mode the workers may still be replaying WAL, so tests must also wait for // the backend's logical readiness barrier before issuing requests. - let result = mudu_sys::task_async::block_on_tokio_current_thread(async move{ + let result = mudu_sys::task_async::block_on_tokio_current_thread(async move { tokio::time::timeout(Duration::from_secs(10), waiter.wait()).await }) .map_err(|e| {