diff --git a/.gitignore b/.gitignore index 978060a6..309e8e34 100644 --- a/.gitignore +++ b/.gitignore @@ -42,3 +42,4 @@ toxi.log *.sqlite3 perf.data perf.data.old +/shard_test/ diff --git a/Cargo.lock b/Cargo.lock index 58a00a54..05e81387 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -386,6 +386,7 @@ dependencies = [ "iana-time-zone", "js-sys", "num-traits", + "serde", "wasm-bindgen", "windows-link", ] diff --git a/pgdog/Cargo.toml b/pgdog/Cargo.toml index 48104fb5..a12c60cb 100644 --- a/pgdog/Cargo.toml +++ b/pgdog/Cargo.toml @@ -49,7 +49,7 @@ uuid = { version = "1", features = ["v4"] } url = "2" ratatui = { version = "0.30.0-alpha.1", optional = true } rmp-serde = "1" -chrono = "0.4" +chrono = { version = "0.4", features = ["serde"] } hyper = { version = "1", features = ["full"] } http-body-util = "0.1" hyper-util = { version = "0.1", features = ["full"] } diff --git a/pgdog/src/backend/pool/cluster.rs b/pgdog/src/backend/pool/cluster.rs index c07a7e5f..deb37527 100644 --- a/pgdog/src/backend/pool/cluster.rs +++ b/pgdog/src/backend/pool/cluster.rs @@ -360,6 +360,9 @@ mod test { data_type: DataType::Bigint, centroids_path: None, centroid_probes: 1, + sharding_method: None, + shard_range_map: None, + shard_list_map: None, }], vec!["sharded_omni".into()], false, diff --git a/pgdog/src/config/mod.rs b/pgdog/src/config/mod.rs index 09058a46..7baa23df 100644 --- a/pgdog/src/config/mod.rs +++ b/pgdog/src/config/mod.rs @@ -3,6 +3,7 @@ pub mod convert; pub mod error; pub mod overrides; +mod shards; pub mod url; use error::Error; @@ -21,6 +22,7 @@ use serde::{Deserialize, Serialize}; use tracing::info; use tracing::warn; +pub(crate) use crate::config::shards::{ShardListMap, ShardRangeMap, ShardingMethod}; use crate::net::messages::Vector; use crate::util::{human_duration_optional, random_string}; @@ -826,6 +828,12 @@ pub struct ShardedTable { /// How many centroids to probe. #[serde(default)] pub centroid_probes: usize, + #[serde(default)] + pub sharding_method: Option, + + pub shard_range_map: Option, + + pub shard_list_map: Option, } impl ShardedTable { @@ -865,6 +873,10 @@ pub enum DataType { Bigint, Uuid, Vector, + // TODO: implement more types? + // String, + // DateTimeUTC + // Float } #[derive(Serialize, Deserialize, PartialEq, Debug, Clone, Default)] @@ -955,9 +967,10 @@ pub struct MultiTenant { #[cfg(test)] pub mod test { - use crate::backend::databases::init; use super::*; + use crate::backend::databases::init; + use crate::config::shards::ShardRange; pub fn load_test() { let mut config = ConfigAndUsers::default(); @@ -1052,4 +1065,315 @@ column = "tenant_id" assert_eq!(config.tcp.retries().unwrap(), 5); assert_eq!(config.multi_tenant.unwrap().column, "tenant_id"); } + + #[test] + fn test_load_sharded_table_with_range_map() { + let toml_str = r#" + database = "pgdog_sharded" + name = "range_sharded" + column = "user_id" + data_type = "bigint" + sharding_method = "range" + + [shard_range_map] + "0" = { start = 0, end = 1000 } + "1" = { start = 1000, end = 2000 } + "2" = { start = 2000, end = 3000 } + "#; + + let table: ShardedTable = toml::from_str(toml_str).unwrap(); + + // Verify basic fields + assert_eq!(table.database, "pgdog_sharded"); + assert_eq!(table.name, Some("range_sharded".to_string())); + assert_eq!(table.column, "user_id"); + assert_eq!(table.data_type, crate::config::DataType::Bigint); + assert_eq!(table.sharding_method, Some(ShardingMethod::Range)); + + // Verify shard_range_map + let range_map = table.shard_range_map.unwrap(); + assert_eq!(range_map.0.len(), 3); + + // Check first range + let range_0 = range_map.0.get(&0).unwrap(); + assert_eq!(range_0.start, Some(0)); + assert_eq!(range_0.end, Some(1000)); + assert_eq!(range_0.no_min, false); + assert_eq!(range_0.no_max, false); + + // Check second range + let range_1 = range_map.0.get(&1).unwrap(); + assert_eq!(range_1.start, Some(1000)); + assert_eq!(range_1.end, Some(2000)); + + // Check third range + let range_2 = range_map.0.get(&2).unwrap(); + assert_eq!(range_2.start, Some(2000)); + assert_eq!(range_2.end, Some(3000)); + + // Verify that shard_list_map is None + assert!(table.shard_list_map.is_none()); + } + + #[test] + fn test_load_sharded_table_with_list_map() { + let toml_str = r#" + database = "pgdog_sharded" + name = "list_sharded" + column = "category_id" + data_type = "bigint" + sharding_method = "list" + + [shard_list_map] + "0" = { values = [1, 3, 5, 7, 9] } + "1" = { values = [2, 4, 6, 8, 10] } + "2" = { values = [11, 12, 13, 14, 15] } + "#; + + let table: ShardedTable = toml::from_str(toml_str).unwrap(); + + // Verify basic fields + assert_eq!(table.database, "pgdog_sharded"); + assert_eq!(table.name, Some("list_sharded".to_string())); + assert_eq!(table.column, "category_id"); + assert_eq!(table.data_type, crate::config::DataType::Bigint); + assert_eq!(table.sharding_method, Some(ShardingMethod::List)); + + // Verify shard_list_map + let list_map = table.shard_list_map.unwrap(); + assert_eq!(list_map.0.len(), 3); + + // Check first list + let list_0 = list_map.0.get(&0).unwrap(); + assert_eq!(list_0.values, vec![1, 3, 5, 7, 9]); + + // Check second list + let list_1 = list_map.0.get(&1).unwrap(); + assert_eq!(list_1.values, vec![2, 4, 6, 8, 10]); + + // Check third list + let list_2 = list_map.0.get(&2).unwrap(); + assert_eq!(list_2.values, vec![11, 12, 13, 14, 15]); + + // Verify that shard_range_map is None + assert!(table.shard_range_map.is_none()); + } + + #[test] + fn test_load_sharded_table_with_special_range_flags() { + let toml_str = r#" + database = "pgdog_sharded" + name = "special_range_sharded" + column = "timestamp_id" + data_type = "bigint" + sharding_method = "range" + + [shard_range_map] + "0" = { start = 0, end = 1000 } + "1" = { start = 1000, no_max = true } + "2" = { no_min = true, end = 0 } + "#; + + let table: ShardedTable = toml::from_str(toml_str).unwrap(); + + // Verify shard_range_map with special flags + let range_map = table.shard_range_map.unwrap(); + assert_eq!(range_map.0.len(), 3); + + // Standard range + let range_0 = range_map.0.get(&0).unwrap(); + assert_eq!(range_0.start, Some(0)); + assert_eq!(range_0.end, Some(1000)); + assert_eq!(range_0.no_min, false); + assert_eq!(range_0.no_max, false); + + // Range with no maximum (unbounded upper) + let range_1 = range_map.0.get(&1).unwrap(); + assert_eq!(range_1.start, Some(1000)); + assert_eq!(range_1.end, None); + assert_eq!(range_1.no_min, false); + assert_eq!(range_1.no_max, true); + + // Range with no minimum (unbounded lower) + let range_2 = range_map.0.get(&2).unwrap(); + assert_eq!(range_2.start, None); + assert_eq!(range_2.end, Some(0)); + assert_eq!(range_2.no_min, true); + assert_eq!(range_2.no_max, false); + } + + #[test] + fn test_load_sharded_table_with_empty_list_values() { + let toml_str = r#" + database = "pgdog_sharded" + name = "empty_list_sharded" + column = "tag_id" + data_type = "bigint" + sharding_method = "list" + + [shard_list_map] + "0" = { values = [1, 2, 3] } + "1" = { values = [] } + "#; + + let table: ShardedTable = toml::from_str(toml_str).unwrap(); + + // Verify shard_list_map with an empty list + let list_map = table.shard_list_map.unwrap(); + assert_eq!(list_map.0.len(), 2); + + // Check first list + let list_0 = list_map.0.get(&0).unwrap(); + assert_eq!(list_0.values, vec![1, 2, 3]); + + // Check empty list + let list_1 = list_map.0.get(&1).unwrap(); + assert!(list_1.values.is_empty()); + } + + #[test] + fn test_load_sharded_table_with_invalid_shard_map_keys() { + let toml_str = r#" + database = "pgdog_sharded" + name = "invalid_keys" + column = "user_id" + data_type = "bigint" + sharding_method = "range" + + [shard_range_map] + "invalid" = { start = 0, end = 1000 } + "0" = { start = 1000, end = 2000 } + "#; + + let result = toml::from_str::(toml_str); + assert!(result.is_err()); + + // Verify the error message contains information about parsing failure + let error = result.unwrap_err().to_string(); + assert!(error.contains("invalid") || error.contains("parse")); + } + + #[test] + fn test_load_sharded_table_with_both_maps() { + let toml_str = r#" + database = "pgdog_sharded" + name = "dual_sharded" + column = "id" + data_type = "bigint" + sharding_method = "range" + + [shard_range_map] + "0" = { start = 0, end = 1000 } + "1" = { start = 1000, end = 2000 } + + [shard_list_map] + "0" = { values = [1, 3, 5] } + "1" = { values = [2, 4, 6] } + "#; + + let table: ShardedTable = toml::from_str(toml_str).unwrap(); + + // Both maps should be populated, but the actual sharding method used + // should be determined by the sharding_method field + assert_eq!(table.sharding_method, Some(ShardingMethod::Range)); + + // Verify both maps exist + assert!(table.shard_range_map.is_some()); + assert!(table.shard_list_map.is_some()); + + // Check range map + let range_map = table.shard_range_map.unwrap(); + assert_eq!(range_map.0.len(), 2); + + // Check list map + let list_map = table.shard_list_map.unwrap(); + assert_eq!(list_map.0.len(), 2); + } + + #[test] + fn test_load_sharded_table_without_sharding_method() { + let toml_str = r#" + database = "pgdog_sharded" + name = "implicit_hash" + column = "id" + data_type = "bigint" + + [shard_range_map] + "0" = { start = 0, end = 1000 } + "1" = { start = 1000, end = 2000 } + "#; + + let table: ShardedTable = toml::from_str(toml_str).unwrap(); + + // If sharding_method is not specified, it should default to Hash + assert_eq!(table.sharding_method, None); + + // But the range map should still be populated + assert!(table.shard_range_map.is_some()); + let range_map = table.shard_range_map.unwrap(); + assert_eq!(range_map.0.len(), 2); + } + + #[test] + fn test_programmatically_create_and_serialize() { + // Create a ShardedTable with range map programmatically + let mut range_map = HashMap::new(); + range_map.insert( + 0, + ShardRange { + start: Some(0), + end: Some(1000), + no_min: false, + no_max: false, + }, + ); + range_map.insert( + 1, + ShardRange { + start: Some(1000), + end: None, + no_min: false, + no_max: true, + }, + ); + + let shard_range_map = ShardRangeMap(range_map); + + let table = ShardedTable { + database: "pgdog_sharded".to_string(), + name: Some("range_table".to_string()), + column: "id".to_string(), + data_type: crate::config::DataType::Bigint, + sharding_method: Some(ShardingMethod::Range), + shard_range_map: Some(shard_range_map), + shard_list_map: None, + primary: false, + centroids: Vec::new(), + centroids_path: None, + centroid_probes: 0, + }; + + // Serialize to TOML + let toml_str = toml::to_string(&table).unwrap(); + + // Deserialize back to validate + let parsed_table: ShardedTable = toml::from_str(&toml_str).unwrap(); + + // Verify the deserialized structure matches the original + assert_eq!(parsed_table.database, "pgdog_sharded"); + assert_eq!(parsed_table.name, Some("range_table".to_string())); + assert_eq!(parsed_table.sharding_method, Some(ShardingMethod::Range)); + + let parsed_range_map = parsed_table.shard_range_map.unwrap(); + assert_eq!(parsed_range_map.0.len(), 2); + + let range_0 = parsed_range_map.0.get(&0).unwrap(); + assert_eq!(range_0.start, Some(0)); + assert_eq!(range_0.end, Some(1000)); + + let range_1 = parsed_range_map.0.get(&1).unwrap(); + assert_eq!(range_1.start, Some(1000)); + assert_eq!(range_1.end, None); + assert_eq!(range_1.no_max, true); + } } diff --git a/pgdog/src/config/shards.rs b/pgdog/src/config/shards.rs new file mode 100644 index 00000000..a87c269c --- /dev/null +++ b/pgdog/src/config/shards.rs @@ -0,0 +1,196 @@ +use std::collections::HashMap; + +use serde::{Deserialize, Deserializer, Serialize, Serializer}; + +use crate::frontend::router::parser::Shard; + +// ============================================================================= +// Serialization Helper Module +// ============================================================================= + +/// Helper module for (de)serializing maps with usize keys as strings +mod usize_map_keys_as_strings { + use super::*; + + pub fn serialize(map: &HashMap, serializer: S) -> Result + where + S: Serializer, + V: Serialize, + { + let string_map: HashMap = map.iter().map(|(k, v)| (k.to_string(), v)).collect(); + string_map.serialize(serializer) + } + + pub fn deserialize<'de, D, V>(deserializer: D) -> Result, D::Error> + where + D: Deserializer<'de>, + V: Deserialize<'de>, + { + let string_map = HashMap::::deserialize(deserializer)?; + string_map + .into_iter() + .map(|(s, v)| { + s.parse::() + .map(|k| (k, v)) + .map_err(serde::de::Error::custom) + }) + .collect() + } +} + +// ============================================================================= +// Core Sharding Types +// ============================================================================= + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)] +#[serde(rename_all = "snake_case")] +pub enum ShardingMethod { + #[default] + Hash, + Range, + List, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct ShardRange { + pub start: Option, + pub end: Option, + #[serde(default)] + pub no_max: bool, + #[serde(default)] + pub no_min: bool, +} + +impl ShardRange { + /// Check if a value falls within this range + pub fn contains(&self, value: i64) -> bool { + // Check lower bound + if !self.no_min { + if let Some(start) = self.start { + if value < start { + return false; + } + } + } + + // Check upper bound + if !self.no_max { + if let Some(end) = self.end { + if value >= end { + // Using >= for exclusive upper bound + return false; + } + } + } + + true + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct ShardList { + pub values: Vec, +} + +impl ShardList { + /// Check if a value is contained in this list + pub fn contains(&self, value: i64) -> bool { + self.values.contains(&value) + } +} + +// ============================================================================= +// Shard Map Types +// ============================================================================= + +/// A map of shard IDs to their range definitions +#[derive(Debug, Clone, PartialEq)] +pub struct ShardRangeMap(pub HashMap); + +impl ShardRangeMap { + pub fn new() -> Self { + Self::default() + } + + /// Find the shard key for a given value based on range containment + pub fn find_shard_key(&self, value: i64) -> Option { + for (shard_id, range) in &self.0 { + if range.contains(value) { + return Some(Shard::Direct(*shard_id)); + } + } + None + } +} + +impl Default for ShardRangeMap { + fn default() -> Self { + Self(HashMap::new()) + } +} + +impl Serialize for ShardRangeMap { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + usize_map_keys_as_strings::serialize(&self.0, serializer) + } +} + +impl<'de> Deserialize<'de> for ShardRangeMap { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + Ok(ShardRangeMap(usize_map_keys_as_strings::deserialize( + deserializer, + )?)) + } +} + +/// A map of shard IDs to their list definitions +#[derive(Debug, Clone, PartialEq)] +pub struct ShardListMap(pub HashMap); + +impl ShardListMap { + pub fn new() -> Self { + Self::default() + } + + /// Find the shard key for a given value based on list containment + pub fn find_shard_key(&self, value: i64) -> Option { + for (shard_id, list) in &self.0 { + if list.contains(value) { + return Some(Shard::Direct(*shard_id)); + } + } + None + } +} + +impl Default for ShardListMap { + fn default() -> Self { + Self(HashMap::new()) + } +} + +impl Serialize for ShardListMap { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + usize_map_keys_as_strings::serialize(&self.0, serializer) + } +} + +impl<'de> Deserialize<'de> for ShardListMap { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + Ok(ShardListMap(usize_map_keys_as_strings::deserialize( + deserializer, + )?)) + } +} diff --git a/pgdog/src/frontend/router/sharding/context.rs b/pgdog/src/frontend/router/sharding/context.rs index 9687081b..c3f1bd3e 100644 --- a/pgdog/src/frontend/router/sharding/context.rs +++ b/pgdog/src/frontend/router/sharding/context.rs @@ -1,6 +1,5 @@ -use crate::frontend::router::parser::Shard; - use super::{Error, Operator, Value}; +use crate::frontend::router::parser::Shard; #[derive(Debug)] pub struct Context<'a> { @@ -16,7 +15,6 @@ impl<'a> Context<'a> { return Ok(Shard::Direct(hash as usize % shards)); } } - Operator::Centroids { shards, probes, @@ -26,6 +24,16 @@ impl<'a> Context<'a> { return Ok(centroids.shard(&vector, *shards, *probes)); } } + Operator::Ranges(srm) => { + if let Some(i) = self.value.int()? { + return Ok(srm.find_shard_key(i).unwrap()); + } + } + Operator::Lists(slm) => { + if let Some(i) = self.value.int()? { + return Ok(slm.find_shard_key(i).unwrap()); + } + } } Ok(Shard::All) diff --git a/pgdog/src/frontend/router/sharding/context_builder.rs b/pgdog/src/frontend/router/sharding/context_builder.rs index 4dbb4458..e552b2d7 100644 --- a/pgdog/src/frontend/router/sharding/context_builder.rs +++ b/pgdog/src/frontend/router/sharding/context_builder.rs @@ -1,4 +1,4 @@ -use crate::config::{DataType, ShardedTable}; +use crate::config::{DataType, ShardListMap, ShardRangeMap, ShardedTable, ShardingMethod}; use super::{Centroids, Context, Data, Error, Operator, Value}; @@ -8,6 +8,9 @@ pub struct ContextBuilder<'a> { operator: Option>, centroids: Option>, probes: usize, + sharding_method: Option, + shard_range_map: Option, + shard_list_map: Option, } impl<'a> ContextBuilder<'a> { @@ -22,6 +25,11 @@ impl<'a> ContextBuilder<'a> { probes: table.centroid_probes, operator: None, value: None, + // added for list and range sharding + // todo: add lifetimes to these to avoid cloning + sharding_method: table.sharding_method.clone(), + shard_range_map: table.shard_range_map.clone(), + shard_list_map: table.shard_list_map.clone(), } } @@ -37,6 +45,9 @@ impl<'a> ContextBuilder<'a> { probes: 0, centroids: None, operator: None, + sharding_method: None, + shard_range_map: None, + shard_list_map: None, }) } else if uuid.valid() { Ok(Self { @@ -45,6 +56,9 @@ impl<'a> ContextBuilder<'a> { probes: 0, centroids: None, operator: None, + sharding_method: None, + shard_range_map: None, + shard_list_map: None, }) } else { Err(Error::IncompleteContext) @@ -57,9 +71,25 @@ impl<'a> ContextBuilder<'a> { shards, probes: self.probes, centroids, - }); - } else { - self.operator = Some(Operator::Shards(shards)) + }) + } else if let Some(method) = self.sharding_method.take() { + match method { + ShardingMethod::Hash => { + self.operator = Some(Operator::Shards(shards)); + return self; + } + ShardingMethod::Range => { + if self.shard_range_map.is_some() { + self.operator = + Some(Operator::Ranges(self.shard_range_map.clone().unwrap())) + } + } + ShardingMethod::List => { + if self.shard_list_map.is_some() { + self.operator = Some(Operator::Lists(self.shard_list_map.clone().unwrap())) + } + } + } } self } diff --git a/pgdog/src/frontend/router/sharding/mod.rs b/pgdog/src/frontend/router/sharding/mod.rs index 9b5abb35..94807a32 100644 --- a/pgdog/src/frontend/router/sharding/mod.rs +++ b/pgdog/src/frontend/router/sharding/mod.rs @@ -1,3 +1,4 @@ +use bytes::Bytes; use uuid::Uuid; use crate::{ @@ -41,6 +42,15 @@ pub fn uuid(uuid: Uuid) -> u64 { } } +pub fn bytes(bytes: Bytes) -> u64 { + unsafe { + ffi::hash_combine64( + 0, + ffi::hash_bytes_extended(bytes.as_ptr(), bytes.len() as i64), + ) + } +} + /// Shard a string value, parsing out a BIGINT, UUID, or vector. /// /// TODO: This is really not great, we should pass in the type oid diff --git a/pgdog/src/frontend/router/sharding/operator.rs b/pgdog/src/frontend/router/sharding/operator.rs index 380742d8..fd8c08bc 100644 --- a/pgdog/src/frontend/router/sharding/operator.rs +++ b/pgdog/src/frontend/router/sharding/operator.rs @@ -1,4 +1,5 @@ use super::Centroids; +use crate::config::{ShardListMap, ShardRangeMap}; #[derive(Debug)] pub enum Operator<'a> { @@ -8,4 +9,6 @@ pub enum Operator<'a> { probes: usize, centroids: Centroids<'a>, }, + Lists(ShardListMap), + Ranges(ShardRangeMap), } diff --git a/pgdog/src/frontend/router/sharding/value.rs b/pgdog/src/frontend/router/sharding/value.rs index 101ecb9a..ee9f461d 100644 --- a/pgdog/src/frontend/router/sharding/value.rs +++ b/pgdog/src/frontend/router/sharding/value.rs @@ -1,5 +1,4 @@ use std::str::{from_utf8, FromStr}; - use uuid::Uuid; use super::{bigint, uuid, Error}; @@ -72,6 +71,22 @@ impl<'a> Value<'a> { } } + pub fn int(&self) -> Result, Error> { + match self.data_type { + DataType::Bigint => match self.data { + Data::Text(text) => Ok(Some(text.parse::()?)), + Data::Binary(data) => Ok(Some(match data.len() { + 2 => i16::from_be_bytes(data.try_into()?) as i64, + 4 => i32::from_be_bytes(data.try_into()?) as i64, + 8 => i64::from_be_bytes(data.try_into()?) as i64, + _ => return Err(Error::IntegerSize), + })), + Data::Integer(int) => Ok(Some(int)), + }, + _ => Ok(None), + } + } + pub fn valid(&self) -> bool { match self.data_type { DataType::Bigint => match self.data {