diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 00000000..dc8f030f --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,20 @@ +# Bash commands + +- `cargo check` to test that the code compiles. It shouldn't contain warnings. This is quicker than `cargo build`. +- `cargo fmt` to reformat code according to Rust standards. +- `cargo nextest run <test_name>` to run a specific test. +- `cargo nextest run --test-threads=1` to run all tests. Make sure to use `--test-threads=1` because some tests conflict with each other. + +# Code style + +Use standard Rust code style. Use `cargo fmt` to reformat code automatically after every edit. + +# Workflow + +- Prefer to run individual tests with `cargo nextest run <test_name>`. This is much faster. +- A local PostgreSQL server is required for some tests to pass. Set it up and create a database called "pgdog". Create a user called "pgdog" with password "pgdog". +- Ignore files in all folders except `./pgdog`. + +# About the project + +PgDog is a connection pooler for Postgres that can shard databases. It implements the Postgres network protocol and uses pg_query to parse SQL queries. It aims to be 100% compatible with Postgres, without clients knowing they are talking to a proxy. 
diff --git a/integration/logical/pgdog.toml b/integration/logical/pgdog.toml index 572a27e0..60586b9d 100644 --- a/integration/logical/pgdog.toml +++ b/integration/logical/pgdog.toml @@ -17,13 +17,13 @@ database_name = "pgdog" min_pool_size = 0 shard = 0 -[[databases]] -name = "destination" -host = "127.0.0.1" -port = 5434 -database_name = "pgdog" -min_pool_size = 0 -shard = 1 +# [[databases]] +# name = "destination" +# host = "127.0.0.1" +# port = 5434 +# database_name = "pgdog" +# min_pool_size = 0 +# shard = 1 [[sharded_tables]] database = "destination" diff --git a/pgdog/src/backend/schema/sync/error.rs b/pgdog/src/backend/schema/sync/error.rs index d4f5d15f..9b2d4cae 100644 --- a/pgdog/src/backend/schema/sync/error.rs +++ b/pgdog/src/backend/schema/sync/error.rs @@ -31,4 +31,7 @@ pub enum Error { #[error("cluster has no databases")] NoDatabases, + + #[error("missing entity in dump")] + MissingEntity, } diff --git a/pgdog/src/backend/schema/sync/pg_dump.rs b/pgdog/src/backend/schema/sync/pg_dump.rs index 19ce33f0..721d052b 100644 --- a/pgdog/src/backend/schema/sync/pg_dump.rs +++ b/pgdog/src/backend/schema/sync/pg_dump.rs @@ -16,7 +16,7 @@ use crate::{ Cluster, }, config::config, - frontend::router::parser::Table, + frontend::router::parser::{sequence::Sequence, Column, Table}, }; use tokio::process::Command; @@ -158,8 +158,10 @@ pub struct PgDumpOutput { pub enum SyncState { PreData, PostData, + Cutover, } +#[derive(Debug)] pub enum Statement<'a> { Index { table: Table<'a>, @@ -175,15 +177,28 @@ pub enum Statement<'a> { Other { sql: &'a str, }, + + SequenceOwner { + column: Column<'a>, + sequence: Sequence<'a>, + sql: &'a str, + }, + + SequenceSetMax { + sequence: Sequence<'a>, + sql: String, + }, } impl<'a> Deref for Statement<'a> { - type Target = &'a str; + type Target = str; fn deref(&self) -> &Self::Target { match self { - Self::Index { sql, .. } => sql, - Self::Table { sql, .. } => sql, - Self::Other { sql } => sql, + Self::Index { sql, .. 
} => *sql, + Self::Table { sql, .. } => *sql, + Self::SequenceOwner { sql, .. } => *sql, + Self::Other { sql } => *sql, + Self::SequenceSetMax { sql, .. } => sql.as_str(), } } } @@ -276,9 +291,30 @@ impl PgDumpOutput { } } - NodeEnum::AlterSeqStmt(_stmt) => { - if state == SyncState::PreData { - result.push(original.into()); + NodeEnum::AlterSeqStmt(stmt) => { + if matches!(state, SyncState::PreData | SyncState::Cutover) { + let sequence = stmt + .sequence + .as_ref() + .map(Table::from) + .ok_or(Error::MissingEntity)?; + let sequence = Sequence::from(sequence); + let column = stmt.options.first().ok_or(Error::MissingEntity)?; + let column = + Column::try_from(column).map_err(|_| Error::MissingEntity)?; + + if state == SyncState::PreData { + result.push(Statement::SequenceOwner { + column, + sequence, + sql: original, + }); + } else { + let sql = sequence + .setval_from_column(&column) + .map_err(|_| Error::MissingEntity)?; + result.push(Statement::SequenceSetMax { sequence, sql }) + } } } @@ -381,7 +417,9 @@ mod test { .await .unwrap(); - let output = output.statements(SyncState::PreData).unwrap(); + let output_pre = output.statements(SyncState::PreData).unwrap(); + let output_post = output.statements(SyncState::PostData).unwrap(); + let output_cutover = output.statements(SyncState::Cutover).unwrap(); let mut dest = test_server().await; dest.execute("DROP SCHEMA IF EXISTS test_pg_dump_execute_dest CASCADE") @@ -395,7 +433,7 @@ mod test { .await .unwrap(); - for stmt in output { + for stmt in output_pre { // Hack around us using the same database as destination. // I know, not very elegant. let stmt = stmt.replace("pgdog.", "test_pg_dump_execute_dest."); @@ -408,7 +446,30 @@ mod test { .unwrap(); assert_eq!(id[0], i + 1); // Sequence has made it over. - // Unique index has not made it over tho. + // Unique index didn't make it over. 
+ } + + dest.execute("DELETE FROM test_pg_dump_execute_dest.test_pg_dump_execute") + .await + .unwrap(); + + for stmt in output_post { + let stmt = stmt.replace("pgdog.", "test_pg_dump_execute_dest."); + dest.execute(stmt).await.unwrap(); + } + + let q = "INSERT INTO test_pg_dump_execute_dest.test_pg_dump_execute VALUES (DEFAULT, 'test@test', NOW()) RETURNING id"; + assert!(dest.execute(q).await.is_ok()); + let err = dest.execute(q).await.err().unwrap(); + assert!(err.to_string().contains( + r#"duplicate key value violates unique constraint "test_pg_dump_execute_email_key""# + )); // Unique index made it over. + + assert_eq!(output_cutover.len(), 1); + for stmt in output_cutover { + let stmt = stmt.replace("pgdog.", "test_pg_dump_execute_dest."); + assert!(stmt.starts_with("SELECT setval('")); + dest.execute(stmt).await.unwrap(); } dest.execute("DROP SCHEMA test_pg_dump_execute_dest CASCADE") diff --git a/pgdog/src/backend/schema/sync/progress.rs b/pgdog/src/backend/schema/sync/progress.rs index 0b9db38b..32e943c5 100644 --- a/pgdog/src/backend/schema/sync/progress.rs +++ b/pgdog/src/backend/schema/sync/progress.rs @@ -19,6 +19,10 @@ pub enum Item { schema: String, name: String, }, + // SequenceOwner { + // sequence: String, + // owner: String, + // }, Other { sql: String, }, @@ -89,6 +93,12 @@ impl From<&Statement<'_>> for Item { Statement::Other { sql } => Item::Other { sql: sql.to_string(), }, + Statement::SequenceOwner { sql, .. } => Item::Other { + sql: sql.to_string(), + }, + Statement::SequenceSetMax { sql, .. 
} => Item::Other { + sql: sql.to_string(), + }, } } } diff --git a/pgdog/src/frontend/router/parser/column.rs b/pgdog/src/frontend/router/parser/column.rs index 4b7f54d8..d37aac72 100644 --- a/pgdog/src/frontend/router/parser/column.rs +++ b/pgdog/src/frontend/router/parser/column.rs @@ -4,12 +4,49 @@ use pg_query::{ protobuf::{self, String as PgQueryString}, Node, NodeEnum, }; +use std::fmt::{Display, Formatter, Result as FmtResult}; + +use super::Table; +use crate::util::escape_identifier; /// Column name extracted from a query. -#[derive(Debug, Clone, Copy, PartialEq)] +#[derive(Debug, Clone, Copy, PartialEq, Default)] pub struct Column<'a> { /// Column name. pub name: &'a str, + /// Table name, if the column reference was table-qualified. + pub table: Option<&'a str>, + /// Schema name, if the column reference was schema-qualified. + pub schema: Option<&'a str>, +} + +/// Owned version of Column that owns its string data. +#[derive(Debug, Clone, PartialEq, Default)] +pub struct OwnedColumn { + /// Column name. + pub name: String, + /// Table name, if the column reference was table-qualified. + pub table: Option<String>, + /// Schema name. 
+ pub schema: Option<String>, +} + +impl<'a> Column<'a> { + /// Table this column is qualified with, if any. + /// + /// Returns `None` for unqualified columns; the schema is carried over + /// when the reference included one. + pub fn table(&self) -> Option<Table<'a>> { + self.table.map(|name| Table { + name, + schema: self.schema, + }) + } + + /// Convert this borrowed Column to an owned OwnedColumn + pub fn to_owned(&self) -> OwnedColumn { + OwnedColumn::from(*self) + } } impl<'a> Column<'a> { @@ -17,6 +54,7 @@ impl<'a> Column<'a> { match &string.node { Some(NodeEnum::String(protobuf::String { sval })) => Ok(Self { name: sval.as_str(), + ..Default::default() }), _ => Err(()), @@ -24,6 +62,60 @@ impl<'a> Column<'a> { } } +impl<'a> Display for Column<'a> { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + match (self.schema, self.table) { + (Some(schema), Some(table)) => { + write!( + f, + "\"{}\".\"{}\".\"{}\"", + escape_identifier(schema), + escape_identifier(table), + escape_identifier(self.name) + ) + } + (None, Some(table)) => { + write!( + f, + "\"{}\".\"{}\"", + escape_identifier(table), + escape_identifier(self.name) + ) + } + _ => { + write!(f, "\"{}\"", escape_identifier(self.name)) + } + } + } +} + +impl Display for OwnedColumn { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + let borrowed = Column::from(self); + borrowed.fmt(f) + } +} + +impl<'a> From<Column<'a>> for OwnedColumn { + fn from(column: Column<'a>) -> Self { + Self { + name: column.name.to_owned(), + table: column.table.map(|s| s.to_owned()), + schema: column.schema.map(|s| s.to_owned()), + } + } +} + +impl<'a> From<&'a OwnedColumn> for Column<'a> { + fn from(owned: &'a OwnedColumn) -> Self { + Self { + name: &owned.name, + table: owned.table.as_deref(), + schema: owned.schema.as_deref(), + } + } +} + impl<'a> TryFrom<&'a Node> for Column<'a> { type Error = (); @@ -36,27 +128,78 @@ impl<'a> TryFrom<&'a Option<NodeEnum>> for Column<'a> { type Error = (); fn try_from(value: &'a Option<NodeEnum>) -> Result<Self, Self::Error> { + fn from_node(node: &Node) -> Option<&str> { + if let Some(NodeEnum::String(PgQueryString { sval })) = &node.node { + 
Some(sval.as_str()) + } else { + None + } + } + + /// Build a column from the parts of a dotted name: `name`, `table.name` + /// or `schema.table.name`. Any other number of parts is an error, and so + /// is a final part that is not a string node. + fn from_slice<'a>(nodes: &'a [Node]) -> Result<Column<'a>, ()> { + match nodes.len() { + 3 => { + let schema = nodes.first().and_then(from_node); + let table = nodes.get(1).and_then(from_node); + let name = nodes.get(2).and_then(from_node).ok_or(())?; + Ok(Column { + schema, + table, + name, + }) + } + + 2 => { + let table = nodes.first().and_then(from_node); + let name = nodes.get(1).and_then(from_node).ok_or(())?; + Ok(Column { + schema: None, + table, + name, + }) + } + + 1 => { + let name = nodes.first().and_then(from_node).ok_or(())?; + Ok(Column { + name, + ..Default::default() + }) + } + + _ => Err(()), + } + } + + match value { Some(NodeEnum::ResTarget(res_target)) => { return Ok(Self { name: res_target.name.as_str(), + ..Default::default() }); } - Some(NodeEnum::ColumnRef(column_ref)) => { - if let Some(node) = column_ref.fields.last() { - if let Some(NodeEnum::String(PgQueryString { sval })) = &node.node { - return Ok(Self { - name: sval.as_str(), - }); + Some(NodeEnum::List(list)) => from_slice(&list.items), + + Some(NodeEnum::ColumnRef(column_ref)) => from_slice(&column_ref.fields), + + Some(NodeEnum::DefElem(list)) => { + if list.defname == "owned_by" { + if let Some(ref node) = list.arg { + Column::try_from(&node.node) 
+ } else { + Err(()) } + } else { + Err(()) } } _ => return Err(()), } - - Err(()) } } @@ -92,11 +235,41 @@ mod test { .unwrap(); assert_eq!( columns, - vec![Column { name: "id" }, Column { name: "email" }] + vec![ + Column { + name: "id", + ..Default::default() + }, + Column { + name: "email", + ..Default::default() + } + ] ); } _ => panic!("not a select"), } } + + #[test] + fn test_column_sequence() { + let query = + parse("ALTER SEQUENCE public.user_profiles_id_seq OWNED BY public.user_profiles.id") + .unwrap(); + let alter = query.protobuf.stmts.first().unwrap().stmt.as_ref().unwrap(); + match alter.node { + Some(NodeEnum::AlterSeqStmt(ref stmt)) => { + if let Some(node) = stmt.options.first() { + let column = Column::try_from(node).unwrap(); + assert_eq!(column.name, "id"); + assert_eq!(column.schema, Some("public")); + assert_eq!(column.table, Some("user_profiles")); + } else { + panic!("no owned by clause"); + } + } + _ => panic!("not an alter sequence"), + } + } } diff --git a/pgdog/src/frontend/router/parser/error.rs b/pgdog/src/frontend/router/parser/error.rs index a1e26e73..77aebcdc 100644 --- a/pgdog/src/frontend/router/parser/error.rs +++ b/pgdog/src/frontend/router/parser/error.rs @@ -53,4 +53,7 @@ pub enum Error { #[error("missing parameter: ${0}")] MissingParameter(usize), + + #[error("column has no associated table")] + ColumnNoTable, } diff --git a/pgdog/src/frontend/router/parser/insert.rs b/pgdog/src/frontend/router/parser/insert.rs index 115d9880..a1bc0a83 100644 --- a/pgdog/src/frontend/router/parser/insert.rs +++ b/pgdog/src/frontend/router/parser/insert.rs @@ -151,7 +151,16 @@ mod test { ); assert_eq!( insert.columns(), - vec![Column { name: "id" }, Column { name: "email" }] + vec![ + Column { + name: "id", + ..Default::default() + }, + Column { + name: "email", + ..Default::default() + } + ] ); } diff --git a/pgdog/src/frontend/router/parser/mod.rs b/pgdog/src/frontend/router/parser/mod.rs index cd6cbc3c..67df7e6e 100644 --- 
a/pgdog/src/frontend/router/parser/mod.rs +++ b/pgdog/src/frontend/router/parser/mod.rs @@ -21,6 +21,7 @@ pub mod prepare; pub mod query; pub mod rewrite; pub mod route; +pub mod sequence; pub mod table; pub mod tuple; pub mod value; @@ -29,7 +30,7 @@ pub mod where_clause; pub use aggregate::{Aggregate, AggregateFunction, AggregateTarget}; pub use binary::BinaryStream; pub use cache::Cache; -pub use column::Column; +pub use column::{Column, OwnedColumn}; pub use command::Command; pub use context::QueryParserContext; pub use copy::{CopyFormat, CopyParser}; @@ -45,7 +46,8 @@ pub use order_by::OrderBy; pub use prepare::Prepare; pub use query::QueryParser; pub use route::{Route, Shard}; -pub use table::Table; +pub use sequence::{OwnedSequence, Sequence}; +pub use table::{OwnedTable, Table}; pub use tuple::Tuple; pub use value::Value; pub use where_clause::WhereClause; diff --git a/pgdog/src/frontend/router/parser/query/mod.rs b/pgdog/src/frontend/router/parser/query/mod.rs index 54fc4dea..182ea4a8 100644 --- a/pgdog/src/frontend/router/parser/query/mod.rs +++ b/pgdog/src/frontend/router/parser/query/mod.rs @@ -49,7 +49,7 @@ use tracing::{debug, trace}; /// #[derive(Debug)] pub struct QueryParser { - // The statement is executed inside a tranasction. + // The statement is executed inside a transaction. in_transaction: bool, // No matter what query is executed, we'll send it to the primary. write_override: bool, diff --git a/pgdog/src/frontend/router/parser/sequence.rs b/pgdog/src/frontend/router/parser/sequence.rs new file mode 100644 index 00000000..dcac8ebc --- /dev/null +++ b/pgdog/src/frontend/router/parser/sequence.rs @@ -0,0 +1,165 @@ +use std::fmt::Display; + +use super::{error::Error, Column, OwnedTable, Table}; +use crate::util::escape_identifier; + +/// Sequence name in a query. +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct Sequence<'a> { + /// Table representing the sequence name and schema. 
+ pub table: Table<'a>, +} + +/// Owned version of Sequence that owns its string data. +#[derive(Debug, Clone, PartialEq)] +pub struct OwnedSequence { + /// Table representing the sequence name and schema. + pub table: OwnedTable, +} + +impl Display for Sequence<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.table.fmt(f) + } +} + +impl Default for Sequence<'_> { + fn default() -> Self { + Self { + table: Table::default(), + } + } +} + +impl Display for OwnedSequence { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let borrowed = Sequence::from(self); + borrowed.fmt(f) + } +} + +impl Default for OwnedSequence { + fn default() -> Self { + Self { + table: OwnedTable::default(), + } + } +} + +impl<'a> From<Sequence<'a>> for OwnedSequence { + fn from(sequence: Sequence<'a>) -> Self { + Self { + table: OwnedTable::from(sequence.table), + } + } +} + +impl<'a> From<&'a OwnedSequence> for Sequence<'a> { + fn from(owned: &'a OwnedSequence) -> Self { + Self { + table: Table::from(&owned.table), + } + } +} + +impl From<OwnedTable> for OwnedSequence { + fn from(table: OwnedTable) -> Self { + Self { table } + } +} + +impl<'a> Sequence<'a> { + /// Convert this borrowed Sequence to an owned OwnedSequence + pub fn to_owned(&self) -> OwnedSequence { + OwnedSequence::from(*self) + } + + /// Generate a setval statement to set the sequence to the max value of the given column. + /// Fails with `Error::ColumnNoTable` when the column is not table-qualified. + pub fn setval_from_column(&self, column: &Column<'a>) -> Result<String, Error> { + // The name is embedded in a single-quoted literal: double any single quotes. + let sequence_name = self.table.to_string().replace('\'', "''"); + let table = column.table().ok_or(Error::ColumnNoTable)?; + let table_name = table.to_string(); + let column_name = format!("\"{}\"", escape_identifier(column.name)); + + Ok(format!( + "SELECT setval('{}', COALESCE((SELECT MAX({}) FROM {}), 1), true);", + sequence_name, column_name, table_name + )) + } +} + +impl<'a> From<Table<'a>> for Sequence<'a> { + fn from(table: Table<'a>) -> Self { + Self { table } + } +} + +#[cfg(test)] +mod test { + use pg_query::{parse, 
NodeEnum}; + + use super::{Column, Sequence, Table}; + + #[test] + fn test_sequence_setval_from_alter_statement() { + let query = + parse("ALTER SEQUENCE public.user_profiles_id_seq OWNED BY public.user_profiles.id") + .unwrap(); + let alter = query.protobuf.stmts.first().unwrap().stmt.as_ref().unwrap(); + + match alter.node { + Some(NodeEnum::AlterSeqStmt(ref stmt)) => { + // Extract sequence name from the relation + let sequence_table = Table::from(stmt.sequence.as_ref().unwrap()); + let sequence = Sequence::from(sequence_table); + + // Extract column from the owned_by option + if let Some(node) = stmt.options.first() { + let column = Column::try_from(node).unwrap(); + + // Test the setval generation + let setval_sql = sequence.setval_from_column(&column).unwrap(); + + assert_eq!( + setval_sql, + "SELECT setval('\"public\".\"user_profiles_id_seq\"', COALESCE((SELECT MAX(\"id\") FROM \"public\".\"user_profiles\"), 1), true);" + ); + + // Verify the individual components + assert_eq!(sequence.table.name, "user_profiles_id_seq"); + assert_eq!(sequence.table.schema, Some("public")); + assert_eq!(column.name, "id"); + assert_eq!(column.table, Some("user_profiles")); + assert_eq!(column.schema, Some("public")); + } else { + panic!("no owned by clause"); + } + } + _ => panic!("not an alter sequence"), + } + } + + #[test] + fn test_sequence_display() { + let table = Table { + name: "my_seq", + schema: Some("public"), + }; + let sequence = Sequence::from(table); + + assert_eq!(sequence.to_string(), "\"public\".\"my_seq\""); + } + + #[test] + fn test_sequence_display_no_schema() { + let table = Table { + name: "my_seq", + schema: None, + }; + let sequence = Sequence::from(table); + + assert_eq!(sequence.to_string(), "\"my_seq\""); + } +} diff --git a/pgdog/src/frontend/router/parser/table.rs b/pgdog/src/frontend/router/parser/table.rs index 61c346e4..f7b46b8e 100644 --- a/pgdog/src/frontend/router/parser/table.rs +++ b/pgdog/src/frontend/router/parser/table.rs @@ -1,4 +1,8 
@@ -use pg_query::{protobuf::*, NodeEnum}; +use std::fmt::Display; + +use pg_query::{protobuf::RangeVar, Node, NodeEnum}; + +use crate::util::escape_identifier; /// Table name in a query. #[derive(Debug, Clone, Copy, PartialEq)] @@ -9,6 +13,30 @@ pub struct Table<'a> { pub schema: Option<&'a str>, } +/// Owned version of Table that owns its string data. +#[derive(Debug, Clone, PartialEq)] +pub struct OwnedTable { + /// Table name. + pub name: String, + /// Schema name, if specified. + pub schema: Option<String>, +} + +impl Display for Table<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if let Some(schema) = self.schema { + write!( + f, + "\"{}\".\"{}\"", + escape_identifier(schema), + escape_identifier(self.name) + ) + } else { + write!(f, "\"{}\"", escape_identifier(self.name)) + } + } +} + impl Default for Table<'_> { fn default() -> Self { Self { @@ -18,6 +46,47 @@ impl Default for Table<'_> { } } +impl<'a> Table<'a> { + /// Convert this borrowed Table to an owned OwnedTable + pub fn to_owned(&self) -> OwnedTable { + OwnedTable::from(*self) + } +} + +impl Display for OwnedTable { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let borrowed = Table::from(self); + borrowed.fmt(f) + } +} + +impl Default for OwnedTable { + fn default() -> Self { + Self { + name: String::new(), + schema: None, + } + } +} + +impl<'a> From<Table<'a>> for OwnedTable { + fn from(table: Table<'a>) -> Self { + Self { + name: table.name.to_owned(), + schema: table.schema.map(|s| s.to_owned()), + } + } +} + +impl<'a> From<&'a OwnedTable> for Table<'a> { + fn from(owned: &'a OwnedTable) -> Self { + Self { + name: &owned.name, + schema: owned.schema.as_deref(), + } + } +} + impl<'a> TryFrom<&'a Node> for Table<'a> { type Error = (); diff --git a/pgdog/src/util.rs b/pgdog/src/util.rs index 614dddf4..eccbdb25 100644 --- a/pgdog/src/util.rs +++ b/pgdog/src/util.rs @@ -70,6 +70,11 @@ pub fn random_string(n: usize) -> String { .collect() } +/// Escape 
PostgreSQL identifiers by doubling any embedded quotes. +pub fn escape_identifier(s: &str) -> String { + s.replace("\"", "\"\"") +} + #[cfg(test)] mod test { @@ -94,4 +99,16 @@ mod test { ); let _now = postgres_now(); } + + #[test] + fn test_escape_identifier() { + assert_eq!(escape_identifier("simple"), "simple"); + assert_eq!(escape_identifier("has\"quote"), "has\"\"quote"); + assert_eq!(escape_identifier("\"starts_with"), "\"\"starts_with"); + assert_eq!(escape_identifier("ends_with\""), "ends_with\"\""); + assert_eq!( + escape_identifier("\"multiple\"quotes\""), + "\"\"multiple\"\"quotes\"\"" + ); + } }