diff --git a/example/vote/src/generated/procedure.rs b/example/vote/src/generated/procedure.rs index 6bd3586..4510b4a 100644 --- a/example/vote/src/generated/procedure.rs +++ b/example/vote/src/generated/procedure.rs @@ -337,8 +337,8 @@ pub async fn mudu_inner_p2_create_user( Ok(::mudu_contract::procedure::procedure_result::ProcedureResult::from(tuple, &return_desc)?) } -pub fn mudu_argv_desc_create_user( -) -> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { +pub fn mudu_argv_desc_create_user() +-> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { static ARGV_DESC: std::sync::OnceLock< ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc, > = std::sync::OnceLock::new(); @@ -353,8 +353,8 @@ pub fn mudu_argv_desc_create_user( }) } -pub fn mudu_result_desc_create_user( -) -> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { +pub fn mudu_result_desc_create_user() +-> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { static RESULT_DESC: std::sync::OnceLock< ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc, > = std::sync::OnceLock::new(); @@ -426,8 +426,8 @@ pub async fn mudu_inner_p2_cast_vote( Ok(::mudu_contract::procedure::procedure_result::ProcedureResult::from(tuple, &return_desc)?) } -pub fn mudu_argv_desc_cast_vote( -) -> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { +pub fn mudu_argv_desc_cast_vote() +-> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { static ARGV_DESC: std::sync::OnceLock< ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc, > = std::sync::OnceLock::new(); @@ -464,8 +464,8 @@ pub fn mudu_argv_desc_cast_vote( }) } -pub fn mudu_result_desc_cast_vote( -) -> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { +pub fn mudu_result_desc_cast_vote() +-> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { static RESULT_DESC: std::sync::OnceLock< ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc, > = std::sync::OnceLock::new(); @@ -532,8 +532,8 @@ pub async fn mudu_inner_p2_get_vote_result( Ok(::mudu_contract::procedure::procedure_result::ProcedureResult::from(tuple, &return_desc)?) } -pub fn mudu_argv_desc_get_vote_result( -) -> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { +pub fn mudu_argv_desc_get_vote_result() +-> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { static ARGV_DESC: std::sync::OnceLock< ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc, > = std::sync::OnceLock::new(); @@ -548,8 +548,8 @@ pub fn mudu_argv_desc_get_vote_result( }) } -pub fn mudu_result_desc_get_vote_result( -) -> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { +pub fn mudu_result_desc_get_vote_result() +-> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { static RESULT_DESC: std::sync::OnceLock< ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc, > = std::sync::OnceLock::new(); @@ -617,8 +617,8 @@ pub async fn mudu_inner_p2_get_voting_history( Ok(::mudu_contract::procedure::procedure_result::ProcedureResult::from(tuple, &return_desc)?) } -pub fn mudu_argv_desc_get_voting_history( -) -> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { +pub fn mudu_argv_desc_get_voting_history() +-> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { static ARGV_DESC: std::sync::OnceLock< ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc, > = std::sync::OnceLock::new(); @@ -633,8 +633,8 @@ pub fn mudu_argv_desc_get_voting_history( }) } -pub fn mudu_result_desc_get_voting_history( -) -> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { +pub fn mudu_result_desc_get_voting_history() +-> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { static RESULT_DESC: std::sync::OnceLock< ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc, > = std::sync::OnceLock::new(); @@ -649,8 +649,8 @@ pub fn mudu_result_desc_get_voting_history( }) } -pub fn mudu_proc_desc_get_voting_history( -) -> &'static ::mudu_contract::procedure::proc_desc::ProcDesc { +pub fn mudu_proc_desc_get_voting_history() +-> &'static ::mudu_contract::procedure::proc_desc::ProcDesc { static _PROC_DESC: std::sync::OnceLock<::mudu_contract::procedure::proc_desc::ProcDesc> = std::sync::OnceLock::new(); _PROC_DESC.get_or_init(|| { @@ -709,8 +709,8 @@ pub async fn mudu_inner_p2_add_option( Ok(::mudu_contract::procedure::procedure_result::ProcedureResult::from(tuple, &return_desc)?) } -pub fn mudu_argv_desc_add_option( -) -> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { +pub fn mudu_argv_desc_add_option() +-> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { static ARGV_DESC: std::sync::OnceLock< ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc, > = std::sync::OnceLock::new(); @@ -726,8 +726,8 @@ pub fn mudu_argv_desc_add_option( }) } -pub fn mudu_result_desc_add_option( -) -> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { +pub fn mudu_result_desc_add_option() +-> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { static RESULT_DESC: std::sync::OnceLock< ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc, > = std::sync::OnceLock::new(); @@ -799,8 +799,8 @@ pub async fn mudu_inner_p2_create_vote( Ok(::mudu_contract::procedure::procedure_result::ProcedureResult::from(tuple, &return_desc)?) } -pub fn mudu_argv_desc_create_vote( -) -> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { +pub fn mudu_argv_desc_create_vote() +-> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { static ARGV_DESC: std::sync::OnceLock< ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc, > = std::sync::OnceLock::new(); @@ -849,8 +849,8 @@ pub fn mudu_argv_desc_create_vote( }) } -pub fn mudu_result_desc_create_vote( -) -> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { +pub fn mudu_result_desc_create_vote() +-> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { static RESULT_DESC: std::sync::OnceLock< ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc, > = std::sync::OnceLock::new(); @@ -918,8 +918,8 @@ pub async fn mudu_inner_p2_withdraw_vote( Ok(::mudu_contract::procedure::procedure_result::ProcedureResult::from(tuple, &return_desc)?) } -pub fn mudu_argv_desc_withdraw_vote( -) -> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { +pub fn mudu_argv_desc_withdraw_vote() +-> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { static ARGV_DESC: std::sync::OnceLock< ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc, > = std::sync::OnceLock::new(); @@ -934,8 +934,8 @@ pub fn mudu_argv_desc_withdraw_vote( }) } -pub fn mudu_result_desc_withdraw_vote( -) -> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { +pub fn mudu_result_desc_withdraw_vote() +-> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { static RESULT_DESC: std::sync::OnceLock< ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc, > = std::sync::OnceLock::new(); diff --git a/example/wallet/src/generated/procedures.rs b/example/wallet/src/generated/procedures.rs index efcc739..635c689 100644 --- a/example/wallet/src/generated/procedures.rs +++ b/example/wallet/src/generated/procedures.rs @@ -435,8 +435,8 @@ pub async fn mudu_inner_p2_create_user( } } -pub fn mudu_argv_desc_create_user( -) -> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { +pub fn mudu_argv_desc_create_user() +-> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { static ARGV_DESC: std::sync::OnceLock< ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc, > = std::sync::OnceLock::new(); @@ -458,8 +458,8 @@ pub fn mudu_argv_desc_create_user( }) } -pub fn mudu_result_desc_create_user( -) -> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { +pub fn mudu_result_desc_create_user() +-> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { static RESULT_DESC: std::sync::OnceLock< ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc, > = std::sync::OnceLock::new(); @@ -555,8 +555,8 @@ pub fn mudu_argv_desc_purchase() -> &'static ::mudu_contract::tuple::tuple_field }) } -pub fn mudu_result_desc_purchase( -) -> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { +pub fn mudu_result_desc_purchase() +-> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { static RESULT_DESC: std::sync::OnceLock< ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc, > = std::sync::OnceLock::new(); @@ -627,8 +627,8 @@ pub async fn mudu_inner_p2_delete_user( } } -pub fn mudu_argv_desc_delete_user( -) -> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { +pub fn mudu_argv_desc_delete_user() +-> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { static ARGV_DESC: std::sync::OnceLock< ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc, > = std::sync::OnceLock::new(); @@ -642,8 +642,8 @@ pub fn mudu_argv_desc_delete_user( }) } -pub fn mudu_result_desc_delete_user( -) -> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { +pub fn mudu_result_desc_delete_user() +-> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { static RESULT_DESC: std::sync::OnceLock< ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc, > = std::sync::OnceLock::new(); @@ -734,8 +734,8 @@ pub fn mudu_argv_desc_deposit() -> &'static ::mudu_contract::tuple::tuple_field_ }) } -pub fn mudu_result_desc_deposit( -) -> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { +pub fn mudu_result_desc_deposit() +-> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { static RESULT_DESC: std::sync::OnceLock< ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc, > = std::sync::OnceLock::new(); @@ -808,8 +808,8 @@ pub async fn mudu_inner_p2_transfer_funds( } } -pub fn mudu_argv_desc_transfer_funds( -) -> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { +pub fn mudu_argv_desc_transfer_funds() +-> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { static ARGV_DESC: std::sync::OnceLock< ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc, > = std::sync::OnceLock::new(); @@ -831,8 +831,8 @@ pub fn mudu_argv_desc_transfer_funds( }) } -pub fn mudu_result_desc_transfer_funds( -) -> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { +pub fn mudu_result_desc_transfer_funds() +-> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { static RESULT_DESC: std::sync::OnceLock< ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc, > = std::sync::OnceLock::new(); @@ -905,8 +905,8 @@ pub async fn mudu_inner_p2_update_user( } } -pub fn mudu_argv_desc_update_user( -) -> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { +pub fn mudu_argv_desc_update_user() +-> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { static ARGV_DESC: std::sync::OnceLock< ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc, > = std::sync::OnceLock::new(); @@ -928,8 +928,8 @@ pub fn mudu_argv_desc_update_user( }) } -pub fn mudu_result_desc_update_user( -) -> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { +pub fn mudu_result_desc_update_user() +-> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { static RESULT_DESC: std::sync::OnceLock< ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc, > = std::sync::OnceLock::new(); @@ -1020,8 +1020,8 @@ pub fn mudu_argv_desc_withdraw() -> &'static ::mudu_contract::tuple::tuple_field }) } -pub fn mudu_result_desc_withdraw( -) -> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { +pub fn mudu_result_desc_withdraw() +-> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { static RESULT_DESC: std::sync::OnceLock< ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc, > = std::sync::OnceLock::new(); @@ -1117,8 +1117,8 @@ pub fn mudu_argv_desc_transfer() -> &'static ::mudu_contract::tuple::tuple_field }) } -pub fn mudu_result_desc_transfer( -) -> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { +pub fn mudu_result_desc_transfer() +-> &'static ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc { static RESULT_DESC: std::sync::OnceLock< ::mudu_contract::tuple::tuple_field_desc::TupleFieldDesc, > = std::sync::OnceLock::new(); diff --git a/example/ycsb/src/bin/ycsb_benchmark.rs b/example/ycsb/src/bin/ycsb_benchmark.rs index 31adb8b..c8e9655 100644 --- a/example/ycsb/src/bin/ycsb_benchmark.rs +++ b/example/ycsb/src/bin/ycsb_benchmark.rs @@ -2,15 +2,15 @@ use clap::Parser; use mudu::common::result::RS; use mudu::common::xid::XID; use mudu_binding::universal::uni_session_open_argv::UniSessionOpenArgv; -use mudu_cli::management::{fetch_server_topology, ServerTopology}; +use mudu_cli::management::{ServerTopology, fetch_server_topology}; use mudu_contract::database::sql_stmt_text::SQLStmtText; use mudu_utils::debug::debug_serve; use mudu_utils::notifier::NotifyWait; use mudu_utils::task::spawn_task; use mudu_utils::task_trace; -use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use std::sync::Barrier as StdBarrier; +use std::sync::atomic::{AtomicUsize, Ordering}; use std::thread; use std::time::{Duration, Instant}; use sys_interface::async_api::{ diff --git a/mudu/src/common/_arb_de_en.rs b/mudu/src/common/_arb_de_en.rs index 39a9a07..44e5df4 100644 --- a/mudu/src/common/_arb_de_en.rs +++ b/mudu/src/common/_arb_de_en.rs @@ -1,6 +1,4 @@ -use crate::common::bc::{hdr_size, tail_size}; -use crate::common::bc_dec::{Decode, decode_binary}; -use crate::common::bc_enc::{Encode, encode_binary}; +use crate::common::codec::{Decode, Encode}; #[cfg(any(test, feature = "test"))] use arbitrary::{Arbitrary, Unstructured}; use std::fmt::Debug; @@ -17,20 +15,15 @@ pub fn _fuzz_decode_and_encode<'a, T: Arbitrary<'a> + Decode + Encode + Eq + Deb break; } }; - let _r = encode_binary(&t); - let b = match _r { - Ok(b) => b, - Err(_e) => { - panic!("{:?}", _e); - } - }; + let mut b = Vec::new(); + t.encode(&mut b).unwrap(); let _size = t.size().unwrap(); if _size != b.len() { let _ = t.size().unwrap(); } - assert_eq!(b.len(), _size + hdr_size() + tail_size()); - let _r = decode_binary::(&b); + assert_eq!(b.len(), _size); + let _r = T::decode(&mut (b.clone(), 0)); let _t = match _r { Ok(_t) => _t, Err(_e) => { diff --git a/mudu/src/common/bc.rs b/mudu/src/common/bc.rs deleted file mode 100644 index 807482b..0000000 --- a/mudu/src/common/bc.rs +++ /dev/null @@ -1,97 +0,0 @@ -use crate::common::bc_dec::{DecErr, Decode, Decoder}; -use crate::common::bc_enc::{EncErr, Encode, Encoder}; -use std::mem::size_of; - -pub fn hdr_size() -> usize { - BCHdr::hdr_size() -} - -pub fn tail_size() -> usize { - BCTail::tail_size() -} - -/// header, -/// 4 bytes body size -/// 8 bytes body crc -pub struct BCHdr { - length: u32, - crc: u64, -} - -/// tail, -/// 8 bytes body crc -pub struct BCTail { - crc: u64, -} - -impl BCHdr { - pub fn new(length: u32, crc: u64) -> Self { - Self { length, crc } - } - - // body length - pub fn length(&self) -> u32 { - self.length - } - - pub fn crc(&self) -> u64 { - self.crc - } - - pub fn hdr_size() -> usize { - // length size + crc size - size_of::() + size_of::() - } -} - -impl Decode for BCHdr { - fn decode(decoder: &mut D) -> Result { - let length = decoder.read_u32()?; - let crc = decoder.read_u64()?; - Ok(Self { length, crc }) - } -} - -impl Encode for BCHdr { - fn encode(&self, encoder: &mut E) -> Result<(), EncErr> { - encoder.write_u32(self.length)?; - encoder.write_u64(self.crc)?; - Ok(()) - } - - fn size(&self) -> Result { - Ok(size_of::() + size_of::()) - } -} - -impl BCTail { - pub fn new(crc: u64) -> Self { - Self { crc } - } - - pub fn crc(&self) -> u64 { - self.crc - } - - pub fn tail_size() -> usize { - size_of::() - } -} - -impl Decode for BCTail { - fn decode(decoder: &mut D) -> Result { - let crc = decoder.read_u64()?; - Ok(Self { crc }) - } -} - -impl Encode for BCTail { - fn encode(&self, encoder: &mut E) -> Result<(), EncErr> { - encoder.write_u64(self.crc)?; - Ok(()) - } - - fn size(&self) -> Result { - Ok(Self::tail_size()) - } -} diff --git a/mudu/src/common/bc_dec.rs b/mudu/src/common/bc_dec.rs deleted file mode 100644 index 6fbe1c9..0000000 --- a/mudu/src/common/bc_dec.rs +++ /dev/null @@ -1,122 +0,0 @@ -use crate::common::bc::{BCHdr, BCTail}; -use crate::common::endian::Endian; -use crate::common::slice::SliceRef; -use byteorder::ByteOrder; -use std::error::Error; -use std::fmt::{Display, Formatter}; - -#[derive(Debug, Clone)] -pub enum DecErr { - CapacityNotAvailable, - EmptyEnum { type_name: String }, - ErrorCRC, -} - -impl Error for DecErr {} -impl Display for DecErr { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self)?; - Ok(()) - } -} - -pub trait Decoder { - fn read_i8(&mut self, _n: u8) -> Result { - let mut s = [0i8; 1]; - let bytes = unsafe { &mut *(&mut s as *mut [i8] as *mut [u8]) }; - self.read(bytes)?; - Ok(s[0]) - } - - fn read_u8(&mut self) -> Result { - let mut s = [0u8; 1]; - self.read(&mut s)?; - Ok(s[0]) - } - - fn read_u32(&mut self) -> Result { - let mut s = [0u8; 4]; - self.read(&mut s)?; - let n = Endian::read_u32(&s); - Ok(n) - } - - fn read_i32(&mut self) -> Result { - let mut s = [0u8; 4]; - self.read(&mut s)?; - let n = Endian::read_i32(&s); - Ok(n) - } - - fn read_i64(&mut self) -> Result { - let mut s = [0u8; 8]; - self.read(&mut s)?; - let n = Endian::read_i64(&s); - Ok(n) - } - - fn read_u64(&mut self) -> Result { - let mut s = [0u8; 8]; - self.read(&mut s)?; - let n = Endian::read_u64(&s); - Ok(n) - } - - fn read_i128(&mut self) -> Result { - let mut s = [0u8; 16]; - self.read(&mut s)?; - let n = Endian::read_i128(&s); - Ok(n) - } - - fn read_u128(&mut self) -> Result { - let mut s = [0u8; 16]; - self.read(&mut s)?; - let n = Endian::read_u128(&s); - Ok(n) - } - - fn read_bytes(&mut self, s: &mut [u8]) -> Result<(), DecErr> { - self.read(s)?; - Ok(()) - } - - fn read(&mut self, s: &mut [u8]) -> Result<(), DecErr>; -} - -pub trait Decode: Sized { - /// Encode a given type. - fn decode(decoder: &mut D) -> Result; -} - -fn decode_binary_header(slice: &[u8]) -> Result<(u32, u64), DecErr> { - let mut s = SliceRef::new(slice); - let hdr = BCHdr::decode(&mut s)?; - Ok((hdr.length(), hdr.crc())) -} - -fn decode_binary_tail(slice: &[u8]) -> Result { - let mut s = SliceRef::new(slice); - let tail = BCTail::decode(&mut s)?; - Ok(tail.crc()) -} - -fn decode_binary_body(slice: &[u8]) -> Result { - let mut r = SliceRef::new(slice); - let d = D::decode(&mut r)?; - Ok(d) -} - -pub fn decode_binary(slice: &[u8]) -> Result { - if slice.len() < BCHdr::hdr_size() + BCTail::tail_size() { - return Err(DecErr::CapacityNotAvailable); - } - let (length, start_crc) = decode_binary_header(&slice[0..BCHdr::hdr_size()])?; - let d = - decode_binary_body::(&slice[BCHdr::hdr_size()..BCHdr::hdr_size() + length as usize])?; - let end_crc = decode_binary_tail(&slice[BCHdr::hdr_size() + length as usize..])?; - if start_crc != end_crc { - return Err(DecErr::ErrorCRC); - } - Ok(d) -} diff --git a/mudu/src/common/bc_enc.rs b/mudu/src/common/bc_enc.rs deleted file mode 100644 index 51b0f2f..0000000 --- a/mudu/src/common/bc_enc.rs +++ /dev/null @@ -1,159 +0,0 @@ -use crate::common::bc::{BCHdr, BCTail, hdr_size, tail_size}; -use crate::common::bc_dec::{DecErr, Decoder}; -use crate::common::buf::Buf; -use crate::common::crc::calc_crc; -use crate::common::endian::Endian; -use crate::common::slice::SliceMutRef; -use byteorder::ByteOrder; - -#[derive(Debug, Ord, PartialOrd, PartialEq, Eq)] -pub enum EncErr { - CapacityNotAvailable, -} - -pub trait Encoder { - fn write_i8(&mut self, n: i8) -> Result<(), EncErr> { - let a = [n]; - let bytes = unsafe { &*(&a as *const [i8] as *const [u8]) }; - self.write(bytes)?; - Ok(()) - } - - fn write_u8(&mut self, n: u8) -> Result<(), EncErr> { - self.write(&[n])?; - Ok(()) - } - - fn write_i32(&mut self, n: i32) -> Result<(), EncErr> { - let mut buf = [0; 4]; - Endian::write_i32(&mut buf, n); - self.write(&buf)?; - Ok(()) - } - - fn write_u32(&mut self, n: u32) -> Result<(), EncErr> { - let mut buf = [0; 4]; - Endian::write_u32(&mut buf, n); - self.write(&buf)?; - Ok(()) - } - - fn write_i64(&mut self, n: i64) -> Result<(), EncErr> { - let mut buf = [0; 8]; - Endian::write_i64(&mut buf, n); - self.write(&buf)?; - Ok(()) - } - - fn write_u64(&mut self, n: u64) -> Result<(), EncErr> { - let mut buf = [0; 8]; - Endian::write_u64(&mut buf, n); - self.write(&buf)?; - Ok(()) - } - - fn write_i128(&mut self, n: i128) -> Result<(), EncErr> { - let mut buf = [0; 16]; - Endian::write_i128(&mut buf, n); - self.write(&buf)?; - Ok(()) - } - - fn write_u128(&mut self, n: u128) -> Result<(), EncErr> { - let mut buf = [0; 16]; - Endian::write_u128(&mut buf, n); - self.write(&buf)?; - Ok(()) - } - - fn write_bytes(&mut self, s: &[u8]) -> Result<(), EncErr> { - self.write(s) - } - - fn write(&mut self, s: &[u8]) -> Result<(), EncErr>; -} - -pub trait Encode { - /// Encode a given type. - fn encode(&self, ncoder: &mut E) -> Result<(), EncErr>; - - fn size(&self) -> Result; -} - -const DEFAULT_BUF_SIZE: usize = 1024; - -pub fn encode_binary(e: &E) -> Result { - let n = DEFAULT_BUF_SIZE + hdr_size() + tail_size(); - let mut buf: Buf = vec![0; n]; - let r = _binary_encode(e, &mut buf)?; - match r { - Ok(()) => Ok(buf), - Err(size) => { - buf.resize(size + hdr_size() + tail_size(), 0); - let _r = _binary_encode(e, &mut buf)?; - match _r { - Ok(()) => Ok(buf), - Err(_) => { - buf.resize(size + hdr_size() + tail_size(), 0); - let _r = _binary_encode(e, &mut buf)?; - panic!("error capacity"); - } - } - } - } -} - -fn _binary_encode(e: &E, buf: &mut Buf) -> Result, EncErr> { - let header_size = hdr_size(); - if buf.len() < header_size { - return Ok(Err(e.size()?)); - }; - - let buf_len = buf.len(); - let mut s = SliceMutRef::new(&mut buf.as_mut_slice()[header_size..buf_len]); - let r = e.encode(&mut s); - match r { - Ok(()) => { - let body_size = s.write_pos(); - let _ = s; - - let size = header_size + tail_size() + body_size; - - buf.resize(size, 0); - let crc = calc_crc(&buf.as_slice()[header_size..header_size + body_size]); - { - let hdr = BCHdr::new(body_size as u32, crc); - let mut s1 = SliceMutRef::new(&mut buf.as_mut_slice()[0..header_size]); - hdr.encode(&mut s1)?; - } - { - let mut s2 = SliceMutRef::new(&mut buf.as_mut_slice()[header_size + body_size..]); - let tail = BCTail::new(crc); - tail.encode(&mut s2)?; - } - Ok(Ok(())) - } - Err(err) => match err { - EncErr::CapacityNotAvailable => Ok(Err(e.size()?)), - }, - } -} - -impl Encoder for Buf { - fn write(&mut self, bytes: &[u8]) -> Result<(), EncErr> { - self.extend(bytes); - Ok(()) - } -} - -impl Decoder for (Buf, usize) { - fn read(&mut self, bytes: &mut [u8]) -> Result<(), DecErr> { - if self.0.len() >= self.1 + bytes.len() { - bytes.copy_from_slice(&self.0[self.1..self.1 + bytes.len()]); - self.1 += bytes.len(); - Ok(()) - } else { - Err(DecErr::CapacityNotAvailable) - } - } -} diff --git a/mudu/src/common/codec.rs b/mudu/src/common/codec.rs new file mode 100644 index 0000000..7e5b61b --- /dev/null +++ b/mudu/src/common/codec.rs @@ -0,0 +1,165 @@ +use crate::common::buf::Buf; +use crate::common::endian::Endian; +use byteorder::ByteOrder; +use std::error::Error; +use std::fmt::{Display, Formatter}; + +#[derive(Debug, Clone)] +pub enum DecErr { + CapacityNotAvailable, + EmptyEnum { type_name: String }, + ErrorCRC, +} + +impl Error for DecErr {} + +impl Display for DecErr { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self) + } +} + +#[derive(Debug, Ord, PartialOrd, PartialEq, Eq)] +pub enum EncErr { + CapacityNotAvailable, +} + +pub trait Decoder { + fn read_i8(&mut self, _n: u8) -> Result { + let mut s = [0i8; 1]; + let bytes = unsafe { &mut *(&mut s as *mut [i8] as *mut [u8]) }; + self.read(bytes)?; + Ok(s[0]) + } + + fn read_u8(&mut self) -> Result { + let mut s = [0u8; 1]; + self.read(&mut s)?; + Ok(s[0]) + } + + fn read_u32(&mut self) -> Result { + let mut s = [0u8; 4]; + self.read(&mut s)?; + Ok(Endian::read_u32(&s)) + } + + fn read_i32(&mut self) -> Result { + let mut s = [0u8; 4]; + self.read(&mut s)?; + Ok(Endian::read_i32(&s)) + } + + fn read_i64(&mut self) -> Result { + let mut s = [0u8; 8]; + self.read(&mut s)?; + Ok(Endian::read_i64(&s)) + } + + fn read_u64(&mut self) -> Result { + let mut s = [0u8; 8]; + self.read(&mut s)?; + Ok(Endian::read_u64(&s)) + } + + fn read_i128(&mut self) -> Result { + let mut s = [0u8; 16]; + self.read(&mut s)?; + Ok(Endian::read_i128(&s)) + } + + fn read_u128(&mut self) -> Result { + let mut s = [0u8; 16]; + self.read(&mut s)?; + Ok(Endian::read_u128(&s)) + } + + fn read_bytes(&mut self, s: &mut [u8]) -> Result<(), DecErr> { + self.read(s) + } + + fn read(&mut self, s: &mut [u8]) -> Result<(), DecErr>; +} + +pub trait Decode: Sized { + fn decode(decoder: &mut D) -> Result; +} + +pub trait Encoder { + fn write_i8(&mut self, n: i8) -> Result<(), EncErr> { + let a = [n]; + let bytes = unsafe { &*(&a as *const [i8] as *const [u8]) }; + self.write(bytes) + } + + fn write_u8(&mut self, n: u8) -> Result<(), EncErr> { + self.write(&[n]) + } + + fn write_i32(&mut self, n: i32) -> Result<(), EncErr> { + let mut buf = [0; 4]; + Endian::write_i32(&mut buf, n); + self.write(&buf) + } + + fn write_u32(&mut self, n: u32) -> Result<(), EncErr> { + let mut buf = [0; 4]; + Endian::write_u32(&mut buf, n); + self.write(&buf) + } + + fn write_i64(&mut self, n: i64) -> Result<(), EncErr> { + let mut buf = [0; 8]; + Endian::write_i64(&mut buf, n); + self.write(&buf) + } + + fn write_u64(&mut self, n: u64) -> Result<(), EncErr> { + let mut buf = [0; 8]; + Endian::write_u64(&mut buf, n); + self.write(&buf) + } + + fn write_i128(&mut self, n: i128) -> Result<(), EncErr> { + let mut buf = [0; 16]; + Endian::write_i128(&mut buf, n); + self.write(&buf) + } + + fn write_u128(&mut self, n: u128) -> Result<(), EncErr> { + let mut buf = [0; 16]; + Endian::write_u128(&mut buf, n); + self.write(&buf) + } + + fn write_bytes(&mut self, s: &[u8]) -> Result<(), EncErr> { + self.write(s) + } + + fn write(&mut self, s: &[u8]) -> Result<(), EncErr>; +} + +pub trait Encode { + fn encode(&self, encoder: &mut E) -> Result<(), EncErr>; + + fn size(&self) -> Result; +} + +impl Encoder for Buf { + fn write(&mut self, bytes: &[u8]) -> Result<(), EncErr> { + self.extend(bytes); + Ok(()) + } +} + +impl Decoder for (Buf, usize) { + fn read(&mut self, bytes: &mut [u8]) -> Result<(), DecErr> { + if self.0.len() >= self.1 + bytes.len() { + bytes.copy_from_slice(&self.0[self.1..self.1 + bytes.len()]); + self.1 += bytes.len(); + Ok(()) + } else { + Err(DecErr::CapacityNotAvailable) + } + } +} diff --git a/mudu/src/common/id.rs b/mudu/src/common/id.rs index 517365c..80509cd 100644 --- a/mudu/src/common/id.rs +++ b/mudu/src/common/id.rs @@ -6,6 +6,9 @@ pub type OID = u128; // Nth attribute index of data tuple pub type AttrIndex = usize; +// Nth datum position inside a key or value tuple +pub type DatumIndex = usize; + pub type TupleID = u64; pub type ThdID = u64; diff --git a/mudu/src/common/mod.rs b/mudu/src/common/mod.rs index 3665ae9..53447f6 100644 --- a/mudu/src/common/mod.rs +++ b/mudu/src/common/mod.rs @@ -2,10 +2,8 @@ pub mod _arb_de_en; pub mod _debug; -mod bc; -pub mod bc_dec; -pub mod bc_enc; pub mod buf; +pub mod codec; pub mod crc; pub mod endian; pub mod expected; diff --git a/mudu/src/common/slice.rs b/mudu/src/common/slice.rs index fca7023..db6147a 100644 --- a/mudu/src/common/slice.rs +++ b/mudu/src/common/slice.rs @@ -1,5 +1,4 @@ -use crate::common::bc_dec::{DecErr, Decoder}; -use crate::common::bc_enc::{EncErr, Encoder}; +use crate::common::codec::{DecErr, Decoder, EncErr, Encoder}; pub struct SliceRef<'r> { s: &'r [u8], read_pos: usize, diff --git a/mudu/src/common/update_delta.rs b/mudu/src/common/update_delta.rs index d23500d..e8e8114 100644 --- a/mudu/src/common/update_delta.rs +++ b/mudu/src/common/update_delta.rs @@ -1,5 +1,4 @@ -use crate::common::bc_dec::{DecErr, Decode, Decoder}; -use crate::common::bc_enc::{EncErr, Encode, Encoder}; +use crate::common::codec::{DecErr, Decode, Decoder, EncErr, Encode, Encoder}; use crate::common::buf::Buf; #[cfg(any(test, feature = "test"))] use arbitrary::{Arbitrary, Unstructured}; diff --git a/mudu_api/rust/src/universal/uni_dat_type_id.rs b/mudu_api/rust/src/universal/uni_dat_type_id.rs index 4776fe9..8021202 100644 --- a/mudu_api/rust/src/universal/uni_dat_type_id.rs +++ b/mudu_api/rust/src/universal/uni_dat_type_id.rs @@ -28,19 +28,23 @@ pub enum UniDatTypeId { I64 = 8, - F32 = 9, + OID = 9, - F64 = 10, + I128 = 10, - Char = 11, + F32 = 11, - String = 12, + F64 = 12, - Array = 13, + Char = 13, - Record = 14, + String = 14, - Binary = 15, + Array = 15, + + Record = 16, + + Binary = 17, } impl Default for UniDatTypeId { diff --git a/mudu_api/rust/src/universal/uni_dat_type_id_impl.rs b/mudu_api/rust/src/universal/uni_dat_type_id_impl.rs index 516f3d0..813300c 100644 --- a/mudu_api/rust/src/universal/uni_dat_type_id_impl.rs +++ b/mudu_api/rust/src/universal/uni_dat_type_id_impl.rs @@ -9,6 +9,8 @@ impl UniDatTypeId { let ty_id = match self { Self::I32 => DatTypeID::I32, Self::I64 => DatTypeID::I64, + Self::OID => DatTypeID::U128, + Self::I128 => DatTypeID::I128, Self::F32 => DatTypeID::F32, Self::F64 => DatTypeID::F64, Self::String => DatTypeID::String, @@ -25,6 +27,8 @@ impl UniDatTypeId { let uni_ty = match ty { DatTypeID::I32 => Self::I32, DatTypeID::I64 => Self::I64, + DatTypeID::U128 => Self::OID, + DatTypeID::I128 => Self::I128, DatTypeID::F32 => Self::F32, DatTypeID::F64 => Self::F64, DatTypeID::String => Self::String, diff --git a/mudu_api/rust/src/universal/uni_dat_value_impl.rs b/mudu_api/rust/src/universal/uni_dat_value_impl.rs index f4b33ad..3be1e12 100644 --- a/mudu_api/rust/src/universal/uni_dat_value_impl.rs +++ b/mudu_api/rust/src/universal/uni_dat_value_impl.rs @@ -32,7 +32,9 @@ impl UniDatValue { UniPrimitiveValue::U64(_) => { unimplemented!() } + UniPrimitiveValue::U128(v) => DatValue::from_u128(v), UniPrimitiveValue::I64(v) => DatValue::from_i64(v), + UniPrimitiveValue::I128(v) => DatValue::from_i128(v), UniPrimitiveValue::F32(v) => DatValue::from_f32(v), UniPrimitiveValue::F64(v) => DatValue::from_f64(v), UniPrimitiveValue::Char(_) => { @@ -72,6 +74,12 @@ impl UniDatValue { DatTypeID::I64 => { UniDatValue::from_primitive(UniPrimitiveValue::I64(dat_value.expect_i64().clone())) } + DatTypeID::I128 => UniDatValue::from_primitive(UniPrimitiveValue::I128( + dat_value.expect_i128().clone(), + )), + DatTypeID::U128 => UniDatValue::from_primitive(UniPrimitiveValue::U128( + dat_value.expect_u128().clone(), + )), DatTypeID::F32 => { UniDatValue::from_primitive(UniPrimitiveValue::F32(dat_value.expect_f32().clone())) } diff --git a/mudu_api/rust/src/universal/uni_primitive.rs b/mudu_api/rust/src/universal/uni_primitive.rs index 1dff67b..a0b53b4 100644 --- a/mudu_api/rust/src/universal/uni_primitive.rs +++ b/mudu_api/rust/src/universal/uni_primitive.rs @@ -26,17 +26,21 @@ pub enum UniPrimitive { U64 = 7, - I64 = 8, + U128 = 8, - F32 = 9, + I64 = 9, - F64 = 10, + I128 = 10, - Char = 11, + F32 = 11, - String = 12, + F64 = 12, - Blob = 13, + Char = 13, + + String = 14, + + Blob = 15, } impl Default for UniPrimitive { diff --git a/mudu_api/rust/src/universal/uni_primitive_impl.rs b/mudu_api/rust/src/universal/uni_primitive_impl.rs index 6ab6797..425a873 100644 --- a/mudu_api/rust/src/universal/uni_primitive_impl.rs +++ b/mudu_api/rust/src/universal/uni_primitive_impl.rs @@ -32,7 +32,9 @@ impl UniPrimitive { UniPrimitive::U64 => { unimplemented!() } + UniPrimitive::U128 => DatType::default_for(DatTypeID::U128), UniPrimitive::I64 => DatType::default_for(DatTypeID::I64), + UniPrimitive::I128 => DatType::default_for(DatTypeID::I128), UniPrimitive::F32 => DatType::default_for(DatTypeID::F32), UniPrimitive::F64 => DatType::default_for(DatTypeID::F64), UniPrimitive::Char => { @@ -48,6 +50,8 @@ impl UniPrimitive { let uni_prim = match ty.dat_type_id() { DatTypeID::I32 => Self::I32, DatTypeID::I64 => Self::I64, + DatTypeID::I128 => Self::I128, + DatTypeID::U128 => Self::U128, DatTypeID::F32 => Self::F32, DatTypeID::F64 => Self::F64, DatTypeID::String => Self::String, diff --git a/mudu_api/rust/src/universal/uni_primitive_value.rs b/mudu_api/rust/src/universal/uni_primitive_value.rs index d10227d..f10aa5d 100644 --- a/mudu_api/rust/src/universal/uni_primitive_value.rs +++ b/mudu_api/rust/src/universal/uni_primitive_value.rs @@ -17,8 +17,12 @@ pub enum UniPrimitiveValue { U64(u64), + U128(u128), + I64(i64), + I128(i128), + F32(f32), F64(f64), @@ -179,6 +183,24 @@ impl UniPrimitiveValue { } } + pub fn from_u128(inner: u128) -> Self { + Self::U128(inner) + } + + pub fn as_u128(&self) -> Option<&u128> { + match self { + Self::U128(inner) => Some(inner), + _ => None, + } + } + + pub fn expect_u128(&self) -> &u128 { + match self { + Self::U128(inner) => inner, + _ => unsafe { std::hint::unreachable_unchecked() }, + } + } + pub fn from_i64(inner: i64) -> Self { Self::I64(inner) } @@ -197,6 +219,24 @@ impl UniPrimitiveValue { } } + pub fn from_i128(inner: i128) -> Self { + Self::I128(inner) + } + + pub fn as_i128(&self) -> Option<&i128> { + match self { + Self::I128(inner) => Some(inner), + _ => None, + } + } + + pub fn expect_i128(&self) -> &i128 { + match self { + Self::I128(inner) => inner, + _ => unsafe { std::hint::unreachable_unchecked() }, + } + } + pub fn from_f32(inner: f32) -> Self { Self::F32(inner) } @@ -318,30 +358,40 @@ impl serde::Serialize for UniPrimitiveValue { serialize_seq.serialize_element(&inner)?; } - UniPrimitiveValue::I64(inner) => { + UniPrimitiveValue::U128(inner) => { serialize_seq.serialize_element(&8u32)?; serialize_seq.serialize_element(&inner)?; } - UniPrimitiveValue::F32(inner) => { + UniPrimitiveValue::I64(inner) => { serialize_seq.serialize_element(&9u32)?; serialize_seq.serialize_element(&inner)?; } - UniPrimitiveValue::F64(inner) => { + UniPrimitiveValue::I128(inner) => { serialize_seq.serialize_element(&10u32)?; serialize_seq.serialize_element(&inner)?; } - UniPrimitiveValue::Char(inner) => { + UniPrimitiveValue::F32(inner) => { serialize_seq.serialize_element(&11u32)?; serialize_seq.serialize_element(&inner)?; } - UniPrimitiveValue::String(inner) => { + UniPrimitiveValue::F64(inner) => { serialize_seq.serialize_element(&12u32)?; serialize_seq.serialize_element(&inner)?; } + + UniPrimitiveValue::Char(inner) => { + serialize_seq.serialize_element(&13u32)?; + serialize_seq.serialize_element(&inner)?; + } + + UniPrimitiveValue::String(inner) => { + serialize_seq.serialize_element(&14u32)?; + serialize_seq.serialize_element(&inner)?; + } } serialize_seq.end() } @@ -428,34 +478,48 @@ impl<'de> serde::de::Visitor<'de> for UniPrimitiveValueVisitor { } 8 => { + let value = seq + .next_element::()? + .map_or_else(|| Err(A::Error::invalid_length(1, &self)), Ok)?; + Ok(Self::Value::U128(value)) + } + + 9 => { let value = seq .next_element::()? .map_or_else(|| Err(A::Error::invalid_length(1, &self)), Ok)?; Ok(Self::Value::I64(value)) } - 9 => { + 10 => { + let value = seq + .next_element::()? + .map_or_else(|| Err(A::Error::invalid_length(1, &self)), Ok)?; + Ok(Self::Value::I128(value)) + } + + 11 => { let value = seq .next_element::()? .map_or_else(|| Err(A::Error::invalid_length(1, &self)), Ok)?; Ok(Self::Value::F32(value)) } - 10 => { + 12 => { let value = seq .next_element::()? .map_or_else(|| Err(A::Error::invalid_length(1, &self)), Ok)?; Ok(Self::Value::F64(value)) } - 11 => { + 13 => { let value = seq .next_element::()? .map_or_else(|| Err(A::Error::invalid_length(1, &self)), Ok)?; Ok(Self::Value::Char(value)) } - 12 => { + 14 => { let value = seq .next_element::()? .map_or_else(|| Err(A::Error::invalid_length(1, &self)), Ok)?; diff --git a/mudu_binding/src/universal/uni_dat_type_id.rs b/mudu_binding/src/universal/uni_dat_type_id.rs index 4776fe9..8021202 100644 --- a/mudu_binding/src/universal/uni_dat_type_id.rs +++ b/mudu_binding/src/universal/uni_dat_type_id.rs @@ -28,19 +28,23 @@ pub enum UniDatTypeId { I64 = 8, - F32 = 9, + OID = 9, - F64 = 10, + I128 = 10, - Char = 11, + F32 = 11, - String = 12, + F64 = 12, - Array = 13, + Char = 13, - Record = 14, + String = 14, - Binary = 15, + Array = 15, + + Record = 16, + + Binary = 17, } impl Default for UniDatTypeId { diff --git a/mudu_binding/src/universal/uni_dat_type_id_impl.rs b/mudu_binding/src/universal/uni_dat_type_id_impl.rs index 516f3d0..813300c 100644 --- a/mudu_binding/src/universal/uni_dat_type_id_impl.rs +++ b/mudu_binding/src/universal/uni_dat_type_id_impl.rs @@ -9,6 +9,8 @@ impl UniDatTypeId { let ty_id = match self { Self::I32 => DatTypeID::I32, Self::I64 => DatTypeID::I64, + Self::OID => DatTypeID::U128, + Self::I128 => DatTypeID::I128, Self::F32 => DatTypeID::F32, Self::F64 => DatTypeID::F64, Self::String => DatTypeID::String, @@ -25,6 +27,8 @@ impl UniDatTypeId { let uni_ty = match ty { DatTypeID::I32 => Self::I32, DatTypeID::I64 => Self::I64, + DatTypeID::U128 => Self::OID, + DatTypeID::I128 => Self::I128, DatTypeID::F32 => Self::F32, DatTypeID::F64 => Self::F64, DatTypeID::String => Self::String, diff --git a/mudu_binding/src/universal/uni_dat_value_impl.rs b/mudu_binding/src/universal/uni_dat_value_impl.rs index f4b33ad..3be1e12 100644 --- a/mudu_binding/src/universal/uni_dat_value_impl.rs +++ b/mudu_binding/src/universal/uni_dat_value_impl.rs @@ -32,7 +32,9 @@ impl UniDatValue { UniPrimitiveValue::U64(_) => { unimplemented!() } + UniPrimitiveValue::U128(v) => DatValue::from_u128(v), UniPrimitiveValue::I64(v) => DatValue::from_i64(v), + UniPrimitiveValue::I128(v) => DatValue::from_i128(v), UniPrimitiveValue::F32(v) => DatValue::from_f32(v), UniPrimitiveValue::F64(v) => DatValue::from_f64(v), UniPrimitiveValue::Char(_) => { @@ -72,6 +74,12 @@ impl UniDatValue { DatTypeID::I64 => { UniDatValue::from_primitive(UniPrimitiveValue::I64(dat_value.expect_i64().clone())) } + DatTypeID::I128 => UniDatValue::from_primitive(UniPrimitiveValue::I128( + dat_value.expect_i128().clone(), + )), + DatTypeID::U128 => UniDatValue::from_primitive(UniPrimitiveValue::U128( + dat_value.expect_u128().clone(), + )), DatTypeID::F32 => { UniDatValue::from_primitive(UniPrimitiveValue::F32(dat_value.expect_f32().clone())) } diff --git a/mudu_binding/src/universal/uni_primitive.rs b/mudu_binding/src/universal/uni_primitive.rs index 1dff67b..a0b53b4 100644 --- a/mudu_binding/src/universal/uni_primitive.rs +++ b/mudu_binding/src/universal/uni_primitive.rs @@ -26,17 +26,21 @@ pub enum UniPrimitive { U64 = 7, - I64 = 8, + U128 = 8, - F32 = 9, + I64 = 9, - F64 = 10, + I128 = 10, - Char = 11, + F32 = 11, - String = 12, + F64 = 12, - Blob = 13, + Char = 13, + + String = 14, + + Blob = 15, } impl Default for UniPrimitive { diff --git a/mudu_binding/src/universal/uni_primitive_impl.rs b/mudu_binding/src/universal/uni_primitive_impl.rs index 6ab6797..425a873 100644 --- a/mudu_binding/src/universal/uni_primitive_impl.rs +++ b/mudu_binding/src/universal/uni_primitive_impl.rs @@ -32,7 +32,9 @@ impl UniPrimitive { UniPrimitive::U64 => { unimplemented!() } + UniPrimitive::U128 => DatType::default_for(DatTypeID::U128), UniPrimitive::I64 => DatType::default_for(DatTypeID::I64), + UniPrimitive::I128 => DatType::default_for(DatTypeID::I128), UniPrimitive::F32 => DatType::default_for(DatTypeID::F32), UniPrimitive::F64 => DatType::default_for(DatTypeID::F64), UniPrimitive::Char => { @@ -48,6 +50,8 @@ impl UniPrimitive { let uni_prim = match ty.dat_type_id() { DatTypeID::I32 => Self::I32, DatTypeID::I64 => Self::I64, + DatTypeID::I128 => Self::I128, + DatTypeID::U128 => Self::U128, DatTypeID::F32 => Self::F32, DatTypeID::F64 => Self::F64, DatTypeID::String => Self::String, diff --git a/mudu_binding/src/universal/uni_primitive_value.rs b/mudu_binding/src/universal/uni_primitive_value.rs index d10227d..f10aa5d 100644 --- a/mudu_binding/src/universal/uni_primitive_value.rs +++ b/mudu_binding/src/universal/uni_primitive_value.rs @@ -17,8 +17,12 @@ pub enum UniPrimitiveValue { U64(u64), + U128(u128), + I64(i64), + I128(i128), + F32(f32), F64(f64), @@ -179,6 +183,24 @@ impl UniPrimitiveValue { } } + pub fn from_u128(inner: u128) -> Self { + Self::U128(inner) + } + + pub fn as_u128(&self) -> Option<&u128> { + match self { + Self::U128(inner) => Some(inner), + _ => None, + } + } + + pub fn expect_u128(&self) -> &u128 { + match self { + Self::U128(inner) => inner, + _ => unsafe { std::hint::unreachable_unchecked() }, + } + } + pub fn from_i64(inner: i64) -> Self { Self::I64(inner) } @@ -197,6 +219,24 @@ impl UniPrimitiveValue { } } + pub fn from_i128(inner: i128) -> Self { + Self::I128(inner) + } + + pub fn as_i128(&self) -> Option<&i128> { + match self { + Self::I128(inner) => Some(inner), + _ => None, + } + } + + pub fn expect_i128(&self) -> &i128 { + match self { + Self::I128(inner) => inner, + _ => unsafe { std::hint::unreachable_unchecked() }, + } + } + pub fn from_f32(inner: f32) -> Self { Self::F32(inner) } @@ -318,30 +358,40 @@ impl serde::Serialize for UniPrimitiveValue { serialize_seq.serialize_element(&inner)?; } - UniPrimitiveValue::I64(inner) => { + UniPrimitiveValue::U128(inner) => { serialize_seq.serialize_element(&8u32)?; serialize_seq.serialize_element(&inner)?; } - UniPrimitiveValue::F32(inner) => { + UniPrimitiveValue::I64(inner) => { serialize_seq.serialize_element(&9u32)?; serialize_seq.serialize_element(&inner)?; } - UniPrimitiveValue::F64(inner) => { + UniPrimitiveValue::I128(inner) => { serialize_seq.serialize_element(&10u32)?; serialize_seq.serialize_element(&inner)?; } - UniPrimitiveValue::Char(inner) => { + UniPrimitiveValue::F32(inner) => { serialize_seq.serialize_element(&11u32)?; serialize_seq.serialize_element(&inner)?; } - UniPrimitiveValue::String(inner) => { + UniPrimitiveValue::F64(inner) => { serialize_seq.serialize_element(&12u32)?; serialize_seq.serialize_element(&inner)?; } + + UniPrimitiveValue::Char(inner) => { + serialize_seq.serialize_element(&13u32)?; + serialize_seq.serialize_element(&inner)?; + } + + UniPrimitiveValue::String(inner) => { + serialize_seq.serialize_element(&14u32)?; + serialize_seq.serialize_element(&inner)?; + } } serialize_seq.end() } @@ -428,34 +478,48 @@ impl<'de> serde::de::Visitor<'de> for UniPrimitiveValueVisitor { } 8 => { + let value = seq + .next_element::()? + .map_or_else(|| Err(A::Error::invalid_length(1, &self)), Ok)?; + Ok(Self::Value::U128(value)) + } + + 9 => { let value = seq .next_element::()? .map_or_else(|| Err(A::Error::invalid_length(1, &self)), Ok)?; Ok(Self::Value::I64(value)) } - 9 => { + 10 => { + let value = seq + .next_element::()? + .map_or_else(|| Err(A::Error::invalid_length(1, &self)), Ok)?; + Ok(Self::Value::I128(value)) + } + + 11 => { let value = seq .next_element::()? .map_or_else(|| Err(A::Error::invalid_length(1, &self)), Ok)?; Ok(Self::Value::F32(value)) } - 10 => { + 12 => { let value = seq .next_element::()? .map_or_else(|| Err(A::Error::invalid_length(1, &self)), Ok)?; Ok(Self::Value::F64(value)) } - 11 => { + 13 => { let value = seq .next_element::()? .map_or_else(|| Err(A::Error::invalid_length(1, &self)), Ok)?; Ok(Self::Value::Char(value)) } - 12 => { + 14 => { let value = seq .next_element::()? .map_or_else(|| Err(A::Error::invalid_length(1, &self)), Ok)?; diff --git a/mudu_binding/wit/uni-dat-type-id.wit b/mudu_binding/wit/uni-dat-type-id.wit index f6485d5..8bc0039 100644 --- a/mudu_binding/wit/uni-dat-type-id.wit +++ b/mudu_binding/wit/uni-dat-type-id.wit @@ -9,6 +9,8 @@ enum uni-dat-type-id { i32, %u64, i64, + oid, + i128, %f32, %f64, %char, @@ -17,4 +19,4 @@ enum uni-dat-type-id { record, binary } -} \ No newline at end of file +} diff --git a/mudu_cli/src/client/async_client.rs b/mudu_cli/src/client/async_client.rs index 1d2d960..0331c27 100644 --- a/mudu_cli/src/client/async_client.rs +++ b/mudu_cli/src/client/async_client.rs @@ -9,7 +9,7 @@ use mudu_contract::protocol::{ SessionCreateResponse, decode_error_response, decode_get_response, decode_procedure_invoke_response, decode_put_response, decode_range_scan_response, decode_server_response, decode_session_close_response, decode_session_create_response, - encode_client_request_with_message_type, encode_get_request, encode_procedure_invoke_request, + encode_batch_request, encode_client_request_with_message_type, encode_get_request, encode_procedure_invoke_request, encode_put_request, encode_range_scan_request, encode_session_close_request, encode_session_create_request, }; @@ -20,6 +20,7 @@ use tokio::net::TcpStream; pub trait AsyncClient: Send { async fn query(&mut self, request: ClientRequest) -> RS; async fn execute(&mut self, request: ClientRequest) -> RS; + async fn batch(&mut self, request: ClientRequest) -> RS; async fn get(&mut self, request: GetRequest) -> RS; async fn put(&mut self, request: PutRequest) -> RS; async fn range_scan(&mut self, request: RangeScanRequest) -> RS; @@ -123,6 +124,12 @@ impl AsyncClient for AsyncClientImpl { decode_server_response(&frame) } + async fn batch(&mut self, request: ClientRequest) -> RS { + let payload = encode_batch_request(self.take_request_id(), &request)?; + let frame = self.send_and_receive(&payload).await?; + decode_server_response(&frame) + } + async fn get(&mut self, request: GetRequest) -> RS { let payload = encode_get_request(self.take_request_id(), &request)?; let frame = self.send_and_receive(&payload).await?; diff --git a/mudu_cli/src/client/client.rs b/mudu_cli/src/client/client.rs index 6e2e8b6..682d1b9 100644 --- a/mudu_cli/src/client/client.rs +++ b/mudu_cli/src/client/client.rs @@ -6,7 +6,7 @@ use mudu_contract::protocol::{ PutRequest, RangeScanRequest, ServerResponse, SessionCloseRequest, SessionCreateRequest, decode_error_response, decode_get_response, decode_procedure_invoke_response, decode_put_response, decode_range_scan_response, decode_server_response, - decode_session_close_response, decode_session_create_response, encode_client_request, + decode_session_close_response, decode_session_create_response, encode_batch_request, encode_client_request, encode_client_request_with_message_type, encode_get_request, encode_procedure_invoke_request, encode_put_request, encode_range_scan_request, encode_session_close_request, encode_session_create_request, @@ -59,6 +59,19 @@ impl SyncClient { decode_server_response(&frame) } + pub fn batch( + &mut self, + app_name: impl Into, + sql: impl Into, + ) -> RS { + let request_id = self.take_request_id(); + let request = ClientRequest::new(app_name, sql); + let payload = encode_batch_request(request_id, &request)?; + let frame = self.send_and_receive(&payload)?; + self.ensure_success_frame(&frame)?; + decode_server_response(&frame) + } + pub fn get(&mut self, session_id: u128, key: impl Into>) -> RS>> { let request_id = self.take_request_id(); let payload = encode_get_request(request_id, &GetRequest::new(session_id, key.into()))?; diff --git a/mudu_cli/src/client/json_client.rs b/mudu_cli/src/client/json_client.rs index 2abf810..bcf6c67 100644 --- a/mudu_cli/src/client/json_client.rs +++ b/mudu_cli/src/client/json_client.rs @@ -248,7 +248,9 @@ fn uni_dat_value_to_json_value(value: UniDatValue) -> RS { UniPrimitiveValue::U32(v) => Ok(json!(v)), UniPrimitiveValue::I32(v) => Ok(json!(v)), UniPrimitiveValue::U64(v) => Ok(json!(v)), + UniPrimitiveValue::U128(v) => Ok(Value::String(v.to_string())), UniPrimitiveValue::I64(v) => Ok(json!(v)), + UniPrimitiveValue::I128(v) => Ok(Value::String(v.to_string())), UniPrimitiveValue::F32(v) => Ok(json!(v)), UniPrimitiveValue::F64(v) => Ok(json!(v)), UniPrimitiveValue::Char(v) => Ok(json!(v.to_string())), @@ -314,6 +316,7 @@ mod tests { struct MockAsyncIoUringTcpClient { last_query: Option, last_execute: Option, + last_batch: Option, last_get: Option, last_put: Option, last_range: Option, @@ -325,6 +328,7 @@ mod tests { Self { last_query: None, last_execute: None, + last_batch: None, last_get: None, last_put: None, last_range: None, @@ -350,6 +354,11 @@ mod tests { Ok(ServerResponse::new(vec![], vec![], 2, None)) } + async fn batch(&mut self, request: ClientRequest) -> RS { + self.last_batch = Some(request); + Ok(ServerResponse::new(vec![], vec![], 3, None)) + } + async fn get(&mut self, request: GetRequest) -> RS { self.last_get = Some(request); Ok(GetResponse::new(Some( diff --git a/mudu_contract/src/protocol.rs b/mudu_contract/src/protocol.rs index 98428a8..bacab2e 100644 --- a/mudu_contract/src/protocol.rs +++ b/mudu_contract/src/protocol.rs @@ -13,14 +13,15 @@ pub enum MessageType { Auth = 2, Query = 3, Execute = 4, - Response = 5, - Error = 6, - Get = 7, - Put = 8, - RangeScan = 9, - ProcedureInvoke = 10, - SessionCreate = 11, - SessionClose = 12, + Batch = 5, + Response = 6, + Error = 7, + Get = 8, + Put = 9, + RangeScan = 10, + ProcedureInvoke = 11, + SessionCreate = 12, + SessionClose = 13, } impl From for u16 { @@ -38,14 +39,15 @@ impl TryFrom for MessageType { 2 => Ok(MessageType::Auth), 3 => Ok(MessageType::Query), 4 => Ok(MessageType::Execute), - 5 => Ok(MessageType::Response), - 6 => Ok(MessageType::Error), - 7 => Ok(MessageType::Get), - 8 => Ok(MessageType::Put), - 9 => Ok(MessageType::RangeScan), - 10 => Ok(MessageType::ProcedureInvoke), - 11 => Ok(MessageType::SessionCreate), - 12 => Ok(MessageType::SessionClose), + 5 => Ok(MessageType::Batch), + 6 => Ok(MessageType::Response), + 7 => Ok(MessageType::Error), + 8 => Ok(MessageType::Get), + 9 => Ok(MessageType::Put), + 10 => Ok(MessageType::RangeScan), + 11 => Ok(MessageType::ProcedureInvoke), + 12 => Ok(MessageType::SessionCreate), + 13 => Ok(MessageType::SessionClose), _ => Err(m_error!( EC::ParseErr, format!("unknown message type {}", value) @@ -66,6 +68,7 @@ pub struct FrameHeader { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ClientRequest { + oid: u128, app_name: String, sql: String, } @@ -287,11 +290,24 @@ impl FrameHeader { impl ClientRequest { pub fn new(app_name: impl Into, sql: impl Into) -> Self { Self { + oid: 0, app_name: app_name.into(), sql: sql.into(), } } + pub fn new_with_oid(oid: u128, app_name: impl Into, sql: impl Into) -> Self { + Self { + oid, + app_name: app_name.into(), + sql: sql.into(), + } + } + + pub fn oid(&self) -> u128 { + self.oid + } + pub fn app_name(&self) -> &str { &self.app_name } @@ -560,6 +576,10 @@ pub fn decode_client_request(frame: &Frame) -> RS { decode_payload(frame.payload(), "decode client request error") } +pub fn encode_batch_request(request_id: u64, request: &ClientRequest) -> RS> { + encode_client_request_with_message_type(MessageType::Batch, request_id, request) +} + pub fn encode_server_response(request_id: u64, response: &ServerResponse) -> RS> { let payload = encode_payload(response, "encode server response error")?; Ok(Frame::new(MessageType::Response, request_id, payload).encode()) diff --git a/mudu_gen/src/lang_impl/csharp/lang_def.rs b/mudu_gen/src/lang_impl/csharp/lang_def.rs index f95afeb..2109025 100644 --- a/mudu_gen/src/lang_impl/csharp/lang_def.rs +++ b/mudu_gen/src/lang_impl/csharp/lang_def.rs @@ -14,10 +14,12 @@ impl_primitive! { (U16, "ushort"), (U32, "uint"), (U64, "ulong"), + (U128, "Mudu.OID"), (I8, "sbyte"), (I16, "short"), (I32, "int"), (I64, "long"), + (I128, "Int128"), (F32, "float"), (F64, "double"), (Char, "char"), diff --git a/mudu_gen/src/lang_impl/lang/lang_data_type.rs b/mudu_gen/src/lang_impl/lang/lang_data_type.rs index 7ae1c2e..c51d702 100644 --- a/mudu_gen/src/lang_impl/lang/lang_data_type.rs +++ b/mudu_gen/src/lang_impl/lang/lang_data_type.rs @@ -18,10 +18,12 @@ pub fn csharp_default_value_expr(wit_ty: &UniDatType) -> RS { UniPrimitive::U16 => "0".to_string(), UniPrimitive::U32 => "0".to_string(), UniPrimitive::U64 => "0".to_string(), + UniPrimitive::U128 => "default".to_string(), UniPrimitive::I8 => "0".to_string(), UniPrimitive::I16 => "0".to_string(), UniPrimitive::I32 => "0".to_string(), UniPrimitive::I64 => "0".to_string(), + UniPrimitive::I128 => "0".to_string(), UniPrimitive::F32 => "0".to_string(), UniPrimitive::F64 => "0".to_string(), UniPrimitive::Char => "'\\0'".to_string(), diff --git a/mudu_gen/src/lang_impl/rust/lang_def.rs b/mudu_gen/src/lang_impl/rust/lang_def.rs index 89e0e6a..0141232 100644 --- a/mudu_gen/src/lang_impl/rust/lang_def.rs +++ b/mudu_gen/src/lang_impl/rust/lang_def.rs @@ -14,10 +14,12 @@ impl_primitive! { (U16, "u16"), (U32, "u32"), (U64, "u64"), + (U128, "OID"), (I8, "i8"), (I16, "i16"), (I32, "i32"), (I64, "i64"), + (I128, "i128"), (F32, "f32"), (F64, "f64"), (Char, "char"), diff --git a/mudu_kernel/src/collection/hash_map.rs b/mudu_kernel/src/collection/hash_map.rs deleted file mode 100644 index c3cb165..0000000 --- a/mudu_kernel/src/collection/hash_map.rs +++ /dev/null @@ -1,77 +0,0 @@ -use mudu::common::result::RS; -use scc::HashMap; -use std::hash::Hash; - -pub async fn hash_map_async_get_or_create( - scc_hash_map: &HashMap, - key: K, - create: C, - call: T, -) -> RS -where - K: Hash + Eq + Copy + 'static, - V: Clone + 'static, - C: Fn() -> V + 'static, - T: for<'r> Fn(&'r V) -> Option + 'static, -{ - let mut key = key; - loop { - let opt = scc_hash_map.get_async(&key).await; - let value = match opt { - Some(e) => e.get().clone(), - None => { - let v_created = create(); - let opt = scc_hash_map.insert_async(key, v_created.clone()).await; - match opt { - Ok(_) => v_created, - Err((k, v)) => { - key = k; - v - } - } - } - }; - let r = call(&value); - match r { - Some(r) => return Ok(r), - None => {} - } - } -} - -pub fn hash_map_get_or_create( - scc_hash_map: &HashMap, - key: K, - creater: C, - call: T, -) -> RS -where - K: Hash + Eq + Clone + 'static, - V: Clone + 'static, - C: Fn() -> V + 'static, - T: for<'r> Fn(&'r V) -> Option + 'static, -{ - let mut key = key; - loop { - let opt = scc_hash_map.get_sync(&key); - let v = match opt { - Some(e) => e.get().clone(), - None => { - let v = creater(); - let _opt = scc_hash_map.insert_sync(key.clone(), v.clone()); - match _opt { - Ok(_) => v, - Err((_k, _v)) => { - key = _k; - v - } - } - } - }; - let r = call(&v); - match r { - Some(r) => return Ok(r), - None => {} - } - } -} diff --git a/mudu_kernel/src/collection/mod.rs b/mudu_kernel/src/collection/mod.rs deleted file mode 100644 index 9a16d54..0000000 --- a/mudu_kernel/src/collection/mod.rs +++ /dev/null @@ -1,4 +0,0 @@ -#![allow(dead_code)] - -pub mod hash_map; -pub mod tree_map; diff --git a/mudu_kernel/src/collection/tree_map.rs b/mudu_kernel/src/collection/tree_map.rs deleted file mode 100644 index 813fa3d..0000000 --- a/mudu_kernel/src/collection/tree_map.rs +++ /dev/null @@ -1,88 +0,0 @@ -use scc::Guard; -use scc::TreeIndex; - -pub async fn tree_map_async_get_or_create( - scc_tree_map: &TreeIndex, - key: K, - create: C, - is_valid: T, -) -> Option -where - K: Ord + Eq + Copy + 'static, - V: Clone + 'static, - C: Fn() -> V + 'static, - T: for<'r> Fn(&'r V) -> bool + 'static, -{ - let mut key = key; - let mut value = None; - loop { - let guard = Guard::new(); - let opt = scc_tree_map.peek(&key, &guard); - match opt { - Some(e) => { - if is_valid(e) { - return Some(e.clone()); - } - } - None => { - let v = match value { - Some(v) => v, - None => create(), - }; - let opt = scc_tree_map.insert_async(key, v.clone()).await; - match opt { - Ok(_) => { - return Some(v); - } - Err((k, v)) => { - key = k; - value = Some(v) - } - } - } - } - } -} - -pub fn tree_map_get_or_create( - scc_tree_map: &TreeIndex, - key: K, - creater: C, - is_valid: T, -) -> Option -where - K: Ord + Eq + Copy + 'static, - V: Clone + 'static, - C: Fn() -> V + 'static, - T: for<'r> Fn(&'r V) -> bool + 'static, -{ - let mut key = key; - let mut value = None; - loop { - let guard = Guard::new(); - let opt = scc_tree_map.peek(&key, &guard); - match opt { - Some(e) => { - if is_valid(e) { - return Some(e.clone()); - } - } - None => { - let v = match value { - Some(v) => v, - None => creater(), - }; - let opt = scc_tree_map.insert_sync(key, v.clone()); - match opt { - Ok(_) => { - return Some(v); - } - Err((k, v)) => { - key = k; - value = Some(v) - } - } - } - } - } -} diff --git a/mudu_kernel/src/command/create_table.rs b/mudu_kernel/src/command/create_table.rs index bf3fc5f..ab15442 100644 --- a/mudu_kernel/src/command/create_table.rs +++ b/mudu_kernel/src/command/create_table.rs @@ -84,7 +84,7 @@ impl _InnerCreateTable { async fn run(&mut self) -> RS<()> { task_trace!(); self.x_contract - .create_table(self.param.xid, &self.param.schema) + .create_table(self.param.tx_mgr.clone(), &self.param.schema) .await } } diff --git a/mudu_kernel/src/command/delete_key_value.rs b/mudu_kernel/src/command/delete_key_value.rs index 9454280..338ad98 100644 --- a/mudu_kernel/src/command/delete_key_value.rs +++ b/mudu_kernel/src/command/delete_key_value.rs @@ -59,7 +59,7 @@ impl _DeleteKeyValue { let deleted = self .x_contract .delete( - self.param.xid, + self.param.tx_mgr.clone(), self.param.table_id, &self.param.key, &Predicate::CNF(Vec::new()), diff --git a/mudu_kernel/src/command/drop_table.rs b/mudu_kernel/src/command/drop_table.rs index dae7aaa..0d1169e 100644 --- a/mudu_kernel/src/command/drop_table.rs +++ b/mudu_kernel/src/command/drop_table.rs @@ -40,7 +40,7 @@ impl CmdExec for DropTable { task_trace!(); if let Some(table_id) = self.drop_param.oid { self.x_contract - .drop_table(self.drop_param.xid, table_id) + .drop_table(self.drop_param.tx_mgr.clone(), table_id) .await?; } Ok(()) diff --git a/mudu_kernel/src/command/insert_key_value.rs b/mudu_kernel/src/command/insert_key_value.rs index f326c51..1f58a72 100644 --- a/mudu_kernel/src/command/insert_key_value.rs +++ b/mudu_kernel/src/command/insert_key_value.rs @@ -79,7 +79,7 @@ impl _InsertKeyValue { task_trace!(); self.x_contract .insert( - self.param.xid, + self.param.tx_mgr.clone(), self.param.table_id, &self.param.key, &self.param.value, diff --git a/mudu_kernel/src/command/load_from_file.rs b/mudu_kernel/src/command/load_from_file.rs index 0707e70..562ddee 100644 --- a/mudu_kernel/src/command/load_from_file.rs +++ b/mudu_kernel/src/command/load_from_file.rs @@ -1,6 +1,7 @@ use crate::contract::cmd_exec::CmdExec; use crate::contract::meta_mgr::MetaMgr; use crate::x_engine::api::{OptInsert, VecDatum, XContract}; +use crate::x_engine::tx_mgr::TxMgr; use async_std::fs::File; use async_trait::async_trait; use csv_async::StringRecord; @@ -8,7 +9,6 @@ use futures::StreamExt; use mudu::common::buf::Buf; use mudu::common::id::OID; use mudu::common::result::RS; -use mudu::common::xid::XID; use mudu::error::ec::EC as ER; use mudu::m_error; use std::sync::Arc; @@ -20,7 +20,7 @@ pub struct LoadFromFile { struct _LoadFromFile { csv_file: String, - xid: XID, + tx_mgr: Arc, table_id: OID, key_index: Vec, value_index: Vec, @@ -32,7 +32,7 @@ struct _LoadFromFile { impl LoadFromFile { pub fn new( csv_file: String, - xid: XID, + tx_mgr: Arc, table_id: OID, key_index: Vec, value_index: Vec, @@ -42,7 +42,7 @@ impl LoadFromFile { Self { inner: Arc::new(Mutex::new(_LoadFromFile::new( csv_file, - xid, + tx_mgr, table_id, key_index, value_index, @@ -56,7 +56,7 @@ impl LoadFromFile { impl _LoadFromFile { fn new( csv_file: String, - xid: XID, + tx_mgr: Arc, table_id: OID, key_index: Vec, value_index: Vec, @@ -65,7 +65,7 @@ impl _LoadFromFile { ) -> Self { Self { csv_file, - xid, + tx_mgr, table_id, key_index, value_index, @@ -115,15 +115,26 @@ impl _LoadFromFile { )); } - let key = Self::build_datum_from_line(&record, &self.key_index, &table_desc, 0)?; + let key = Self::build_datum_from_line( + &record, + &self.key_index, + table_desc.key_indices(), + &table_desc, + )?; let value = Self::build_datum_from_line( &record, &self.value_index, + table_desc.value_indices(), &table_desc, - table_desc.key_info().len(), )?; self.x_contract - .insert(self.xid, self.table_id, &key, &value, &OptInsert::default()) + .insert( + self.tx_mgr.clone(), + self.table_id, + &key, + &value, + &OptInsert::default(), + ) .await?; rows += 1; } @@ -141,15 +152,16 @@ impl _LoadFromFile { fn build_datum_from_line( record: &StringRecord, csv_index: &[usize], + attr_indices: &[usize], table_desc: &crate::contract::table_desc::TableDesc, - attr_base: usize, ) -> RS { let mut datum = Vec::with_capacity(csv_index.len()); for (position, csv_col) in csv_index.iter().enumerate() { let textual = record .get(*csv_col) .ok_or_else(|| m_error!(ER::IndexOutOfRange))?; - let field = table_desc.get_attr(attr_base + position); + let attr_index = attr_indices[position]; + let field = table_desc.get_attr(attr_index); let dat_type = field.type_desc(); let dat_id = dat_type.dat_type_id(); let internal = dat_id.fn_input()(textual, dat_type) @@ -157,7 +169,7 @@ impl _LoadFromFile { let binary: Buf = dat_id.fn_send()(&internal, dat_type) .map_err(|e| m_error!(ER::TypeBaseErr, "converting internal to binary error", e))? .into(); - datum.push((attr_base + position, binary)); + datum.push((attr_index, binary)); } Ok(VecDatum::new(datum)) } diff --git a/mudu_kernel/src/command/save_to_file.rs b/mudu_kernel/src/command/save_to_file.rs index d4d84c6..9084dd7 100644 --- a/mudu_kernel/src/command/save_to_file.rs +++ b/mudu_kernel/src/command/save_to_file.rs @@ -2,12 +2,12 @@ use crate::contract::cmd_exec::CmdExec; use crate::contract::meta_mgr::MetaMgr; use crate::contract::table_desc::TableDesc; use crate::x_engine::api::{OptRead, Predicate, RangeData, VecSelTerm, XContract}; +use crate::x_engine::tx_mgr::TxMgr; use async_std::fs::File; use async_trait::async_trait; use csv_async::AsyncWriter; use mudu::common::id::OID; use mudu::common::result::RS; -use mudu::common::xid::XID; use mudu::error::ec::EC as ER; use mudu::m_error; use mudu_contract::tuple::datum_desc::DatumDesc; @@ -21,7 +21,7 @@ pub struct SaveToFile { struct _SaveToFile { file_path: String, - xid: XID, + tx_mgr: Arc, table_id: OID, key_indexing: Vec, value_indexing: Vec, @@ -33,7 +33,7 @@ struct _SaveToFile { impl SaveToFile { pub fn new( file_path: String, - xid: XID, + tx_mgr: Arc, table_id: OID, key_indexing: Vec, value_indexing: Vec, @@ -43,7 +43,7 @@ impl SaveToFile { Self { inner: AMutex::new(_SaveToFile::new( file_path, - xid, + tx_mgr, table_id, key_indexing, value_indexing, @@ -77,7 +77,7 @@ impl CmdExec for SaveToFile { impl _SaveToFile { fn new( file_path: String, - xid: XID, + tx_mgr: Arc, table_id: OID, key_indexing: Vec, value_indexing: Vec, @@ -86,7 +86,7 @@ impl _SaveToFile { ) -> Self { Self { file_path, - xid, + tx_mgr, table_id, key_indexing, value_indexing, @@ -114,7 +114,7 @@ impl _SaveToFile { let cursor = self .x_contract .read_range( - self.xid, + self.tx_mgr.clone(), self.table_id, &RangeData::new(Bound::Unbounded, Bound::Unbounded), &Predicate::CNF(Vec::new()), @@ -186,12 +186,12 @@ impl _SaveToFile { } fn build_select(table_desc: &TableDesc) -> VecSelTerm { - let total = table_desc.key_info().len() + table_desc.value_info().len(); + let total = table_desc.fields().len(); VecSelTerm::new((0..total).collect()) } fn build_output_desc(table_desc: &TableDesc) -> Vec { - let total = table_desc.key_info().len() + table_desc.value_info().len(); + let total = table_desc.fields().len(); (0..total) .map(|attr| { let field = table_desc.get_attr(attr); @@ -201,7 +201,7 @@ impl _SaveToFile { } fn build_header(table_desc: &TableDesc) -> Vec { - let total = table_desc.key_info().len() + table_desc.value_info().len(); + let total = table_desc.fields().len(); (0..total) .map(|attr| table_desc.get_attr(attr).name().clone()) .collect() diff --git a/mudu_kernel/src/command/update_key_value.rs b/mudu_kernel/src/command/update_key_value.rs index 3df2f03..bdb2f44 100644 --- a/mudu_kernel/src/command/update_key_value.rs +++ b/mudu_kernel/src/command/update_key_value.rs @@ -62,7 +62,7 @@ impl _UpdateKeyValue { let updated = self .x_contract .update( - self.param.xid, + self.param.tx_mgr.clone(), self.param.table_id, &self.param.key, &Predicate::CNF(Vec::new()), diff --git a/mudu_kernel/src/contract/field_info.rs b/mudu_kernel/src/contract/field_info.rs index 62489f5..7cd1468 100644 --- a/mudu_kernel/src/contract/field_info.rs +++ b/mudu_kernel/src/contract/field_info.rs @@ -1,4 +1,4 @@ -use mudu::common::id::OID; +use mudu::common::id::{AttrIndex, DatumIndex, OID}; use mudu_type::dt_fn_param::DatType; #[derive(Clone, Debug, Default)] @@ -7,21 +7,28 @@ pub struct FieldInfo { id: OID, type_desc: DatType, // index in key or value tuple - datum_index: usize, + datum_index: DatumIndex, // index in original create table column definition list - column_index: usize, - is_primary: bool, + column_index: AttrIndex, + primary_index: Option, } impl FieldInfo { - pub fn new(name: String, id: OID, type_desc: DatType, index: usize, is_primary: bool) -> Self { + pub fn new( + name: String, + id: OID, + type_desc: DatType, + datum_index: DatumIndex, + column_index: AttrIndex, + primary_index: Option, + ) -> Self { Self { name, id, type_desc, - datum_index: index, - column_index: index, - is_primary, + datum_index, + column_index, + primary_index, } } @@ -33,19 +40,23 @@ impl FieldInfo { self.id } - pub fn column_index(&self) -> usize { + pub fn column_index(&self) -> AttrIndex { self.column_index } pub fn is_primary(&self) -> bool { - self.is_primary + self.primary_index.is_some() } - pub fn datum_index(&self) -> usize { + pub fn primary_index(&self) -> Option { + self.primary_index + } + + pub fn datum_index(&self) -> DatumIndex { self.datum_index } - pub fn set_datum_index(&mut self, index: usize) { + pub fn set_datum_index(&mut self, index: DatumIndex) { self.datum_index = index; } diff --git a/mudu_kernel/src/contract/meta_mgr.rs b/mudu_kernel/src/contract/meta_mgr.rs index dd79a98..eb43be6 100644 --- a/mudu_kernel/src/contract/meta_mgr.rs +++ b/mudu_kernel/src/contract/meta_mgr.rs @@ -15,4 +15,8 @@ pub trait MetaMgr: Send + Sync { async fn create_table(&self, schema: &SchemaTable) -> RS<()>; async fn drop_table(&self, table_id: OID) -> RS<()>; + + async fn list_schemas(&self) -> RS> { + Ok(Vec::new()) + } } diff --git a/mudu_kernel/src/contract/mod.rs b/mudu_kernel/src/contract/mod.rs index 43767c0..442c2b6 100644 --- a/mudu_kernel/src/contract/mod.rs +++ b/mudu_kernel/src/contract/mod.rs @@ -23,11 +23,5 @@ pub mod timestamp; pub mod version_delta; pub mod version_tuple; pub mod waiter; -pub mod xl_batch; -pub mod xl_chunk; -mod xl_d_delete; -mod xl_d_insert; pub mod xl_d_up_tuple; -mod xl_d_update; -mod xl_op; -pub mod xl_rec; +mod worker_snapshot; diff --git a/mudu_kernel/src/contract/schema_column.rs b/mudu_kernel/src/contract/schema_column.rs index a9cc8cd..4d530fc 100644 --- a/mudu_kernel/src/contract/schema_column.rs +++ b/mudu_kernel/src/contract/schema_column.rs @@ -1,6 +1,6 @@ #[cfg(any(test, feature = "test"))] use arbitrary::{Arbitrary, Unstructured}; -use mudu::common::id::{gen_oid, OID}; +use mudu::common::id::{gen_oid, AttrIndex, OID}; use mudu_type::dat_type_id::DatTypeID as TypeID; use mudu_type::dt_info::DTInfo; use serde::{Deserialize, Serialize}; @@ -11,8 +11,8 @@ pub struct SchemaColumn { name: String, type_id: TypeID, type_param: DTInfo, - index: u32, - is_primary: bool, + index: AttrIndex, + is_primary: Option, } impl SchemaColumn { @@ -24,7 +24,19 @@ impl SchemaColumn { type_param: type_param.clone(), index: 0, - is_primary: false, + is_primary: None, + } + } + + pub fn new_with_oid(oid: OID, name: String, data_type: TypeID, type_param: DTInfo) -> Self { + Self { + oid, + name, + type_id: data_type, + type_param: type_param.clone(), + + index: 0, + is_primary: None, } } @@ -37,18 +49,22 @@ impl SchemaColumn { } pub fn is_primary(&self) -> bool { + self.is_primary.is_some() + } + + pub fn primary_index(&self) -> Option { self.is_primary } - pub fn set_primary(&mut self, is_primary: bool) { - self.is_primary = is_primary; + pub fn set_primary_index(&mut self, index: Option) { + self.is_primary = index; } - pub fn get_index(&self) -> u32 { + pub fn get_index(&self) -> AttrIndex { self.index } - pub fn set_index(&mut self, index: u32) { + pub fn set_index(&mut self, index: AttrIndex) { self.index = index; } diff --git a/mudu_kernel/src/contract/schema_table.rs b/mudu_kernel/src/contract/schema_table.rs index 49d9151..5a76bed 100644 --- a/mudu_kernel/src/contract/schema_table.rs +++ b/mudu_kernel/src/contract/schema_table.rs @@ -2,7 +2,7 @@ use crate::contract::field_info::FieldInfo; use crate::contract::schema_column::SchemaColumn; #[cfg(any(test, feature = "test"))] use arbitrary::{Arbitrary, Unstructured}; -use mudu::common::id::{gen_oid, OID}; +use mudu::common::id::{gen_oid, AttrIndex, DatumIndex, OID}; use mudu::common::result::RS; use mudu_contract::tuple::tuple_binary_desc::TupleBinaryDesc as TupleDesc; use serde::{Deserialize, Serialize}; @@ -13,27 +13,36 @@ use test_utils::_arb_limit; pub struct SchemaTable { oid: OID, table_name: String, - key_columns: Vec, - value_columns: Vec, + columns: Vec, + key_indices: Vec, + value_indices: Vec, } -pub fn schema_columns_to_tuple_desc(fields: &[SchemaColumn]) -> RS<(TupleDesc, Vec)> { - let mut desc = Vec::with_capacity(fields.len()); - for (i, sc) in fields.iter().enumerate() { +// Build a tuple descriptor from a key/value column slice. +// The input AttrIndex is the original column order in the table schema, +// while the generated FieldInfo.datum_index is the position inside this tuple. +pub fn schema_columns_to_tuple_desc( + fields: Vec<(AttrIndex, &SchemaColumn)>, +) -> RS<(TupleDesc, Vec)> { + let field_count = fields.len(); + let mut desc = Vec::with_capacity(field_count); + for (_, (column_index, sc)) in fields.into_iter().enumerate() { let ty = sc.type_param().to_dat_type()?; let field_info = FieldInfo::new( sc.get_name().clone(), sc.get_oid(), ty.clone(), - i, - sc.is_primary(), + DatumIndex::MAX, // set an invalid index + column_index, + sc.primary_index(), ); desc.push((ty, field_info)) } - assert_eq!(desc.len(), fields.len()); + assert_eq!(desc.len(), field_count); let (vec_tuple_desc, mut vec_payload) = TupleDesc::normalized_type_desc_vec(desc)?; for (i, f) in vec_payload.iter_mut().enumerate() { + // set its real index f.set_datum_index(i); } let tuple_desc = TupleDesc::from(vec_tuple_desc)?; @@ -48,40 +57,65 @@ impl<'a> Arbitrary<'a> for SchemaTable { let v2 = u32::arbitrary(u)?; let n1 = v1 % _arb_limit::_ARB_MAX_TUPLE_KEY_FIELD as u32 + 1; let n2 = v2 % _arb_limit::_ARB_MAX_TUPLE_VALUE_FIELD as u32 + 1; - let mut primary_key_fields = vec![]; - let mut value_fields = vec![]; + let mut columns = vec![]; + let mut key_indices = vec![]; + let mut value_indices = vec![]; for _i in 0..n1 { let s = SchemaColumn::arbitrary(u)?; - primary_key_fields.push(s); + key_indices.push(columns.len()); + columns.push(s); } for _i in 0..n2 { let s = SchemaColumn::arbitrary(u)?; - value_fields.push(s); + value_indices.push(columns.len()); + columns.push(s); } - let schema = SchemaTable::new(name, primary_key_fields, value_fields); + let schema = SchemaTable::new(name, columns, key_indices, value_indices); Ok(schema) } } impl SchemaTable { + // `columns` shall preserve the original column order of the table schema. + // `key_indices` / `value_indices` shall reference entries in `columns` via AttrIndex. + // This constructor shall be used only for new schema creation. + // During recovery, the schema shall be loaded from storage and deserialized. + // For any given SchemaTable value, TableInfo::new(...).table_desc() deterministically + // yields an identical field mapping and identical index semantics. + // Each SchemaColumn.index shall be normalized to its position within the + // corresponding key or value tuple. pub fn new( table_name: String, - key_columns: Vec, - value_columns: Vec, + columns: Vec, + key_indices: Vec, + value_indices: Vec, + ) -> Self { + Self::new_with_oid(gen_oid(), table_name, columns, key_indices, value_indices) + } + + pub fn new_with_oid( + oid: OID, + table_name: String, + columns: Vec, + key_indices: Vec, + value_indices: Vec, ) -> Self { let mut s = SchemaTable { - oid: gen_oid(), + oid, table_name, - key_columns, - value_columns, + columns, + key_indices, + value_indices, }; - for (i, sc) in s.key_columns.iter_mut().enumerate() { - sc.set_primary(true); - sc.set_index(i as u32); + for (i, index) in s.key_indices.iter().copied().enumerate() { + let sc = &mut s.columns[index]; + sc.set_primary_index(Some(i as AttrIndex)); + sc.set_index(i as AttrIndex); } - for (i, sc) in s.value_columns.iter_mut().enumerate() { - sc.set_primary(false); - sc.set_index(i as u32); + for (i, index) in s.value_indices.iter().copied().enumerate() { + let sc = &mut s.columns[index]; + sc.set_primary_index(None); + sc.set_index(i as AttrIndex); } s } @@ -94,19 +128,51 @@ impl SchemaTable { &self.table_name } - pub fn key_columns(&self) -> &Vec { - &self.key_columns + pub fn columns(&self) -> &Vec { + &self.columns + } + + pub fn column_by_index(&self, index: AttrIndex) -> &SchemaColumn { + &self.columns[index] + } + + pub fn key_indices(&self) -> &Vec { + &self.key_indices + } + + pub fn value_indices(&self) -> &Vec { + &self.value_indices + } + + pub fn key_columns(&self) -> Vec<&SchemaColumn> { + self.key_indices + .iter() + .map(|index| &self.columns[*index]) + .collect() } - pub fn value_columns(&self) -> &Vec { - &self.value_columns + pub fn value_columns(&self) -> Vec<&SchemaColumn> { + self.value_indices + .iter() + .map(|index| &self.columns[*index]) + .collect() } pub fn key_tuple_desc(&self) -> RS<(TupleDesc, Vec)> { - schema_columns_to_tuple_desc(&self.key_columns) + schema_columns_to_tuple_desc( + self.key_indices + .iter() + .map(|index| (*index, &self.columns[*index])) + .collect(), + ) } pub fn value_tuple_desc(&self) -> RS<(TupleDesc, Vec)> { - schema_columns_to_tuple_desc(&self.value_columns) + schema_columns_to_tuple_desc( + self.value_indices + .iter() + .map(|index| (*index, &self.columns[*index])) + .collect(), + ) } } diff --git a/mudu_kernel/src/contract/table_desc.rs b/mudu_kernel/src/contract/table_desc.rs index e20ce0e..31a9cfb 100644 --- a/mudu_kernel/src/contract/table_desc.rs +++ b/mudu_kernel/src/contract/table_desc.rs @@ -10,13 +10,12 @@ pub struct TableDesc { key_oid: Vec, value_oid: Vec, - // use AttrIndex index to access key/value - // [0 -- N ] , key datum, if 0 <= AttrIndex < N, this index would be key - // [N + 1 -- M ] , value datum, if N <= AttrIndex < M, this index would be value + // AttrIndex is the column order in the original table definition. key_desc: TupleDesc, value_desc: TupleDesc, - key_info: Vec, - value_info: Vec, + fields: Vec, + key_indices: Vec, + value_indices: Vec, name2oid: HashMap, oid2col: HashMap, column_oid: Vec, @@ -28,25 +27,15 @@ impl TableDesc { oid: OID, key_oid: Vec, value_oid: Vec, + key_indices: Vec, + value_indices: Vec, + fields: Vec, key_desc: TupleDesc, value_desc: TupleDesc, name2oid: HashMap, oid2col: HashMap, ) -> Self { - let mut vec: Vec<(&OID, &FieldInfo)> = oid2col.iter().collect(); - vec.sort_by(|a, b| a.1.column_index().cmp(&b.1.column_index())); - let column_oid: Vec = vec.iter().map(|(id, _)| *(*id)).collect(); - let mut key_info: Vec<_> = Vec::new(); - let mut value_info: Vec<_> = Vec::new(); - key_info.resize(key_desc.field_count(), FieldInfo::default()); - value_info.resize(value_desc.field_count(), FieldInfo::default()); - for (_oid, field) in oid2col.iter() { - if field.is_primary() { - key_info[field.column_index()] = field.clone(); - } else { - value_info[field.column_index()] = field.clone(); - } - } + let column_oid = fields.iter().map(|field| field.id()).collect(); Self { name, oid, @@ -54,8 +43,9 @@ impl TableDesc { value_oid, key_desc, value_desc, - key_info, - value_info, + fields, + key_indices, + value_indices, oid2col, name2oid, column_oid, @@ -70,22 +60,39 @@ impl TableDesc { &self.value_oid } + // AttrIndex always refers to the original table column order. + // Use FieldInfo.datum_index() to locate the field inside the key/value tuple. pub fn get_attr(&self, index: AttrIndex) -> &FieldInfo { - if index < self.key_info.len() { - &self.key_info[index] - } else { - &self.value_info[index - self.key_info.len()] - } + &self.fields[index] } - pub fn key_info(&self) -> &Vec { - &self.key_info + + pub fn fields(&self) -> &Vec { + &self.fields + } + + pub fn key_indices(&self) -> &Vec { + &self.key_indices + } + + pub fn value_indices(&self) -> &Vec { + &self.value_indices + } + + pub fn key_info(&self) -> Vec<&FieldInfo> { + self.key_indices + .iter() + .map(|index| &self.fields[*index]) + .collect() } pub fn key_desc(&self) -> &TupleDesc { &self.key_desc } - pub fn value_info(&self) -> &Vec { - &self.value_info + pub fn value_info(&self) -> Vec<&FieldInfo> { + self.value_indices + .iter() + .map(|index| &self.fields[*index]) + .collect() } pub fn value_desc(&self) -> &TupleDesc { &self.value_desc diff --git a/mudu_kernel/src/contract/table_info.rs b/mudu_kernel/src/contract/table_info.rs index 2c19087..e5e04e6 100644 --- a/mudu_kernel/src/contract/table_info.rs +++ b/mudu_kernel/src/contract/table_info.rs @@ -1,7 +1,7 @@ use crate::contract::field_info::FieldInfo; use crate::contract::schema_table::SchemaTable; use crate::contract::table_desc::TableDesc; -use mudu::common::id::OID; +use mudu::common::id::{AttrIndex, OID}; use mudu::common::result::RS; use mudu_contract::tuple::tuple_binary_desc::TupleBinaryDesc as TupleDesc; use std::collections::HashMap; @@ -17,8 +17,11 @@ struct TableInner { schema_table: Arc, name2oid: HashMap, oid2column: HashMap, + fields: Vec, key_oid: Vec, value_oid: Vec, + key_indices: Vec, + value_indices: Vec, key_tuple_desc: TupleDesc, value_tuple_desc: TupleDesc, } @@ -37,6 +40,9 @@ impl TableInfo { inner.id(), inner.key_oid.clone(), inner.value_oid.clone(), + inner.key_indices.clone(), + inner.value_indices.clone(), + inner.fields.clone(), inner.key_tuple_desc.clone(), inner.value_tuple_desc.clone(), inner.name2oid.clone(), @@ -52,6 +58,8 @@ impl TableInfo { impl TableInner { pub fn new(table_schema: SchemaTable) -> RS { + let key_indices = table_schema.key_indices().clone(); + let value_indices = table_schema.value_indices().clone(); let (key_tuple_desc, key_tuple_payload_info) = table_schema.key_tuple_desc()?; let (value_tuple_desc, value_tuple_payload_info) = table_schema.value_tuple_desc()?; if value_tuple_desc.field_count() != value_tuple_payload_info.len() { @@ -59,25 +67,33 @@ impl TableInner { } let mut name2oid = HashMap::new(); let mut oid2column = HashMap::new(); + let mut fields = vec![FieldInfo::default(); table_schema.columns().len()]; let mut key_oid = Vec::new(); let mut value_oid = Vec::new(); - for (payload, oids) in [ - (key_tuple_payload_info, &mut key_oid), - (value_tuple_payload_info, &mut value_oid), - ] { - for field_info in payload { - oids.push(field_info.id()); - name2oid.insert(field_info.name().clone(), field_info.id()); - oid2column.insert(field_info.id(), field_info.clone()); - } + for field_info in key_tuple_payload_info { + let column_index = field_info.column_index(); + key_oid.push(field_info.id()); + name2oid.insert(field_info.name().clone(), field_info.id()); + oid2column.insert(field_info.id(), field_info.clone()); + fields[column_index] = field_info; + } + for field_info in value_tuple_payload_info { + let column_index = field_info.column_index(); + value_oid.push(field_info.id()); + name2oid.insert(field_info.name().clone(), field_info.id()); + oid2column.insert(field_info.id(), field_info.clone()); + fields[column_index] = field_info; } Ok(Self { schema_table: Arc::new(table_schema), name2oid, oid2column, + fields, key_oid, value_oid, + key_indices, + value_indices, key_tuple_desc, value_tuple_desc, }) diff --git a/mudu_kernel/src/contract/test_schema.rs b/mudu_kernel/src/contract/test_schema.rs index 684ffd4..2de4eef 100644 --- a/mudu_kernel/src/contract/test_schema.rs +++ b/mudu_kernel/src/contract/test_schema.rs @@ -19,11 +19,11 @@ pub mod _fuzz { for s in vec.iter() { let (key_desc, key_mapping) = s.key_tuple_desc().unwrap(); let (value_desc, value_mapping) = s.value_tuple_desc().unwrap(); - let key_columns = s.key_columns(); - let value_columns = s.value_columns(); - for (_i, (columns, desc, mapping)) in vec![ - (key_columns, key_desc, key_mapping), - (value_columns, value_desc, value_mapping), + let key_indices = s.key_indices(); + let value_indices = s.value_indices(); + for (_i, (indices, desc, mapping)) in vec![ + (key_indices, key_desc, key_mapping), + (value_indices, value_desc, value_mapping), ] .into_iter() .enumerate() @@ -31,13 +31,13 @@ pub mod _fuzz { assert_eq!(desc.field_count(), mapping.len()); for (i, field_info) in mapping.iter().enumerate() { let fd = desc.get_field_desc(i); - let sc = &columns[field_info.column_index()]; + let sc = s.column_by_index(indices[field_info.column_index()]); if _i == 0 { assert!(sc.is_primary()) } else if _i == 1 { assert!(!sc.is_primary()) } - assert_eq!(sc.get_index(), field_info.column_index() as u32); + assert_eq!(sc.get_index(), field_info.column_index()); assert_eq!(sc.is_fixed_length(), fd.is_fixed_len()); assert_eq!(sc.type_id(), fd.data_type()); assert_eq!(sc.get_name(), field_info.name()); diff --git a/mudu_kernel/src/contract/timestamp.rs b/mudu_kernel/src/contract/timestamp.rs index f9c884d..fd85822 100644 --- a/mudu_kernel/src/contract/timestamp.rs +++ b/mudu_kernel/src/contract/timestamp.rs @@ -1,7 +1,6 @@ #[cfg(any(test, feature = "test"))] use arbitrary::Arbitrary; -use mudu::common::bc_dec::{DecErr, Decode, Decoder}; -use mudu::common::bc_enc::{EncErr, Encode, Encoder}; +use mudu::common::codec::{DecErr, Decode, Decoder, EncErr, Encode, Encoder}; #[cfg_attr(any(test, feature = "test"), derive(Arbitrary))] #[derive(Clone, Debug, Eq, PartialEq)] diff --git a/mudu_kernel/src/contract/worker_snapshot.rs b/mudu_kernel/src/contract/worker_snapshot.rs new file mode 100644 index 0000000..e69de29 diff --git a/mudu_kernel/src/contract/xl_batch.rs b/mudu_kernel/src/contract/xl_batch.rs deleted file mode 100644 index 01fb44d..0000000 --- a/mudu_kernel/src/contract/xl_batch.rs +++ /dev/null @@ -1,71 +0,0 @@ -use crate::contract::xl_rec::XLRec; -#[cfg(any(test, feature = "test"))] -use arbitrary::Arbitrary; -use mudu::common::bc_dec::{DecErr, Decode, Decoder}; -use mudu::common::bc_enc::{EncErr, Encode, Encoder}; - -#[cfg_attr(any(test, feature = "test"), derive(Arbitrary))] -#[derive(Debug, Eq, PartialEq)] -pub struct XLBatch { - lsn: u64, - records: Vec, -} - -impl XLBatch { - pub fn new(lsn: u64, records: Vec) -> XLBatch { - XLBatch { lsn, records } - } - - pub fn lsn(&self) -> u64 { - self.lsn - } - - pub fn records(&self) -> &Vec { - &self.records - } - - pub fn into_records(self) -> Vec { - self.records - } -} - -impl Encode for XLBatch { - fn encode(&self, encoder: &mut E) -> Result<(), EncErr> { - encoder.write_u64(self.lsn)?; - let len = self.records.len() as u32; - encoder.write_u32(len)?; - for r in self.records.iter() { - Encode::encode(r, encoder)? - } - Ok(()) - } - - fn size(&self) -> Result { - let mut len = size_of::() + size_of::(); - for r in self.records.iter() { - len += r.size()? - } - Ok(len) - } -} - -impl Decode for XLBatch { - fn decode(decoder: &mut E) -> Result { - let lsn = decoder.read_u64()?; - let len = decoder.read_u32()? as usize; - let mut records = vec![]; - for _i in 0..len { - let rec = XLRec::decode(decoder)?; - records.push(rec); - } - Ok(Self { lsn, records }) - } -} - -#[allow(unused)] -pub mod _fuzz { - #[allow(dead_code)] - pub fn _dc_en_x_l_batch(data: &[u8]) { - //_fuzz_decode_and_encode::(data); - } -} diff --git a/mudu_kernel/src/contract/xl_chunk.rs b/mudu_kernel/src/contract/xl_chunk.rs deleted file mode 100644 index 8a21034..0000000 --- a/mudu_kernel/src/contract/xl_chunk.rs +++ /dev/null @@ -1,247 +0,0 @@ -use crate::contract::lsn::LSN; -use crate::contract::xl_batch::XLBatch; -use crate::io::file::File; -use mudu::common::bc_dec::{DecErr, Decode, Decoder}; -use mudu::common::bc_enc::{EncErr, Encode, Encoder}; -use mudu::common::buf::Buf; -use mudu::common::crc::calc_crc; -use mudu::common::result::RS; -use mudu::common::result_of::std_io_error; -use mudu::common::slice::{SliceMutRef, SliceRef}; -use mudu_utils::task_trace; -use tokio::io::AsyncWriteExt; - -pub const LOG_CHUNK_PART: u8 = 1u8; -pub const LOG_CHUNK_WHOLE: u8 = 2u8; -pub const LOG_C_TYPE_SIZE: usize = size_of::(); -pub const LOG_C_CRC_SIZE: usize = size_of::(); -pub const LOG_C_COMMON_HDR_SIZE: usize = LOG_C_TYPE_SIZE + LOG_C_CRC_SIZE + size_of::(); -pub const LOG_C_HDR_SEQ_SIZE: usize = size_of::(); -pub const LOG_C_PART_HDR_SIZE: usize = LOG_C_COMMON_HDR_SIZE + LOG_C_HDR_SEQ_SIZE; -pub const LOG_C_TAIL_SIZE: usize = LOG_C_CRC_SIZE; - -const CHUNK_LAST_MASK: u32 = 1u32 << 31; -pub struct ChunkHdr { - lsn: u64, - body_crc: u64, - body_length: u32, - // first 1 bit, is this chunk the last? - // last chunk 31 bit sequence - chunk_seq: u32, -} - -pub struct ChunkTail { - body_crc: u64, -} - -pub enum XLChunk { - Part(Vec), - Whole(XLBatch), -} - -#[derive(Eq, PartialEq, Copy, Clone)] -pub enum XLChunkType { - Part, - Whole, -} - -impl XLChunkType { - pub fn from(t: u8) -> XLChunkType { - match t { - LOG_CHUNK_PART => XLChunkType::Part, - LOG_CHUNK_WHOLE => XLChunkType::Whole, - _ => { - panic!("unknown enum value") - } - } - } - - pub fn to_u8(&self) -> u8 { - match self { - XLChunkType::Part => LOG_CHUNK_PART, - XLChunkType::Whole => LOG_CHUNK_WHOLE, - } - } - - pub fn hdr_size(&self) -> usize { - match self { - XLChunkType::Part => LOG_C_COMMON_HDR_SIZE, - XLChunkType::Whole => LOG_C_PART_HDR_SIZE, - } - } -} - -pub async fn write_chunk_to_u_file( - file: &mut File, - lsn: LSN, - body: &[u8], - chunk_seq: Option<(u32, bool)>, -) -> RS { - let _ = task_trace!(); - const HEADER_SIZE: usize = LOG_C_COMMON_HDR_SIZE + LOG_C_HDR_SEQ_SIZE; - const TAIL_SIZE: usize = LOG_C_TAIL_SIZE; - let mut buf: [u8; HEADER_SIZE + TAIL_SIZE] = [0; HEADER_SIZE + TAIL_SIZE]; - let crc = calc_crc(body); - - let hdr = ChunkHdr::new(lsn, crc, body.len() as u32, chunk_seq); - let tail = ChunkTail::new(crc); - - // write header - let mut bf1 = SliceMutRef::new(&mut buf); - let _ = hdr.encode(&mut bf1).unwrap(); - let mut size = bf1.as_slice().len(); - let r = file.write_all(bf1.as_slice()).await; - std_io_error(r)?; - - // write body - size += body.len(); - let r = file.write_all(body).await; - std_io_error(r)?; - - // write tail - let mut bf2 = SliceMutRef::new(&mut buf); - let _ = tail.encode(&mut bf2).unwrap(); - size += bf2.as_slice().len(); - let r = file.write_all(bf2.as_slice()).await; - std_io_error(r)?; - - Ok(size) -} - -pub fn decode_chunk(buf: &[u8]) -> Result<(ChunkHdr, &[u8]), DecErr> { - let h = decode_chunk_hdr(buf)?; - let hdr_size = ChunkHdr::size_of(); - let body = &buf[hdr_size..hdr_size + h.body_length() as usize]; - let t = decode_chunk_tail(&buf[hdr_size + h.body_length() as usize..])?; - if t.body_crc() != h.body_crc() { - return Err(DecErr::ErrorCRC); - } - Ok((h, body)) -} - -fn decode_chunk_tail(buf: &[u8]) -> Result { - let mut r = SliceRef::new(buf); - let c = ChunkTail::decode(&mut r)?; - Ok(c) -} - -pub fn decode_chunk_hdr(buf: &[u8]) -> Result { - let mut r = SliceRef::new(buf); - let c = ChunkHdr::decode(&mut r)?; - Ok(c) -} - -impl ChunkHdr { - pub fn new(lsn: LSN, body_crc: u64, body_length: u32, chunk_seq: Option<(u32, bool)>) -> Self { - let v = match chunk_seq { - Some((seq, last)) => { - if last { - seq | CHUNK_LAST_MASK - } else { - seq - } - } - None => u32::MAX, - }; - Self { - body_length, - body_crc, - lsn, - chunk_seq: v, - } - } - - pub fn chunk_seq(&self) -> u32 { - // clear first bit - self.chunk_seq & !CHUNK_LAST_MASK - } - - pub fn chunk_type(&self) -> XLChunkType { - if self.chunk_seq == u32::MAX { - XLChunkType::Whole - } else { - XLChunkType::Part - } - } - - pub fn lsn(&self) -> LSN { - self.lsn - } - pub fn body_length(&self) -> u32 { - self.body_length - } - - pub fn body_crc(&self) -> u64 { - self.body_crc - } - - pub fn size_of() -> usize { - size_of::() - } -} - -impl Decode for ChunkHdr { - fn decode(decoder: &mut D) -> Result { - let lsn = decoder.read_u64()?; - let body_crc = decoder.read_u64()?; - let body_length = decoder.read_u32()?; - let pad = decoder.read_u32()?; - Ok(Self { - lsn, - body_length, - body_crc, - chunk_seq: pad, - }) - } -} - -impl Encode for ChunkHdr { - fn encode(&self, encoder: &mut E) -> Result<(), EncErr> { - encoder.write_u64(self.lsn)?; - encoder.write_u64(self.body_crc)?; - encoder.write_u32(self.body_length)?; - encoder.write_u32(self.chunk_seq)?; - Ok(()) - } - - fn size(&self) -> Result { - let mut size = 0; - size += size_of_val(&self.lsn); - size += size_of_val(&self.body_crc); - size += size_of_val(&self.body_length); - size += size_of_val(&self.chunk_seq); - Ok(size) - } -} - -impl ChunkTail { - fn new(body_crc: u64) -> Self { - Self { body_crc } - } - - pub fn body_crc(&self) -> u64 { - self.body_crc - } - - pub fn size_of() -> usize { - size_of::() - } -} - -impl Decode for ChunkTail { - fn decode(decoder: &mut D) -> Result { - let body_crc = decoder.read_u64()?; - Ok(Self { body_crc }) - } -} - -impl Encode for ChunkTail { - fn encode(&self, encoder: &mut E) -> Result<(), EncErr> { - encoder.write_u64(self.body_crc)?; - Ok(()) - } - - fn size(&self) -> Result { - Ok(LOG_C_CRC_SIZE) - } -} diff --git a/mudu_kernel/src/contract/xl_d_delete.rs b/mudu_kernel/src/contract/xl_d_delete.rs deleted file mode 100644 index ab5d941..0000000 --- a/mudu_kernel/src/contract/xl_d_delete.rs +++ /dev/null @@ -1,60 +0,0 @@ -#[cfg(test)] -use arbitrary::Arbitrary; -use mudu::common::bc_dec::{DecErr, Decode, Decoder}; -use mudu::common::bc_enc::{EncErr, Encode, Encoder}; -use mudu::common::buf::Buf; -use mudu::common::id::OID; - -// delete key value -#[cfg_attr(any(test, feature = "test"), derive(Arbitrary))] -#[derive(Debug, Eq, PartialEq)] -pub struct XLDDelete { - table_id: OID, - tuple_id: OID, - key: Buf, -} - -impl XLDDelete { - pub fn new(table_id: OID, tuple_id: OID, key: Buf) -> Self { - Self { - table_id, - tuple_id, - key, - } - } -} -impl Encode for XLDDelete { - fn encode(&self, encoder: &mut E) -> Result<(), EncErr> { - encoder.write_u128(self.table_id)?; - encoder.write_u128(self.tuple_id)?; - encoder.write_u32(self.key.len() as u32)?; - encoder.write_bytes(self.key.as_slice())?; - Ok(()) - } - - fn size(&self) -> Result { - let mut size = 0usize; - size += size_of_val(&self.table_id); - size += size_of_val(&self.tuple_id); - size += size_of::(); - size += self.key.len(); - Ok(size) - } -} - -impl Decode for XLDDelete { - fn decode(decoder: &mut E) -> Result { - let table_id = decoder.read_u128()?; - let tuple_id = decoder.read_u128()?; - let mut key = Buf::new(); - let len: u32 = decoder.read_u32()?; - key.resize(len as usize, 0); - decoder.read_bytes(key.as_mut_slice())?; - - Ok(Self { - table_id, - tuple_id, - key, - }) - } -} diff --git a/mudu_kernel/src/contract/xl_d_insert.rs b/mudu_kernel/src/contract/xl_d_insert.rs deleted file mode 100644 index eaa2f32..0000000 --- a/mudu_kernel/src/contract/xl_d_insert.rs +++ /dev/null @@ -1,89 +0,0 @@ -#[cfg(test)] -use arbitrary::Arbitrary; -use mudu::common::bc_dec::{DecErr, Decode, Decoder}; -use mudu::common::bc_enc::{EncErr, Encode, Encoder}; -use mudu::common::buf::Buf; -use mudu::common::id::OID; - -// insert or replace a key value pair -#[cfg_attr(any(test, feature = "test"), derive(Arbitrary))] -#[derive(Debug, Eq, PartialEq)] -pub struct XLDInsert { - table_id: OID, - tuple_id: OID, - key: Buf, - value: Buf, -} - -impl XLDInsert { - pub fn new(table_id: OID, tuple_id: OID, key: Buf, value: Buf) -> XLDInsert { - Self { - table_id, - tuple_id, - key, - value, - } - } - - pub fn table_id(&self) -> OID { - self.table_id - } - - pub fn tuple_id(&self) -> OID { - self.tuple_id - } - - pub fn key(&self) -> &Buf { - &self.key - } - - pub fn value(&self) -> &Buf { - &self.value - } -} - -impl Encode for XLDInsert { - fn encode(&self, encoder: &mut E) -> Result<(), EncErr> { - encoder.write_u128(self.table_id)?; - encoder.write_u128(self.tuple_id)?; - encoder.write_u32(self.key.len() as u32)?; - encoder.write_bytes(self.key.as_slice())?; - encoder.write_u32(self.value.len() as u32)?; - encoder.write_bytes(self.value.as_slice())?; - Ok(()) - } - - fn size(&self) -> Result { - let mut size = 0; - size += size_of_val(&self.table_id); - size += size_of_val(&self.tuple_id); - size += size_of::(); // key length - size += self.key.len(); // key - size += size_of::(); // value length - size += self.value.len(); // value - Ok(size) - } -} - -impl Decode for XLDInsert { - fn decode(decoder: &mut E) -> Result { - let table_id = decoder.read_u128()?; - let tuple_id = decoder.read_u128()?; - let mut key = Buf::new(); - let mut value = Buf::new(); - let len: u32 = decoder.read_u32()?; - - key.resize(len as usize, 0); - decoder.read_bytes(key.as_mut_slice())?; - - let len: u32 = decoder.read_u32()?; - value.resize(len as usize, 0); - decoder.read_bytes(value.as_mut_slice())?; - Ok(Self { - table_id, - tuple_id, - key, - value, - }) - } -} diff --git a/mudu_kernel/src/contract/xl_d_update.rs b/mudu_kernel/src/contract/xl_d_update.rs deleted file mode 100644 index 8d329b0..0000000 --- a/mudu_kernel/src/contract/xl_d_update.rs +++ /dev/null @@ -1,94 +0,0 @@ -#[cfg(test)] -use arbitrary::Arbitrary; -use mudu::common::bc_dec::{DecErr, Decode, Decoder}; -use mudu::common::bc_enc::{EncErr, Encode, Encoder}; -use mudu::common::buf::Buf; -use mudu::common::id::OID; -use mudu::common::update_delta::UpdateDelta; - -#[cfg_attr(any(test, feature = "test"), derive(Arbitrary))] -#[derive(Debug, Eq, PartialEq)] -pub struct XLDUpdate { - table_id: OID, - tuple_id: OID, - key: Buf, - delta: Vec, -} - -impl XLDUpdate { - pub fn new(table_id: OID, tuple_id: OID, key: Buf, delta: Vec) -> XLDUpdate { - Self { - table_id, - tuple_id, - key, - delta, - } - } - - pub fn table_id(&self) -> OID { - self.table_id - } - - pub fn tuple_id(&self) -> OID { - self.tuple_id - } - - pub fn key(&self) -> &Buf { - &self.key - } - - pub fn delta(&self) -> &Vec { - &self.delta - } -} - -impl Encode for XLDUpdate { - fn encode(&self, encoder: &mut E) -> Result<(), EncErr> { - encoder.write_u128(self.table_id)?; - encoder.write_u128(self.tuple_id)?; - encoder.write_u32(self.key.len() as u32)?; - encoder.write_bytes(self.key.as_slice())?; - encoder.write_u32(self.delta.len() as u32)?; - for d in self.delta.iter() { - Encode::encode(d, encoder)?; - } - Ok(()) - } - - fn size(&self) -> Result { - let mut size = 0; - size += size_of_val(&self.table_id); - size += size_of_val(&self.tuple_id); - size += size_of::(); - size += self.key.len(); - size += size_of::(); // delta len - for d in self.delta.iter() { - size += d.size()?; - } - Ok(size) - } -} - -impl Decode for XLDUpdate { - fn decode(decoder: &mut E) -> Result { - let table_id = decoder.read_u128()?; - let tuple_id = decoder.read_u128()?; - let mut key = Buf::new(); - let len: u32 = decoder.read_u32()?; - key.resize(len as usize, 0); - decoder.read_bytes(key.as_mut_slice())?; - let num_delta = decoder.read_u32()?; - let mut delta = vec![]; - for _i in 0..num_delta { - let data = UpdateDelta::decode(decoder)?; - delta.push(data); - } - - Ok(Self { - table_id, - tuple_id, - key, - delta, - }) - } -} diff --git a/mudu_kernel/src/contract/xl_op.rs b/mudu_kernel/src/contract/xl_op.rs deleted file mode 100644 index ad3f71b..0000000 --- a/mudu_kernel/src/contract/xl_op.rs +++ /dev/null @@ -1,101 +0,0 @@ -use crate::contract::xl_d_delete::XLDDelete; -use crate::contract::xl_d_insert::XLDInsert; -use crate::contract::xl_d_update::XLDUpdate; -#[cfg(test)] -use arbitrary::Arbitrary; -use mudu::common::bc_dec::{DecErr, Decode, Decoder}; -use mudu::common::bc_enc::{EncErr, Encode, Encoder}; - -#[cfg_attr(any(test, feature = "test"), derive(Arbitrary))] -#[derive(Debug, Eq, PartialEq)] -pub enum XLOp { - // transaction control op - CBegin, - CCommit, - CAbort, - // data op - DInsert(XLDInsert), - DUpdate(XLDUpdate), - DDelete(XLDDelete), -} - -const INVALID: u8 = 0; -const BEGIN: u8 = 1; -const COMMIT: u8 = 2; -const ABORT: u8 = 3; - -const INSERT: u8 = 4; -const UPDATE: u8 = 5; -const DELETE: u8 = 6; - -impl Encode for XLOp { - fn encode(&self, encoder: &mut E) -> Result<(), EncErr> { - match self { - XLOp::CBegin => { - encoder.write_u8(BEGIN)?; - } - XLOp::CCommit => { - encoder.write_u8(COMMIT)?; - } - XLOp::CAbort => { - encoder.write_u8(ABORT)?; - } - XLOp::DInsert(op) => { - encoder.write_u8(INSERT)?; - Encode::encode(op, encoder)?; - } - XLOp::DUpdate(op) => { - encoder.write_u8(UPDATE)?; - Encode::encode(op, encoder)?; - } - XLOp::DDelete(op) => { - encoder.write_u8(DELETE)?; - Encode::encode(op, encoder)?; - } - } - - Ok(()) - } - - fn size(&self) -> Result { - let size = size_of::(); - let n = match self { - XLOp::CBegin => 0, - XLOp::CCommit => 0, - XLOp::CAbort => 0, - XLOp::DInsert(op) => op.size()?, - XLOp::DUpdate(op) => op.size()?, - XLOp::DDelete(op) => op.size()?, - }; - Ok(size + n) - } -} - -impl Decode for XLOp { - fn decode(decoder: &mut E) -> Result { - let xl_type: u8 = decoder.read_u8()?; - let res = match xl_type { - BEGIN => XLOp::CBegin, - COMMIT => XLOp::CCommit, - ABORT => XLOp::CAbort, - INSERT => { - let op = XLDInsert::decode(decoder)?; - XLOp::DInsert(op) - } - UPDATE => { - let op = XLDUpdate::decode(decoder)?; - XLOp::DUpdate(op) - } - DELETE => { - let op = XLDDelete::decode(decoder)?; - XLOp::DDelete(op) - } - _ => { - return Err(DecErr::EmptyEnum { - type_name: "XLOp".to_string(), - }) - } - }; - Ok(res) - } -} diff --git a/mudu_kernel/src/contract/xl_rec.rs b/mudu_kernel/src/contract/xl_rec.rs deleted file mode 100644 index 231a4eb..0000000 --- a/mudu_kernel/src/contract/xl_rec.rs +++ /dev/null @@ -1,84 +0,0 @@ -use crate::contract::xl_d_delete::XLDDelete; -use crate::contract::xl_d_insert::XLDInsert; -use crate::contract::xl_d_update::XLDUpdate; -use crate::contract::xl_op::XLOp; -#[cfg(any(test, feature = "test"))] -use arbitrary::Arbitrary; -use mudu::common::bc_dec::{DecErr, Decode, Decoder}; -use mudu::common::bc_enc::{EncErr, Encode, Encoder}; -use mudu::common::buf::Buf; -use mudu::common::id::OID; -use mudu::common::update_delta::UpdateDelta; -use mudu::common::xid::XID; -use std::mem::size_of; - -#[cfg_attr(any(test, feature = "test"), derive(Arbitrary))] -#[derive(Debug, Eq, PartialEq)] -pub struct XLRec { - xid: XID, - ops: Vec, -} - -impl XLRec { - pub fn new(xid: XID) -> XLRec { - let ops = vec![]; - Self { xid, ops } - } - - pub fn add_insert(&mut self, table_id: OID, tuple_id: OID, key: Buf, value: Buf) { - let op = XLOp::DInsert(XLDInsert::new(table_id, tuple_id, key, value)); - self.ops.push(op); - } - - pub fn add_update(&mut self, table_id: OID, tuple_id: OID, key: Buf, value: Vec) { - let op = XLOp::DUpdate(XLDUpdate::new(table_id, tuple_id, key, value)); - self.ops.push(op); - } - - pub fn add_delete(&mut self, table_id: OID, tuple_id: OID, key: Buf) { - let op = XLOp::DDelete(XLDDelete::new(table_id, tuple_id, key)); - self.ops.push(op); - } - - pub fn commit(&mut self) { - let op = XLOp::CCommit; - self.ops.push(op) - } -} - -impl Encode for XLRec { - fn encode(&self, encoder: &mut E) -> Result<(), EncErr> { - encoder.write_u128(self.xid)?; - let len = self.ops.len() as u32; - encoder.write_u32(len)?; - for x in self.ops.iter() { - Encode::encode(x, encoder)? - } - Ok(()) - } - - fn size(&self) -> Result { - let mut len = 0; - len += size_of::(); - len += size_of::(); - for x in self.ops.iter() { - let n = Encode::size(x)?; - len += n; - } - Ok(len) - } -} - -impl Decode for XLRec { - fn decode(decoder: &mut E) -> Result { - let xid = decoder.read_u128()?; - let mut ops = vec![]; - let len = decoder.read_u32()? as usize; - for _i in 0..len { - let op = Decode::decode(decoder)?; - ops.push(op); - } - let res = Self { xid, ops }; - Ok(res) - } -} diff --git a/mudu_kernel/src/executor/index_access_key.rs b/mudu_kernel/src/executor/index_access_key.rs index 173719b..2f265cd 100644 --- a/mudu_kernel/src/executor/index_access_key.rs +++ b/mudu_kernel/src/executor/index_access_key.rs @@ -75,7 +75,13 @@ impl _IndexAccessKey { let p = &self.param; let row = self .x_contract - .read_key(p.xid, p.table_id, &p.pred_key, &p.select, &p.opt_read) + .read_key( + p.tx_mgr.clone(), + p.table_id, + &p.pred_key, + &p.select, + &p.opt_read, + ) .await?; Ok(row.map(TupleRow::new)) } diff --git a/mudu_kernel/src/executor/index_access_range.rs b/mudu_kernel/src/executor/index_access_range.rs index b4808af..e6710c3 100644 --- a/mudu_kernel/src/executor/index_access_range.rs +++ b/mudu_kernel/src/executor/index_access_range.rs @@ -66,7 +66,7 @@ impl _IndexAccessRange { let cursor = self .x_contract .read_range( - param.xid, + param.tx_mgr.clone(), param.table_id, ¶m.pred_key, ¶m.pred_non_key, diff --git a/mudu_kernel/src/fuzz/_fuzz_run.rs b/mudu_kernel/src/fuzz/_fuzz_run.rs index 8a8c77b..3ea58f0 100644 --- a/mudu_kernel/src/fuzz/_fuzz_run.rs +++ b/mudu_kernel/src/fuzz/_fuzz_run.rs @@ -19,7 +19,7 @@ lazy_static! { ), ( "_de_en_x_l_batch", - crate::contract::xl_batch::_fuzz::_dc_en_x_l_batch, + crate::wal::xl_batch::_fuzz::_de_en_x_l_batch, ), ("_gen_order_csv", crate::test::fuzz_gen_csv::_gen_order_csv,), ]; diff --git a/mudu_kernel/src/io/file.rs b/mudu_kernel/src/io/file.rs index 1dc7ab4..b9fdbd7 100644 --- a/mudu_kernel/src/io/file.rs +++ b/mudu_kernel/src/io/file.rs @@ -738,7 +738,12 @@ pub(crate) fn submit_file_io( ) -> FileInflightOp { match request { FileIoRequest::Open(request) => { - sqe.prep_openat(libc::AT_FDCWD, request.path().as_c_str(), request.flags(), request.mode()); + sqe.prep_openat( + libc::AT_FDCWD, + request.path().as_c_str(), + request.flags(), + request.mode(), + ); FileInflightOp::Open(Box::new(request)) } FileIoRequest::Close(request) => { @@ -747,7 +752,12 @@ pub(crate) fn submit_file_io( } FileIoRequest::Read(request) => { let mut buf = vec![0u8; request.len()]; - sqe.prep_read_raw(request.fd(), buf.as_mut_ptr(), request.len(), request.offset()); + sqe.prep_read_raw( + request.fd(), + buf.as_mut_ptr(), + request.len(), + request.offset(), + ); FileInflightOp::Read { request: Box::new(request), buf, diff --git a/mudu_kernel/src/io/socket.rs b/mudu_kernel/src/io/socket.rs index 0282d0a..20f5a23 100644 --- a/mudu_kernel/src/io/socket.rs +++ b/mudu_kernel/src/io/socket.rs @@ -193,11 +193,7 @@ impl SocketOpenRequest { } impl SocketConnectRequest { - fn new( - fd: RawFd, - addr: mudu_sys::uring::SockAddrBuf, - state: Arc>, - ) -> Self { + fn new(fd: RawFd, addr: mudu_sys::uring::SockAddrBuf, state: Arc>) -> Self { Self { fd, addr, state } } @@ -836,7 +832,12 @@ pub(crate) fn submit_socket_io( ) -> SocketInflightOp { match request { SocketIoRequest::Socket(request) => { - sqe.prep_socket(request.domain(), request.socket_type(), request.protocol(), 0); + sqe.prep_socket( + request.domain(), + request.socket_type(), + request.protocol(), + 0, + ); SocketInflightOp::Open(Box::new(request)) } SocketIoRequest::Connect(request) => { @@ -849,7 +850,12 @@ pub(crate) fn submit_socket_io( SocketInflightOp::Accept(request) } SocketIoRequest::Recv(request) => { - sqe.prep_recv_raw(request.fd(), request.buf_ptr().cast(), request.len(), request.flags()); + sqe.prep_recv_raw( + request.fd(), + request.buf_ptr().cast(), + request.len(), + request.flags(), + ); SocketInflightOp::Recv(Box::new(request)) } SocketIoRequest::Send(request) => { diff --git a/mudu_kernel/src/io/worker_ring.rs b/mudu_kernel/src/io/worker_ring.rs index 4b369d6..22f0aa4 100644 --- a/mudu_kernel/src/io/worker_ring.rs +++ b/mudu_kernel/src/io/worker_ring.rs @@ -1,4 +1,4 @@ -use std::cell::RefCell; +use std::cell::UnsafeCell; use std::collections::{HashMap, VecDeque}; use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::{Arc, Mutex}; @@ -12,7 +12,8 @@ use crate::io::socket::{complete_socket_io, submit_socket_io, SocketInflightOp, use crate::server::task_registry::WorkerTaskRegistry; thread_local! { - static CURRENT_WORKER_RING: RefCell>> = const { RefCell::new(None) }; + static CURRENT_WORKER_RING: UnsafeCell>> = + const { UnsafeCell::new(None) }; } pub(crate) enum WorkerRingOp { @@ -45,6 +46,7 @@ impl WorkerLocalRing { pub fn worker_task_registry(&self) -> &WorkerTaskRegistry { &self.worker_tasks } + pub(crate) fn register(&self, op: WorkerRingOp) -> RS { let op_id = self.next_op_id.fetch_add(1, Ordering::Relaxed); self.ops @@ -96,18 +98,27 @@ impl WorkerLocalRing { pub(crate) fn set_current_worker_ring(ring: Arc) { CURRENT_WORKER_RING.with(|slot| { - *slot.borrow_mut() = Some(ring); + // Safety: this slot is thread-local and only accessed through these helpers. + unsafe { + *slot.get() = Some(ring); + } }); } pub(crate) fn unset_current_worker_ring() { CURRENT_WORKER_RING.with(|slot| { - *slot.borrow_mut() = None; + // Safety: this slot is thread-local and only accessed through these helpers. + unsafe { + *slot.get() = None; + } }); } pub(crate) fn has_current_worker_ring() -> bool { - CURRENT_WORKER_RING.with(|slot| slot.borrow().is_some()) + CURRENT_WORKER_RING.with(|slot| { + // Safety: shared reads are confined to the current thread-local slot. + unsafe { (*slot.get()).is_some() } + }) } pub(crate) fn with_current_ring(f: F) -> RS @@ -115,7 +126,8 @@ where F: FnOnce(&Arc) -> RS, { CURRENT_WORKER_RING.with(|slot| { - let ring = slot.borrow(); + // Safety: shared reads are confined to the current thread-local slot. + let ring = unsafe { &*slot.get() }; let ring = ring .as_ref() .ok_or_else(|| m_error!(EC::NoSuchElement, "current worker ring is not set"))?; diff --git a/mudu_kernel/src/io/worker_ring_stub.rs b/mudu_kernel/src/io/worker_ring_stub.rs index 66d7c6c..9e0232f 100644 --- a/mudu_kernel/src/io/worker_ring_stub.rs +++ b/mudu_kernel/src/io/worker_ring_stub.rs @@ -27,6 +27,10 @@ pub(crate) fn has_current_worker_ring() -> bool { false } +pub(crate) fn current_ring() -> &'static WorkerLocalRing { + panic!("worker ring is only available on linux") +} + pub(crate) fn with_current_ring(_f: F) -> RS where F: FnOnce(&Arc) -> RS, diff --git a/mudu_kernel/src/lib.rs b/mudu_kernel/src/lib.rs index b01d241..9fae0b8 100644 --- a/mudu_kernel/src/lib.rs +++ b/mudu_kernel/src/lib.rs @@ -1,10 +1,10 @@ -mod collection; mod common; pub mod contract; pub mod fuzz; pub mod index; pub mod io; mod meta; +pub mod mudu_conn; pub mod sql; pub mod wal; diff --git a/mudu_kernel/src/meta/_fuzz.rs b/mudu_kernel/src/meta/_fuzz.rs index 9aea56b..cdbb6af 100644 --- a/mudu_kernel/src/meta/_fuzz.rs +++ b/mudu_kernel/src/meta/_fuzz.rs @@ -103,11 +103,12 @@ fn fuzz_row_for_schema<'a>( return Ok(()); } let key = loop { - let mut key = Vec::with_capacity(schema.key_columns().len()); + let key_columns = schema.key_columns(); + let mut key = Vec::with_capacity(key_columns.len()); if u.len() == 0 { return Ok(()); } - for c in schema.key_columns() { + for c in key_columns { let s = arb_string(c, u)?; key.push(s); } @@ -115,8 +116,9 @@ fn fuzz_row_for_schema<'a>( break key; } }; - let mut value = Vec::with_capacity(schema.value_columns().len()); - for c in schema.value_columns() { + let value_columns = schema.value_columns(); + let mut value = Vec::with_capacity(value_columns.len()); + for c in value_columns { let s = arb_string(c, u)?; value.push(s); } diff --git a/mudu_kernel/src/meta/meta_mgr.rs b/mudu_kernel/src/meta/meta_mgr.rs index 7956f0d..c9f389f 100644 --- a/mudu_kernel/src/meta/meta_mgr.rs +++ b/mudu_kernel/src/meta/meta_mgr.rs @@ -1,22 +1,42 @@ -use crate::contract::meta_mgr::MetaMgr; -use crate::contract::schema_table::SchemaTable; -use crate::contract::table_desc::TableDesc; -use crate::contract::table_info::TableInfo; +use std::collections::HashMap; +use std::fs; +use std::path::{Path, PathBuf}; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::{Arc, Mutex as StdMutex, OnceLock, Weak}; +use std::time::{SystemTime, UNIX_EPOCH}; + use async_trait::async_trait; use mudu::common::id::OID; use mudu::common::result::RS; -use mudu::common::result_of::rs_io; use mudu::error::ec::EC as ER; use mudu::m_error; -use std::collections::HashMap; -use std::fs; -use std::fs::{File, OpenOptions}; -use std::path::{Path, PathBuf}; -use std::sync::Arc; -use tracing::info; + +use crate::contract::meta_mgr::MetaMgr; +use crate::contract::schema_table::SchemaTable; +use crate::contract::table_desc::TableDesc; +use crate::contract::table_info::TableInfo; +use crate::meta::schema_catalog::{ + delete_schema_from_catalog, load_schemas_from_catalog, open_schema_catalog, + write_schema_to_catalog, +}; +use crate::storage::relation::relation::Relation; + +type MetaMgrRegistry = HashMap>>; + +fn registry() -> &'static StdMutex { + static REGISTRY: OnceLock> = OnceLock::new(); + REGISTRY.get_or_init(|| StdMutex::new(HashMap::new())) +} + +fn ddl_lock() -> &'static tokio::sync::Mutex<()> { + static DDL_LOCK: OnceLock> = OnceLock::new(); + DDL_LOCK.get_or_init(|| tokio::sync::Mutex::new(())) +} pub struct MetaMgrImpl { path: String, + schema_catalog: Relation, + next_catalog_xid: AtomicU64, id2table: scc::HashMap, name2id: scc::HashMap, table: scc::HashMap, @@ -24,124 +44,159 @@ pub struct MetaMgrImpl { impl MetaMgrImpl { pub fn new>(path: P) -> RS { - let mut hash_table = HashMap::new(); let path = PathBuf::from(path.as_ref()); - if fs::metadata(path.clone()).is_err() { - fs::create_dir(path.clone()).map_err(|e| m_error!(ER::IOErr, "", e))?; - } - - for entry in rs_io(fs::read_dir(path.clone()))? { - let entry = rs_io(entry)?; - let path = entry.path(); - - let metadata = rs_io(fs::metadata(&path))?; - if metadata.is_file() { - let schema = Self::read_schema_from_file(&path.to_str().unwrap().to_string())?; - hash_table.insert(schema.table_name().to_string(), TableInfo::new(schema)?); - } + if fs::metadata(&path).is_err() { + fs::create_dir_all(&path).map_err(|e| m_error!(ER::IOErr, "", e))?; } + let path_string = path.to_string_lossy().to_string(); + let schema_catalog = open_schema_catalog(&path_string)?; let this = Self { - path: path.to_str().unwrap().to_string(), + path: path_string, + schema_catalog, + next_catalog_xid: AtomicU64::new(now_catalog_xid()), id2table: Default::default(), name2id: Default::default(), table: Default::default(), }; - - for (table_name, table_info) in hash_table { - let table_id = table_info.schema().id(); - let _ = this - .table - .insert_sync(table_name.clone(), table_info.clone()); - let _ = this.id2table.insert_sync(table_id, table_info); - let _ = this.name2id.insert_sync(table_name, table_id); + for schema in load_schemas_from_catalog(&this.schema_catalog)? { + this.apply_create_table_local(&schema)?; } - Ok(this) } + pub fn register_global(self: &Arc) { + let mut guard = registry().lock().unwrap(); + guard + .entry(self.path.clone()) + .or_default() + .push(Arc::downgrade(self)); + } + pub fn lookup_table_info_by_id(&self, oid: OID) -> Option { let opt = self.id2table.get_sync(&oid); - opt.map(|e| e.get().clone()) + opt.map(|entry| entry.get().clone()) } pub fn lookup_table_by_name(&self, name: &String) -> RS>> { let opt = self.table.get_sync(name); let table_desc = match opt { None => return Ok(None), - Some(t) => t.get().table_desc()?, + Some(table) => table.get().table_desc()?, }; Ok(Some(table_desc)) } - pub fn create_table_inner(&self, schema: &SchemaTable) -> RS<()> { - if !self.table.contains_sync(schema.table_name()) { - let table_name = schema.table_name().clone(); - let mut pb = PathBuf::from(self.path.clone()); - pb.push(format!("{}.json", schema.table_name().clone())); - let r = Self::write_schema_to_file(&pb.to_str().unwrap().to_string(), &schema); - match r { - Ok(_) => {} - Err(e) => { - info!("{:?}", e) - } - } - let table_id = schema.id(); - let table = TableInfo::new(schema.clone())?; - let _ = self.table.insert_sync(table_name.clone(), table.clone()); - let _ = self.id2table.insert_sync(table_id, table); - let _ = self.name2id.insert_sync(table_name, table_id); - } else { + pub fn list_schemas_inner(&self) -> Vec { + let mut schemas = Vec::new(); + self.table.iter_sync(|_table_name, table_info| { + schemas.push(table_info.schema().as_ref().clone()); + true + }); + schemas.sort_by_key(|schema| schema.id()); + schemas + } + + pub async fn create_table_inner(&self, schema: &SchemaTable) -> RS<()> { + let _ddl_guard = ddl_lock() + .lock() + .await; + if self.table.contains_sync(schema.table_name()) { return Err(m_error!(ER::ExistingSuchElement, "")); } - Ok(()) + + write_schema_to_catalog(&self.schema_catalog, schema, self.next_catalog_xid()).await?; + self.broadcast_create(schema) } - pub fn drop_table_inner(&self, oid: OID) -> RS<()> { + pub async fn drop_table_inner(&self, oid: OID) -> RS<()> { + let _ddl_guard = ddl_lock() + .lock() + .await; let table = self .lookup_table_info_by_id(oid) .ok_or_else(|| m_error!(ER::NoSuchElement, format!("no such table {}", oid)))?; - let schema = table.schema(); - let table_name = schema.table_name().clone(); - let mut pb = PathBuf::from(self.path.clone()); - pb.push(format!("{}.json", table_name)); - if pb.exists() { - fs::remove_file(&pb).map_err(|e| m_error!(ER::IOErr, "remove schema file error", e))?; + delete_schema_from_catalog(&self.schema_catalog, oid, self.next_catalog_xid()).await?; + self.broadcast_drop(table.schema().table_name(), oid) + } + + fn next_catalog_xid(&self) -> u64 { + let mut next = self.next_catalog_xid.load(Ordering::Relaxed); + loop { + let candidate = now_catalog_xid().max(next.saturating_add(1)); + match self.next_catalog_xid.compare_exchange( + next, + candidate, + Ordering::SeqCst, + Ordering::SeqCst, + ) { + Ok(_) => return candidate, + Err(actual) => next = actual, + } } + } + fn apply_create_table_local(&self, schema: &SchemaTable) -> RS<()> { + let table_id = schema.id(); + let table_name = schema.table_name().clone(); + let table = TableInfo::new(schema.clone())?; + let _ = self.table.insert_sync(table_name.clone(), table.clone()); + let _ = self.id2table.insert_sync(table_id, table); + let _ = self.name2id.insert_sync(table_name, table_id); + Ok(()) + } + + fn apply_drop_table_local(&self, table_name: &str, oid: OID) { let _ = self.id2table.remove_sync(&oid); - let _ = self.name2id.remove_sync(&table_name); - let _ = self.table.remove_sync(&table_name); + let _ = self.name2id.remove_sync(table_name); + let _ = self.table.remove_sync(table_name); + } + + fn broadcast_create(&self, schema: &SchemaTable) -> RS<()> { + let peers = self.peer_instances(); + if peers.is_empty() { + return self.apply_create_table_local(schema); + } + for mgr in peers { + mgr.apply_create_table_local(schema)?; + } Ok(()) } - fn read_schema_from_file(path: &String) -> RS { - let r_open = File::open(path); - let file = rs_io(r_open)?; - let r_from_reader = serde_json::from_reader::<_, SchemaTable>(file); - let schema = match r_from_reader { - Ok(e) => e, - Err(e) => { - return Err(m_error!(ER::DecodeErr, "read schema error", e)); - } - }; - Ok(schema) - } - - fn write_schema_to_file(path: &String, schema: &SchemaTable) -> RS<()> { - let r_open = OpenOptions::new() - .create(true) - .truncate(true) - .write(true) - .open(path); - let file = rs_io(r_open)?; - let r = serde_json::to_writer_pretty(file, schema); - match r { - Ok(_) => Ok(()), - Err(e) => Err(m_error!(ER::EncodeErr, "write schema error", e)), + fn broadcast_drop(&self, table_name: &str, oid: OID) -> RS<()> { + let peers = self.peer_instances(); + if peers.is_empty() { + self.apply_drop_table_local(table_name, oid); + return Ok(()); } + for mgr in peers { + mgr.apply_drop_table_local(table_name, oid); + } + Ok(()) } + + fn peer_instances(&self) -> Vec> { + let mut guard = registry().lock().unwrap(); + let peers = guard.entry(self.path.clone()).or_default(); + let mut live = Vec::with_capacity(peers.len()); + peers.retain(|weak| match weak.upgrade() { + Some(peer) => { + live.push(peer); + true + } + None => false, + }); + live + } +} + +fn now_catalog_xid() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_nanos() + .min(u64::MAX as u128) as u64 } #[async_trait] @@ -149,7 +204,7 @@ impl MetaMgr for MetaMgrImpl { async fn get_table_by_id(&self, oid: OID) -> RS> { let opt = self.lookup_table_info_by_id(oid); match opt { - Some(t) => t.table_desc(), + Some(table) => table.table_desc(), None => Err(m_error!( ER::NoSuchElement, format!("no such table {}", oid) @@ -162,14 +217,88 @@ impl MetaMgr for MetaMgrImpl { } async fn create_table(&self, schema: &SchemaTable) -> RS<()> { - self.create_table_inner(schema) + self.create_table_inner(schema).await } async fn drop_table(&self, table_id: OID) -> RS<()> { - self.drop_table_inner(table_id) + self.drop_table_inner(table_id).await + } + + async fn list_schemas(&self) -> RS> { + Ok(self.list_schemas_inner()) } } unsafe impl Sync for MetaMgrImpl {} unsafe impl Send for MetaMgrImpl {} + +#[cfg(test)] +mod tests { + use std::env::temp_dir; + + use mudu_type::dat_type_id::DatTypeID; + use mudu_type::dt_info::DTInfo; + + use crate::contract::schema_column::SchemaColumn; + + use super::*; + + fn test_schema() -> SchemaTable { + SchemaTable::new( + "meta_recovery_t".to_string(), + vec![ + SchemaColumn::new( + "id".to_string(), + DatTypeID::I32, + DTInfo::from_text(DatTypeID::I32, String::new()), + ), + SchemaColumn::new( + "v".to_string(), + DatTypeID::I32, + DTInfo::from_text(DatTypeID::I32, String::new()), + ), + ], + vec![0], + vec![1], + ) + } + + #[test] + fn meta_mgr_recovers_schema_catalog_after_reopen() { + let dir = temp_dir().join(format!("meta_mgr_catalog_{}", mudu::common::id::gen_oid())); + let mgr = Arc::new(MetaMgrImpl::new(&dir).unwrap()); + mgr.register_global(); + + let schema = test_schema(); + futures::executor::block_on(mgr.create_table(&schema)).unwrap(); + assert_eq!( + crate::meta::schema_catalog::load_schemas_from_catalog(&mgr.schema_catalog) + .unwrap() + .len(), + 1 + ); + drop(mgr); + + let reopened = MetaMgrImpl::new(&dir).unwrap(); + let table = futures::executor::block_on(reopened.get_table_by_id(schema.id())).unwrap(); + assert_eq!(table.name(), schema.table_name()); + } + + #[test] + fn meta_mgr_broadcasts_ddl_to_peer_instances() { + let dir = temp_dir().join(format!("meta_mgr_peer_{}", mudu::common::id::gen_oid())); + let mgr1 = Arc::new(MetaMgrImpl::new(&dir).unwrap()); + mgr1.register_global(); + let mgr2 = Arc::new(MetaMgrImpl::new(&dir).unwrap()); + mgr2.register_global(); + + let schema = test_schema(); + futures::executor::block_on(mgr1.create_table(&schema)).unwrap(); + let table = futures::executor::block_on(mgr2.get_table_by_id(schema.id())).unwrap(); + assert_eq!(table.name(), schema.table_name()); + + futures::executor::block_on(mgr2.drop_table(schema.id())).unwrap(); + assert!(futures::executor::block_on(mgr1.get_table_by_id(schema.id())).is_err()); + } +} diff --git a/mudu_kernel/src/meta/meta_mgr_factory.rs b/mudu_kernel/src/meta/meta_mgr_factory.rs index bc99572..b421401 100644 --- a/mudu_kernel/src/meta/meta_mgr_factory.rs +++ b/mudu_kernel/src/meta/meta_mgr_factory.rs @@ -10,7 +10,8 @@ impl MetaMgrFactory { pub fn create(path: String) -> RS> { let mut path = PathBuf::from(path); path.push("meta"); - let meta_mgr = MetaMgrImpl::new(path)?; - Ok(Arc::new(meta_mgr)) + let meta_mgr = Arc::new(MetaMgrImpl::new(path)?); + meta_mgr.register_global(); + Ok(meta_mgr) } } diff --git a/mudu_kernel/src/meta/mod.rs b/mudu_kernel/src/meta/mod.rs index ccb4610..7fd0105 100644 --- a/mudu_kernel/src/meta/mod.rs +++ b/mudu_kernel/src/meta/mod.rs @@ -4,3 +4,4 @@ pub mod _fuzz; pub mod meta_mgr; pub mod meta_mgr_factory; +pub mod schema_catalog; diff --git a/mudu_kernel/src/meta/schema_catalog.rs b/mudu_kernel/src/meta/schema_catalog.rs new file mode 100644 index 0000000..ba9ce12 --- /dev/null +++ b/mudu_kernel/src/meta/schema_catalog.rs @@ -0,0 +1,137 @@ +use std::ops::Bound; +use std::sync::Arc; +use std::time::{SystemTime, UNIX_EPOCH}; + +use mudu::common::endian; +use mudu::common::id::OID; +use mudu::common::result::RS; +use mudu_type::dat_type_id::DatTypeID; +use mudu_type::dt_info::DTInfo; + +use crate::contract::schema_column::SchemaColumn; +use crate::contract::schema_table::SchemaTable; +use crate::contract::table_desc::TableDesc; +use crate::contract::table_info::TableInfo; +use crate::server::worker_snapshot::WorkerSnapshot; +use crate::storage::relation::relation::Relation; + +pub const SCHEMA_CATALOG_PARTITION_ID: OID = 0; +pub const SCHEMA_CATALOG_TABLE_ID: OID = 0x1; +const SCHEMA_CATALOG_TABLE_NAME: &str = "__meta_schema_table"; +const SCHEMA_CATALOG_TABLE_OID_COLUMN_ID: OID = 0x10001; +const SCHEMA_CATALOG_SCHEMA_COLUMN_ID: OID = 0x10002; + +pub fn schema_catalog_schema() -> SchemaTable { + SchemaTable::new_with_oid( + SCHEMA_CATALOG_TABLE_ID, + SCHEMA_CATALOG_TABLE_NAME.to_string(), + vec![ + SchemaColumn::new_with_oid( + SCHEMA_CATALOG_TABLE_OID_COLUMN_ID, + "table_oid".to_string(), + DatTypeID::U128, + DTInfo::from_text(DatTypeID::U128, String::new()), + ), + SchemaColumn::new_with_oid( + SCHEMA_CATALOG_SCHEMA_COLUMN_ID, + "schema".to_string(), + DatTypeID::Binary, + DTInfo::from_text(DatTypeID::Binary, String::new()), + ), + ], + vec![0], + vec![1], + ) +} + +pub fn schema_catalog_desc() -> RS> { + TableInfo::new(schema_catalog_schema())?.table_desc() +} + +pub fn open_schema_catalog(path: &str) -> RS { + let desc = schema_catalog_desc()?; + Ok(Relation::new( + SCHEMA_CATALOG_TABLE_ID, + SCHEMA_CATALOG_PARTITION_ID, + path.to_string(), + desc.as_ref(), + )) +} + +pub fn encode_schema_catalog_key(oid: OID) -> RS> { + let mut key = vec![0; std::mem::size_of::()]; + endian::write_u128(&mut key, oid); + Ok(key) +} + +pub fn encode_schema_catalog_value(schema: &SchemaTable) -> RS> { + rmp_serde::to_vec(schema).map_err(|e| { + mudu::m_error!( + mudu::error::ec::EC::EncodeErr, + "encode schema catalog schema error", + e + ) + }) +} + +pub fn decode_schema_catalog_key(tuple: &[u8]) -> RS { + Ok(endian::read_u128(tuple)) +} + +pub fn decode_schema_catalog_value(tuple: &[u8]) -> RS { + rmp_serde::from_slice(tuple).map_err(|e| { + mudu::m_error!( + mudu::error::ec::EC::DecodeErr, + "decode schema catalog schema error", + e + ) + }) +} + +pub fn load_schemas_from_catalog(relation: &Relation) -> RS> { + let rows = relation.visible_range_sync( + (Bound::Unbounded, Bound::Unbounded), + &WorkerSnapshot::new(visible_snapshot_xid(), vec![]), + )?; + let mut schemas = Vec::with_capacity(rows.len()); + for (key, value) in rows { + let key_oid = decode_schema_catalog_key(&key)?; + let schema = decode_schema_catalog_value(&value)?; + if key_oid != schema.id() { + return Err(mudu::m_error!( + mudu::error::ec::EC::DecodeErr, + format!( + "schema catalog key oid {} does not match schema oid {}", + key_oid, + schema.id() + ) + )); + } + schemas.push(schema); + } + Ok(schemas) +} + +fn visible_snapshot_xid() -> u64 { + let base = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_nanos() + .min((u64::MAX - 2) as u128) as u64; + base.saturating_add(1) +} + +pub async fn write_schema_to_catalog( + relation: &Relation, + schema: &SchemaTable, + xid: u64, +) -> RS<()> { + let key = encode_schema_catalog_key(schema.id())?; + let value = encode_schema_catalog_value(schema)?; + relation.write_value(key, value, xid).await +} + +pub async fn delete_schema_from_catalog(relation: &Relation, oid: OID, xid: u64) -> RS<()> { + let key = encode_schema_catalog_key(oid)?; + relation.write_delete(key, xid).await +} diff --git a/mudu_kernel/src/mudu_conn/mod.rs b/mudu_kernel/src/mudu_conn/mod.rs new file mode 100644 index 0000000..609c32e --- /dev/null +++ b/mudu_kernel/src/mudu_conn/mod.rs @@ -0,0 +1,4 @@ +pub mod mudu_conn_core; +pub mod mudu_conn_async; +pub mod mudu_prepared_stmt; +pub mod mudu_result_set_async; diff --git a/mudu_kernel/src/mudu_conn/mudu_conn_async.rs b/mudu_kernel/src/mudu_conn/mudu_conn_async.rs new file mode 100644 index 0000000..ffa2c6f --- /dev/null +++ b/mudu_kernel/src/mudu_conn/mudu_conn_async.rs @@ -0,0 +1,125 @@ +use async_trait::async_trait; +use mudu::common::id::OID; +use mudu::common::result::RS; +use mudu::common::xid::XID; +use mudu::error::ec::EC; +use mudu::m_error; +use mudu_contract::database::db_conn::DBConnAsync; +use mudu_contract::database::prepared_stmt::PreparedStmt; +use mudu_contract::database::result_set::ResultSetAsync; +use mudu_contract::database::sql_params::SQLParams; +use mudu_contract::database::sql_stmt::SQLStmt; +use sql_parser::ast::parser::SQLParser; +use sql_parser::ast::stmt_type::StmtType; +use std::sync::Arc; +use tokio::sync::Mutex; + +use crate::mudu_conn::mudu_prepared_stmt::MuduPreparedStmt; +use crate::server::worker_local::{current_worker_local, WorkerExecute, WorkerLocalRef}; +use crate::sql::describer::Describer; + +pub struct MuduConnAsync { + worker_local: WorkerLocalRef, + parser: Arc, + session_id: Arc>>, +} + +impl MuduConnAsync { + pub fn new() -> Self { + Self { + worker_local: current_worker_local(), + parser: Arc::new(SQLParser::new()), + session_id: Arc::new(Mutex::new(None)), + } + } + + fn parse_one(&self, sql: &dyn SQLStmt) -> RS { + let stmt_list = self.parser.parse(&sql.to_sql_string())?; + let mut stmts = stmt_list.into_stmts(); + if stmts.len() != 1 { + return Err(m_error!(EC::ParseErr, "expected exactly one statement")); + } + Ok(stmts.remove(0)) + } + + async fn ensure_session_id(&self) -> RS { + let mut guard = self.session_id.lock().await; + if let Some(session_id) = *guard { + return Ok(session_id); + } + let session_id = self.worker_local.open_async().await?; + *guard = Some(session_id); + Ok(session_id) + } + + async fn active_session_id(&self) -> RS { + let guard = self.session_id.lock().await; + guard.ok_or_else(|| m_error!(EC::NoSuchElement, "no active session")) + } +} + +#[async_trait] +impl DBConnAsync for MuduConnAsync { + async fn prepare(&self, stmt: Box) -> RS> { + let parsed = self.parse_one(stmt.as_ref())?; + let desc = Describer:: + describe(self.worker_local.meta_mgr().as_ref(), parsed) + .await?; + Ok(Arc::new(MuduPreparedStmt::new( + self.worker_local.clone(), + self.session_id.clone(), + stmt, + Arc::new(desc), + ))) + } + + async fn exec_silent(&self, sql_text: String) -> RS<()> { + let session_id = self.ensure_session_id().await?; + let _ = self + .worker_local + .batch(session_id, Box::new(sql_text), Box::new(())) + .await?; + Ok(()) + } + + async fn begin_tx(&self) -> RS { + let session_id = self.ensure_session_id().await?; + self.worker_local + .execute_async(session_id, WorkerExecute::BeginTx) + .await?; + Ok(session_id) + } + + async fn rollback_tx(&self) -> RS<()> { + let session_id = self.active_session_id().await?; + self.worker_local + .execute_async(session_id, WorkerExecute::RollbackTx) + .await + } + + async fn commit_tx(&self) -> RS<()> { + let session_id = self.active_session_id().await?; + self.worker_local + .execute_async(session_id, WorkerExecute::CommitTx) + .await + } + + async fn query( + &self, + sql: Box, + param: Box, + ) -> RS> { + let session_id = self.ensure_session_id().await?; + self.worker_local.query(session_id, sql, param).await + } + + async fn execute(&self, sql: Box, param: Box) -> RS { + let session_id = self.ensure_session_id().await?; + self.worker_local.execute(session_id, sql, param).await + } + + async fn batch(&self, sql: Box, param: Box) -> RS { + let session_id = self.ensure_session_id().await?; + self.worker_local.batch(session_id, sql, param).await + } +} diff --git a/mudu_kernel/src/mudu_conn/mudu_conn_core.rs b/mudu_kernel/src/mudu_conn/mudu_conn_core.rs new file mode 100644 index 0000000..e2c0349 --- /dev/null +++ b/mudu_kernel/src/mudu_conn/mudu_conn_core.rs @@ -0,0 +1,155 @@ +use mudu::common::result::RS; +use mudu::error::ec::EC; +use mudu::m_error; +use mudu_contract::database::sql_params::SQLParams; +use mudu_contract::database::sql_stmt::SQLStmt; +use mudu_contract::tuple::tuple_field_desc::TupleFieldDesc; +use mudu_contract::tuple::tuple_value::TupleValue; +use mudu_contract::tuple::typed_bin::TypedBin; +use mudu_type::datum::DatumDyn; +use sql_parser::ast::parser::SQLParser; +use sql_parser::ast::stmt_type::StmtType; +use std::sync::Arc; + +use crate::contract::meta_mgr::MetaMgr; +use crate::contract::query_exec::QueryExec; +use crate::mudu_conn::mudu_result_set_async::MuduResultSetAsync; +use crate::sql::binder::Binder; +use crate::sql::bound_stmt::BoundStmt; +use crate::sql::describer::Describer; +use crate::sql::plan_ctx::PlanCtx; +use crate::sql::planner::Planner; +use crate::x_engine::api::XContract; +use crate::x_engine::tx_mgr::TxMgr; + +pub struct MuduConnCore { + meta_mgr: Arc, + parser: Arc, +} + +impl MuduConnCore { + pub fn new(meta_mgr: Arc) -> Self { + Self { + meta_mgr, + parser: Arc::new(SQLParser::new()), + } + } + + pub fn parse_one(&self, sql: &dyn SQLStmt) -> RS { + let stmt_list = self.parser.parse(&sql.to_sql_string())?; + let mut stmts = stmt_list.into_stmts(); + if stmts.len() != 1 { + return Err(m_error!(EC::ParseErr, "expected exactly one statement")); + } + Ok(stmts.remove(0)) + } + + pub fn parse_many(&self, sql: &dyn SQLStmt) -> RS> { + Ok(self.parser.parse(&sql.to_sql_string())?.into_stmts()) + } + + pub async fn describe_stmt(&self, stmt: StmtType) -> RS> { + let desc = Describer::describe(self.meta_mgr.as_ref(), stmt).await?; + Ok(Arc::new(desc)) + } + + pub async fn query( + &self, + stmt: StmtType, + params: Box, + tx_mgr: Arc, + x_contract: Arc, + ) -> RS> { + let (rows, desc) = self.query_rows(stmt, params, tx_mgr, x_contract).await?; + Ok(Arc::new(MuduResultSetAsync::from_rows(rows, desc))) + } + + pub async fn query_rows( + &self, + stmt: StmtType, + params: Box, + tx_mgr: Arc, + x_contract: Arc, + ) -> RS<(Vec, TupleFieldDesc)> { + self.query_inner(stmt, params, tx_mgr, x_contract).await + } + + pub async fn execute( + &self, + stmt: StmtType, + params: Box, + tx_mgr: Arc, + x_contract: Arc, + ) -> RS { + self.execute_inner(stmt, params, tx_mgr, x_contract).await + } + + async fn query_inner( + &self, + stmt: StmtType, + params: Box, + tx_mgr: Arc, + x_contract: Arc, + ) -> RS<(Vec, TupleFieldDesc)> { + let bound = Binder::new(self.meta_mgr.clone()) + .bind(stmt, params.as_ref()) + .await?; + let BoundStmt::Query(bound_query) = bound else { + return Err(m_error!(EC::TypeErr, "statement is not a query")); + }; + let planner = Planner::new(PlanCtx { + tx_mgr, + meta_mgr: self.meta_mgr.clone(), + x_contract, + }); + let exec = planner.plan_query(bound_query).await?; + query_exec_to_rows(exec).await + } + + async fn execute_inner( + &self, + stmt: StmtType, + params: Box, + tx_mgr: Arc, + x_contract: Arc, + ) -> RS { + let bound = Binder::new(self.meta_mgr.clone()) + .bind(stmt, params.as_ref()) + .await?; + let BoundStmt::Command(bound_command) = bound else { + return Err(m_error!(EC::TypeErr, "statement is not a command")); + }; + let planner = Planner::new(PlanCtx { + tx_mgr, + meta_mgr: self.meta_mgr.clone(), + x_contract, + }); + let cmd = planner.plan_command(bound_command).await?; + cmd.prepare().await?; + cmd.run().await?; + cmd.affected_rows().await + } +} + +pub async fn query_exec_to_rows(exec: Arc) -> RS<(Vec, TupleFieldDesc)> { + exec.open().await?; + let desc = exec.tuple_desc()?; + let mut rows = Vec::new(); + while let Some(row) = exec.next().await? { + rows.push(tuple_field_to_value(row, &desc)?); + } + Ok((rows, desc)) +} + +fn tuple_field_to_value( + row: mudu_contract::tuple::tuple_field::TupleField, + desc: &TupleFieldDesc, +) -> RS { + let mut values = Vec::with_capacity(row.fields().len()); + for (index, field) in row.fields().iter().enumerate() { + let datum_desc = &desc.fields()[index]; + let typed = TypedBin::new(datum_desc.dat_type_id(), field.clone()); + values.push(typed.to_value(datum_desc.dat_type())?); + } + Ok(TupleValue::from(values)) +} diff --git a/mudu_kernel/src/mudu_conn/mudu_prepared_stmt.rs b/mudu_kernel/src/mudu_conn/mudu_prepared_stmt.rs new file mode 100644 index 0000000..fccfce7 --- /dev/null +++ b/mudu_kernel/src/mudu_conn/mudu_prepared_stmt.rs @@ -0,0 +1,63 @@ +use async_trait::async_trait; +use mudu::common::id::OID; +use mudu::common::result::RS; +use mudu_contract::database::prepared_stmt::PreparedStmt; +use mudu_contract::database::result_set::ResultSetAsync; +use mudu_contract::database::sql_params::SQLParams; +use mudu_contract::database::sql_stmt::SQLStmt; +use mudu_contract::tuple::tuple_field_desc::TupleFieldDesc; +use std::sync::Arc; +use tokio::sync::Mutex; + +use crate::server::worker_local::WorkerLocalRef; + +pub struct MuduPreparedStmt { + worker_local: WorkerLocalRef, + session_id: Arc>>, + sql: Box, + desc: Arc, +} + +impl MuduPreparedStmt { + pub fn new( + worker_local: WorkerLocalRef, + session_id: Arc>>, + sql: Box, + desc: Arc, + ) -> Self { + Self { + worker_local, + session_id, + sql, + desc, + } + } + + async fn current_oid(&self) -> OID { + let guard = self.session_id.lock().await; + guard.unwrap_or(0) + } +} + +#[async_trait] +impl PreparedStmt for MuduPreparedStmt { + async fn query(&self, params: Box) -> RS> { + self.worker_local + .query(self.current_oid().await, self.sql.clone_boxed(), params) + .await + } + + async fn execute(&self, params: Box) -> RS { + self.worker_local + .execute(self.current_oid().await, self.sql.clone_boxed(), params) + .await + } + + async fn desc(&self) -> RS> { + Ok(self.desc.clone()) + } + + async fn reset(&self) -> RS<()> { + Ok(()) + } +} diff --git a/mudu_runtime/src/backend/mudu_result_set_async.rs b/mudu_kernel/src/mudu_conn/mudu_result_set_async.rs similarity index 74% rename from mudu_runtime/src/backend/mudu_result_set_async.rs rename to mudu_kernel/src/mudu_conn/mudu_result_set_async.rs index eb8b8b6..ef22424 100644 --- a/mudu_runtime/src/backend/mudu_result_set_async.rs +++ b/mudu_kernel/src/mudu_conn/mudu_result_set_async.rs @@ -1,4 +1,3 @@ -use crate::backend::mudu_conn_core::query_exec_to_rows; use async_trait::async_trait; use mudu::common::result::RS; use mudu_contract::database::result_set::ResultSetAsync; @@ -7,6 +6,8 @@ use mudu_contract::tuple::tuple_value::TupleValue; use std::sync::Arc; use tokio::sync::Mutex; +use crate::contract::query_exec::QueryExec; + pub struct MuduResultSetAsync { desc: Arc, inner: Mutex, @@ -18,14 +19,16 @@ struct ResultRows { } impl MuduResultSetAsync { - pub async fn from_query_exec( - exec: Arc, - ) -> RS { - let (rows, desc) = query_exec_to_rows(exec).await?; - Ok(Self { + pub fn from_rows(rows: Vec, desc: TupleFieldDesc) -> Self { + Self { desc: Arc::new(desc), inner: Mutex::new(ResultRows { rows, index: 0 }), - }) + } + } + + pub async fn from_query_exec(exec: Arc) -> RS { + let (rows, desc) = super::mudu_conn_core::query_exec_to_rows(exec).await?; + Ok(Self::from_rows(rows, desc)) } } diff --git a/mudu_kernel/src/server/frame_dispatch.rs b/mudu_kernel/src/server/frame_dispatch.rs index 7e79aae..cccd215 100644 --- a/mudu_kernel/src/server/frame_dispatch.rs +++ b/mudu_kernel/src/server/frame_dispatch.rs @@ -8,9 +8,7 @@ use crate::server::worker::IoUringWorker; use mudu::common::result::RS; use mudu::error::ec::EC; use mudu::m_error; -use mudu_contract::protocol::{ - decode_client_request, encode_server_response, Frame, MessageType, ServerResponse, HEADER_LEN, -}; +use mudu_contract::protocol::{Frame, MessageType, HEADER_LEN}; pub fn try_decode_next_frame(buf: &[u8]) -> RS> { if buf.len() < HEADER_LEN { @@ -39,24 +37,12 @@ pub async fn dispatch_frame_async( return result; } match frame.header().message_type() { - MessageType::Query | MessageType::Execute => { - let request = decode_client_request(frame)?; - Ok(HandleResult::Response(encode_server_response( - frame.header().request_id(), - &ServerResponse::new( - vec![], - vec![], - 0, - Some(format!( - "SQL interface is disabled in the client backend for app '{}'", - request.app_name() - )), - ), - )?)) - } MessageType::Get | MessageType::Put | MessageType::RangeScan + | MessageType::Query + | MessageType::Execute + | MessageType::Batch | MessageType::ProcedureInvoke | MessageType::SessionCreate | MessageType::SessionClose => unreachable!(), diff --git a/mudu_kernel/src/server/handlers/batch.rs b/mudu_kernel/src/server/handlers/batch.rs new file mode 100644 index 0000000..6c1494b --- /dev/null +++ b/mudu_kernel/src/server/handlers/batch.rs @@ -0,0 +1,22 @@ +use async_trait::async_trait; +use mudu::common::result::RS; +use mudu_contract::protocol::{decode_client_request, Frame, MessageType}; + +use crate::server::async_func_task::HandleResult; +use crate::server::message_dispatcher::MessageHandler; +use crate::server::request_ctx::RequestCtx; + +pub(in crate::server) struct BatchHandler; + +#[async_trait] +impl MessageHandler for BatchHandler { + fn message_type(&self) -> MessageType { + MessageType::Batch + } + + async fn handle(&self, ctx: &RequestCtx, frame: &Frame) -> RS { + let request = decode_client_request(frame)?; + ctx.batch(request.oid() as _, request.app_name(), request.sql()) + .await + } +} diff --git a/mudu_kernel/src/server/handlers/execute.rs b/mudu_kernel/src/server/handlers/execute.rs new file mode 100644 index 0000000..792b301 --- /dev/null +++ b/mudu_kernel/src/server/handlers/execute.rs @@ -0,0 +1,22 @@ +use async_trait::async_trait; +use mudu::common::result::RS; +use mudu_contract::protocol::{decode_client_request, Frame, MessageType}; + +use crate::server::async_func_task::HandleResult; +use crate::server::message_dispatcher::MessageHandler; +use crate::server::request_ctx::RequestCtx; + +pub(in crate::server) struct ExecuteHandler; + +#[async_trait] +impl MessageHandler for ExecuteHandler { + fn message_type(&self) -> MessageType { + MessageType::Execute + } + + async fn handle(&self, ctx: &RequestCtx, frame: &Frame) -> RS { + let request = decode_client_request(frame)?; + ctx.execute_sql(request.oid() as _, request.app_name(), request.sql()) + .await + } +} diff --git a/mudu_kernel/src/server/handlers/mod.rs b/mudu_kernel/src/server/handlers/mod.rs index 4bd792d..5f4c6ff 100644 --- a/mudu_kernel/src/server/handlers/mod.rs +++ b/mudu_kernel/src/server/handlers/mod.rs @@ -1,13 +1,19 @@ +mod batch; +mod execute; mod get; mod procedure_invoke; mod put; +mod query; mod range_scan; mod session_close; mod session_create; +pub(in crate::server) use batch::BatchHandler; +pub(in crate::server) use execute::ExecuteHandler; pub(in crate::server) use get::GetHandler; pub(in crate::server) use procedure_invoke::ProcedureInvokeHandler; pub(in crate::server) use put::PutHandler; +pub(in crate::server) use query::QueryHandler; pub(in crate::server) use range_scan::RangeScanHandler; pub(in crate::server) use session_close::SessionCloseHandler; pub(in crate::server) use session_create::SessionCreateHandler; diff --git a/mudu_kernel/src/server/handlers/query.rs b/mudu_kernel/src/server/handlers/query.rs new file mode 100644 index 0000000..9f450b1 --- /dev/null +++ b/mudu_kernel/src/server/handlers/query.rs @@ -0,0 +1,22 @@ +use async_trait::async_trait; +use mudu::common::result::RS; +use mudu_contract::protocol::{decode_client_request, Frame, MessageType}; + +use crate::server::async_func_task::HandleResult; +use crate::server::message_dispatcher::MessageHandler; +use crate::server::request_ctx::RequestCtx; + +pub(in crate::server) struct QueryHandler; + +#[async_trait] +impl MessageHandler for QueryHandler { + fn message_type(&self) -> MessageType { + MessageType::Query + } + + async fn handle(&self, ctx: &RequestCtx, frame: &Frame) -> RS { + let request = decode_client_request(frame)?; + ctx.query(request.oid() as _, request.app_name(), request.sql()) + .await + } +} diff --git a/mudu_kernel/src/server/message_dispatcher.rs b/mudu_kernel/src/server/message_dispatcher.rs index de6bce6..63a6900 100644 --- a/mudu_kernel/src/server/message_dispatcher.rs +++ b/mudu_kernel/src/server/message_dispatcher.rs @@ -6,8 +6,8 @@ use mudu_contract::protocol::{Frame, MessageType}; use crate::server::async_func_task::HandleResult; use crate::server::handlers::{ - GetHandler, ProcedureInvokeHandler, PutHandler, RangeScanHandler, SessionCloseHandler, - SessionCreateHandler, + BatchHandler, ExecuteHandler, GetHandler, ProcedureInvokeHandler, PutHandler, QueryHandler, + RangeScanHandler, SessionCloseHandler, SessionCreateHandler, }; use crate::server::request_ctx::RequestCtx; @@ -29,6 +29,9 @@ impl MessageDispatcher { fn new() -> Self { let mut handlers: Vec<(MessageType, Box)> = Vec::new(); + register(&mut handlers, Box::new(QueryHandler)); + register(&mut handlers, Box::new(ExecuteHandler)); + register(&mut handlers, Box::new(BatchHandler)); register(&mut handlers, Box::new(GetHandler)); register(&mut handlers, Box::new(PutHandler)); register(&mut handlers, Box::new(RangeScanHandler)); diff --git a/mudu_kernel/src/server/perf_test.rs b/mudu_kernel/src/server/perf_test.rs index 0cd3d41..088b8cd 100644 --- a/mudu_kernel/src/server/perf_test.rs +++ b/mudu_kernel/src/server/perf_test.rs @@ -1,7 +1,6 @@ use crate::server::routing::{route_worker, RoutingContext, RoutingMode}; use crate::server::server::{IoUringTcpBackend, IoUringTcpServerConfig}; use crate::server::worker_registry::{load_or_create_worker_registry, WorkerRegistry}; -use tracing::log::info; use mudu::common::result::RS; use mudu::error::ec::EC; use mudu::error::err::MError; @@ -22,7 +21,7 @@ use std::net::{Ipv4Addr, SocketAddr, TcpListener}; use std::ops::RangeInclusive; use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use std::sync::mpsc::{self, Receiver, TryRecvError}; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, Mutex, OnceLock}; use std::thread; use std::time::Duration; use tokio::io::{AsyncReadExt, AsyncWriteExt}; @@ -30,6 +29,7 @@ use tokio::net::{TcpSocket as TokioTcpSocket, TcpStream as TokioTcpStream}; use tokio::sync::Notify; use tokio::task::JoinSet; use tracing::debug; +use tracing::info; use uuid::Uuid; struct AsyncPerfClient { @@ -187,6 +187,11 @@ fn reserve_listener() -> Option { None } +fn network_perf_test_lock() -> &'static Mutex<()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) +} + fn bind_reserved_listener(port: u16) -> std::io::Result { let listener = TcpListener::bind(("127.0.0.1", port))?; listener.set_nonblocking(true)?; @@ -400,6 +405,7 @@ fn avg_us(samples: &[u64]) -> Option { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn iouring_backend_perf_put_get() -> RS<()> { + let _guard = network_perf_test_lock().lock().unwrap(); log_setup("info"); let notifier = NotifyWait::new(); { @@ -599,6 +605,7 @@ async fn iouring_backend_perf_put_get() -> RS<()> { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn iouring_backend_recovery_replays_worker_logs() -> RS<()> { + let _guard = network_perf_test_lock().lock().unwrap(); let Some(listener) = reserve_listener() else { return Ok(()); }; @@ -674,6 +681,7 @@ async fn iouring_backend_recovery_replays_worker_logs() -> RS<()> { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn iouring_backend_recovery_replays_across_multiple_chunks() -> RS<()> { + let _guard = network_perf_test_lock().lock().unwrap(); let Some(listener) = reserve_listener() else { return Ok(()); }; @@ -738,6 +746,7 @@ async fn iouring_backend_recovery_replays_across_multiple_chunks() -> RS<()> { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn iouring_backend_open_session_routes_connection_to_requested_partition() -> RS<()> { + let _guard = network_perf_test_lock().lock().unwrap(); let Some(listener) = reserve_listener() else { return Ok(()); }; @@ -814,6 +823,7 @@ async fn iouring_backend_open_session_routes_connection_to_requested_partition() #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn iouring_backend_open_session_rebind_keeps_same_session_id() -> RS<()> { + let _guard = network_perf_test_lock().lock().unwrap(); let Some(listener) = reserve_listener() else { return Ok(()); }; diff --git a/mudu_kernel/src/server/request_ctx.rs b/mudu_kernel/src/server/request_ctx.rs index c17141b..bf39579 100644 --- a/mudu_kernel/src/server/request_ctx.rs +++ b/mudu_kernel/src/server/request_ctx.rs @@ -1,11 +1,16 @@ use mudu::common::id::OID; use mudu::common::result::RS; +use mudu::error::ec::EC; +use mudu::m_error; +use mudu_contract::database::result_set::ResultSetAsync; use mudu_contract::protocol::{ encode_get_response, encode_procedure_invoke_response, encode_put_response, encode_range_scan_response, encode_session_close_response, encode_session_create_response, - GetResponse, KeyValue, ProcedureInvokeResponse, PutResponse, RangeScanResponse, + encode_server_response, GetResponse, KeyValue, ProcedureInvokeResponse, PutResponse, + RangeScanResponse, ServerResponse, SessionCloseResponse, SessionCreateResponse, }; +use mudu_type::datum::DatumDyn; use std::sync::Arc; use crate::server::async_func_task::HandleResult; @@ -120,6 +125,51 @@ impl RequestCtx { )?)) } + pub(in crate::server) async fn query( + &self, + oid: OID, + app_name: &str, + sql: &str, + ) -> RS { + let _ = app_name; + let response = self + .worker + .query(oid, Box::new(sql.to_string()), Box::new(())) + .await?; + let response = Self::query_response(response).await?; + self.encode_server_response(response) + } + + pub(in crate::server) async fn execute_sql( + &self, + oid: OID, + app_name: &str, + sql: &str, + ) -> RS { + let _ = app_name; + let affected_rows = self + .worker + .execute(oid, Box::new(sql.to_string()), Box::new(())) + .await?; + let response = ServerResponse::new(Vec::new(), Vec::new(), affected_rows, None); + self.encode_server_response(response) + } + + pub(in crate::server) async fn batch( + &self, + oid: OID, + app_name: &str, + sql: &str, + ) -> RS { + let _ = app_name; + let affected_rows = self + .worker + .batch(oid, Box::new(sql.to_string()), Box::new(())) + .await?; + let response = ServerResponse::new(Vec::new(), Vec::new(), affected_rows, None); + self.encode_server_response(response) + } + pub(in crate::server) async fn session_create( &self, config: SessionOpenConfig, @@ -155,4 +205,32 @@ impl RequestCtx { ), )?)) } + + fn encode_server_response(&self, response: ServerResponse) -> RS { + Ok(HandleResult::Response(encode_server_response( + self.request_id, + &response, + )?)) + } + + async fn query_response(result_set: Arc) -> RS { + let desc = result_set.desc(); + let columns = desc + .fields() + .iter() + .map(|field| field.name().to_string()) + .collect(); + let mut rows = Vec::new(); + while let Some(row) = result_set.next().await? { + if row.values().len() != desc.fields().len() { + return Err(m_error!(EC::FatalError, "non consistent column number")); + } + let mut values = Vec::with_capacity(row.values().len()); + for (value, field_desc) in row.values().iter().zip(desc.fields().iter()) { + values.push(value.to_textual(field_desc.dat_type())?.into()); + } + rows.push(values); + } + Ok(ServerResponse::new(columns, rows, 0, None)) + } } diff --git a/mudu_kernel/src/server/session_bound_worker_runtime.rs b/mudu_kernel/src/server/session_bound_worker_runtime.rs index be18adb..b69316c 100644 --- a/mudu_kernel/src/server/session_bound_worker_runtime.rs +++ b/mudu_kernel/src/server/session_bound_worker_runtime.rs @@ -4,14 +4,20 @@ use crate::server::worker::IoUringWorker; use crate::server::worker_local::{WorkerExecute, WorkerLocal, WorkerLocalRef}; use crate::server::worker_registry::WorkerRegistry; use crate::server::worker_snapshot::KvItem; +use crate::contract::meta_mgr::MetaMgr; use async_trait::async_trait; use mudu::common::id::OID; use mudu::common::result::RS; use mudu::error::ec::EC; use mudu::m_error; +use mudu_contract::database::result_set::ResultSetAsync; +use mudu_contract::database::sql_params::SQLParams; +use mudu_contract::database::sql_stmt::SQLStmt; use mudu_contract::protocol::{ProcedureInvokeRequest, ProcedureInvokeResponse}; use std::sync::Arc; +use crate::x_engine::api::XContract; + struct SessionBoundWorkerRuntime { worker: Arc, current_session_id: OID, @@ -33,6 +39,14 @@ pub(crate) fn as_worker_local_ref(worker: WorkerRuntimeRef) -> WorkerLocalRef { #[async_trait] impl WorkerLocal for SessionBoundWorkerRuntime { + fn x_contract(&self) -> Arc { + self.worker.x_contract() + } + + fn meta_mgr(&self) -> Arc { + self.worker.meta_mgr() + } + async fn open_async(&self) -> RS { self.worker.open_session(self.current_session_id) } @@ -71,7 +85,7 @@ impl WorkerLocal for SessionBoundWorkerRuntime { } async fn get_async(&self, session_id: OID, key: &[u8]) -> RS>> { - self.worker.get_in_session(session_id, key) + self.worker.get_in_session(session_id, key).await } async fn range_async( @@ -80,7 +94,34 @@ impl WorkerLocal for SessionBoundWorkerRuntime { start_key: &[u8], end_key: &[u8], ) -> RS> { - self.worker.range_in_session(session_id, start_key, end_key) + self.worker.range_in_session(session_id, start_key, end_key).await + } + + async fn query( + &self, + oid: OID, + sql: Box, + param: Box, + ) -> RS> { + self.worker.query(oid, sql, param).await + } + + async fn execute( + &self, + oid: OID, + sql: Box, + param: Box, + ) -> RS { + self.worker.execute(oid, sql, param).await + } + + async fn batch( + &self, + oid: OID, + sql: Box, + param: Box, + ) -> RS { + self.worker.batch(oid, sql, param).await } } diff --git a/mudu_kernel/src/server/worker.rs b/mudu_kernel/src/server/worker.rs index d9647b0..bb562d3 100644 --- a/mudu_kernel/src/server/worker.rs +++ b/mudu_kernel/src/server/worker.rs @@ -1,3 +1,5 @@ +use crate::mudu_conn::mudu_conn_core::MuduConnCore; +use crate::contract::meta_mgr::MetaMgr; use crate::server::async_func_runtime::AsyncFuncInvokerPtr; use crate::server::routing::{ route_worker, RoutingContext, RoutingMode, SessionOpenConfig, SessionOpenTransferAction, @@ -9,14 +11,18 @@ use crate::server::worker_local::{WorkerExecute, WorkerLocalRef}; use crate::server::worker_registry::{WorkerIdentity, WorkerRegistry}; use crate::server::worker_session_manager::{SessionContext, WorkerSessionManager}; use crate::server::worker_snapshot::KvItem; -use crate::server::worker_tx_manager::WorkerTxManager; use crate::server::x_contract::IoUringXContract; use crate::wal::worker_log::{ChunkedWorkerLogBackend, WorkerLogBatching, WorkerLogLayout}; use crate::wal::xl_batch::XLBatch; +use crate::x_engine::api::XContract; +use crate::x_engine::tx_mgr::TxMgr; use mudu::common::id::OID; use mudu::common::result::RS; use mudu::error::ec::EC; use mudu::m_error; +use mudu_contract::database::result_set::ResultSetAsync; +use mudu_contract::database::sql_params::SQLParams; +use mudu_contract::database::sql_stmt::SQLStmt; use mudu_contract::protocol::{ProcedureInvokeRequest, ProcedureInvokeResponse}; use std::collections::BTreeMap; use std::net::SocketAddr; @@ -29,6 +35,12 @@ use std::sync::Arc; /// The `IoUringWorker` name is also historical. The type is shared by both the /// Linux native `io_uring` loop and the non-Linux fallback loop so upper /// layers do not need target-specific worker abstractions. +/// +/// Workers are sized around execution resources such as CPU cores, while +/// partitions are derived from user-defined data partitioning. The system does +/// not require partitions to map one-to-one to workers, although the current +/// runtime path still operates on a single active partition per worker. A +/// worker may own multiple partitions in the future. pub struct IoUringWorker { worker_index: usize, worker_id: OID, @@ -78,6 +90,8 @@ impl IoUringWorker { registry: Arc, ) -> RS { let active_sessions = Arc::new(AtomicUsize::new(0)); + // The runtime currently activates only the first partition assigned to + // this worker, while preserving `partition_ids` for future multi-partition support. let partition_id = identity.partition_ids.first().copied().ok_or_else(|| { m_error!( EC::ParseErr, @@ -91,20 +105,25 @@ impl IoUringWorker { log_layout.clone(), active_sessions.clone(), )?; + let contract = Arc::new(IoUringXContract::with_worker_log_and_data_dir( + log, + partition_id, + data_dir, + )); + let session_manager = Arc::new(WorkerSessionManager::new( + active_sessions, + contract.meta_mgr(), + )); Ok(Self { worker_index: identity.worker_index, worker_id, partition_ids: identity.partition_ids, worker_count, routing_mode, - contract: Arc::new(IoUringXContract::with_worker_log_and_data_dir( - log, - partition_id, - data_dir, - )), + contract: contract.clone(), log_layout, procedure_runtime, - session_manager: Arc::new(WorkerSessionManager::new(active_sessions)), + session_manager, registry, }) } @@ -167,14 +186,14 @@ impl IoUringWorker { self.session_manager.session_context(session_id) } - pub fn get_for_connection( + pub async fn get_for_connection( &self, conn_id: u64, session_id: OID, key: &[u8], ) -> RS>> { self.ensure_session_owned_by_connection(conn_id, session_id)?; - self.get_in_session(session_id, key) + self.get_in_session(session_id, key).await } pub fn put_for_connection( @@ -199,7 +218,7 @@ impl IoUringWorker { self.put_in_session_async(session_id, key, value).await } - pub fn range_for_connection( + pub async fn range_for_connection( &self, conn_id: u64, session_id: OID, @@ -207,47 +226,22 @@ impl IoUringWorker { end_key: &[u8], ) -> RS> { self.ensure_session_owned_by_connection(conn_id, session_id)?; - self.range_in_session(session_id, start_key, end_key) + self.range_in_session(session_id, start_key, end_key).await } #[allow(dead_code)] fn execute_tx(&self, session_id: OID, instruction: WorkerExecute) -> RS<()> { - let session = self.session_context(session_id)?; match instruction { - WorkerExecute::BeginTx => { - if session.tx_manager_ref().is_some() { - return Err(m_error!( - EC::ExistingSuchElement, - format!("session {} already has an active transaction", session_id) - )); - } - session - .set_tx_manager(Some(WorkerTxManager::new(self.contract.worker_begin_tx()?))); - Ok(()) - } + WorkerExecute::BeginTx => self + .session_manager + .begin_session_tx(session_id, self.contract.worker_begin_tx()?), WorkerExecute::CommitTx => { - let tx_manager = session.take_tx_manager().ok_or_else(|| { - m_error!( - EC::NoSuchElement, - format!("session {} has no active transaction", session_id) - ) - })?; - let snapshot = tx_manager.snapshot().clone(); - let xid = tx_manager.xid(); - let items = tx_manager.staged_put_items(); - let batch = tx_manager.into_xl_batch(); - self.contract - .worker_commit_put_batch(&snapshot, xid, items, batch)?; - Ok(()) + let tx_manager = self.session_manager.take_session_tx(session_id)?; + self.contract.worker_commit_tx(tx_manager) } WorkerExecute::RollbackTx => { - let tx_manager = session.take_tx_manager().ok_or_else(|| { - m_error!( - EC::NoSuchElement, - format!("session {} has no active transaction", session_id) - ) - })?; - self.contract.worker_rollback_tx(tx_manager.xid())?; + let tx_manager = self.session_manager.take_session_tx(session_id)?; + self.contract.worker_rollback_tx(tx_manager)?; Ok(()) } } @@ -258,56 +252,32 @@ impl IoUringWorker { session_id: OID, instruction: WorkerExecute, ) -> RS<()> { - let session = self.session_context(session_id)?; match instruction { - WorkerExecute::BeginTx => { - if session.tx_manager_ref().is_some() { - return Err(m_error!( - EC::ExistingSuchElement, - format!("session {} already has an active transaction", session_id) - )); - } - session - .set_tx_manager(Some(WorkerTxManager::new(self.contract.worker_begin_tx()?))); - Ok(()) - } + WorkerExecute::BeginTx => self + .session_manager + .begin_session_tx(session_id, self.contract.worker_begin_tx()?), WorkerExecute::CommitTx => { - let tx_manager = session.take_tx_manager().ok_or_else(|| { - m_error!( - EC::NoSuchElement, - format!("session {} has no active transaction", session_id) - ) - })?; - let snapshot = tx_manager.snapshot().clone(); - let xid = tx_manager.xid(); - let items = tx_manager.staged_put_items(); - let batch = tx_manager.into_xl_batch(); - self.contract - .worker_commit_put_batch_async(&snapshot, xid, items, batch) - .await + let tx_manager = self.session_manager.take_session_tx(session_id)?; + self.contract.worker_commit_tx_async(tx_manager).await } WorkerExecute::RollbackTx => { - let tx_manager = session.take_tx_manager().ok_or_else(|| { - m_error!( - EC::NoSuchElement, - format!("session {} has no active transaction", session_id) - ) - })?; - self.contract.worker_rollback_tx(tx_manager.xid())?; + let tx_manager = self.session_manager.take_session_tx(session_id)?; + self.contract.worker_rollback_tx(tx_manager)?; Ok(()) } } } fn put_in_session(&self, session_id: OID, key: Vec, value: Vec) -> RS<()> { - let session = self.session_context(session_id)?; - match session.tx_manager_mut().as_mut() { - Some(tx_manager) => { - tx_manager.put(key, value); - Ok(()) + self.session_manager.with_session_tx(session_id, |tx_manager| { + match tx_manager { + Some(tx_manager) => { + tx_manager.put(key, value); + Ok(()) + } + None => self.contract.worker_put(key, value), } - None => self.contract.worker_put(key, value), - } + }) } pub(crate) async fn put_in_session_async( @@ -316,65 +286,86 @@ impl IoUringWorker { key: Vec, value: Vec, ) -> RS<()> { - let session = self.session_context(session_id)?; - match session.tx_manager_mut().as_mut() { - Some(tx_manager) => { - tx_manager.put(key, value); - Ok(()) - } - None => self.contract.worker_put_async(key, value).await, + let handled = self + .session_manager + .with_session_tx(session_id, |tx_manager| match tx_manager { + Some(tx_manager) => { + tx_manager.put(key.clone(), value.clone()); + Ok(true) + } + None => Ok(false), + })?; + if handled { + Ok(()) + } else { + self.contract.worker_put_async(key, value).await } } pub(crate) async fn delete_in_session_async(&self, session_id: OID, key: &[u8]) -> RS<()> { - let session = self.session_context(session_id)?; - match session.tx_manager_mut().as_mut() { - Some(tx_manager) => { - tx_manager.delete(key.to_vec()); - Ok(()) - } - None => self.contract.worker_delete_async(key).await, + let key_vec = key.to_vec(); + let handled = self + .session_manager + .with_session_tx(session_id, |tx_manager| match tx_manager { + Some(tx_manager) => { + tx_manager.delete(key_vec.clone()); + Ok(true) + } + None => Ok(false), + })?; + if handled { + Ok(()) + } else { + self.contract.worker_delete_async(key).await } } - pub(crate) fn get_in_session(&self, session_id: OID, key: &[u8]) -> RS>> { - let session = self.session_context(session_id)?; - let staged = session - .tx_manager_ref() + pub(crate) async fn get_in_session(&self, session_id: OID, key: &[u8]) -> RS>> { + let tx_manager = self + .session_manager + .with_session_tx(session_id, |tx_manager| Ok(tx_manager))?; + let staged = tx_manager .as_ref() .and_then(|tx_manager| tx_manager.get(key)); match staged { Some(value) => Ok(value), - None => match session.tx_manager_ref().as_ref() { - Some(tx_manager) => self - .contract - .worker_get_with_snapshot(tx_manager.snapshot(), key), - None => self.contract.worker_get(key), + None => match tx_manager { + Some(tx_manager) => { + self.contract + .worker_get_with_snapshot_async(&tx_manager.snapshot(), key) + .await + } + None => self.contract.worker_get_async(key).await, }, } } - pub(crate) fn range_in_session( + pub(crate) async fn range_in_session( &self, session_id: OID, start_key: &[u8], end_key: &[u8], ) -> RS> { - let session = self.session_context(session_id)?; - let staged = session - .tx_manager_ref() + let tx_manager = self + .session_manager + .with_session_tx(session_id, |tx_manager| Ok(tx_manager))?; + let staged = tx_manager .as_ref() .map(|tx_manager| tx_manager.staged_items_in_range(start_key, end_key)) .unwrap_or_default(); let mut merged = BTreeMap::new(); - let base_items = match session.tx_manager_ref().as_ref() { - Some(tx_manager) => self.contract.worker_range_scan_with_snapshot( - tx_manager.snapshot(), - start_key, - end_key, - )?, - None => self.contract.worker_range_scan(start_key, end_key)?, + let base_items = match tx_manager { + Some(tx_manager) => { + self.contract + .worker_range_scan_with_snapshot_async( + &tx_manager.snapshot(), + start_key, + end_key, + ) + .await? + } + None => self.contract.worker_range_scan_async(start_key, end_key).await?, }; for item in base_items { merged.insert(item.key, Some(item.value)); @@ -441,6 +432,198 @@ impl IoUringWorker { self.contract.worker_log() } + pub fn x_contract(&self) -> Arc { + self.contract.clone() + } + + pub fn meta_mgr(&self) -> Arc { + self.contract.meta_mgr() + } + + fn sql_core(&self, oid: OID) -> RS> { + if oid == 0 { + return Ok(Arc::new(MuduConnCore::new(self.meta_mgr()))); + } + Ok(self.session_context(oid)?.mudu_conn_core()) + } + + fn sql_tx_mgr(&self, oid: OID) -> RS>> { + if oid == 0 { + return Ok(None); + } + self.session_manager + .with_session_tx(oid, |tx_manager| Ok(tx_manager)) + } + + async fn run_sql_query_with_tx( + &self, + core: Arc, + stmt: Box, + param: Box, + tx_mgr: Arc, + ) -> RS> { + let stmt = core.parse_one(stmt.as_ref())?; + core.query(stmt, param, tx_mgr, self.contract.clone()).await + } + + async fn run_sql_execute_with_tx( + &self, + core: Arc, + stmt: Box, + param: Box, + tx_mgr: Arc, + ) -> RS { + let stmt = core.parse_one(stmt.as_ref())?; + core.execute(stmt, param, tx_mgr, self.contract.clone()).await + } + + pub(crate) async fn query( + &self, + oid: OID, + sql: Box, + param: Box, + ) -> RS> { + let core = self.sql_core(oid)?; + if oid == 0 { + let tx_mgr = self.contract.begin_tx().await?; + let result = self + .run_sql_query_with_tx(core, sql, param, tx_mgr.clone()) + .await; + if result.is_ok() { + self.contract.commit_tx(tx_mgr).await?; + } else { + self.contract.abort_tx(tx_mgr).await?; + } + return result; + } + let started_tx = if self.session_manager.has_session_tx(oid)? { + false + } else { + self.session_manager + .begin_session_tx(oid, self.contract.worker_begin_tx()?)?; + true + }; + let tx_mgr = self + .sql_tx_mgr(oid)? + .ok_or_else(|| m_error!(EC::InternalErr, "session transaction is missing"))?; + let result = self.run_sql_query_with_tx(core, sql, param, tx_mgr).await; + if started_tx { + let tx_manager = self.session_manager.take_session_tx(oid)?; + if result.is_ok() { + self.contract.worker_commit_tx_async(tx_manager).await?; + } else { + self.contract.worker_rollback_tx(tx_manager)?; + } + } + result + } + + pub(crate) async fn execute( + &self, + oid: OID, + sql: Box, + param: Box, + ) -> RS { + let core = self.sql_core(oid)?; + if oid == 0 { + let tx_mgr = self.contract.begin_tx().await?; + let result = self + .run_sql_execute_with_tx(core, sql, param, tx_mgr.clone()) + .await; + if result.is_ok() { + self.contract.commit_tx(tx_mgr).await?; + } else { + self.contract.abort_tx(tx_mgr).await?; + } + return result; + } + let started_tx = if self.session_manager.has_session_tx(oid)? { + false + } else { + self.session_manager + .begin_session_tx(oid, self.contract.worker_begin_tx()?)?; + true + }; + let tx_mgr = self + .sql_tx_mgr(oid)? + .ok_or_else(|| m_error!(EC::InternalErr, "session transaction is missing"))?; + let result = self.run_sql_execute_with_tx(core, sql, param, tx_mgr).await; + if started_tx { + let tx_manager = self.session_manager.take_session_tx(oid)?; + if result.is_ok() { + self.contract.worker_commit_tx_async(tx_manager).await?; + } else { + self.contract.worker_rollback_tx(tx_manager)?; + } + } + result + } + + pub(crate) async fn batch( + &self, + oid: OID, + sql: Box, + param: Box, + ) -> RS { + if param.size() != 0 { + return Err(m_error!( + EC::NotImplemented, + "batch with parameters is not implemented" + )); + } + let core = self.sql_core(oid)?; + let stmts = core.parse_many(sql.as_ref())?; + if oid == 0 { + let tx_mgr = self.contract.begin_tx().await?; + let mut total = 0; + for stmt in stmts { + match core + .execute(stmt, Box::new(()), tx_mgr.clone(), self.contract.clone()) + .await + { + Ok(affected) => total += affected, + Err(err) => { + self.contract.abort_tx(tx_mgr).await?; + return Err(err); + } + } + } + self.contract.commit_tx(tx_mgr).await?; + return Ok(total); + } + let started_tx = if self.session_manager.has_session_tx(oid)? { + false + } else { + self.session_manager + .begin_session_tx(oid, self.contract.worker_begin_tx()?)?; + true + }; + let tx_mgr = self + .sql_tx_mgr(oid)? + .ok_or_else(|| m_error!(EC::InternalErr, "session transaction is missing"))?; + let mut total = 0; + for stmt in stmts { + match core + .execute(stmt, Box::new(()), tx_mgr.clone(), self.contract.clone()) + .await + { + Ok(affected) => total += affected, + Err(err) => { + if started_tx { + let tx_manager = self.session_manager.take_session_tx(oid)?; + self.contract.worker_rollback_tx(tx_manager)?; + } + return Err(err); + } + } + } + if started_tx { + let tx_manager = self.session_manager.take_session_tx(oid)?; + self.contract.worker_commit_tx_async(tx_manager).await?; + } + Ok(total) + } + pub fn replay_log_batch(&self, batch: XLBatch) -> RS<()> { self.contract.replay_worker_log_batch(batch) } @@ -636,16 +819,20 @@ mod tests { fn test_schema() -> SchemaTable { SchemaTable::new( "t".to_string(), - vec![SchemaColumn::new( - "id".to_string(), - DatTypeID::I32, - DTInfo::from_text(DatTypeID::I32, String::new()), - )], - vec![SchemaColumn::new( - "v".to_string(), - DatTypeID::I32, - DTInfo::from_text(DatTypeID::I32, String::new()), - )], + vec![ + SchemaColumn::new( + "id".to_string(), + DatTypeID::I32, + DTInfo::from_text(DatTypeID::I32, String::new()), + ), + SchemaColumn::new( + "v".to_string(), + DatTypeID::I32, + DTInfo::from_text(DatTypeID::I32, String::new()), + ), + ], + vec![0], + vec![1], ) } @@ -820,7 +1007,9 @@ mod tests { ); let schema = test_schema(); let table_id = schema.id(); - contract.create_table(0, &schema).await.unwrap(); + let tx_mgr = contract.begin_tx().await.unwrap(); + contract.create_table(tx_mgr.clone(), &schema).await.unwrap(); + contract.commit_tx(tx_mgr).await.unwrap(); let key_path = TimeSeriesFile::relation_file_path(&log_dir, partition_id, table_id, 0); let value_path = TimeSeriesFile::relation_file_path(&log_dir, partition_id, table_id, 1); @@ -933,7 +1122,7 @@ mod tests { .prepare_connection_transfer(conn_id, Some(action)) .unwrap(); assert_eq!(transferred.len(), 2); - assert!(source.get_for_connection(conn_id, session_a, b"k").is_err()); + assert!(futures::executor::block_on(source.get_for_connection(conn_id, session_a, b"k")).is_err()); target .adopt_connection_sessions(conn_id, &transferred) @@ -948,7 +1137,7 @@ mod tests { .put_for_connection(conn_id, session_b, b"k".to_vec(), b"v".to_vec()) .unwrap(); assert_eq!( - target.get_for_connection(conn_id, session_b, b"k").unwrap(), + futures::executor::block_on(target.get_for_connection(conn_id, session_b, b"k")).unwrap(), Some(b"v".to_vec()) ); } diff --git a/mudu_kernel/src/server/worker_local.rs b/mudu_kernel/src/server/worker_local.rs index afa2a03..76383a3 100644 --- a/mudu_kernel/src/server/worker_local.rs +++ b/mudu_kernel/src/server/worker_local.rs @@ -1,9 +1,21 @@ use crate::server::worker_snapshot::KvItem; +use crate::contract::meta_mgr::MetaMgr; use async_trait::async_trait; use mudu::common::id::OID; use mudu::common::result::RS; +use mudu_contract::database::result_set::ResultSetAsync; +use mudu_contract::database::sql_params::SQLParams; +use mudu_contract::database::sql_stmt::SQLStmt; +use std::cell::UnsafeCell; use std::sync::Arc; +use crate::x_engine::api::XContract; + +thread_local! { + static CURRENT_WORKER_LOCAL: UnsafeCell> = + const { UnsafeCell::new(None) }; +} + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum WorkerExecute { BeginTx, @@ -13,6 +25,9 @@ pub enum WorkerExecute { #[async_trait] pub trait WorkerLocal: Send + Sync { + fn x_contract(&self) -> Arc; + fn meta_mgr(&self) -> Arc; + async fn open_async(&self) -> RS; async fn open_argv_async(&self, worker_id: OID) -> RS { @@ -42,6 +57,56 @@ pub trait WorkerLocal: Send + Sync { start_key: &[u8], end_key: &[u8], ) -> RS>; + + async fn query( + &self, + oid: OID, + sql: Box, + param: Box, + ) -> RS>; + + async fn execute( + &self, + oid: OID, + sql: Box, + param: Box, + ) -> RS; + + async fn batch( + &self, + oid: OID, + sql: Box, + param: Box, + ) -> RS; } pub type WorkerLocalRef = Arc; + +pub(crate) fn set_current_worker_local(worker_local: WorkerLocalRef) { + CURRENT_WORKER_LOCAL.with(|slot| { + // Safety: the slot is thread-local and only mutated through these helpers. + unsafe { + *slot.get() = Some(worker_local); + } + }); +} + +pub(crate) fn unset_current_worker_local() { + CURRENT_WORKER_LOCAL.with(|slot| { + // Safety: the slot is thread-local and only mutated through these helpers. + unsafe { + *slot.get() = None; + } + }); +} + +pub(crate) fn current_worker_local() -> WorkerLocalRef { + CURRENT_WORKER_LOCAL.with(|slot| { + // Safety: shared reads are confined to the current thread-local slot. + let worker_local = unsafe { &*slot.get() }; + worker_local + .as_ref() + .cloned() + .unwrap_or_else(|| panic!("current worker local is not set")) + }) +} diff --git a/mudu_kernel/src/server/worker_registry.rs b/mudu_kernel/src/server/worker_registry.rs index c079956..4204444 100644 --- a/mudu_kernel/src/server/worker_registry.rs +++ b/mudu_kernel/src/server/worker_registry.rs @@ -12,6 +12,10 @@ const PARTITION_MARKER_SUFFIX: &str = ".pid"; #[derive(Debug, Clone, PartialEq, Eq)] pub struct WorkerIdentity { + // Workers are provisioned according to execution resources such as CPU cores, + // whereas partitions are defined by user-level data partitioning. + // Partitions are not automatically required to map one-to-one to workers. + // A worker may own multiple partitions in the future. pub worker_index: usize, pub worker_id: OID, pub partition_ids: Vec, diff --git a/mudu_kernel/src/server/worker_ring_loop.rs b/mudu_kernel/src/server/worker_ring_loop.rs index 996a15d..4824312 100644 --- a/mudu_kernel/src/server/worker_ring_loop.rs +++ b/mudu_kernel/src/server/worker_ring_loop.rs @@ -13,8 +13,10 @@ use crate::server::loop_user_io::{ }; use crate::server::server_iouring; use crate::server::server_iouring::RecoveryCoordinator; +use crate::server::session_bound_worker_runtime::{as_worker_local_ref, new_session_bound_worker_runtime}; use crate::server::worker::IoUringWorker; use crate::server::worker_loop_stats::WorkerLoopStats; +use crate::server::worker_local::{set_current_worker_local, unset_current_worker_local}; use crate::server::worker_mailbox::WorkerMailboxMsg; use crate::server::worker_task::{spawn_system_worker_task, WorkerTaskFuture}; use crate::wal::worker_log::ChunkedWorkerLogBackend; @@ -138,15 +140,21 @@ impl WorkerRingLoop { /// The worker-local ring pointer is installed for the duration of the run /// so user-level async file I/O can enqueue requests onto this loop. pub(in crate::server) fn run(&mut self) -> RS { + set_current_worker_local(as_worker_local_ref(new_session_bound_worker_runtime( + self.worker.clone(), + 0, + ))); set_current_worker_ring(self.worker_local_ring.clone()); if let Err(err) = self.recover_worker_log() { unset_current_worker_ring(); + unset_current_worker_local(); self.recovery_coordinator.worker_failed(); return Err(err); } self.recovery_coordinator.worker_succeeded()?; let r = self.run_service_loop(); unset_current_worker_ring(); + unset_current_worker_local(); r } @@ -489,10 +497,7 @@ mod tests { continue; } let cqe = loop_state.ring.wait().map_err(|wait_rc| { - m_error!( - EC::NetErr, - format!("io_uring_wait_cqe error {}", wait_rc) - ) + m_error!(EC::NetErr, format!("io_uring_wait_cqe error {}", wait_rc)) })?; loop_state.process_cqe(cqe)?; yield_now().await; diff --git a/mudu_kernel/src/server/worker_ring_loop/recovery.rs b/mudu_kernel/src/server/worker_ring_loop/recovery.rs index cc65c2e..238787f 100644 --- a/mudu_kernel/src/server/worker_ring_loop/recovery.rs +++ b/mudu_kernel/src/server/worker_ring_loop/recovery.rs @@ -88,7 +88,12 @@ impl WorkerRingLoop { continue; }; sqe.set_user_data(0); - sqe.prep_read_raw(file.as_raw_fd(), buf[offset..].as_mut_ptr(), size - offset, offset as u64); + sqe.prep_read_raw( + file.as_raw_fd(), + buf[offset..].as_mut_ptr(), + size - offset, + offset as u64, + ); let submitted = self.ring.submit(); if submitted < 0 { return Err(m_error!( diff --git a/mudu_kernel/src/server/worker_session_manager.rs b/mudu_kernel/src/server/worker_session_manager.rs index 576e792..1edb2f6 100644 --- a/mudu_kernel/src/server/worker_session_manager.rs +++ b/mudu_kernel/src/server/worker_session_manager.rs @@ -1,4 +1,6 @@ -use crate::server::worker_tx_manager::WorkerTxManager; +use crate::contract::meta_mgr::MetaMgr; +use crate::mudu_conn::mudu_conn_core::MuduConnCore; +use crate::x_engine::tx_mgr::TxMgr; use mudu::common::id::OID; use mudu::common::result::RS; use mudu::common::xid::new_xid; @@ -14,23 +16,25 @@ pub(crate) struct WorkerSessionManager { connection_sessions: SccHashMap>>, session_contexts: SccHashMap>, active_sessions: Arc, + meta_mgr: Arc, } -#[derive(Default)] pub(crate) struct SessionContext { - tx_manager: UnsafeCell>, + tx_manager: UnsafeCell>>, + mudu_conn_core: Arc, } unsafe impl Send for SessionContext {} unsafe impl Sync for SessionContext {} impl WorkerSessionManager { - pub(crate) fn new(active_sessions: Arc) -> Self { + pub(crate) fn new(active_sessions: Arc, meta_mgr: Arc) -> Self { Self { session_owner: SccHashMap::new(), connection_sessions: SccHashMap::new(), session_contexts: SccHashMap::new(), active_sessions, + meta_mgr, } } @@ -40,7 +44,7 @@ impl WorkerSessionManager { if self.session_owner.insert_sync(session_id, conn_id).is_err() { continue; } - let session_context = Arc::new(SessionContext::default()); + let session_context = Arc::new(SessionContext::new(self.meta_mgr.clone())); if self .session_contexts .insert_sync(session_id, session_context) @@ -182,7 +186,10 @@ impl WorkerSessionManager { })?; if self .session_contexts - .insert_sync(session_id, Arc::new(SessionContext::default())) + .insert_sync( + session_id, + Arc::new(SessionContext::new(self.meta_mgr.clone())), + ) .is_err() { let _ = self.session_owner.remove_sync(&session_id); @@ -203,14 +210,47 @@ impl WorkerSessionManager { pub(crate) fn connection_has_active_tx(&self, conn_id: u64) -> RS { let session_ids = self.connection_session_ids(conn_id); for session_id in session_ids { - let session = self.session_context(session_id)?; - if session.tx_manager_ref().is_some() { + if self.has_session_tx(session_id)? { return Ok(true); } } Ok(false) } + pub(crate) fn has_session_tx(&self, session_id: OID) -> RS { + Ok(self.session_context(session_id)?.tx_manager_ref().is_some()) + } + + pub(crate) fn begin_session_tx(&self, session_id: OID, tx_mgr: Arc) -> RS<()> { + let session = self.session_context(session_id)?; + if session.tx_manager_ref().is_some() { + return Err(m_error!( + EC::ExistingSuchElement, + format!("session {} already has an active transaction", session_id) + )); + } + session.set_tx_manager(Some(tx_mgr)); + Ok(()) + } + + pub(crate) fn take_session_tx(&self, session_id: OID) -> RS> { + let session = self.session_context(session_id)?; + session.take_tx_manager().ok_or_else(|| { + m_error!( + EC::NoSuchElement, + format!("session {} has no active transaction", session_id) + ) + }) + } + + pub(crate) fn with_session_tx(&self, session_id: OID, f: F) -> RS + where + F: FnOnce(Option>) -> RS, + { + let session = self.session_context(session_id)?; + f(session.tx_manager_ref().clone()) + } + pub(crate) fn detach_connection_sessions(&self, conn_id: u64) -> RS> { let Some((_conn_id, conn_sessions)) = self.connection_sessions.remove_sync(&conn_id) else { return Ok(Vec::new()); @@ -263,21 +303,28 @@ impl WorkerSessionManager { } impl SessionContext { - pub(crate) fn tx_manager_ref(&self) -> &Option { - unsafe { &*self.tx_manager.get() } + fn new(meta_mgr: Arc) -> Self { + Self { + tx_manager: UnsafeCell::new(None), + mudu_conn_core: Arc::new(MuduConnCore::new(meta_mgr)), + } } - pub(crate) fn tx_manager_mut(&self) -> &mut Option { - unsafe { &mut *self.tx_manager.get() } + pub(crate) fn tx_manager_ref(&self) -> &Option> { + unsafe { &*self.tx_manager.get() } } - pub(crate) fn set_tx_manager(&self, tx_manager: Option) { + pub(crate) fn set_tx_manager(&self, tx_manager: Option>) { unsafe { *self.tx_manager.get() = tx_manager; } } - pub(crate) fn take_tx_manager(&self) -> Option { - self.tx_manager_mut().take() + pub(crate) fn take_tx_manager(&self) -> Option> { + unsafe { (&mut *self.tx_manager.get()).take() } + } + + pub(crate) fn mudu_conn_core(&self) -> Arc { + self.mudu_conn_core.clone() } } diff --git a/mudu_kernel/src/server/worker_snapshot.rs b/mudu_kernel/src/server/worker_snapshot.rs index e6b1f76..677dacc 100644 --- a/mudu_kernel/src/server/worker_snapshot.rs +++ b/mudu_kernel/src/server/worker_snapshot.rs @@ -2,6 +2,8 @@ use crate::contract::snapshot::{RunningXList, Snapshot}; use mudu::common::result::RS; use mudu::error::ec::EC; use mudu::m_error; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Mutex; #[derive(Debug, Clone, PartialEq, Eq)] pub struct KvItem { @@ -15,10 +17,9 @@ pub struct WorkerSnapshot { running: Vec, } -#[derive(Default)] pub struct WorkerSnapshotMgr { - next_ts: u64, - running: Vec, + next_ts: AtomicU64, + running: Mutex>, } impl WorkerSnapshot { @@ -40,32 +41,36 @@ impl WorkerSnapshot { } impl WorkerSnapshotMgr { - pub fn begin_tx(&mut self) -> WorkerSnapshot { - self.next_ts += 1; - let xid = self.next_ts; + pub fn begin_tx(&self) -> WorkerSnapshot { + let xid = self.next_ts.fetch_add(1, Ordering::Relaxed) + 1; + let mut running = self + .running + .lock() + .expect("worker snapshot manager running list lock poisoned"); let snapshot = WorkerSnapshot { xid, - running: self.running.clone(), + running: running.clone(), }; - insert_sorted_unique(&mut self.running, xid); + insert_sorted_unique(&mut running, xid); snapshot } - pub fn alloc_committed_ts(&mut self) -> u64 { - self.next_ts += 1; - self.next_ts + pub fn alloc_committed_ts(&self) -> u64 { + self.next_ts.fetch_add(1, Ordering::Relaxed) + 1 } - pub fn observe_committed_ts(&mut self, xid: u64) { - if self.next_ts < xid { - self.next_ts = xid; - } + pub fn observe_committed_ts(&self, xid: u64) { + self.next_ts.fetch_max(xid, Ordering::Relaxed); } - pub fn end_tx(&mut self, xid: u64) -> RS<()> { - match self.running.binary_search(&xid) { + pub fn end_tx(&self, xid: u64) -> RS<()> { + let mut running = self + .running + .lock() + .expect("worker snapshot manager running list lock poisoned"); + match running.binary_search(&xid) { Ok(index) => { - self.running.remove(index); + running.remove(index); Ok(()) } Err(_) => Err(m_error!( @@ -76,6 +81,15 @@ impl WorkerSnapshotMgr { } } +impl Default for WorkerSnapshotMgr { + fn default() -> Self { + Self { + next_ts: AtomicU64::new(0), + running: Mutex::new(Vec::new()), + } + } +} + fn is_visible_to_snapshot(version_xid: u64, snapshot: &WorkerSnapshot) -> bool { if version_xid > snapshot.xid { return false; diff --git a/mudu_kernel/src/server/worker_storage.rs b/mudu_kernel/src/server/worker_storage.rs index 60cd89b..2ba0ba7 100644 --- a/mudu_kernel/src/server/worker_storage.rs +++ b/mudu_kernel/src/server/worker_storage.rs @@ -1,7 +1,8 @@ use std::collections::{BTreeMap, Bound}; use std::ops::Bound::{Excluded, Included, Unbounded}; -use std::sync::Arc; +use std::sync::{Arc, Mutex, OnceLock, Weak}; +use futures::executor::block_on; use mudu::common::id::OID; use mudu::common::result::RS; use mudu::error::ec::EC; @@ -21,6 +22,15 @@ use crate::storage::relation::relation::Relation; use crate::wal::xl_batch::XLBatch; use crate::wal::xl_data_op::{XLDelete, XLInsert}; use crate::wal::xl_entry::TxOp; +use crate::x_engine::tx_mgr::TxMgr; + +type WorkerStorageRegistry = std::collections::HashMap>>; + +fn storage_registry() -> &'static Mutex { + static REGISTRY: OnceLock> = OnceLock::new(); + REGISTRY.get_or_init(|| Mutex::new(std::collections::HashMap::new())) +} + #[derive(Clone, Debug)] pub(crate) struct PreparedWorkerCommit { xid: u64, @@ -48,72 +58,92 @@ impl WorkerStorage { } } + pub fn register_global(self: &Arc) { + let mut guard = storage_registry().lock().unwrap(); + guard + .entry(self.relation_path.clone()) + .or_default() + .push(Arc::downgrade(self)); + } + + pub fn bootstrap_existing_tables_sync(&self) -> RS<()> { + for schema in block_on(self.mgr.list_schemas())? { + self.apply_create_table_local(&schema)?; + } + Ok(()) + } + pub async fn create_table_async(&self, schema: &SchemaTable) -> RS<()> { - let oid = schema.id(); self.mgr.create_table(schema).await?; - let table_desc = self.mgr.get_table_by_id(oid).await?; - self.create_relation_index(oid, table_desc.as_ref())?; - Ok(()) + self.broadcast_create_table(schema) } pub async fn drop_table_async(&self, oid: OID) -> RS<()> { self.mgr.drop_table(oid).await?; - let _ = self.relation_store.remove_sync(&oid); - Ok(()) + self.broadcast_drop_table(oid) } + #[allow(dead_code)] - pub fn contains_key(&self, oid: OID, key: &KeyTuple, txm: &mut WorkerTxManager) -> RS { + pub async fn contains_key(&self, oid: OID, key: &KeyTuple, txm: &dyn TxMgr) -> RS { if let Some(staged) = txm.get_relation(oid, key.as_slice()) { return Ok(staged.is_some()); } - self.read_visible_relation_exists(oid, key, txm.snapshot()) + self.read_visible_relation_exists(oid, key, &txm.snapshot()).await } - pub fn get(&self, oid: OID, key: &[u8], txm: &mut WorkerTxManager) -> RS>> { + + pub async fn get(&self, oid: OID, key: &[u8], txm: &dyn TxMgr) -> RS>> { if let Some(staged) = txm.get_relation(oid, key) { return Ok(staged); } - let key = KeyTuple::from(key.to_vec()); - self.read_visible_relation_value(oid, &key, txm.snapshot()) + self.read_visible_relation_value(oid, &key, &txm.snapshot()).await } - pub fn insert( + pub async fn put( &self, oid: OID, key: Vec, value: Vec, - txm: &mut WorkerTxManager, + txm: &dyn TxMgr, ) -> RS<()> { let key_tuple = KeyTuple::from(key.clone()); - self.ensure_no_relation_write_conflict(oid, &key_tuple, txm.snapshot())?; + + self.ensure_no_relation_write_conflict(oid, &key_tuple, &txm.snapshot()) + .await?; txm.put_relation(oid, key, value); Ok(()) } - pub fn remove(&self, oid: OID, key: &[u8], txm: &mut WorkerTxManager) -> RS>> { - let key_tuple = KeyTuple::from(key.to_vec()); - self.ensure_no_relation_write_conflict(oid, &key_tuple, txm.snapshot())?; + pub async fn remove(&self, oid: OID, key: &[u8], txm: &dyn TxMgr) -> RS>> { + let key_tuple = KeyTuple::from(key.to_vec()); + self.ensure_no_relation_write_conflict(oid, &key_tuple, &txm.snapshot()) + .await?; let current = match txm.get_relation(oid, key) { Some(staged) => staged, - None => self.read_visible_relation_value(oid, &key_tuple, txm.snapshot())?, + None => self + .read_visible_relation_value(oid, &key_tuple, &txm.snapshot()) + .await?, }; - if current.is_some() { txm.delete_relation(oid, key.to_vec()); } Ok(current) + } - pub fn range( + + pub async fn range( &self, oid: OID, bounds: (Bound<&[u8]>, Bound<&[u8]>), - txm: &mut WorkerTxManager, + txm: &dyn TxMgr, ) -> RS, Vec)>> { - let base_items = self.range_visible_relation(oid, bounds, txm.snapshot())?; + let base_items = self + .range_visible_relation(oid, bounds, &txm.snapshot()) + .await?; let (start_key, end_key) = bounds_to_scan(&bounds); let staged_items = txm.staged_relation_items_in_range(oid, &start_key, &end_key); @@ -131,67 +161,95 @@ impl WorkerStorage { .collect()) } - pub fn worker_get(&self, key: &[u8], snapshot: Option<&WorkerSnapshot>) -> RS>> { - let row = self.kv_store.get_sync(key); + pub async fn kv_get( + &self, + key: &[u8], + snapshot: Option<&WorkerSnapshot>, + ) -> RS>> { + let row = self.kv_store.get_sync(key).map(|entry| entry.get().clone()); let version = match snapshot { - Some(snapshot) => row.and_then(|row| { - let snapshot = snapshot.to_snapshot(); - read_visible_version(row.get(), &snapshot) - }), - None => row.and_then(|row| latest_version(row.get())), + Some(snapshot) => match row { + Some(row) => { + let snapshot = snapshot.to_snapshot(); + row.read(&snapshot).await? + } + None => None, + }, + None => match row { + Some(row) => row.read_latest().await?, + None => None, + }, }; Ok(version .filter(|version| !version.is_deleted()) .map(|version| version.tuple().clone())) } - pub fn worker_range( + + pub async fn kv_range( &self, start_key: &[u8], end_key: &[u8], snapshot: Option<&WorkerSnapshot>, ) -> RS> { - let mut items = Vec::new(); + let mut rows = Vec::new(); self.kv_store.iter_sync(|key, row| { let in_range = if end_key.is_empty() { key.as_slice() >= start_key } else { key.as_slice() >= start_key && key.as_slice() < end_key }; - if !in_range { - return true; + if in_range { + rows.push((key.clone(), row.clone())); } + true + }); + let mut items = Vec::new(); + for (key, row) in rows { let visible = match snapshot { Some(snapshot) => { let snapshot = snapshot.to_snapshot(); - read_visible_version(row, &snapshot) + row.read(&snapshot).await? } - None => latest_version(row), + None => row.read_latest().await?, }; if let Some(visible) = visible.filter(|version| !version.is_deleted()) { items.push(KvItem { - key: key.clone(), + key, value: visible.tuple().clone(), }); } - true - }); + } items.sort_by(|left, right| left.key.cmp(&right.key)); Ok(items) } + #[allow(dead_code)] - pub(crate) fn commit_tx(&self, txm: &mut WorkerTxManager) -> RS<()> { - let prepared = self.prepare_commit(txm)?; - self.apply_prepared_commit(prepared) + pub(crate) async fn commit_tx(&self, txm: &mut WorkerTxManager) -> RS<()> { + let prepared = self.prepare_commit_async(txm).await?; + self.apply_relation_rows_async(&prepared).await?; + self.apply_kv_rows_async(&prepared).await?; + Ok(()) } - pub(crate) fn prepare_commit(&self, txm: &WorkerTxManager) -> RS { + pub(crate) async fn prepare_commit_async(&self, txm: &dyn TxMgr) -> RS { + self.prepare_commit_parts_async( + &txm.snapshot(), + txm.xid(), + txm.staged_relation_ops(), + txm.staged_put_items().into_iter().collect(), + txm.xl_batch(), + ) + .await + } + + pub(crate) fn prepare_commit(&self, txm: &dyn TxMgr) -> RS { self.prepare_commit_parts( - txm.snapshot(), + &txm.snapshot(), txm.xid(), - txm.staged_relation_ops().clone(), + txm.staged_relation_ops(), txm.staged_put_items().into_iter().collect(), txm.xl_batch(), ) @@ -228,6 +286,12 @@ impl WorkerStorage { Ok(()) } + pub(crate) async fn apply_prepared_commit_async(&self, prepared: PreparedWorkerCommit) -> RS<()> { + self.apply_relation_rows_async(&prepared).await?; + self.apply_kv_rows_async(&prepared).await?; + Ok(()) + } + pub(crate) fn replay_batch(&self, batch: XLBatch) -> RS<()> { for entry in batch.entries { for op in entry.ops { @@ -278,6 +342,26 @@ impl WorkerStorage { }) } + async fn prepare_commit_parts_async( + &self, + snapshot: &WorkerSnapshot, + xid: u64, + relation_rows: BTreeMap, Option>>>, + kv_rows: BTreeMap, Option>>, + batch: XLBatch, + ) -> RS { + self.ensure_no_relation_conflicts_async(snapshot, xid, &relation_rows) + .await?; + self.ensure_no_kv_conflicts(snapshot, xid, &kv_rows)?; + + Ok(PreparedWorkerCommit { + xid, + relation_rows, + kv_rows, + batch, + }) + } + fn ensure_no_relation_conflicts( &self, snapshot: &WorkerSnapshot, @@ -291,7 +375,34 @@ impl WorkerStorage { .ok_or_else(|| m_error!(EC::NoSuchElement, format!("no such table {}", oid)))?; for key in rows.keys() { let key_tuple = KeyTuple::from(key.clone()); - if relation.get().has_write_conflict(&key_tuple, snapshot)? { + if relation.get().has_write_conflict_sync(&key_tuple, snapshot)? { + return Err(m_error!( + EC::TxErr, + format!( + "write-write conflict on table {} key {:?} for transaction {}", + oid, key, xid + ) + )); + } + } + } + Ok(()) + } + + async fn ensure_no_relation_conflicts_async( + &self, + snapshot: &WorkerSnapshot, + xid: u64, + relation_rows: &BTreeMap, Option>>>, + ) -> RS<()> { + for (oid, rows) in relation_rows { + let relation = self + .relation_store + .get_sync(oid) + .ok_or_else(|| m_error!(EC::NoSuchElement, format!("no such table {}", oid)))?; + for key in rows.keys() { + let key_tuple = KeyTuple::from(key.clone()); + if relation.get().has_write_conflict(&key_tuple, snapshot).await? { return Err(m_error!( EC::TxErr, format!( @@ -341,7 +452,23 @@ impl WorkerStorage { for (key, value) in rows { relation .get() - .write_row(key.clone(), value.clone(), prepared.xid)?; + .write_row_sync(key.clone(), value.clone(), prepared.xid)?; + } + } + Ok(()) + } + + async fn apply_relation_rows_async(&self, prepared: &PreparedWorkerCommit) -> RS<()> { + for (oid, rows) in &prepared.relation_rows { + let relation = self + .relation_store + .get_sync(oid) + .ok_or_else(|| m_error!(EC::NoSuchElement, format!("no such table {}", oid)))?; + for (key, value) in rows { + relation + .get() + .write_row(key.clone(), value.clone(), prepared.xid) + .await?; } } Ok(()) @@ -354,6 +481,19 @@ impl WorkerStorage { Ok(()) } + async fn apply_kv_rows_async(&self, prepared: &PreparedWorkerCommit) -> RS<()> { + for (key, value) in &prepared.kv_rows { + write_version_to_kv_store_async( + &self.kv_store, + key.clone(), + value.clone(), + prepared.xid, + ) + .await?; + } + Ok(()) + } + fn apply_relation_replay_insert(&self, insert: XLInsert, xid: u64) -> RS<()> { let relation = self .relation_store @@ -364,7 +504,7 @@ impl WorkerStorage { format!("no such table {}", insert.table_id) ) })?; - relation.get().write_value(insert.key, insert.value, xid) + relation.get().write_value_sync(insert.key, insert.value, xid) } fn apply_relation_replay_delete(&self, delete: XLDelete, xid: u64) -> RS<()> { @@ -377,10 +517,11 @@ impl WorkerStorage { format!("no such table {}", delete.table_id) ) })?; - relation.get().write_delete(delete.key, xid) + relation.get().write_delete_sync(delete.key, xid) } - fn read_visible_relation_exists( + #[allow(dead_code)] + async fn read_visible_relation_exists( &self, oid: OID, key: &KeyTuple, @@ -390,10 +531,10 @@ impl WorkerStorage { .relation_store .get_sync(&oid) .ok_or_else(|| m_error!(EC::NoSuchElement, format!("no such table {}", oid)))?; - relation.get().has_visible_version(key, snapshot) + relation.get().has_visible_version(key, snapshot).await } - fn read_visible_relation_value( + async fn read_visible_relation_value( &self, oid: OID, key: &KeyTuple, @@ -403,10 +544,10 @@ impl WorkerStorage { .relation_store .get_sync(&oid) .ok_or_else(|| m_error!(EC::NoSuchElement, format!("no such table {}", oid)))?; - relation.get().visible_value(key, snapshot) + relation.get().visible_value(key, snapshot).await } - fn range_visible_relation( + async fn range_visible_relation( &self, oid: OID, bounds: (Bound<&[u8]>, Bound<&[u8]>), @@ -416,10 +557,10 @@ impl WorkerStorage { .relation_store .get_sync(&oid) .ok_or_else(|| m_error!(EC::NoSuchElement, format!("no such table {}", oid)))?; - relation.get().visible_range(bounds, snapshot) + relation.get().visible_range(bounds, snapshot).await } - fn ensure_no_relation_write_conflict( + async fn ensure_no_relation_write_conflict( &self, oid: OID, key: &KeyTuple, @@ -429,7 +570,7 @@ impl WorkerStorage { .relation_store .get_sync(&oid) .ok_or_else(|| m_error!(EC::NoSuchElement, format!("no such table {}", oid)))?; - if relation.get().has_write_conflict(key, snapshot)? { + if relation.get().has_write_conflict(key, snapshot).await? { return Err(m_error!( EC::TxErr, format!( @@ -455,6 +596,53 @@ impl WorkerStorage { ); Ok(()) } + + fn apply_create_table_local(&self, schema: &SchemaTable) -> RS<()> { + let table_desc = + crate::contract::table_info::TableInfo::new(schema.clone())?.table_desc()?; + self.create_relation_index(schema.id(), table_desc.as_ref()) + } + + fn apply_drop_table_local(&self, oid: OID) { + let _ = self.relation_store.remove_sync(&oid); + } + + fn broadcast_create_table(&self, schema: &SchemaTable) -> RS<()> { + let peers = self.peer_instances(); + if peers.is_empty() { + return self.apply_create_table_local(schema); + } + for storage in peers { + storage.apply_create_table_local(schema)?; + } + Ok(()) + } + + fn broadcast_drop_table(&self, oid: OID) -> RS<()> { + let peers = self.peer_instances(); + if peers.is_empty() { + self.apply_drop_table_local(oid); + return Ok(()); + } + for storage in peers { + storage.apply_drop_table_local(oid); + } + Ok(()) + } + + fn peer_instances(&self) -> Vec> { + let mut guard = storage_registry().lock().unwrap(); + let peers = guard.entry(self.relation_path.clone()).or_default(); + let mut live = Vec::with_capacity(peers.len()); + peers.retain(|weak| match weak.upgrade() { + Some(storage) => { + live.push(storage); + true + } + None => false, + }); + live + } } impl PreparedWorkerCommit { @@ -486,15 +674,27 @@ fn write_version_to_kv_store( Ok(()) } -fn latest_version(row: &DataRow) -> Option { - row.read_latest_sync().ok().flatten() +async fn write_version_to_kv_store_async( + kv_store: &SccHashMap, DataRow>, + key: Vec, + value: Option>, + xid: u64, +) -> RS<()> { + let row = kv_store + .get_sync(&key) + .map(|entry| entry.get().clone()) + .unwrap_or_else(|| DataRow::new(0)); + let version = match value { + Some(value) => new_value_version(xid, value), + None => VersionTuple::new_delete(Timestamp::new(xid, u64::MAX)), + }; + row.write(version, None).await?; + let _ = kv_store.insert_sync(key, row); + Ok(()) } -fn read_visible_version( - row: &DataRow, - snapshot: &crate::contract::snapshot::Snapshot, -) -> Option { - row.read_sync(snapshot).ok().flatten() +fn latest_version(row: &DataRow) -> Option { + row.read_latest_sync().ok().flatten() } fn bounds_to_scan(bounds: &(Bound<&[u8]>, Bound<&[u8]>)) -> (Vec, Vec) { @@ -573,16 +773,20 @@ mod tests { fn test_schema() -> SchemaTable { SchemaTable::new( "t".to_string(), - vec![SchemaColumn::new( - "id".to_string(), - DatTypeID::I32, - DTInfo::from_text(DatTypeID::I32, String::new()), - )], - vec![SchemaColumn::new( - "v".to_string(), - DatTypeID::I32, - DTInfo::from_text(DatTypeID::I32, String::new()), - )], + vec![ + SchemaColumn::new( + "id".to_string(), + DatTypeID::I32, + DTInfo::from_text(DatTypeID::I32, String::new()), + ), + SchemaColumn::new( + "v".to_string(), + DatTypeID::I32, + DTInfo::from_text(DatTypeID::I32, String::new()), + ), + ], + vec![0], + vec![1], ) } @@ -605,6 +809,33 @@ mod tests { (storage, oid) } + fn test_shared_storage() -> ( + Arc, + Arc, + Arc, + OID, + ) { + let mgr = Arc::new(TestMetaMgr::new()); + let root = std::env::temp_dir() + .join(format!( + "worker_storage_shared_test_{}", + mudu::common::id::gen_oid() + )) + .to_string_lossy() + .to_string(); + let storage1 = Arc::new(WorkerStorage::new(mgr.clone(), 1, root.clone())); + storage1.register_global(); + storage1.bootstrap_existing_tables_sync().unwrap(); + let storage2 = Arc::new(WorkerStorage::new(mgr.clone(), 2, root)); + storage2.register_global(); + storage2.bootstrap_existing_tables_sync().unwrap(); + + let schema = test_schema(); + let oid = schema.id(); + futures::executor::block_on(storage1.create_table_async(&schema)).unwrap(); + (mgr, storage1, storage2, oid) + } + fn begin_tx(xid: u64, running: Vec) -> WorkerTxManager { WorkerTxManager::new(WorkerSnapshot::new(xid, running)) } @@ -613,42 +844,53 @@ mod tests { v.to_be_bytes().to_vec() } + #[test] + fn worker_storage_broadcasts_create_and_drop_to_peer_workers() { + let (mgr, _storage1, storage2, oid) = test_shared_storage(); + let mut tx = begin_tx(1, vec![]); + block_on(storage2.put(oid, i32_bytes(7), i32_bytes(70), &mut tx)).unwrap(); + block_on(storage2.commit_tx(&mut tx)).unwrap(); + assert!(futures::executor::block_on(mgr.get_table_by_id(oid)).is_ok()); + + futures::executor::block_on(storage2.drop_table_async(oid)).unwrap(); + assert!(futures::executor::block_on(mgr.get_table_by_id(oid)).is_err()); + + let mut tx = begin_tx(2, vec![]); + let err = block_on(storage2.put(oid, i32_bytes(8), i32_bytes(80), &mut tx)) + .unwrap_err(); + assert!(format!("{err}").contains("no such table")); + } + #[test] fn worker_storage_reads_own_writes() { let (storage, oid) = test_storage(); let mut tx = begin_tx(10, vec![]); - storage - .insert(oid, i32_bytes(1), i32_bytes(11), &mut tx) - .unwrap(); + block_on(storage.put(oid, i32_bytes(1), i32_bytes(11), &mut tx)).unwrap(); assert_eq!( - storage.get(oid, &i32_bytes(1), &mut tx).unwrap(), + block_on(storage.get(oid, &i32_bytes(1), &mut tx)).unwrap(), Some(i32_bytes(11)) ); - assert!(storage - .contains_key(oid, &KeyTuple::from(i32_bytes(1)), &mut tx) - .unwrap()); + assert!( + block_on(storage.contains_key(oid, &KeyTuple::from(i32_bytes(1)), &mut tx)).unwrap() + ); } #[test] fn worker_storage_snapshot_hides_later_commit() { let (storage, oid) = test_storage(); let mut tx1 = begin_tx(1, vec![]); - storage - .insert(oid, i32_bytes(1), i32_bytes(10), &mut tx1) - .unwrap(); - storage.commit_tx(&mut tx1).unwrap(); + block_on(storage.put(oid, i32_bytes(1), i32_bytes(10), &mut tx1)).unwrap(); + block_on(storage.commit_tx(&mut tx1)).unwrap(); let mut old_tx = begin_tx(2, vec![]); let mut new_tx = begin_tx(3, vec![2]); - storage - .insert(oid, i32_bytes(1), i32_bytes(20), &mut new_tx) - .unwrap(); - storage.commit_tx(&mut new_tx).unwrap(); + block_on(storage.put(oid, i32_bytes(1), i32_bytes(20), &mut new_tx)).unwrap(); + block_on(storage.commit_tx(&mut new_tx)).unwrap(); assert_eq!( - storage.get(oid, &i32_bytes(1), &mut old_tx).unwrap(), + block_on(storage.get(oid, &i32_bytes(1), &mut old_tx)).unwrap(), Some(i32_bytes(10)) ); } @@ -657,19 +899,15 @@ mod tests { fn worker_storage_range_is_stable_with_snapshot() { let (storage, oid) = test_storage(); let mut seed = begin_tx(1, vec![]); - storage - .insert(oid, i32_bytes(1), i32_bytes(10), &mut seed) - .unwrap(); - storage.commit_tx(&mut seed).unwrap(); + block_on(storage.put(oid, i32_bytes(1), i32_bytes(10), &mut seed)).unwrap(); + block_on(storage.commit_tx(&mut seed)).unwrap(); let mut old_tx = begin_tx(2, vec![]); let mut new_tx = begin_tx(3, vec![2]); - storage - .insert(oid, i32_bytes(2), i32_bytes(20), &mut new_tx) - .unwrap(); - storage.commit_tx(&mut new_tx).unwrap(); + block_on(storage.put(oid, i32_bytes(2), i32_bytes(20), &mut new_tx)).unwrap(); + block_on(storage.commit_tx(&mut new_tx)).unwrap(); - let rows = storage + let rows = block_on(storage .range( oid, ( @@ -678,7 +916,8 @@ mod tests { ), &mut old_tx, ) - .unwrap(); + ) + .unwrap(); assert_eq!(rows, vec![(i32_bytes(1), i32_bytes(10))]); } @@ -686,21 +925,15 @@ mod tests { fn worker_storage_first_committer_wins() { let (storage, oid) = test_storage(); let mut seed = begin_tx(1, vec![]); - storage - .insert(oid, i32_bytes(1), i32_bytes(10), &mut seed) - .unwrap(); - storage.commit_tx(&mut seed).unwrap(); + block_on(storage.put(oid, i32_bytes(1), i32_bytes(10), &mut seed)).unwrap(); + block_on(storage.commit_tx(&mut seed)).unwrap(); let mut tx1 = begin_tx(2, vec![]); let mut tx2 = begin_tx(3, vec![2]); - storage - .insert(oid, i32_bytes(1), i32_bytes(11), &mut tx1) - .unwrap(); - storage - .insert(oid, i32_bytes(1), i32_bytes(12), &mut tx2) - .unwrap(); - storage.commit_tx(&mut tx1).unwrap(); - let err = storage.commit_tx(&mut tx2).unwrap_err(); + block_on(storage.put(oid, i32_bytes(1), i32_bytes(11), &mut tx1)).unwrap(); + block_on(storage.put(oid, i32_bytes(1), i32_bytes(12), &mut tx2)).unwrap(); + block_on(storage.commit_tx(&mut tx1)).unwrap(); + let err = block_on(storage.commit_tx(&mut tx2)).unwrap_err(); assert!(err.to_string().contains("write-write conflict")); } @@ -709,26 +942,24 @@ mod tests { fn worker_storage_delete_respects_snapshot() { let (storage, oid) = test_storage(); let mut seed = begin_tx(1, vec![]); - storage - .insert(oid, i32_bytes(1), i32_bytes(10), &mut seed) - .unwrap(); - storage.commit_tx(&mut seed).unwrap(); + block_on(storage.put(oid, i32_bytes(1), i32_bytes(10), &mut seed)).unwrap(); + block_on(storage.commit_tx(&mut seed)).unwrap(); let mut old_tx = begin_tx(2, vec![]); let mut delete_tx = begin_tx(3, vec![2]); assert_eq!( - storage.remove(oid, &i32_bytes(1), &mut delete_tx).unwrap(), + block_on(storage.remove(oid, &i32_bytes(1), &mut delete_tx)).unwrap(), Some(i32_bytes(10)) ); - storage.commit_tx(&mut delete_tx).unwrap(); + block_on(storage.commit_tx(&mut delete_tx)).unwrap(); assert_eq!( - storage.get(oid, &i32_bytes(1), &mut old_tx).unwrap(), + block_on(storage.get(oid, &i32_bytes(1), &mut old_tx)).unwrap(), Some(i32_bytes(10)) ); let mut fresh_tx = begin_tx(4, vec![]); assert_eq!( - storage.get(oid, &i32_bytes(1), &mut fresh_tx).unwrap(), + block_on(storage.get(oid, &i32_bytes(1), &mut fresh_tx)).unwrap(), None ); } @@ -750,10 +981,10 @@ mod tests { storage.apply_prepared_commit(prepared).unwrap(); assert_eq!( - storage.worker_get(b"a", Some(&snapshot)).unwrap(), + block_on(storage.kv_get(b"a", Some(&snapshot))).unwrap(), Some(b"0".to_vec()) ); - assert_eq!(storage.worker_get(b"a", None).unwrap(), Some(b"1".to_vec())); + assert_eq!(block_on(storage.kv_get(b"a", None)).unwrap(), Some(b"1".to_vec())); } #[test] @@ -767,7 +998,7 @@ mod tests { .worker_put_local(b"b".to_vec(), b"2".to_vec(), 3) .unwrap(); - let rows = storage.worker_range(b"a", b"z", Some(&snapshot)).unwrap(); + let rows = block_on(storage.kv_range(b"a", b"z", Some(&snapshot))).unwrap(); assert_eq!( rows, vec![KvItem { @@ -803,8 +1034,8 @@ mod tests { storage.apply_prepared_commit(prepared1).unwrap(); storage.apply_prepared_commit(prepared2).unwrap(); - assert_eq!(storage.worker_get(b"a", None).unwrap(), Some(b"1".to_vec())); - assert_eq!(storage.worker_get(b"b", None).unwrap(), Some(b"2".to_vec())); + assert_eq!(block_on(storage.kv_get(b"a", None)).unwrap(), Some(b"1".to_vec())); + assert_eq!(block_on(storage.kv_get(b"b", None)).unwrap(), Some(b"2".to_vec())); } #[test] @@ -834,10 +1065,10 @@ mod tests { storage.replay_batch(batch).unwrap(); - assert_eq!(storage.worker_get(b"k", None).unwrap(), Some(b"v".to_vec())); + assert_eq!(block_on(storage.kv_get(b"k", None)).unwrap(), Some(b"v".to_vec())); let mut tx = begin_tx(10, vec![]); assert_eq!( - storage.get(oid, &i32_bytes(7), &mut tx).unwrap(), + block_on(storage.get(oid, &i32_bytes(7), &mut tx)).unwrap(), Some(i32_bytes(70)) ); } @@ -866,6 +1097,6 @@ mod tests { storage.replay_batch(batch).unwrap(); - assert_eq!(storage.worker_get(b"k", None).unwrap(), None); + assert_eq!(block_on(storage.kv_get(b"k", None)).unwrap(), None); } } diff --git a/mudu_kernel/src/server/worker_tx_manager.rs b/mudu_kernel/src/server/worker_tx_manager.rs index 1a10563..4db55cd 100644 --- a/mudu_kernel/src/server/worker_tx_manager.rs +++ b/mudu_kernel/src/server/worker_tx_manager.rs @@ -2,39 +2,50 @@ use crate::server::worker_snapshot::WorkerSnapshot; use crate::wal::xl_batch::XLBatch; use crate::wal::xl_data_op::{XLDelete, XLInsert}; use crate::wal::xl_entry::{TxOp, XLEntry}; +use crate::x_engine::tx_mgr::TxMgr; use mudu::common::id::OID; use std::collections::BTreeMap; +use std::sync::Mutex; -pub struct WorkerTxManager { - snapshot: WorkerSnapshot, +struct WorkerTxState { staged_puts: BTreeMap, Option>>, staged_relation_ops: BTreeMap, Option>>>, write_ops: Vec<(OID, Vec)>, log_buffer: Vec, } +pub struct WorkerTxManager { + snapshot: WorkerSnapshot, + state: Mutex, +} + impl WorkerTxManager { pub fn new(snapshot: WorkerSnapshot) -> Self { Self { snapshot, - staged_puts: BTreeMap::new(), - staged_relation_ops: BTreeMap::new(), - write_ops: vec![], - log_buffer: Vec::new(), + state: Mutex::new(WorkerTxState { + staged_puts: BTreeMap::new(), + staged_relation_ops: BTreeMap::new(), + write_ops: Vec::new(), + log_buffer: Vec::new(), + }), } } +} - pub fn xid(&self) -> u64 { +impl TxMgr for WorkerTxManager { + fn xid(&self) -> u64 { self.snapshot.xid() } - pub fn snapshot(&self) -> &WorkerSnapshot { - &self.snapshot + fn snapshot(&self) -> WorkerSnapshot { + self.snapshot.clone() } - pub fn put(&mut self, key: Vec, value: Vec) { - self.staged_puts.insert(key.clone(), Some(value.clone())); - self.log_buffer.push(TxOp::Insert(XLInsert { + fn put(&self, key: Vec, value: Vec) { + let mut state = self.state.lock().unwrap(); + state.staged_puts.insert(key.clone(), Some(value.clone())); + state.log_buffer.push(TxOp::Insert(XLInsert { table_id: 0, tuple_id: 0, key, @@ -42,25 +53,29 @@ impl WorkerTxManager { })); } - pub fn delete(&mut self, key: Vec) { - self.staged_puts.insert(key.clone(), None); - self.log_buffer.push(TxOp::Delete(XLDelete { + fn delete(&self, key: Vec) { + let mut state = self.state.lock().unwrap(); + state.staged_puts.insert(key.clone(), None); + state.log_buffer.push(TxOp::Delete(XLDelete { table_id: 0, tuple_id: 0, key, })); } - pub fn get(&self, key: &[u8]) -> Option>> { - self.staged_puts.get(key).cloned() + fn get(&self, key: &[u8]) -> Option>> { + let state = self.state.lock().unwrap(); + state.staged_puts.get(key).cloned() } - pub fn put_relation(&mut self, oid: OID, key: Vec, value: Vec) { - self.staged_relation_ops + fn put_relation(&self, oid: OID, key: Vec, value: Vec) { + let mut state = self.state.lock().unwrap(); + state + .staged_relation_ops .entry(oid) .or_default() .insert(key.clone(), Some(value.clone())); - self.log_buffer.push(TxOp::Insert(XLInsert { + state.log_buffer.push(TxOp::Insert(XLInsert { table_id: oid, tuple_id: 0, key, @@ -68,31 +83,37 @@ impl WorkerTxManager { })); } - pub fn delete_relation(&mut self, oid: OID, key: Vec) { - self.staged_relation_ops + fn delete_relation(&self, oid: OID, key: Vec) { + let mut state = self.state.lock().unwrap(); + state + .staged_relation_ops .entry(oid) .or_default() .insert(key.clone(), None); - self.log_buffer.push(TxOp::Delete(XLDelete { + state.log_buffer.push(TxOp::Delete(XLDelete { table_id: oid, tuple_id: 0, key, })); } - pub fn get_relation(&self, oid: OID, key: &[u8]) -> Option>> { - self.staged_relation_ops + fn get_relation(&self, oid: OID, key: &[u8]) -> Option>> { + let state = self.state.lock().unwrap(); + state + .staged_relation_ops .get(&oid) - .and_then(|rows| rows.get(key).map(|value| value.clone())) + .and_then(|rows| rows.get(key).cloned()) } - pub fn staged_relation_items_in_range( + fn staged_relation_items_in_range( &self, oid: OID, start_key: &[u8], end_key: &[u8], ) -> Vec<(Vec, Option>)> { - self.staged_relation_ops + let state = self.state.lock().unwrap(); + state + .staged_relation_ops .get(&oid) .map(|rows| { rows.iter() @@ -103,64 +124,62 @@ impl WorkerTxManager { .unwrap_or_default() } - pub fn staged_relation_ops(&self) -> &BTreeMap, Option>>> { - &self.staged_relation_ops + fn staged_relation_ops(&self) -> BTreeMap, Option>>> { + let state = self.state.lock().unwrap(); + state.staged_relation_ops.clone() } - #[allow(dead_code)] - pub fn drain_relation_ops(&mut self) -> BTreeMap, Option>>> { - std::mem::take(&mut self.staged_relation_ops) - } - - pub fn staged_items_in_range( + fn staged_items_in_range( &self, start_key: &[u8], end_key: &[u8], ) -> Vec<(Vec, Option>)> { - self.staged_puts + let state = self.state.lock().unwrap(); + state + .staged_puts .iter() .filter(|(key, _)| is_key_in_range(key, start_key, end_key)) .map(|(key, value)| (key.clone(), value.clone())) .collect() } - pub fn staged_put_items(&self) -> BTreeMap, Option>> { - self.staged_puts.clone() + fn staged_put_items(&self) -> BTreeMap, Option>> { + let state = self.state.lock().unwrap(); + state.staged_puts.clone() } - pub fn write_ops(&self) -> &Vec<(OID, Vec)> { - &self.write_ops + fn is_empty(&self) -> bool { + let state = self.state.lock().unwrap(); + state.staged_puts.is_empty() && state.staged_relation_ops.is_empty() } - pub fn build_write_ops(&mut self) { - for (k, _) in self.staged_puts.iter() { - self.write_ops.push((0, k.clone())); - } - - for (oid, ops) in self.staged_relation_ops.iter() { - for (k, _) in ops.iter() { - self.write_ops.push((*oid, k.clone())); - } - } - self.write_ops.sort(); + fn write_ops(&self) -> Vec<(OID, Vec)> { + let state = self.state.lock().unwrap(); + state.write_ops.clone() } - pub fn xl_batch(&self) -> XLBatch { - let xid = self.snapshot.xid(); - let mut ops = Vec::with_capacity(self.log_buffer.len() + 2); - ops.push(TxOp::Begin); - ops.extend(self.log_buffer.clone()); - ops.push(TxOp::Commit); - XLBatch { - entries: vec![XLEntry { xid, ops }], + fn build_write_ops(&self) { + let mut state = self.state.lock().unwrap(); + state.write_ops.clear(); + let mut write_ops = Vec::new(); + for key in state.staged_puts.keys() { + write_ops.push((0, key.clone())); + } + for (oid, ops) in &state.staged_relation_ops { + for key in ops.keys() { + write_ops.push((*oid, key.clone())); + } } + state.write_ops = write_ops; + state.write_ops.sort(); } - pub fn into_xl_batch(self) -> XLBatch { + fn xl_batch(&self) -> XLBatch { + let state = self.state.lock().unwrap(); let xid = self.snapshot.xid(); - let mut ops = Vec::with_capacity(self.log_buffer.len() + 2); + let mut ops = Vec::with_capacity(state.log_buffer.len() + 2); ops.push(TxOp::Begin); - ops.extend(self.log_buffer); + ops.extend(state.log_buffer.clone()); ops.push(TxOp::Commit); XLBatch { entries: vec![XLEntry { xid, ops }], diff --git a/mudu_kernel/src/server/x_contract.rs b/mudu_kernel/src/server/x_contract.rs index 1a90cfd..b459ff2 100644 --- a/mudu_kernel/src/server/x_contract.rs +++ b/mudu_kernel/src/server/x_contract.rs @@ -1,22 +1,22 @@ use async_trait::async_trait; +use futures::executor::block_on; use mudu::common::buf::Buf; use mudu::common::id::{AttrIndex, OID}; use mudu::common::result::RS; -use mudu::common::xid::XID; use mudu::error::ec::EC; use mudu::m_error; use mudu_contract::tuple::build_tuple::build_tuple; use mudu_contract::tuple::tuple_binary::TupleBinary as TupleRaw; use mudu_contract::tuple::update_tuple::update_tuple; -use std::collections::HashMap; use std::ops::Bound; use std::sync::{Arc, Mutex}; use crate::contract::meta_mgr::MetaMgr; use crate::contract::schema_table::SchemaTable; use crate::contract::table_desc::TableDesc; +use crate::meta::meta_mgr_factory::MetaMgrFactory; use crate::server::worker_snapshot::{KvItem, WorkerSnapshot, WorkerSnapshotMgr}; -use crate::server::worker_storage::{PreparedWorkerCommit, WorkerStorage}; +use crate::server::worker_storage::WorkerStorage; use crate::server::worker_tx_manager::WorkerTxManager; use crate::server::x_lock_mgr::XLockMgr; use crate::wal::worker_log::ChunkedWorkerLogBackend; @@ -25,20 +25,16 @@ use crate::x_engine::api::{ AlterTable, Filter, OptDelete, OptInsert, OptRead, OptUpdate, Predicate, RSCursor, RangeData, TupleRow, VecDatum, VecSelTerm, XContract, }; +use crate::x_engine::tx_mgr::TxMgr; type DatBin = Buf; pub struct IoUringXContract { - inner: Mutex, - // commit_gate: AsyncMutex<()>, -} - -struct IoUringXContractInner { meta_mgr: Arc, storage: Arc, log: Option, snapshot_mgr: WorkerSnapshotMgr, - tx_ctx: HashMap, tx_lock: XLockMgr, + // commit_gate: AsyncMutex<()>, } struct VecCursor { @@ -65,15 +61,17 @@ impl IoUringXContract { partition_id: OID, data_dir: String, ) -> Self { + let storage = Arc::new(WorkerStorage::new(meta_mgr.clone(), partition_id, data_dir)); + storage.register_global(); + storage + .bootstrap_existing_tables_sync() + .unwrap_or_else(|e| panic!("bootstrap worker storage from meta failed: {e}")); Self { - inner: Mutex::new(IoUringXContractInner { - meta_mgr: meta_mgr.clone(), - storage: Arc::new(WorkerStorage::new(meta_mgr, partition_id, data_dir)), - log, - snapshot_mgr: WorkerSnapshotMgr::default(), - tx_ctx: HashMap::new(), - tx_lock: XLockMgr::new(), - }), + meta_mgr: meta_mgr.clone(), + storage, + log, + snapshot_mgr: WorkerSnapshotMgr::default(), + tx_lock: XLockMgr::new(), } } @@ -86,46 +84,30 @@ impl IoUringXContract { partition_id: OID, data_dir: String, ) -> Self { - let meta_mgr: Arc = Arc::new(NoopMetaMgr); - Self { - inner: Mutex::new(IoUringXContractInner { - meta_mgr: meta_mgr.clone(), - storage: Arc::new(WorkerStorage::new(meta_mgr, partition_id, data_dir)), - log: Some(log.clone()), - snapshot_mgr: WorkerSnapshotMgr::default(), - tx_ctx: HashMap::new(), - tx_lock: XLockMgr::new(), - }), - } - } - - fn lock_inner(&self) -> RS> { - self.inner - .lock() - .map_err(|_| m_error!(EC::InternalErr, "io_uring xcontract lock poisoned")) + let meta_mgr = MetaMgrFactory::create(data_dir.clone()) + .unwrap_or_else(|e| panic!("create worker meta manager failed: {e}")); + Self::with_log_and_data_dir(meta_mgr, Some(log.clone()), partition_id, data_dir) } pub fn worker_log(&self) -> Option { - self.lock_inner().ok().and_then(|guard| guard.log.clone()) + self.log.clone() } - pub fn worker_begin_tx(&self) -> RS { - let mut guard = self.lock_inner()?; - Ok(guard.snapshot_mgr.begin_tx()) + pub fn worker_begin_tx(&self) -> RS> { + Ok(Arc::new(WorkerTxManager::new(self.snapshot_mgr.begin_tx()))) } - pub fn worker_rollback_tx(&self, xid: u64) -> RS<()> { - self.lock_inner()?.snapshot_mgr.end_tx(xid) + pub fn worker_rollback_tx(&self, tx_mgr: Arc) -> RS<()> { + self.snapshot_mgr.end_tx(tx_mgr.xid()) } pub fn worker_put(&self, key: Vec, value: Vec) -> RS<()> { let prepared = { - let mut guard = self.lock_inner()?; - let xid = guard.snapshot_mgr.alloc_committed_ts(); + let xid = self.snapshot_mgr.alloc_committed_ts(); ( - guard.storage.clone(), - guard.log.clone(), - guard.storage.prepare_worker_kv_autocommit( + self.storage.clone(), + self.log.clone(), + self.storage.prepare_worker_kv_autocommit( xid, key.clone(), Some(value.clone()), @@ -142,12 +124,11 @@ impl IoUringXContract { pub async fn worker_put_async(&self, key: Vec, value: Vec) -> RS<()> { let (storage, log, prepared) = { - let mut guard = self.lock_inner()?; - let xid = guard.snapshot_mgr.alloc_committed_ts(); + let xid = self.snapshot_mgr.alloc_committed_ts(); ( - guard.storage.clone(), - guard.log.clone(), - guard.storage.prepare_worker_kv_autocommit( + self.storage.clone(), + self.log.clone(), + self.storage.prepare_worker_kv_autocommit( xid, key.clone(), Some(value.clone()), @@ -158,18 +139,17 @@ impl IoUringXContract { if let Some(log) = log { new_xl_batch_writer(log).append(prepared.batch()).await?; } - storage.apply_prepared_commit(prepared) + storage.apply_prepared_commit_async(prepared).await } pub fn worker_delete(&self, key: &[u8]) -> RS<()> { let key = key.to_vec(); let prepared = { - let mut guard = self.lock_inner()?; - let xid = guard.snapshot_mgr.alloc_committed_ts(); + let xid = self.snapshot_mgr.alloc_committed_ts(); ( - guard.storage.clone(), - guard.log.clone(), - guard.storage.prepare_worker_kv_autocommit( + self.storage.clone(), + self.log.clone(), + self.storage.prepare_worker_kv_autocommit( xid, key.clone(), None, @@ -187,12 +167,11 @@ impl IoUringXContract { pub async fn worker_delete_async(&self, key: &[u8]) -> RS<()> { let key = key.to_vec(); let (storage, log, prepared) = { - let mut guard = self.lock_inner()?; - let xid = guard.snapshot_mgr.alloc_committed_ts(); + let xid = self.snapshot_mgr.alloc_committed_ts(); ( - guard.storage.clone(), - guard.log.clone(), - guard.storage.prepare_worker_kv_autocommit( + self.storage.clone(), + self.log.clone(), + self.storage.prepare_worker_kv_autocommit( xid, key.clone(), None, @@ -203,12 +182,23 @@ impl IoUringXContract { if let Some(log) = log { new_xl_batch_writer(log).append(prepared.batch()).await?; } - storage.apply_prepared_commit(prepared) + storage.apply_prepared_commit_async(prepared).await + } + + pub async fn worker_get_async(&self, key: &[u8]) -> RS>> { + self.storage.kv_get(key, None).await + } + + pub async fn worker_get_with_snapshot_async( + &self, + snapshot: &WorkerSnapshot, + key: &[u8], + ) -> RS>> { + self.storage.kv_get(key, Some(snapshot)).await } pub fn worker_get(&self, key: &[u8]) -> RS>> { - let storage = { self.lock_inner()?.storage.clone() }; - storage.worker_get(key, None) + block_on(self.storage.kv_get(key, None)) } pub fn worker_get_with_snapshot( @@ -216,13 +206,19 @@ impl IoUringXContract { snapshot: &WorkerSnapshot, key: &[u8], ) -> RS>> { - let storage = { self.lock_inner()?.storage.clone() }; - storage.worker_get(key, Some(snapshot)) + block_on(self.storage.kv_get(key, Some(snapshot))) } pub fn worker_range_scan(&self, start_key: &[u8], end_key: &[u8]) -> RS> { - let storage = { self.lock_inner()?.storage.clone() }; - storage.worker_range(start_key, end_key, None) + block_on(self.storage.kv_range(start_key, end_key, None)) + } + + pub async fn worker_range_scan_async( + &self, + start_key: &[u8], + end_key: &[u8], + ) -> RS> { + self.storage.kv_range(start_key, end_key, None).await } pub fn worker_range_scan_with_snapshot( @@ -231,8 +227,18 @@ impl IoUringXContract { start_key: &[u8], end_key: &[u8], ) -> RS> { - let storage = { self.lock_inner()?.storage.clone() }; - storage.worker_range(start_key, end_key, Some(snapshot)) + block_on(self.storage.kv_range(start_key, end_key, Some(snapshot))) + } + + pub async fn worker_range_scan_with_snapshot_async( + &self, + snapshot: &WorkerSnapshot, + start_key: &[u8], + end_key: &[u8], + ) -> RS> { + self.storage + .kv_range(start_key, end_key, Some(snapshot)) + .await } pub fn worker_commit_put_batch( @@ -243,21 +249,20 @@ impl IoUringXContract { batch: XLBatch, ) -> RS<()> { if items.is_empty() { - return self.worker_rollback_tx(xid); + return self.snapshot_mgr.end_tx(xid); } let (storage, log, prepared) = { - let guard = self.lock_inner()?; - let prepared = guard + let prepared = self .storage .prepare_worker_kv_commit(snapshot, xid, items, batch)?; - (guard.storage.clone(), guard.log.clone(), prepared) + (self.storage.clone(), self.log.clone(), prepared) }; if let Some(log) = log { new_xl_batch_writer(log.clone()).append_sync(prepared.batch())?; log.flush()?; } storage.apply_prepared_commit(prepared)?; - self.worker_rollback_tx(xid) + self.snapshot_mgr.end_tx(xid) } pub async fn worker_commit_put_batch_async( @@ -268,14 +273,13 @@ impl IoUringXContract { batch: XLBatch, ) -> RS<()> { if items.is_empty() { - return self.worker_rollback_tx(xid); + return self.snapshot_mgr.end_tx(xid); } let (storage, log, prepared) = { - let guard = self.lock_inner()?; - let prepared = guard + let prepared = self .storage .prepare_worker_kv_commit(snapshot, xid, items, batch)?; - (guard.storage.clone(), guard.log.clone(), prepared) + (self.storage.clone(), self.log.clone(), prepared) }; if let Some(log) = log { new_xl_batch_writer(log.clone()) @@ -283,20 +287,81 @@ impl IoUringXContract { .await?; log.flush_async().await?; } - storage.apply_prepared_commit(prepared)?; - self.worker_rollback_tx(xid) + storage.apply_prepared_commit_async(prepared).await?; + self.snapshot_mgr.end_tx(xid) } - pub fn replay_worker_log_batch(&self, batch: XLBatch) -> RS<()> { - let max_xid = batch.entries.iter().map(|entry| entry.xid).max(); - let storage = { - let mut guard = self.lock_inner()?; - if let Some(max_xid) = max_xid { - guard.snapshot_mgr.observe_committed_ts(max_xid); + pub fn worker_commit_tx(&self, tx: Arc) -> RS<()> { + let xid = tx.xid(); + if tx.is_empty() { + return self.worker_rollback_tx(tx); + } + tx.build_write_ops(); + let (storage, log, prepared) = { + let write_ops = tx.write_ops(); + let can_commit = self.tx_lock.try_lock_some(xid as OID, &write_ops); + if !can_commit { + return Err(m_error!( + EC::TxErr, + format!("transaction {} failed to acquire commit locks", xid) + )); } - guard.storage.clone() + let prepared = self.storage.prepare_commit(tx.as_ref())?; + (self.storage.clone(), self.log.clone(), prepared) }; - storage.replay_batch(batch) + let result = (|| -> RS<()> { + if let Some(log) = log { + new_xl_batch_writer(log.clone()).append_sync(prepared.batch())?; + log.flush()?; + } + storage.apply_prepared_commit(prepared)?; + Ok(()) + })(); + let write_ops = tx.write_ops(); + self.tx_lock.release(xid as OID, &write_ops); + self.worker_rollback_tx(tx)?; + result + } + + pub async fn worker_commit_tx_async(&self, tx: Arc) -> RS<()> { + let xid = tx.xid(); + if tx.is_empty() { + return self.worker_rollback_tx(tx); + } + tx.build_write_ops(); + let (storage, log, prepared) = { + let write_ops = tx.write_ops(); + let can_commit = self.tx_lock.try_lock_some(xid as OID, &write_ops); + if !can_commit { + return Err(m_error!( + EC::TxErr, + format!("transaction {} failed to acquire commit locks", xid) + )); + } + let prepared = self.storage.prepare_commit_async(tx.as_ref()).await?; + (self.storage.clone(), self.log.clone(), prepared) + }; + let result = async { + if let Some(log) = log { + new_xl_batch_writer(log.clone()).append(prepared.batch()).await?; + log.flush_async().await?; + } + storage.apply_prepared_commit_async(prepared).await?; + Ok(()) + } + .await; + let write_ops = tx.write_ops(); + self.tx_lock.release(xid as OID, &write_ops); + self.worker_rollback_tx(tx)?; + result + } + + pub fn replay_worker_log_batch(&self, batch: XLBatch) -> RS<()> { + let max_xid = batch.entries.iter().map(|entry| entry.xid).max(); + if let Some(max_xid) = max_xid { + self.snapshot_mgr.observe_committed_ts(max_xid); + } + self.storage.replay_batch(batch) } } @@ -310,86 +375,15 @@ fn default_worker_storage_data_dir() -> String { .to_string() } -struct NoopMetaMgr; - -#[async_trait] -impl MetaMgr for NoopMetaMgr { - async fn get_table_by_id(&self, oid: OID) -> RS> { - Err(m_error!( - EC::NoSuchElement, - format!("no such table {} in worker-local io_uring xcontract", oid) - )) - } - - async fn get_table_by_name(&self, _name: &String) -> RS>> { - Ok(None) - } - - async fn create_table(&self, _schema: &SchemaTable) -> RS<()> { - Err(m_error!( - EC::NotImplemented, - "create table is not available in worker-local io_uring xcontract" - )) - } - - async fn drop_table(&self, _table_id: OID) -> RS<()> { - Err(m_error!( - EC::NotImplemented, - "drop table is not available in worker-local io_uring xcontract" - )) - } -} - -impl IoUringXContractInner { - fn begin_tx(&mut self) -> XID { - let snapshot = self.snapshot_mgr.begin_tx(); - let xid = snapshot.xid() as XID; - self.tx_ctx.insert(xid, WorkerTxManager::new(snapshot)); - xid - } - - #[allow(dead_code)] - fn commit_tx(&mut self, xid: XID) -> RS<()> { - let mut tx = self.take_tx(xid)?; - let result = self.storage.commit_tx(&mut tx); - self.end_tx(xid); - result - } - - fn commit_tx_prepare( - &mut self, - xid: XID, - ) -> RS<( - Option, - WorkerTxManager, - Arc, - Option, - )> { - let mut tx = self.take_tx(xid)?; - tx.build_write_ops(); - let can_commit = self.tx_lock.try_lock_some(xid, tx.write_ops()); - if can_commit { - let prepared = self.storage.prepare_commit(&tx)?; - Ok((Some(prepared), tx, self.storage.clone(), self.log.clone())) - } else { - Ok((None, tx, self.storage.clone(), self.log.clone())) - } - } - - fn finish_tx(&mut self, xid: XID) { - self.end_tx(xid); - } - - fn abort_tx(&mut self, xid: XID) -> RS<()> { - let _ = self.take_tx(xid)?; - self.end_tx(xid); - Ok(()) +impl IoUringXContract { + fn _begin_tx(&self) -> Arc { + Arc::new(WorkerTxManager::new(self.snapshot_mgr.begin_tx())) } - fn insert( - &mut self, + async fn _insert( + & self, desc: Arc, - xid: XID, + tx_mgr: Arc, table_id: OID, keys: &VecDatum, values: &VecDatum, @@ -397,35 +391,35 @@ impl IoUringXContractInner { ) -> RS<()> { let key = build_key_tuple(keys, &desc)?; let value = build_value_tuple(values, &desc)?; - let mut tx = self.take_tx(xid)?; - let result = self.storage.insert(table_id, key, value, &mut tx); - self.tx_ctx.insert(xid, tx); - result + let contain_key = self.storage.get(table_id, &key, tx_mgr.as_ref()).await?; + if contain_key.is_some() { + Err(m_error!(EC::ExistingSuchElement, "existing key")) + } else { + self.storage.put(table_id, key, value, tx_mgr.as_ref()).await + } } - fn read_key( - &mut self, + async fn _read_key( + & self, desc: Arc, - xid: XID, + tx_mgr: Arc, table_id: OID, pred_key: &VecDatum, select: &VecSelTerm, _opt_read: &OptRead, ) -> RS>> { let key = build_key_tuple(pred_key, &desc)?; - let mut tx = self.take_tx(xid)?; - let opt_value = self.storage.get(table_id, &key, &mut tx)?; - self.tx_ctx.insert(xid, tx); + let opt_value = self.storage.get(table_id, &key, tx_mgr.as_ref()).await?; match opt_value { Some(value) => project_selected_fields(&desc, &key, &value, select).map(Some), None => Ok(None), } } - fn read_range( - &mut self, + async fn _read_range( + & self, desc: Arc, - xid: XID, + tx_mgr: Arc, table_id: OID, pred_key: &RangeData, pred_non_key: &Predicate, @@ -435,9 +429,7 @@ impl IoUringXContractInner { ensure_supported_predicate(pred_non_key)?; let start = build_bound_key(pred_key.start(), &desc)?; let end = build_bound_key(pred_key.end(), &desc)?; - let mut tx = self.take_tx(xid)?; - let rows = self.storage.range(table_id, (start, end), &mut tx)?; - self.tx_ctx.insert(xid, tx); + let rows = self.storage.range(table_id, (start, end), tx_mgr.as_ref()).await?; let projected = rows .into_iter() .map(|(key, value)| { @@ -452,10 +444,10 @@ impl IoUringXContractInner { })) } - fn delete( - &mut self, + async fn _delete( + & self, desc: Arc, - xid: XID, + tx_mgr: Arc, table_id: OID, pred_key: &VecDatum, pred_non_key: &Predicate, @@ -463,16 +455,14 @@ impl IoUringXContractInner { ) -> RS { ensure_supported_predicate(pred_non_key)?; let key = build_key_tuple(pred_key, &desc)?; - let mut tx = self.take_tx(xid)?; - let deleted = self.storage.remove(table_id, &key, &mut tx)?; - self.tx_ctx.insert(xid, tx); + let deleted = self.storage.remove(table_id, &key, tx_mgr.as_ref()).await?; Ok(usize::from(deleted.is_some())) } - fn update( - &mut self, + async fn _update( + & self, desc: Arc, - xid: XID, + tx_mgr: Arc, table_id: OID, pred_key: &VecDatum, pred_non_key: &Predicate, @@ -481,179 +471,139 @@ impl IoUringXContractInner { ) -> RS { ensure_supported_predicate(pred_non_key)?; let key = build_key_tuple(pred_key, &desc)?; - let mut tx = self.take_tx(xid)?; - let current = self.storage.get(table_id, &key, &mut tx)?; + let current = self.storage.get(table_id, &key, tx_mgr.as_ref()).await?; let Some(current) = current else { - self.tx_ctx.insert(xid, tx); return Ok(0); }; let updated = apply_value_update(¤t, values, &desc)?; - let result = self.storage.insert(table_id, key, updated, &mut tx); - self.tx_ctx.insert(xid, tx); - result.map(|()| 1) - } - - fn take_tx(&mut self, xid: XID) -> RS { - self.tx_ctx - .remove(&xid) - .ok_or_else(|| m_error!(EC::NoSuchElement, format!("no such transaction {}", xid))) - } - - fn end_tx(&mut self, xid: XID) { - let _ = self.snapshot_mgr.end_tx(xid as u64); + self.storage + .put(table_id, key, updated, tx_mgr.as_ref()) + .await + .map(|()| 1) } } #[async_trait] impl XContract for IoUringXContract { - async fn create_table(&self, _xid: XID, schema: &SchemaTable) -> RS<()> { - let storage = { - let guard = self.lock_inner()?; - guard.storage.clone() - }; - storage.create_table_async(schema).await + async fn create_table(&self, _tx_mgr: Arc, schema: &SchemaTable) -> RS<()> { + self.storage.create_table_async(schema).await } - async fn drop_table(&self, _xid: XID, oid: OID) -> RS<()> { - let storage = { - let guard = self.lock_inner()?; - guard.storage.clone() - }; - storage.drop_table_async(oid).await + async fn drop_table(&self, _tx_mgr: Arc, oid: OID) -> RS<()> { + self.storage.drop_table_async(oid).await } - async fn alter_table(&self, _xid: XID, _oid: OID, _alter_table: &AlterTable) -> RS<()> { + async fn alter_table( + &self, + _tx_mgr: Arc, + _oid: OID, + _alter_table: &AlterTable, + ) -> RS<()> { Err(m_error!( EC::NotImplemented, "alter table is not implemented" )) } - async fn begin_tx(&self) -> RS { - Ok(self.lock_inner()?.begin_tx()) + async fn begin_tx(&self) -> RS> { + Ok(self._begin_tx()) } - async fn commit_tx(&self, xid: XID) -> RS<()> { - let prepared = { - let mut guard = self.lock_inner()?; - - guard.commit_tx_prepare(xid) - }; - let result = match prepared { - Ok((opt_prepared, tx, storage, log)) => { - if let Some(prepared) = opt_prepared { - if let Some(log) = log { - new_xl_batch_writer(log.clone()) - .append(prepared.batch()) - .await?; - log.flush_async().await?; - } - storage.apply_prepared_commit(prepared)?; - { - let guard = self.inner.lock().unwrap(); - guard.tx_lock.release(xid, tx.write_ops()); - Ok(()) - } - } else { - let guard = self.inner.lock().unwrap(); - guard.tx_lock.release(xid, tx.write_ops()); - Ok(()) - } - } - Err(err) => Err(err), - }; - self.lock_inner()?.finish_tx(xid); - result + async fn commit_tx(&self, tx_mgr: Arc) -> RS<()> { + self.worker_commit_tx_async(tx_mgr).await } - async fn abort_tx(&self, xid: XID) -> RS<()> { - self.lock_inner()?.abort_tx(xid) + async fn abort_tx(&self, tx_mgr: Arc) -> RS<()> { + self.worker_rollback_tx(tx_mgr) } async fn update( &self, - xid: XID, + tx_mgr: Arc, table_id: OID, pred_key: &VecDatum, pred_non_key: &Predicate, values: &VecDatum, opt_update: &OptUpdate, ) -> RS { - let meta_mgr = { self.lock_inner()?.meta_mgr.clone() }; - let desc = meta_mgr.get_table_by_id(table_id).await?; - self.lock_inner()?.update( + let desc = self.meta_mgr.get_table_by_id(table_id).await?; + self._update( desc, - xid, + tx_mgr, table_id, pred_key, pred_non_key, values, opt_update, ) + .await } async fn read_key( &self, - xid: XID, + tx_mgr: Arc, table_id: OID, pred_key: &VecDatum, select: &VecSelTerm, opt_read: &OptRead, ) -> RS>> { - let meta_mgr = { self.lock_inner()?.meta_mgr.clone() }; - let desc = meta_mgr.get_table_by_id(table_id).await?; - self.lock_inner()? - .read_key(desc, xid, table_id, pred_key, select, opt_read) + let desc = self.meta_mgr.get_table_by_id(table_id).await?; + self._read_key(desc, tx_mgr, table_id, pred_key, select, opt_read) + .await } async fn read_range( &self, - xid: XID, + tx_mgr: Arc, table_id: OID, pred_key: &RangeData, pred_non_key: &Predicate, select: &VecSelTerm, opt_read: &OptRead, ) -> RS> { - let meta_mgr = { self.lock_inner()?.meta_mgr.clone() }; - let desc = meta_mgr.get_table_by_id(table_id).await?; - self.lock_inner()?.read_range( + let desc = self.meta_mgr.get_table_by_id(table_id).await?; + self._read_range( desc, - xid, + tx_mgr, table_id, pred_key, pred_non_key, select, opt_read, ) + .await } async fn delete( &self, - xid: XID, + tx_mgr: Arc, table_id: OID, pred_key: &VecDatum, pred_non_key: &Predicate, opt_delete: &OptDelete, ) -> RS { - let meta_mgr = { self.lock_inner()?.meta_mgr.clone() }; - let desc = meta_mgr.get_table_by_id(table_id).await?; - self.lock_inner()? - .delete(desc, xid, table_id, pred_key, pred_non_key, opt_delete) + let desc = self.meta_mgr.get_table_by_id(table_id).await?; + self._delete(desc, tx_mgr, table_id, pred_key, pred_non_key, opt_delete) + .await } async fn insert( &self, - xid: XID, + tx_mgr: Arc, table_id: OID, keys: &VecDatum, values: &VecDatum, opt_insert: &OptInsert, ) -> RS<()> { - let meta_mgr = { self.lock_inner()?.meta_mgr.clone() }; - let desc = meta_mgr.get_table_by_id(table_id).await?; - self.lock_inner()? - .insert(desc, xid, table_id, keys, values, opt_insert) + let desc = self.meta_mgr.get_table_by_id(table_id).await?; + self._insert(desc, tx_mgr, table_id, keys, values, opt_insert) + .await + } +} + +impl IoUringXContract { + pub fn meta_mgr(&self) -> Arc { + self.meta_mgr.clone() } } @@ -706,7 +656,7 @@ fn build_tuple_for( let mut ok = true; vec_data.sort_by(|(id1, _), (id2, _)| { let (f1, f2) = (desc.get_attr(*id1), desc.get_attr(*id2)); - if f1.is_primary() != IS_KEY || f2.is_primary() != IS_KEY { + if f1.primary_index().is_some() != IS_KEY || f2.primary_index().is_some() != IS_KEY { ok = false; } f1.datum_index().cmp(&f2.datum_index()) @@ -753,12 +703,16 @@ fn project_selected_fields( for i in select.vec() { let f = desc.get_attr(*i); let index = f.datum_index(); - let field_desc = if f.is_primary() { + let field_desc = if f.primary_index().is_some() { desc.key_desc().get_field_desc(index) } else { desc.value_desc().get_field_desc(index) }; - let src = if f.is_primary() { key } else { value }; + let src = if f.primary_index().is_some() { + key + } else { + value + }; let slice = field_desc.get(src)?; tuple_ret.push(slice.to_vec()); } @@ -768,10 +722,17 @@ fn project_selected_fields( fn apply_value_update(current: &TupleRaw, values: &VecDatum, desc: &TableDesc) -> RS> { let mut updated = current.clone(); let mut data = values.data().clone(); - data.sort_by(|(id1, _), (id2, _)| id1.cmp(id2)); + data.sort_by_key(|(attr, _)| desc.get_attr(*attr).datum_index()); for (id, dat) in data.iter() { + let field = desc.get_attr(*id); let mut delta = vec![]; - update_tuple(*id as _, dat, desc.value_desc(), current, &mut delta)?; + update_tuple( + field.datum_index() as usize, + dat, + desc.value_desc(), + current, + &mut delta, + )?; for item in delta { item.apply_to(&mut updated); } @@ -877,16 +838,20 @@ mod tests { fn test_schema() -> SchemaTable { SchemaTable::new( "t".to_string(), - vec![SchemaColumn::new( - "id".to_string(), - DatTypeID::I32, - DTInfo::from_text(DatTypeID::I32, String::new()), - )], - vec![SchemaColumn::new( - "v".to_string(), - DatTypeID::I32, - DTInfo::from_text(DatTypeID::I32, String::new()), - )], + vec![ + SchemaColumn::new( + "id".to_string(), + DatTypeID::I32, + DTInfo::from_text(DatTypeID::I32, String::new()), + ), + SchemaColumn::new( + "v".to_string(), + DatTypeID::I32, + DTInfo::from_text(DatTypeID::I32, String::new()), + ), + ], + vec![0], + vec![1], ) } @@ -923,10 +888,8 @@ mod tests { 9, vec![], )); - storage - .insert(table_id, b"k1".to_vec(), b"v1".to_vec(), &mut txm) - .unwrap(); - storage.remove(table_id, b"k1", &mut txm).unwrap(); + block_on(storage.put(table_id, b"k1".to_vec(), b"v1".to_vec(), &mut txm)).unwrap(); + block_on(storage.remove(table_id, b"k1", &mut txm)).unwrap(); let prepared = storage.prepare_commit(&txm).unwrap(); assert_eq!(prepared.batch().entries.len(), 1); @@ -944,17 +907,19 @@ mod tests { let table_id = schema.id(); let contract = IoUringXContract::with_log(meta_mgr, Some(log)); - block_on(contract.create_table(0, &schema)).unwrap(); - let xid = block_on(contract.begin_tx()).unwrap(); + let ddl_tx = block_on(contract.begin_tx()).unwrap(); + block_on(contract.create_table(ddl_tx.clone(), &schema)).unwrap(); + block_on(contract.commit_tx(ddl_tx)).unwrap(); + let tx_mgr = block_on(contract.begin_tx()).unwrap(); block_on(contract.insert( - xid, + tx_mgr.clone(), table_id, &key_row(1), &value_row(10), &OptInsert::default(), )) .unwrap(); - block_on(contract.commit_tx(xid)).unwrap(); + block_on(contract.commit_tx(tx_mgr)).unwrap(); let bytes = std::fs::read(layout.chunk_path(0)).unwrap(); let frames = decode_frames(&bytes).unwrap(); @@ -986,7 +951,9 @@ mod tests { let table_id = schema.id(); let contract = IoUringXContract::with_log(meta_mgr, None); - block_on(contract.create_table(0, &schema)).unwrap(); + let tx_mgr = block_on(contract.begin_tx()).unwrap(); + block_on(contract.create_table(tx_mgr.clone(), &schema)).unwrap(); + block_on(contract.commit_tx(tx_mgr)).unwrap(); let batch = XLBatch { entries: vec![crate::wal::xl_entry::XLEntry { xid: 11, @@ -1061,6 +1028,53 @@ mod tests { assert_eq!(contract.worker_get(b"wk").unwrap(), None); } + #[test] + fn iouring_xcontract_update_maps_table_attr_to_value_tuple_index() { + let meta_mgr = Arc::new(TestMetaMgr::new()); + let schema = test_schema(); + let table_id = schema.id(); + let contract = IoUringXContract::with_log(meta_mgr, None); + + let ddl_tx = block_on(contract.begin_tx()).unwrap(); + block_on(contract.create_table(ddl_tx.clone(), &schema)).unwrap(); + block_on(contract.commit_tx(ddl_tx)).unwrap(); + + let insert_tx = block_on(contract.begin_tx()).unwrap(); + block_on(contract.insert( + insert_tx.clone(), + table_id, + &key_row(1), + &value_row(10), + &OptInsert::default(), + )) + .unwrap(); + block_on(contract.commit_tx(insert_tx)).unwrap(); + + let update_tx = block_on(contract.begin_tx()).unwrap(); + let updated = block_on(contract.update( + update_tx.clone(), + table_id, + &key_row(1), + &Predicate::CNF(vec![]), + &value_row(20), + &OptUpdate {}, + )) + .unwrap(); + assert_eq!(updated, 1); + block_on(contract.commit_tx(update_tx)).unwrap(); + + let read_tx = block_on(contract.begin_tx()).unwrap(); + let relation = block_on(contract.read_key( + read_tx, + table_id, + &key_row(1), + &VecSelTerm::new(vec![1]), + &OptRead::default(), + )) + .unwrap(); + assert_eq!(relation, Some(vec![datum(20)])); + } + fn meta_table(schema: &SchemaTable) -> RS> { TableInfo::new(schema.clone())?.table_desc() } diff --git a/mudu_kernel/src/server/x_lock_mgr.rs b/mudu_kernel/src/server/x_lock_mgr.rs index 16d143a..b7d2997 100644 --- a/mudu_kernel/src/server/x_lock_mgr.rs +++ b/mudu_kernel/src/server/x_lock_mgr.rs @@ -1,20 +1,20 @@ use mudu::common::id::OID; -use std::cell::RefCell; use std::collections::HashMap; +use std::sync::Mutex; pub struct XLockMgr { - lock: RefCell, OID>>>, + lock: Mutex, OID>>>, } impl XLockMgr { pub fn new() -> Self { Self { - lock: Default::default(), + lock: Mutex::new(HashMap::new()), } } pub fn try_lock_some(&self, oid: OID, table_keys: &Vec<(OID, Vec)>) -> bool { - let mut lock = self.lock.borrow_mut(); + let mut lock = self.lock.lock().unwrap(); for (table_oid, key) in table_keys.iter() { let map = lock.entry(table_oid.clone()).or_default(); if map.contains_key(key) { @@ -27,7 +27,7 @@ impl XLockMgr { } pub fn release(&self, oid: OID, table_keys: &Vec<(OID, Vec)>) { - let mut lock = self.lock.borrow_mut(); + let mut lock = self.lock.lock().unwrap(); for (table_oid, key) in table_keys.iter() { let map = lock.entry(table_oid.clone()).or_default(); if let Some(tx) = map.get(key) { diff --git a/mudu_kernel/src/sql/binder.rs b/mudu_kernel/src/sql/binder.rs index 082b71a..94b8242 100644 --- a/mudu_kernel/src/sql/binder.rs +++ b/mudu_kernel/src/sql/binder.rs @@ -94,16 +94,29 @@ impl Binder { stmt.assign_index_for_columns(); let key_columns = stmt .primary_columns() - .iter() + .into_iter() .map(Self::schema_column_from_ast) .collect::>>()?; let value_columns = stmt .non_primary_columns() - .iter() + .into_iter() .map(Self::schema_column_from_ast) .collect::>>()?; + let mut columns = key_columns; + let value_offset = columns.len(); + let mut value_columns = value_columns; + let key_indices = (0..columns.len()).collect(); + let value_indices = (0..value_columns.len()) + .map(|index| index + value_offset) + .collect(); + columns.append(&mut value_columns); Ok(BoundCreateTable { - schema: SchemaTable::new(stmt.table_name().clone(), key_columns, value_columns), + schema: SchemaTable::new( + stmt.table_name().clone(), + columns, + key_indices, + value_indices, + ), }) } @@ -137,7 +150,7 @@ impl Binder { } let columns = if stmt.columns().is_empty() { - let total = table_desc.key_info().len() + table_desc.value_info().len(); + let total = table_desc.fields().len(); (0..total) .map(|attr| table_desc.get_attr(attr).name().clone()) .collect::>() @@ -158,7 +171,7 @@ impl Binder { let field = table_desc.get_attr(attr); let binary = ValueCodec::binary_from_expr(expr, field.type_desc(), params, &mut param_index)?; - if field.is_primary() { + if field.primary_index().is_some() { key.push((attr, binary)); } else { value.push((attr, binary)); @@ -208,7 +221,7 @@ impl Binder { for assignment in stmt.get_set_values() { let attr = self.attr_index_by_name(&table_desc, assignment.get_column_reference())?; let field = table_desc.get_attr(attr); - if field.is_primary() { + if field.primary_index().is_some() { return Err(m_error!( ER::NotImplemented, "updating primary key columns is not implemented" @@ -282,7 +295,7 @@ impl Binder { })?; let attr = self.attr_index_by_name(table_desc, field_name)?; let field = table_desc.get_attr(attr); - if !field.is_primary() { + if field.primary_index().is_none() { return Err(m_error!( ER::NotImplemented, "non-key predicates are not implemented" @@ -341,15 +354,15 @@ impl Binder { ) -> RS)>> { match self.bind_predicate_from(table_desc, predicates, params, param_index)? { BoundPredicate::KeyEq { mut key } => { - if key.len() != table_desc.key_info().len() { + if key.len() != table_desc.key_indices().len() { return Err(m_error!( ER::NotImplemented, "update/delete require a complete primary key predicate" )); } - key.sort_by_key(|(attr, _)| *attr); + key.sort_by_key(|(attr, _)| table_desc.get_attr(*attr).primary_index().unwrap()); for (index, (attr, _)) in key.iter().enumerate() { - if *attr != index { + if table_desc.get_attr(*attr).primary_index() != Some(index) { return Err(m_error!( ER::NotImplemented, "update/delete require one equality predicate for each primary key column" @@ -387,14 +400,7 @@ impl Binder { } fn reverse_compare(op: ValueCompare) -> ValueCompare { - match op { - ValueCompare::EQ => ValueCompare::EQ, - ValueCompare::LE => ValueCompare::GT, - ValueCompare::LT => ValueCompare::GE, - ValueCompare::GE => ValueCompare::LT, - ValueCompare::GT => ValueCompare::LE, - ValueCompare::NE => ValueCompare::NE, - } + ValueCompare::revert_cmp_op(op) } fn schema_column_from_ast(column: &sql_parser::ast::column_def::ColumnDef) -> RS { @@ -404,7 +410,7 @@ impl Binder { ty.dat_type_id(), DTInfo::from_opt_object(&ty), ); - schema_column.set_primary(column.is_primary_key()); + schema_column.set_primary_index(column.primary_key_index()); schema_column.set_index(column.column_index()); Ok(schema_column) } @@ -421,7 +427,7 @@ impl Binder { } fn attr_index_by_name(&self, table_desc: &TableDesc, name: &str) -> RS { - let total = table_desc.key_info().len() + table_desc.value_info().len(); + let total = table_desc.fields().len(); (0..total) .find(|attr| table_desc.get_attr(*attr).name() == name) .ok_or_else(|| m_error!(ER::NoSuchElement, format!("cannot find column {}", name))) @@ -431,6 +437,6 @@ impl Binder { self.meta_mgr .get_table_by_name(name) .await? - .ok_or_else(|| m_error!(ER::NoSuchElement, format!("cannot find table {}", name))) + .ok_or_else(|| m_error!(ER::NoSuchElement, format!("no such table {}", name))) } } diff --git a/mudu_kernel/src/sql/binder_test.rs b/mudu_kernel/src/sql/binder_test.rs new file mode 100644 index 0000000..0dc8505 --- /dev/null +++ b/mudu_kernel/src/sql/binder_test.rs @@ -0,0 +1,313 @@ +#[cfg(test)] +mod tests { + use crate::contract::meta_mgr::MetaMgr; + use crate::contract::schema_column::SchemaColumn; + use crate::contract::schema_table::SchemaTable; + use crate::contract::table_desc::TableDesc; + use crate::contract::table_info::TableInfo; + use crate::sql::binder::Binder; + use crate::sql::bound_stmt::{BoundCommand, BoundPredicate, BoundQuery, BoundStmt}; + use async_trait::async_trait; + use mudu::common::id::OID; + use mudu::common::result::RS; + use mudu::error::ec::EC; + use mudu::m_error; + use mudu_type::dat_type::DatType; + use mudu_type::dat_type_id::DatTypeID; + use mudu_type::dt_info::DTInfo; + use sql_parser::ast::parser::SQLParser; + use sql_parser::ast::stmt_type::StmtType; + use std::collections::HashMap; + use std::sync::{Arc, Mutex}; + + struct TestMetaMgr { + tables: Mutex>>, + } + + impl TestMetaMgr { + fn new(schema: SchemaTable) -> Self { + let table = TableInfo::new(schema).unwrap().table_desc().unwrap(); + let mut tables = HashMap::new(); + tables.insert(table.id(), table); + Self { + tables: Mutex::new(tables), + } + } + } + + #[async_trait] + impl MetaMgr for TestMetaMgr { + async fn get_table_by_id(&self, oid: OID) -> RS> { + self.tables + .lock() + .unwrap() + .get(&oid) + .cloned() + .ok_or_else(|| m_error!(EC::NoSuchElement, format!("no such table {}", oid))) + } + + async fn get_table_by_name(&self, name: &String) -> RS>> { + Ok(self + .tables + .lock() + .unwrap() + .values() + .find(|table| table.name() == name) + .cloned()) + } + + async fn create_table(&self, schema: &SchemaTable) -> RS<()> { + let table = TableInfo::new(schema.clone())?.table_desc()?; + self.tables.lock().unwrap().insert(table.id(), table); + Ok(()) + } + + async fn drop_table(&self, table_id: OID) -> RS<()> { + self.tables.lock().unwrap().remove(&table_id); + Ok(()) + } + } + + fn schema() -> SchemaTable { + SchemaTable::new( + "users".to_string(), + vec![ + SchemaColumn::new( + "id".to_string(), + DatTypeID::I32, + DTInfo::from_opt_object(&DatType::default_for(DatTypeID::I32)), + ), + SchemaColumn::new( + "name".to_string(), + DatTypeID::String, + DTInfo::from_opt_object(&DatType::default_for(DatTypeID::String)), + ), + ], + vec![0], + vec![1], + ) + } + + fn composite_schema() -> SchemaTable { + SchemaTable::new( + "accounts".to_string(), + vec![ + SchemaColumn::new( + "tenant_id".to_string(), + DatTypeID::I32, + DTInfo::from_opt_object(&DatType::default_for(DatTypeID::I32)), + ), + SchemaColumn::new( + "user_id".to_string(), + DatTypeID::I32, + DTInfo::from_opt_object(&DatType::default_for(DatTypeID::I32)), + ), + SchemaColumn::new( + "name".to_string(), + DatTypeID::String, + DTInfo::from_opt_object(&DatType::default_for(DatTypeID::String)), + ), + ], + vec![0, 1], + vec![2], + ) + } + + fn parse_stmt(sql: &str) -> StmtType { + SQLParser::new().parse(sql).unwrap().stmts()[0].clone() + } + + fn binder() -> Binder { + Binder::new(Arc::new(TestMetaMgr::new(schema()))) + } + + fn composite_binder() -> Binder { + Binder::new(Arc::new(TestMetaMgr::new(composite_schema()))) + } + + #[tokio::test] + async fn bind_select_builds_key_eq_predicate() { + let bound = binder() + .bind(parse_stmt("select id from users where id = 1;"), &()) + .await + .unwrap(); + + let BoundStmt::Query(BoundQuery::Select(select)) = bound else { + panic!("expected bound select"); + }; + assert_eq!(select.select_attrs, vec![0]); + match select.predicate { + BoundPredicate::KeyEq { key } => assert_eq!(key.len(), 1), + other => panic!("expected key equality predicate, got {other:?}"), + } + } + + #[tokio::test] + async fn bind_select_reverses_value_column_comparison() { + let bound = binder() + .bind(parse_stmt("select id from users where ? = id;"), &(7i32,)) + .await + .unwrap(); + + let BoundStmt::Query(BoundQuery::Select(select)) = bound else { + panic!("expected bound select"); + }; + match select.predicate { + BoundPredicate::KeyEq { key } => assert_eq!(key.len(), 1), + other => panic!("expected key equality predicate, got {other:?}"), + } + } + + #[tokio::test] + async fn bind_select_builds_range_predicate_from_placeholder() { + let bound = binder() + .bind(parse_stmt("select id from users where id > ?;"), &(7i32,)) + .await + .unwrap(); + + let BoundStmt::Query(BoundQuery::Select(select)) = bound else { + panic!("expected bound select"); + }; + match select.predicate { + BoundPredicate::KeyRange { start, end } => { + assert!(matches!(start, std::ops::Bound::Excluded(_))); + assert!(matches!(end, std::ops::Bound::Unbounded)); + } + other => panic!("expected key range predicate, got {other:?}"), + } + } + + #[tokio::test] + async fn bind_select_rejects_not_equal_predicate() { + let err = binder() + .bind(parse_stmt("select id from users where id != 1;"), &()) + .await + .unwrap_err(); + + assert!(err + .to_string() + .contains("not-equal predicates are not implemented")); + } + + #[tokio::test] + async fn bind_select_rejects_mixed_equality_and_range_predicates() { + let err = binder() + .bind( + parse_stmt("select id from users where id = 1 AND id > 0;"), + &(), + ) + .await + .unwrap_err(); + + assert!(err + .to_string() + .contains("mixed equality and range predicates are not implemented")); + } + + #[tokio::test] + async fn bind_insert_without_column_list_uses_schema_order() { + let bound = binder() + .bind(parse_stmt("insert into users values (1, 'alice');"), &()) + .await + .unwrap(); + + let BoundStmt::Command(BoundCommand::Insert(insert)) = bound else { + panic!("expected bound insert"); + }; + assert_eq!(insert.key.len(), 1); + assert_eq!(insert.value.len(), 1); + } + + #[tokio::test] + async fn bind_insert_rejects_multi_row_insert() { + let err = binder() + .bind( + parse_stmt("insert into users (id, name) values (1, 'alice'), (2, 'bob');"), + &(), + ) + .await + .unwrap_err(); + + assert!(err + .to_string() + .contains("multi-row insert is not implemented")); + } + + #[tokio::test] + async fn bind_insert_rejects_column_size_mismatch() { + let err = binder() + .bind( + parse_stmt("insert into users (id) values (1, 'alice');"), + &(), + ) + .await + .unwrap_err(); + + assert!(err.to_string().contains("insert column size mismatch")); + } + + #[tokio::test] + async fn bind_update_rejects_primary_key_updates() { + let err = binder() + .bind(parse_stmt("update users set id = 2 where id = 1;"), &()) + .await + .unwrap_err(); + + assert!(err + .to_string() + .contains("updating primary key columns is not implemented")); + } + + #[tokio::test] + async fn bind_update_rejects_expression_updates() { + let err = binder() + .bind( + parse_stmt("update users set name = id + 1 where id = 1;"), + &(), + ) + .await + .unwrap_err(); + + assert!(err + .to_string() + .contains("expression updates are not implemented")); + } + + #[tokio::test] + async fn bind_delete_rejects_non_key_predicates() { + let err = binder() + .bind(parse_stmt("delete from users where name = 'alice';"), &()) + .await + .unwrap_err(); + + assert!(err + .to_string() + .contains("non-key predicates are not implemented")); + } + + #[tokio::test] + async fn bind_delete_requires_complete_composite_primary_key() { + let err = composite_binder() + .bind(parse_stmt("delete from accounts where tenant_id = 1;"), &()) + .await + .unwrap_err(); + + assert!(err.to_string().contains("complete primary key predicate")); + } + + #[tokio::test] + async fn bind_delete_accepts_complete_composite_primary_key() { + let bound = composite_binder() + .bind( + parse_stmt("delete from accounts where tenant_id = 1 AND user_id = 2;"), + &(), + ) + .await + .unwrap(); + + let BoundStmt::Command(BoundCommand::Delete(delete)) = bound else { + panic!("expected bound delete"); + }; + assert_eq!(delete.key.len(), 2); + } +} diff --git a/mudu_kernel/src/sql/copy_layout_test.rs b/mudu_kernel/src/sql/copy_layout_test.rs new file mode 100644 index 0000000..a69a412 --- /dev/null +++ b/mudu_kernel/src/sql/copy_layout_test.rs @@ -0,0 +1,85 @@ +#[cfg(test)] +mod tests { + use crate::contract::schema_column::SchemaColumn; + use crate::contract::schema_table::SchemaTable; + use crate::contract::table_info::TableInfo; + use crate::sql::copy_layout::CopyLayout; + use mudu_type::dat_type::DatType; + use mudu_type::dat_type_id::DatTypeID; + use mudu_type::dt_info::DTInfo; + + fn table_desc() -> std::sync::Arc { + let schema = SchemaTable::new( + "accounts".to_string(), + vec![ + SchemaColumn::new( + "tenant_id".to_string(), + DatTypeID::I32, + DTInfo::from_opt_object(&DatType::default_for(DatTypeID::I32)), + ), + SchemaColumn::new( + "user_id".to_string(), + DatTypeID::I32, + DTInfo::from_opt_object(&DatType::default_for(DatTypeID::I32)), + ), + SchemaColumn::new( + "name".to_string(), + DatTypeID::String, + DTInfo::from_opt_object(&DatType::default_for(DatTypeID::String)), + ), + ], + vec![0, 1], + vec![2], + ); + TableInfo::new(schema).unwrap().table_desc().unwrap() + } + + #[test] + fn empty_columns_use_schema_order() { + let table = table_desc(); + let layout = CopyLayout::new(&table, &[]).unwrap(); + assert_eq!(layout.key_index(), &[0, 1]); + assert_eq!(layout.value_index(), &[2]); + } + + #[test] + fn full_column_list_reorders_key_and_value_positions() { + let table = table_desc(); + let columns = vec![ + "user_id".to_string(), + "name".to_string(), + "tenant_id".to_string(), + ]; + let layout = CopyLayout::new(&table, &columns).unwrap(); + assert_eq!(layout.key_index(), &[2, 0]); + assert_eq!(layout.value_index(), &[1]); + } + + #[test] + fn copy_layout_rejects_partial_column_list() { + let table = table_desc(); + let columns = vec!["tenant_id".to_string(), "user_id".to_string()]; + let err = match CopyLayout::new(&table, &columns) { + Ok(_) => panic!("expected partial column list error"), + Err(err) => err, + }; + assert!(err + .to_string() + .contains("is not equal to the size specified")); + } + + #[test] + fn copy_layout_rejects_missing_named_column() { + let table = table_desc(); + let columns = vec![ + "tenant_id".to_string(), + "user_id".to_string(), + "missing".to_string(), + ]; + let err = match CopyLayout::new(&table, &columns) { + Ok(_) => panic!("expected missing column error"), + Err(err) => err, + }; + assert!(err.to_string().contains("cannot find column name name")); + } +} diff --git a/mudu_kernel/src/sql/describer.rs b/mudu_kernel/src/sql/describer.rs index 6877b03..e81d9a3 100644 --- a/mudu_kernel/src/sql/describer.rs +++ b/mudu_kernel/src/sql/describer.rs @@ -9,27 +9,27 @@ use sql_parser::ast::stmt_type::StmtType; use std::sync::Arc; pub struct Describer { - meta_mgr: Arc, + } impl Describer { - pub fn new(meta_mgr: Arc) -> Self { - Self { meta_mgr } + pub fn new() -> Self { + Self { } } - pub async fn describe(&self, stmt: StmtType) -> RS { + pub async fn describe(meta_mgr:&dyn MetaMgr, stmt: StmtType) -> RS { match stmt { - StmtType::Select(stmt) => self.describe_select(stmt).await, + StmtType::Select(stmt) => Self::describe_select(meta_mgr, stmt).await, StmtType::Command(_) => Ok(TupleFieldDesc::new(Vec::new())), } } async fn describe_select( - &self, + meta_mgr:&dyn MetaMgr, stmt: sql_parser::ast::stmt_select::StmtSelect, ) -> RS { - let table_desc = self.get_table_by_name(stmt.get_table_reference()).await?; - let select_attrs = self.select_attrs(&table_desc, stmt.get_select_term_list())?; + let table_desc = Self::get_table_by_name(meta_mgr, stmt.get_table_reference()).await?; + let select_attrs = Self::select_attrs(&table_desc, stmt.get_select_term_list())?; Ok(project_tuple_desc( &table_desc, &crate::x_engine::api::VecSelTerm::new(select_attrs), @@ -37,27 +37,26 @@ impl Describer { } fn select_attrs( - &self, table_desc: &TableDesc, terms: &[sql_parser::ast::select_term::SelectTerm], ) -> RS> { terms .iter() - .map(|term| self.attr_index_by_name(table_desc, term.field().name())) + .map(|term| Self::attr_index_by_name(table_desc, term.field().name())) .collect() } - fn attr_index_by_name(&self, table_desc: &TableDesc, name: &str) -> RS { - let total = table_desc.key_info().len() + table_desc.value_info().len(); + fn attr_index_by_name(table_desc: &TableDesc, name: &str) -> RS { + let total = table_desc.fields().len(); (0..total) .find(|attr| table_desc.get_attr(*attr).name() == name) .ok_or_else(|| m_error!(ER::NoSuchElement, format!("cannot find column {}", name))) } - async fn get_table_by_name(&self, name: &String) -> RS> { - self.meta_mgr + async fn get_table_by_name(meta_mgr:&dyn MetaMgr, name: &String) -> RS> { + meta_mgr .get_table_by_name(name) .await? - .ok_or_else(|| m_error!(ER::NoSuchElement, format!("cannot find table {}", name))) + .ok_or_else(|| m_error!(ER::NoSuchElement, format!("no such table {}", name))) } } diff --git a/mudu_kernel/src/sql/mod.rs b/mudu_kernel/src/sql/mod.rs index 96972e8..d89379d 100644 --- a/mudu_kernel/src/sql/mod.rs +++ b/mudu_kernel/src/sql/mod.rs @@ -2,7 +2,11 @@ mod cmp_pred; mod copy_layout; +#[cfg(test)] +mod copy_layout_test; mod value_codec; +#[cfg(test)] +mod value_codec_test; pub mod stmt_cmd_run; @@ -13,6 +17,8 @@ pub mod plan_ctx; pub mod planner; pub mod proj_list; +#[cfg(test)] +mod binder_test; pub mod stmt_cmd; mod stmt_create_table; @@ -25,5 +31,9 @@ mod stmt_copy_from; mod stmt_copy_to; mod proj_field; +#[cfg(test)] +mod stmt_cmd_run_test; pub mod stmt_query; pub mod stmt_query_run; +#[cfg(test)] +mod stmt_query_run_test; diff --git a/mudu_kernel/src/sql/plan_ctx.rs b/mudu_kernel/src/sql/plan_ctx.rs index 84e7cb1..ef2f1d6 100644 --- a/mudu_kernel/src/sql/plan_ctx.rs +++ b/mudu_kernel/src/sql/plan_ctx.rs @@ -1,11 +1,11 @@ use crate::contract::meta_mgr::MetaMgr; use crate::x_engine::api::XContract; -use mudu::common::xid::XID; +use crate::x_engine::tx_mgr::TxMgr; use std::sync::Arc; #[derive(Clone)] pub struct PlanCtx { - pub xid: XID, + pub tx_mgr: Arc, pub meta_mgr: Arc, pub x_contract: Arc, } diff --git a/mudu_kernel/src/sql/planner.rs b/mudu_kernel/src/sql/planner.rs index 71e742a..257ede3 100644 --- a/mudu_kernel/src/sql/planner.rs +++ b/mudu_kernel/src/sql/planner.rs @@ -53,7 +53,7 @@ impl Planner { BoundPredicate::True => { let exec = crate::executor::index_access_range::IndexAccessRange::new( PAccessRange { - xid: self.ctx.xid, + tx_mgr: self.ctx.tx_mgr.clone(), table_id: stmt.table_id, pred_key: RangeData::new( std::ops::Bound::Unbounded, @@ -72,7 +72,7 @@ impl Planner { BoundPredicate::KeyEq { key } => { let exec = crate::executor::index_access_key::IndexAccessKey::new( PAccessKey { - xid: self.ctx.xid, + tx_mgr: self.ctx.tx_mgr.clone(), table_id: stmt.table_id, pred_key: VecDatum::new(key), select, @@ -87,7 +87,7 @@ impl Planner { BoundPredicate::KeyRange { start, end } => { let exec = crate::executor::index_access_range::IndexAccessRange::new( PAccessRange { - xid: self.ctx.xid, + tx_mgr: self.ctx.tx_mgr.clone(), table_id: stmt.table_id, pred_key: RangeData::new(start, end), pred_non_key: Predicate::CNF(Vec::new()), @@ -106,7 +106,7 @@ impl Planner { fn plan_create_table(&self, stmt: BoundCreateTable) -> CreateTable { CreateTable::new( PCreateTable { - xid: self.ctx.xid, + tx_mgr: self.ctx.tx_mgr.clone(), schema: stmt.schema, }, self.ctx.x_contract.clone(), @@ -117,7 +117,7 @@ impl Planner { fn plan_drop_table(&self, stmt: BoundDropTable) -> DropTable { DropTable::new( PDropTable { - xid: self.ctx.xid, + tx_mgr: self.ctx.tx_mgr.clone(), oid: Some(stmt.table_id), }, self.ctx.x_contract.clone(), @@ -128,7 +128,7 @@ impl Planner { fn plan_insert(&self, stmt: BoundInsert) -> InsertKeyValue { InsertKeyValue::new( PInsertKeyValue { - xid: self.ctx.xid, + tx_mgr: self.ctx.tx_mgr.clone(), table_id: stmt.table_id, key: VecDatum::new(stmt.key), value: VecDatum::new(stmt.value), @@ -141,7 +141,7 @@ impl Planner { fn plan_update(&self, stmt: BoundUpdate) -> UpdateKeyValue { UpdateKeyValue::new( PUpdateKeyValue { - xid: self.ctx.xid, + tx_mgr: self.ctx.tx_mgr.clone(), table_id: stmt.table_id, key: VecDatum::new(stmt.key), value: VecDatum::new(stmt.value), @@ -154,7 +154,7 @@ impl Planner { fn plan_delete(&self, stmt: BoundDelete) -> DeleteKeyValue { DeleteKeyValue::new( PDeleteKeyValue { - xid: self.ctx.xid, + tx_mgr: self.ctx.tx_mgr.clone(), table_id: stmt.table_id, key: VecDatum::new(stmt.key), }, @@ -166,7 +166,7 @@ impl Planner { fn plan_copy_from(&self, stmt: BoundCopyFrom) -> LoadFromFile { LoadFromFile::new( stmt.file_path, - self.ctx.xid, + self.ctx.tx_mgr.clone(), stmt.table_id, stmt.key_index, stmt.value_index, @@ -178,7 +178,7 @@ impl Planner { fn plan_copy_to(&self, stmt: BoundCopyTo) -> SaveToFile { SaveToFile::new( stmt.file_path, - self.ctx.xid, + self.ctx.tx_mgr.clone(), stmt.table_id, stmt.key_indexing, stmt.value_indexing, diff --git a/mudu_kernel/src/sql/stmt_cmd_run_test.rs b/mudu_kernel/src/sql/stmt_cmd_run_test.rs new file mode 100644 index 0000000..349c8cb --- /dev/null +++ b/mudu_kernel/src/sql/stmt_cmd_run_test.rs @@ -0,0 +1,133 @@ +#[cfg(test)] +mod tests { + use crate::contract::cmd_exec::CmdExec; + use crate::contract::ssn_ctx::SsnCtx; + use crate::sql::stmt_cmd::StmtCmd; + use crate::sql::stmt_cmd_run::run_cmd_stmt; + use async_trait::async_trait; + use mudu::common::result::RS; + use mudu::common::xid::XID; + use mudu::error::ec::EC; + use mudu::m_error; + use std::sync::atomic::{AtomicBool, Ordering}; + use std::sync::{Arc, Mutex}; + + #[derive(Default)] + struct TestSsnCtx { + current_tx: Mutex>, + ended: AtomicBool, + } + + impl TestSsnCtx { + fn ended(&self) -> bool { + self.ended.load(Ordering::SeqCst) + } + } + + impl SsnCtx for TestSsnCtx { + fn current_tx(&self) -> Option { + *self.current_tx.lock().unwrap() + } + + fn begin_tx(&self, xid: XID) -> RS<()> { + *self.current_tx.lock().unwrap() = Some(xid); + Ok(()) + } + + fn end_tx(&self) -> RS<()> { + self.ended.store(true, Ordering::SeqCst); + Ok(()) + } + } + + struct TestCmdExec { + fail_prepare: bool, + fail_run: bool, + affected_rows: u64, + } + + #[async_trait] + impl CmdExec for TestCmdExec { + async fn prepare(&self) -> RS<()> { + if self.fail_prepare { + Err(m_error!(EC::InternalErr, "prepare failed")) + } else { + Ok(()) + } + } + + async fn run(&self) -> RS<()> { + if self.fail_run { + Err(m_error!(EC::InternalErr, "run failed")) + } else { + Ok(()) + } + } + + async fn affected_rows(&self) -> RS { + Ok(self.affected_rows) + } + } + + struct TestStmtCmd { + fail_realize: bool, + fail_build: bool, + exec: Arc, + } + + #[async_trait] + impl StmtCmd for TestStmtCmd { + async fn realize(&self, _ctx: &dyn SsnCtx) -> RS<()> { + if self.fail_realize { + Err(m_error!(EC::InternalErr, "realize failed")) + } else { + Ok(()) + } + } + + async fn build(&self, _ctx: &dyn SsnCtx) -> RS> { + if self.fail_build { + Err(m_error!(EC::InternalErr, "build failed")) + } else { + Ok(self.exec.clone()) + } + } + } + + #[tokio::test] + async fn run_cmd_stmt_returns_affected_rows_on_success() { + let ctx = TestSsnCtx::default(); + let stmt = TestStmtCmd { + fail_realize: false, + fail_build: false, + exec: Arc::new(TestCmdExec { + fail_prepare: false, + fail_run: false, + affected_rows: 3, + }), + }; + + let rows = run_cmd_stmt(&stmt, &ctx).await.unwrap(); + assert_eq!(rows, 3); + assert!(ctx.current_tx().is_some()); + assert!(!ctx.ended()); + } + + #[tokio::test] + async fn run_cmd_stmt_ends_tx_on_build_error() { + let ctx = TestSsnCtx::default(); + let stmt = TestStmtCmd { + fail_realize: false, + fail_build: true, + exec: Arc::new(TestCmdExec { + fail_prepare: false, + fail_run: false, + affected_rows: 0, + }), + }; + + let err = run_cmd_stmt(&stmt, &ctx).await.unwrap_err(); + assert!(err.to_string().contains("build failed")); + assert!(ctx.ended()); + } +} diff --git a/mudu_kernel/src/sql/stmt_query_run_test.rs b/mudu_kernel/src/sql/stmt_query_run_test.rs new file mode 100644 index 0000000..f8814b9 --- /dev/null +++ b/mudu_kernel/src/sql/stmt_query_run_test.rs @@ -0,0 +1,154 @@ +#[cfg(test)] +mod tests { + use crate::contract::query_exec::QueryExec; + use crate::contract::ssn_ctx::SsnCtx; + use crate::sql::proj_field::ProjField; + use crate::sql::proj_list::ProjList; + use crate::sql::stmt_query::StmtQuery; + use crate::sql::stmt_query_run::run_query_stmt; + use async_trait::async_trait; + use futures::StreamExt; + use mudu::common::id::gen_oid; + use mudu::common::result::RS; + use mudu::common::xid::XID; + use mudu::error::ec::EC; + use mudu::m_error; + use mudu_contract::tuple::datum_desc::DatumDesc; + use mudu_contract::tuple::tuple_field::TupleField; + use mudu_contract::tuple::tuple_field_desc::TupleFieldDesc; + use mudu_type::dat_type::DatType; + use mudu_type::dat_type_id::DatTypeID; + use std::collections::VecDeque; + use std::sync::atomic::{AtomicBool, Ordering}; + use std::sync::{Arc, Mutex}; + + #[derive(Default)] + struct TestSsnCtx { + current_tx: Mutex>, + ended: AtomicBool, + } + + impl TestSsnCtx { + fn ended(&self) -> bool { + self.ended.load(Ordering::SeqCst) + } + } + + impl SsnCtx for TestSsnCtx { + fn current_tx(&self) -> Option { + *self.current_tx.lock().unwrap() + } + + fn begin_tx(&self, xid: XID) -> RS<()> { + *self.current_tx.lock().unwrap() = Some(xid); + Ok(()) + } + + fn end_tx(&self) -> RS<()> { + self.ended.store(true, Ordering::SeqCst); + Ok(()) + } + } + + struct TestQueryExec { + rows: Mutex>, + tuple_desc: TupleFieldDesc, + } + + #[async_trait] + impl QueryExec for TestQueryExec { + async fn open(&self) -> RS<()> { + Ok(()) + } + + async fn next(&self) -> RS> { + Ok(self.rows.lock().unwrap().pop_front()) + } + + fn tuple_desc(&self) -> RS { + Ok(self.tuple_desc.clone()) + } + } + + struct TestStmtQuery { + fail_realize: bool, + exec: Arc, + proj_list: ProjList, + } + + #[async_trait] + impl StmtQuery for TestStmtQuery { + async fn realize(&self, _ctx: &dyn SsnCtx) -> RS<()> { + if self.fail_realize { + Err(m_error!(EC::InternalErr, "realize failed")) + } else { + Ok(()) + } + } + + async fn build(&self, _ctx: &dyn SsnCtx) -> RS> { + Ok(self.exec.clone()) + } + + fn proj_list(&self) -> RS { + Ok(self.proj_list.clone()) + } + } + + fn int_proj_list() -> ProjList { + ProjList::new(vec![ProjField::new( + 0, + gen_oid(), + "id".to_string(), + DatType::default_for(DatTypeID::I32), + )]) + } + + fn int_tuple_desc() -> TupleFieldDesc { + TupleFieldDesc::new(vec![DatumDesc::new( + "id".to_string(), + DatType::default_for(DatTypeID::I32), + )]) + } + + #[tokio::test] + async fn run_query_stmt_returns_stream_on_success() { + let ctx = TestSsnCtx::default(); + let stmt = TestStmtQuery { + fail_realize: false, + exec: Arc::new(TestQueryExec { + rows: Mutex::new(VecDeque::new()), + tuple_desc: int_tuple_desc(), + }), + proj_list: int_proj_list(), + }; + + let (fields, mut stream) = run_query_stmt(&stmt, &ctx).await.unwrap(); + assert_eq!(fields.len(), 1); + assert!(stream.next().await.is_none()); + assert!(ctx.current_tx().is_some()); + assert!(!ctx.ended()); + } + + #[tokio::test] + async fn run_query_stmt_ends_tx_on_row_shape_error() { + let ctx = TestSsnCtx::default(); + let stmt = TestStmtQuery { + fail_realize: false, + exec: Arc::new(TestQueryExec { + rows: Mutex::new(VecDeque::from(vec![TupleField::new(vec![])])), + tuple_desc: int_tuple_desc(), + }), + proj_list: int_proj_list(), + }; + + let err = match run_query_stmt(&stmt, &ctx).await { + Ok(_) => panic!("expected query error"), + Err(err) => err, + }; + assert!(err + .to_string() + .contains("fatal error: non consistent column number")); + assert!(ctx.ended()); + } +} diff --git a/mudu_kernel/src/sql/value_codec.rs b/mudu_kernel/src/sql/value_codec.rs index 8de8324..c0f552a 100644 --- a/mudu_kernel/src/sql/value_codec.rs +++ b/mudu_kernel/src/sql/value_codec.rs @@ -3,6 +3,8 @@ use mudu::common::result::RS; use mudu::error::ec::EC as ER; use mudu::m_error; use mudu_contract::database::sql_params::SQLParams; +use mudu_type::dat_type_id::DatTypeID; +use mudu_type::dat_typed::DatTyped; use mudu_type::datum::DatumDyn; use mudu_type::dt_fn_param::DatType; use sql_parser::ast::expr_item::ExprValue; @@ -32,11 +34,39 @@ impl ValueCodec { pub(crate) fn binary_from_literal(literal: &ExprLiteral, dat_type: &DatType) -> RS { match literal { - ExprLiteral::DatumLiteral(typed) => typed + ExprLiteral::DatumLiteral(typed) => Self::coerce_literal(typed, dat_type)? .dat_internal() .to_binary(dat_type) .map(|binary| binary.into()) .map_err(|e| m_error!(ER::TypeBaseErr, "literal type mismatch", e)), } } + + fn coerce_literal(literal: &DatTyped, dat_type: &DatType) -> RS { + let source = literal.dat_type().dat_type_id(); + let target = dat_type.dat_type_id(); + if source == target { + return Ok(literal.clone()); + } + + let coerced = match (source, target) { + (DatTypeID::I64, DatTypeID::I32) => { + DatTyped::from_i32(literal.dat_internal().to_i64() as i32) + } + (DatTypeID::I32, DatTypeID::I64) => { + DatTyped::from_i64(literal.dat_internal().to_i32() as i64) + } + (DatTypeID::I64, DatTypeID::I128) => { + DatTyped::from_i128(literal.dat_internal().to_i64() as i128) + } + (DatTypeID::I64, DatTypeID::U128) => { + DatTyped::from_oid(literal.dat_internal().to_i64() as u128) + } + (DatTypeID::F64, DatTypeID::F32) => { + DatTyped::from_f32(literal.dat_internal().to_f64() as f32) + } + _ => return Ok(literal.clone()), + }; + Ok(coerced) + } } diff --git a/mudu_kernel/src/sql/value_codec_test.rs b/mudu_kernel/src/sql/value_codec_test.rs new file mode 100644 index 0000000..9a6d920 --- /dev/null +++ b/mudu_kernel/src/sql/value_codec_test.rs @@ -0,0 +1,99 @@ +#[cfg(test)] +mod tests { + use crate::sql::value_codec::ValueCodec; + use mudu_type::dat_type::DatType; + use mudu_type::dat_type_id::DatTypeID; + use mudu_type::dat_typed::DatTyped; + use mudu_type::datum::DatumDyn; + use sql_parser::ast::expr_item::ExprValue; + use sql_parser::ast::expr_literal::ExprLiteral; + + #[test] + fn placeholder_consumes_parameters_in_order() { + let mut param_index = 0; + let first = ValueCodec::binary_from_expr( + &ExprValue::ValuePlaceholder, + &DatType::default_for(DatTypeID::I32), + &(7i32, 9i32), + &mut param_index, + ) + .unwrap(); + let second = ValueCodec::binary_from_expr( + &ExprValue::ValuePlaceholder, + &DatType::default_for(DatTypeID::I32), + &(7i32, 9i32), + &mut param_index, + ) + .unwrap(); + + assert_eq!(param_index, 2); + assert_eq!( + first.as_slice(), + 7i32.to_binary(&DatType::default_for(DatTypeID::I32)) + .unwrap() + .as_ref() + ); + assert_eq!( + second.as_slice(), + 9i32.to_binary(&DatType::default_for(DatTypeID::I32)) + .unwrap() + .as_ref() + ); + } + + #[test] + fn placeholder_errors_when_parameter_is_missing() { + let mut param_index = 0; + let err = ValueCodec::binary_from_expr( + &ExprValue::ValuePlaceholder, + &DatType::default_for(DatTypeID::I32), + &(), + &mut param_index, + ) + .unwrap_err(); + + assert!(err.to_string().contains("missing parameter 0")); + } + + #[test] + fn literal_is_encoded_via_literal_path() { + let mut param_index = 0; + let binary = ValueCodec::binary_from_expr( + &ExprValue::ValueLiteral(ExprLiteral::DatumLiteral(DatTyped::from_i32(42))), + &DatType::default_for(DatTypeID::I32), + &(), + &mut param_index, + ) + .unwrap(); + + assert_eq!(param_index, 0); + assert_eq!( + binary.as_slice(), + 42i32 + .to_binary(&DatType::default_for(DatTypeID::I32)) + .unwrap() + .as_ref() + ); + } + + #[test] + fn i64_literal_is_narrowed_for_i32_columns() { + let mut param_index = 0; + let binary = ValueCodec::binary_from_expr( + &ExprValue::ValueLiteral(ExprLiteral::DatumLiteral(DatTyped::from_i64(42))), + &DatType::default_for(DatTypeID::I32), + &(), + &mut param_index, + ) + .unwrap(); + + assert_eq!(param_index, 0); + assert_eq!( + binary.as_slice(), + 42i32 + .to_binary(&DatType::default_for(DatTypeID::I32)) + .unwrap() + .as_ref() + ); + } +} diff --git a/mudu_kernel/src/storage/relation/relation.rs b/mudu_kernel/src/storage/relation/relation.rs index e052526..3143821 100644 --- a/mudu_kernel/src/storage/relation/relation.rs +++ b/mudu_kernel/src/storage/relation/relation.rs @@ -1,5 +1,6 @@ use std::ops::Bound; -use std::sync::Mutex; +use std::cell::{Cell, UnsafeCell}; +use futures::executor::block_on; use mudu::common::id::{TupleID, OID}; use mudu::common::result::RS; @@ -24,61 +25,93 @@ const KEY_FILE_INDEX: u32 = 0; const VALUE_FILE_INDEX: u32 = 1; pub struct Relation { - inner: Mutex, + inner: RelationInner, } +unsafe impl Send for Relation {} +unsafe impl Sync for Relation {} + struct RelationInner { _table_id: OID, _partition_id: OID, - index: BTreeIndex, - key_file: TimeSeriesFile, - value_file: TimeSeriesFile, - next_tuple_id: TupleID, + index: UnsafeCell>, + key_file: UnsafeCell, + value_file: UnsafeCell, + next_tuple_id: Cell, } +unsafe impl Send for RelationInner {} +unsafe impl Sync for RelationInner {} + impl Relation { pub fn new(table_id: OID, partition_id: OID, path: String, table_desc: &TableDesc) -> Self { Self { - inner: Mutex::new(RelationInner::new(table_id, partition_id, path, table_desc)), + inner: RelationInner::new(table_id, partition_id, path, table_desc), } } - pub fn has_visible_version(&self, key: &KeyTuple, snapshot: &WorkerSnapshot) -> RS { - Ok(self.lock_inner()?.visible_meta(key, snapshot)?.is_some()) + pub async fn has_visible_version(&self, key: &KeyTuple, snapshot: &WorkerSnapshot) -> RS { + Ok(self.inner.visible_meta(key, snapshot).await?.is_some()) + } + + pub fn has_visible_version_sync(&self, key: &KeyTuple, snapshot: &WorkerSnapshot) -> RS { + Ok(self.inner.visible_meta_sync(key, snapshot)?.is_some()) + } + + pub async fn visible_value(&self, key: &KeyTuple, snapshot: &WorkerSnapshot) -> RS>> { + self.inner.visible_value(key, snapshot).await + } + + pub fn visible_value_sync(&self, key: &KeyTuple, snapshot: &WorkerSnapshot) -> RS>> { + self.inner.visible_value_sync(key, snapshot) } - pub fn visible_value(&self, key: &KeyTuple, snapshot: &WorkerSnapshot) -> RS>> { - self.lock_inner()?.visible_value(key, snapshot) + pub async fn visible_range( + &self, + bounds: (Bound<&[u8]>, Bound<&[u8]>), + snapshot: &WorkerSnapshot, + ) -> RS, Vec)>> { + self.inner.visible_range(bounds, snapshot).await } - pub fn visible_range( + pub fn visible_range_sync( &self, bounds: (Bound<&[u8]>, Bound<&[u8]>), snapshot: &WorkerSnapshot, ) -> RS, Vec)>> { - self.lock_inner()?.visible_range(bounds, snapshot) + self.inner.visible_range_sync(bounds, snapshot) + } + + pub async fn has_write_conflict(&self, key: &KeyTuple, snapshot: &WorkerSnapshot) -> RS { + self.inner.has_write_conflict(key, snapshot).await } - pub fn has_write_conflict(&self, key: &KeyTuple, snapshot: &WorkerSnapshot) -> RS { - self.lock_inner()?.has_write_conflict(key, snapshot) + pub fn has_write_conflict_sync(&self, key: &KeyTuple, snapshot: &WorkerSnapshot) -> RS { + self.inner.has_write_conflict_sync(key, snapshot) } - pub fn write_value(&self, key: Vec, value: Vec, xid: u64) -> RS<()> { - self.lock_inner()?.write_row(key, Some(value), xid) + pub async fn write_value(&self, key: Vec, value: Vec, xid: u64) -> RS<()> { + self.inner.write_row(key, Some(value), xid).await } - pub fn write_delete(&self, key: Vec, xid: u64) -> RS<()> { - self.lock_inner()?.write_row(key, None, xid) + pub fn write_value_sync(&self, key: Vec, value: Vec, xid: u64) -> RS<()> { + self.inner.write_row_sync(key, Some(value), xid) } - pub fn write_row(&self, key: Vec, value: Option>, xid: u64) -> RS<()> { - self.lock_inner()?.write_row(key, value, xid) + pub async fn write_delete(&self, key: Vec, xid: u64) -> RS<()> { + self.inner.write_row(key, None, xid).await } - fn lock_inner(&self) -> RS> { - self.inner - .lock() - .map_err(|_| m_error!(EC::InternalErr, "relation lock poisoned")) + pub fn write_delete_sync(&self, key: Vec, xid: u64) -> RS<()> { + self.inner.write_row_sync(key, None, xid) + } + + pub async fn write_row(&self, key: Vec, value: Option>, xid: u64) -> RS<()> { + self.inner.write_row(key, value, xid).await + } + + pub fn write_row_sync(&self, key: Vec, value: Option>, xid: u64) -> RS<()> { + self.inner.write_row_sync(key, value, xid) } } @@ -95,19 +128,23 @@ impl RelationInner { file_index: VALUE_FILE_INDEX, }; - let mut relation = Self { + let relation = Self { _table_id: table_id, _partition_id: partition_id, - index: BTreeIndex::new(CompareContext { + index: UnsafeCell::new(BTreeIndex::new(CompareContext { result: Ok(()), comparator: TupleComparator::new(), desc: table_desc.key_desc().clone(), - }), - key_file: TimeSeriesFile::open_relation_file_sync(&path, key_identity, true) - .unwrap_or_else(|e| panic!("open relation key file failed: {e}")), - value_file: TimeSeriesFile::open_relation_file_sync(&path, value_identity, true) - .unwrap_or_else(|e| panic!("open relation value file failed: {e}")), - next_tuple_id: 1, + })), + key_file: UnsafeCell::new( + TimeSeriesFile::open_relation_file_sync(&path, key_identity, true) + .unwrap_or_else(|e| panic!("open relation key file failed: {e}")), + ), + value_file: UnsafeCell::new( + TimeSeriesFile::open_relation_file_sync(&path, value_identity, true) + .unwrap_or_else(|e| panic!("open relation value file failed: {e}")), + ), + next_tuple_id: Cell::new(1), }; relation .rebuild_from_files() @@ -115,8 +152,8 @@ impl RelationInner { relation } - fn rebuild_from_files(&mut self) -> RS<()> { - let rows = self.key_file.scan_range_sync(0, u64::MAX)?; + fn rebuild_from_files(&self) -> RS<()> { + let rows = self.key_file().scan_range_sync(0, u64::MAX)?; let mut max_tuple_id = 0; for key_row in rows { @@ -124,7 +161,7 @@ impl RelationInner { max_tuple_id = max_tuple_id.max(tuple_id); let key_tuple = KeyTuple::from(key_row.payload.clone()); - let row = match self.index.get(&key_tuple)?.cloned() { + let row = match self.index().get(&key_tuple)?.cloned() { Some(row) => { let existing_tuple_id = row .tuple_id_sync()? @@ -147,47 +184,62 @@ impl RelationInner { let timestamp = Timestamp::new(key_row.timestamp, u64::MAX); let version = match self - .value_file + .value_file() .get_sync(key_row.timestamp, key_row.tuple_id)? { Some(_) => VersionTuple::new(timestamp, Vec::new()), None => VersionTuple::new_delete(timestamp), }; row.write_sync(version, None)?; - let _ = self.index.insert(key_tuple, row)?; + let _ = self.index_mut().insert(key_tuple, row)?; } - self.next_tuple_id = max_tuple_id.saturating_add(1).max(1); + self.next_tuple_id.set(max_tuple_id.saturating_add(1).max(1)); Ok(()) } - fn visible_meta( + async fn visible_meta( &self, key: &KeyTuple, snapshot: &WorkerSnapshot, ) -> RS> { - let row = match self.index.get(key)? { + let row = match self.index().get(key)? { Some(row) => row, None => return Ok(None), }; let tuple_id = row - .tuple_id_sync()? + .tuple_id() + .await? .ok_or_else(|| m_error!(EC::InternalErr, "missing tuple id"))?; let snapshot = snapshot.to_snapshot(); - Ok(read_visible_version(row, &snapshot) + Ok(read_visible_version_async(row, &snapshot) + .await .filter(|version| !version.is_deleted()) .map(|version| (tuple_id, version))) } - fn visible_value(&self, key: &KeyTuple, snapshot: &WorkerSnapshot) -> RS>> { - let Some((tuple_id, version)) = self.visible_meta(key, snapshot)? else { + fn visible_meta_sync( + &self, + key: &KeyTuple, + snapshot: &WorkerSnapshot, + ) -> RS> { + block_on(self.visible_meta(key, snapshot)) + } + + async fn visible_value(&self, key: &KeyTuple, snapshot: &WorkerSnapshot) -> RS>> { + let Some((tuple_id, version)) = self.visible_meta(key, snapshot).await? else { return Ok(None); }; self.read_value_payload(version.timestamp().c_min(), tuple_id) + .await .map(Some) } - fn visible_range( + fn visible_value_sync(&self, key: &KeyTuple, snapshot: &WorkerSnapshot) -> RS>> { + block_on(self.visible_value(key, snapshot)) + } + + async fn visible_range( &self, bounds: (Bound<&[u8]>, Bound<&[u8]>), snapshot: &WorkerSnapshot, @@ -195,33 +247,46 @@ impl RelationInner { let begin_key = bounds.0.as_ref().map(|key| KeyTuple::from(key.to_vec())); let end_key = bounds.1.as_ref().map(|key| KeyTuple::from(key.to_vec())); let rows = self - .index + .index() .range((bound_key_ref(&begin_key), bound_key_ref(&end_key)))?; - rows.into_iter() - .filter_map(|(_key, row)| { - let snapshot = snapshot.to_snapshot(); - match visible_payloads(&self.key_file, &self.value_file, row, &snapshot) { - Ok(Some(pair)) => Some(Ok(pair)), - Ok(None) => None, - Err(err) => Some(Err(err)), - } - }) - .collect() + let snapshot = snapshot.to_snapshot(); + let mut items = Vec::new(); + for (_key, row) in rows { + if let Some(pair) = + visible_payloads_async(self.key_file(), self.value_file(), row, &snapshot).await? + { + items.push(pair); + } + } + Ok(items) + } + + fn visible_range_sync( + &self, + bounds: (Bound<&[u8]>, Bound<&[u8]>), + snapshot: &WorkerSnapshot, + ) -> RS, Vec)>> { + block_on(self.visible_range(bounds, snapshot)) } - fn has_write_conflict(&self, key: &KeyTuple, snapshot: &WorkerSnapshot) -> RS { - Ok(self - .index - .get(key)? - .and_then(latest_version) + async fn has_write_conflict(&self, key: &KeyTuple, snapshot: &WorkerSnapshot) -> RS { + let latest = match self.index().get(key)? { + Some(row) => latest_version_async(row).await, + None => None, + }; + Ok(latest .map(|latest| !snapshot.is_visible(latest.timestamp().c_min())) .unwrap_or(false)) } - fn write_row(&mut self, key: Vec, value: Option>, xid: u64) -> RS<()> { + fn has_write_conflict_sync(&self, key: &KeyTuple, snapshot: &WorkerSnapshot) -> RS { + block_on(self.has_write_conflict(key, snapshot)) + } + + async fn write_row(&self, key: Vec, value: Option>, xid: u64) -> RS<()> { let key_tuple = KeyTuple::from(key.clone()); - let row = match self.index.get(&key_tuple)?.cloned() { + let row = match self.index().get(&key_tuple)?.cloned() { Some(row) => row, None => { let tuple_id = self.alloc_tuple_id(); @@ -230,34 +295,42 @@ impl RelationInner { }; let tuple_id = row - .tuple_id_sync()? + .tuple_id() + .await? .ok_or_else(|| m_error!(EC::InternalErr, "missing tuple id"))?; let timestamp = Timestamp::new(xid, u64::MAX); - self.key_file - .insert_sync(timestamp.c_min(), tuple_id as u64, &key)?; + self.key_file_mut() + .insert(timestamp.c_min(), tuple_id as u64, &key) + .await?; if let Some(value) = value.as_ref() { - self.value_file - .insert_sync(timestamp.c_min(), tuple_id as u64, value)?; + self.value_file_mut() + .insert(timestamp.c_min(), tuple_id as u64, value) + .await?; } let version = match value { Some(_) => VersionTuple::new(timestamp, Vec::new()), None => VersionTuple::new_delete(timestamp), }; - row.write_sync(version, None)?; - let _ = self.index.insert(key_tuple, row)?; + row.write(version, None).await?; + let _ = self.index_mut().insert(key_tuple, row)?; Ok(()) } - fn alloc_tuple_id(&mut self) -> TupleID { - let tuple_id = self.next_tuple_id; - self.next_tuple_id += 1; + fn write_row_sync(&self, key: Vec, value: Option>, xid: u64) -> RS<()> { + block_on(self.write_row(key, value, xid)) + } + + fn alloc_tuple_id(&self) -> TupleID { + let tuple_id = self.next_tuple_id.get(); + self.next_tuple_id.set(tuple_id + 1); tuple_id } - fn read_value_payload(&self, timestamp: u64, tuple_id: OID) -> RS> { - self.value_file - .get_sync(timestamp, tuple_id as u64)? + async fn read_value_payload(&self, timestamp: u64, tuple_id: OID) -> RS> { + self.value_file() + .get(timestamp, tuple_id as u64) + .await? .map(|record| record.payload) .ok_or_else(|| { m_error!( @@ -266,6 +339,36 @@ impl RelationInner { ) }) } + + fn index(&self) -> &BTreeIndex { + // Safety: Relation is expected to be accessed from a single worker thread. + unsafe { &*self.index.get() } + } + + fn index_mut(&self) -> &mut BTreeIndex { + // Safety: Relation is expected to be accessed from a single worker thread. + unsafe { &mut *self.index.get() } + } + + fn key_file(&self) -> &TimeSeriesFile { + // Safety: Relation is expected to be accessed from a single worker thread. + unsafe { &*self.key_file.get() } + } + + fn key_file_mut(&self) -> &mut TimeSeriesFile { + // Safety: Relation is expected to be accessed from a single worker thread. + unsafe { &mut *self.key_file.get() } + } + + fn value_file(&self) -> &TimeSeriesFile { + // Safety: Relation is expected to be accessed from a single worker thread. + unsafe { &*self.value_file.get() } + } + + fn value_file_mut(&self) -> &mut TimeSeriesFile { + // Safety: Relation is expected to be accessed from a single worker thread. + unsafe { &mut *self.value_file.get() } + } } #[cfg(test)] @@ -285,16 +388,20 @@ mod tests { fn test_schema() -> SchemaTable { SchemaTable::new( "t".to_string(), - vec![SchemaColumn::new( - "id".to_string(), - DatTypeID::I32, - DTInfo::from_text(DatTypeID::I32, String::new()), - )], - vec![SchemaColumn::new( - "v".to_string(), - DatTypeID::I32, - DTInfo::from_text(DatTypeID::I32, String::new()), - )], + vec![ + SchemaColumn::new( + "id".to_string(), + DatTypeID::I32, + DTInfo::from_text(DatTypeID::I32, String::new()), + ), + SchemaColumn::new( + "v".to_string(), + DatTypeID::I32, + DTInfo::from_text(DatTypeID::I32, String::new()), + ), + ], + vec![0], + vec![1], ) } @@ -322,18 +429,18 @@ mod tests { let relation = Relation::new(table_id, partition_id, path.clone(), table_desc.as_ref()); relation - .write_value(i32_bytes(1), i32_bytes(11), 1) + .write_value_sync(i32_bytes(1), i32_bytes(11), 1) .unwrap(); - relation.write_delete(i32_bytes(1), 2).unwrap(); + relation.write_delete_sync(i32_bytes(1), 2).unwrap(); relation - .write_value(i32_bytes(2), i32_bytes(22), 3) + .write_value_sync(i32_bytes(2), i32_bytes(22), 3) .unwrap(); drop(relation); let reopened = Relation::new(table_id, partition_id, path.clone(), table_desc.as_ref()); assert_eq!( reopened - .visible_value( + .visible_value_sync( &KeyTuple::from(i32_bytes(1)), &WorkerSnapshot::new(1, vec![]) ) @@ -342,7 +449,7 @@ mod tests { ); assert_eq!( reopened - .visible_value( + .visible_value_sync( &KeyTuple::from(i32_bytes(1)), &WorkerSnapshot::new(2, vec![]) ) @@ -351,7 +458,7 @@ mod tests { ); assert_eq!( reopened - .visible_value( + .visible_value_sync( &KeyTuple::from(i32_bytes(2)), &WorkerSnapshot::new(3, vec![]) ) @@ -360,7 +467,7 @@ mod tests { ); reopened - .write_value(i32_bytes(3), i32_bytes(33), 4) + .write_value_sync(i32_bytes(3), i32_bytes(33), 4) .unwrap(); let key_file = TimeSeriesFile::open_ts_file_sync( TimeSeriesFile::relation_file_path(&path, partition_id, table_id, 0), @@ -376,22 +483,26 @@ mod tests { } } -fn visible_payloads( +async fn visible_payloads_async( key_file: &TimeSeriesFile, value_file: &TimeSeriesFile, row: &DataRow, snapshot: &Snapshot, ) -> RS, Vec)>> { let tuple_id = row - .tuple_id_sync()? + .tuple_id() + .await? .ok_or_else(|| m_error!(EC::InternalErr, "missing tuple id"))?; - let Some(version) = read_visible_version(row, snapshot).filter(|version| !version.is_deleted()) + let Some(version) = read_visible_version_async(row, snapshot) + .await + .filter(|version| !version.is_deleted()) else { return Ok(None); }; let ts = version.timestamp().c_min(); let key = key_file - .get_sync(ts, tuple_id as u64)? + .get(ts, tuple_id as u64) + .await? .map(|record| record.payload) .ok_or_else(|| { m_error!( @@ -400,7 +511,8 @@ fn visible_payloads( ) })?; let value = value_file - .get_sync(ts, tuple_id as u64)? + .get(ts, tuple_id as u64) + .await? .map(|record| record.payload) .ok_or_else(|| { m_error!( @@ -411,12 +523,12 @@ fn visible_payloads( Ok(Some((key, value))) } -fn latest_version(row: &DataRow) -> Option { - row.read_latest_sync().ok().flatten() +async fn latest_version_async(row: &DataRow) -> Option { + row.read_latest().await.ok().flatten() } -fn read_visible_version(row: &DataRow, snapshot: &Snapshot) -> Option { - row.read_sync(snapshot).ok().flatten() +async fn read_visible_version_async(row: &DataRow, snapshot: &Snapshot) -> Option { + row.read(snapshot).await.ok().flatten() } fn bound_key_ref(bound: &Bound) -> Bound<&KeyTuple> { diff --git a/mudu_kernel/src/tx/lock_slot.rs b/mudu_kernel/src/tx/lock_slot.rs deleted file mode 100644 index 982c6ea..0000000 --- a/mudu_kernel/src/tx/lock_slot.rs +++ /dev/null @@ -1,129 +0,0 @@ -use crate::contract::x_lock_mgr::LockResult; -use mudu::common::result::RS; -use mudu::common::xid::XID; -use mudu_utils::sync::notify_wait::Notify; -use mudu_utils::sync::s_mutex::SMutex; -use std::collections::VecDeque; -use std::sync::Arc; - -#[derive(Clone)] -pub struct LockSlot { - inner: Arc>, -} - -pub struct _LockSlotInner { - is_deleted: bool, - locked: Option, - queue: VecDeque<(XID, Notify)>, -} - -impl LockSlot { - pub fn new() -> Self { - Self { - inner: Arc::new(SMutex::new(_LockSlotInner::new())), - } - } - - /// when return None, this has been deleted, the invoker mut recreate a new LockSlot - pub fn lock(&self, xid: XID, notify: Notify) -> Option> { - let r = self.inner.lock(); - let mut g = match r { - Ok(g) => g, - Err(e) => return Some(Err(e)), - }; - if g.is_deleted() { - None - } else { - let r = g.lock(xid, notify); - match r { - Ok(_) => Some(Ok(())), - Err(e) => Some(Err(e)), - } - } - } - - pub fn release(&self, xid: XID) -> RS { - let mut g = self.inner.lock()?; - let (_, opt_notify) = g.release(xid); - if let Some((_, notify)) = opt_notify { - notify.notify(LockResult::Locked)?; - } - Ok(g.remove_if_empty()) - } -} - -impl _LockSlotInner { - fn new() -> Self { - Self { - is_deleted: false, - locked: None, - queue: Default::default(), - } - } - - fn is_deleted(&self) -> bool { - self.is_deleted - } - - fn set_deleted(&mut self) { - self.is_deleted = true; - } - - fn lock(&mut self, xid: XID, notify: Notify) -> RS { - let ret = if self.locked.is_none() { - assert!(self.queue.is_empty()); - self.locked = Some(xid); - notify.notify(LockResult::Locked)?; - true - } else { - self.queue.push_back((xid, notify)); - false - }; - Ok(ret) - } - - fn release(&mut self, xid: XID) -> (bool, Option<(XID, Notify)>) { - let ok = self.clear_locked(xid); - if ok { - let notify = self.notify_next(); - (true, notify) - } else { - (false, None) - } - } - - fn clear_locked(&mut self, xid: XID) -> bool { - match self.locked { - Some(id) => { - if xid == id { - self.locked = None; - true - } else { - false - } - } - None => false, - } - } - - fn notify_next(&mut self) -> Option<(XID, Notify)> { - if let Some(_t) = self.locked { - return None; - } - let opt = self.queue.pop_front(); - match opt { - Some((xid, notify)) => { - self.locked = Some(xid); - Some((xid, notify)) - } - None => None, - } - } - - fn remove_if_empty(&mut self) -> bool { - if self.locked.is_none() && self.queue.is_empty() { - self.is_deleted = true; - } - self.is_deleted - } -} diff --git a/mudu_kernel/src/tx/lock_table.rs b/mudu_kernel/src/tx/lock_table.rs deleted file mode 100644 index 49ac343..0000000 --- a/mudu_kernel/src/tx/lock_table.rs +++ /dev/null @@ -1,71 +0,0 @@ -use crate::collection::hash_map::hash_map_get_or_create; -use crate::contract::x_lock_mgr::LockResult; -use crate::tx::lock_slot::LockSlot; -use mudu::common::buf::Buf; -use mudu::common::result::RS; -use mudu::common::xid::XID; -use mudu::error::ec::EC as ER; -use mudu::m_error; -use mudu_contract::tuple::tuple_binary_desc::TupleBinaryDesc as TupleDesc; -use mudu_contract::tuple::tuple_key::TupleKey; -use mudu_utils::sync::notify_wait::Notify; -use scc::HashMap; -use std::sync::Arc; - -#[derive(Clone)] -pub struct LockTable { - inner: Arc, -} - -struct LockTableInner { - tuple_desc: TupleDesc, - hash_map: HashMap, -} - -impl LockTable { - pub fn new(tuple_desc: TupleDesc) -> Self { - Self { - inner: Arc::new(LockTableInner::new(tuple_desc)), - } - } - - pub fn lock(&self, notify: Notify, xid: XID, key: Buf) -> RS<()> { - self.inner.lock(notify, xid, key) - } - - pub fn release(&self, xid: XID, key: &Buf) -> RS<()> { - self.inner.release(xid, key) - } -} - -impl LockTableInner { - pub fn new(tuple_desc: TupleDesc) -> Self { - Self { - tuple_desc, - hash_map: HashMap::new(), - } - } - - fn lock(&self, notify: Notify, xid: XID, key: Buf) -> RS<()> { - let key = TupleKey::from_buf(&self.tuple_desc, key); - - hash_map_get_or_create(&self.hash_map, key, LockSlot::new, move |slot| { - let n = notify.clone(); - slot.lock(xid, n) - })? - } - - fn release(&self, xid: XID, key: &Buf) -> RS<()> { - let key = TupleKey::from_buf(&self.tuple_desc, key.clone()); - let opt = self.hash_map.get_sync(&key); - let slot = match opt { - Some(slot) => slot.clone(), - None => return Err(m_error!(ER::NoSuchElement, "")), - }; - let empty = slot.release(xid)?; - if empty { - let _ = self.hash_map.remove_sync(&key); - } - Ok(()) - } -} diff --git a/mudu_kernel/src/tx/mod.rs b/mudu_kernel/src/tx/mod.rs index 7334373..fe1461d 100644 --- a/mudu_kernel/src/tx/mod.rs +++ b/mudu_kernel/src/tx/mod.rs @@ -1,12 +1,4 @@ #![allow(dead_code)] -mod lock_slot; -pub mod lock_table; - mod test_x_snap_mgr; - -pub mod tx_ctx; -pub mod tx_mgr_factory; - -mod x_lock_mgr; pub mod x_snap_mgr; diff --git a/mudu_kernel/src/tx/tx_ctx.rs b/mudu_kernel/src/tx/tx_ctx.rs deleted file mode 100644 index 23dc305..0000000 --- a/mudu_kernel/src/tx/tx_ctx.rs +++ /dev/null @@ -1,219 +0,0 @@ -use crate::contract::data_row::DataRow; -use crate::contract::pst_op_list::PstOpList; -use crate::contract::snapshot::{Snapshot, TimeSeq}; -use crate::contract::timestamp::Timestamp; -use crate::contract::version_tuple::VersionTuple; -use crate::contract::x_lock_mgr::XLockMgr; -use crate::contract::xl_rec::XLRec; -use mudu::common::buf::Buf; -use mudu::common::id::{gen_oid, OID}; -use mudu::common::result::RS; -use mudu::common::update_delta::UpdateDelta; -use mudu::common::xid::XID; -use mudu_utils::sync::a_mutex::AMutex; -use mudu_utils::task_trace; -use std::sync::Arc; - -#[derive(Clone)] -pub struct TxCtx { - xid: XID, - inner: Arc>, -} - -struct _TxCtx { - xid: XID, - snapshot: Snapshot, - write_key: Vec<(OID, Buf)>, - log_rec: XLRec, - ops: Vec, -} - -enum TxWriteOp { - Insert(TxInsert), - Update(TxUpdate), - Delete, -} - -struct TxInsert { - table_id: OID, - tuple_id: OID, - key: Buf, - value: Buf, - row: DataRow, -} - -struct TxUpdate { - table_id: OID, - tuple_id: OID, - key: Buf, - value: Buf, - value_up: Vec, - row: DataRow, -} - -impl TxCtx { - pub fn new(xid: XID, snapshot: Snapshot) -> Self { - Self { - xid, - inner: Arc::new(AMutex::new(_TxCtx::new(xid, snapshot))), - } - } - - pub fn xid(&self) -> XID { - self.xid - } - - pub async fn insert(&self, table_id: OID, keys: Buf, values: Buf, row: DataRow) -> RS<()> { - task_trace!(); - let mut g = self.inner.lock().await; - g.insert(table_id, keys, values, row).await?; - Ok(()) - } - - pub async fn update( - &self, - table_id: OID, - tuple_id: OID, - keys: Buf, - values: Vec, - row: DataRow, - ) -> RS<()> { - let mut g = self.inner.lock().await; - g.update(table_id, tuple_id, keys, values, row); - Ok(()) - } - pub async fn write(&self, oid: OID, buf: Buf) -> RS<()> { - let mut g = self.inner.lock().await; - g.write(oid, buf); - Ok(()) - } - - pub async fn commit(&self, lock_mgr: &dyn XLockMgr) -> RS<()> { - task_trace!(); - let mut g = self.inner.lock().await; - g.commit(lock_mgr).await?; - Ok(()) - } - - async fn abort(&self, lock_mgr: &dyn XLockMgr) -> RS<()> { - let mut g = self.inner.lock().await; - g.abort(lock_mgr).await?; - Ok(()) - } - - pub async fn snapshot(&self) -> RS { - let g = self.inner.lock().await; - Ok(g.snapshot().clone()) - } -} - -impl _TxCtx { - pub fn new(xid: XID, snapshot: Snapshot) -> Self { - Self { - snapshot, - xid, - write_key: vec![], - log_rec: XLRec::new(xid), - ops: vec![], - } - } - - fn write(&mut self, oid: OID, key: Buf) { - self.add_write_key(oid, key); - } - - async fn commit(&mut self, lock_mgr: &dyn XLockMgr) -> RS<()> { - task_trace!(); - self.clear(lock_mgr).await?; - Ok(()) - } - - async fn abort(&mut self, lock_mgr: &dyn XLockMgr) -> RS<()> { - task_trace!(); - self.clear(lock_mgr).await?; - Ok(()) - } - - async fn clear(&self, lock_mgr: &dyn XLockMgr) -> RS<()> { - task_trace!(); - for (id, key) in self.write_key.iter() { - lock_mgr.release(self.xid, *id, key).await?; - } - Ok(()) - } - - fn add_write_key(&mut self, table: OID, key: Buf) { - self.write_key.push((table, key)); - } - - fn snapshot(&self) -> &Snapshot { - &self.snapshot - } - - async fn insert(&mut self, table_id: OID, key: Buf, value: Buf, row: DataRow) -> RS<()> { - task_trace!(); - let op = TxInsert { - table_id, - tuple_id: gen_oid(), - key, - value, - row, - }; - self.ops.push(TxWriteOp::Insert(op)); - Ok(()) - } - - fn update( - &mut self, - table_id: OID, - tuple_id: OID, - key: Buf, - value_up: Vec, - row: DataRow, - ) { - let op = TxUpdate { - table_id, - tuple_id, - key, - value: vec![], - value_up, - row, - }; - self.ops.push(TxWriteOp::Update(op)); - } - - async fn delete(&mut self, table_id: OID, tuple_id: OID, keys: Buf) { - self.log_rec.add_delete(table_id, tuple_id, keys); - } -} - -trait XWriteOp { - fn to_x_log_rec(&self, rec: &mut XLRec); - - async fn apply_to_mem(&self, xid: TimeSeq, row: &DataRow) -> RS<()>; - - fn apply_to_pst(&self, pst_op: PstOpList); -} - -impl XWriteOp for TxUpdate { - fn to_x_log_rec(&self, rec: &mut XLRec) { - rec.add_update( - self.table_id, - self.tuple_id, - self.key.clone(), - self.value_up.clone(), - ); - } - - async fn apply_to_mem(&self, xid: TimeSeq, row: &DataRow) -> RS<()> { - let timestamp = Timestamp::new(xid, u64::MAX); - - row.write(VersionTuple::new(timestamp, self.value.clone()), None) - .await?; - Ok(()) - } - - fn apply_to_pst(&self, _pst_op: PstOpList) { - todo!() - } -} diff --git a/mudu_kernel/src/tx/tx_mgr_factory.rs b/mudu_kernel/src/tx/tx_mgr_factory.rs deleted file mode 100644 index 62921db..0000000 --- a/mudu_kernel/src/tx/tx_mgr_factory.rs +++ /dev/null @@ -1,22 +0,0 @@ -use crate::contract::snapshot::TimeSeq; -use crate::contract::x_lock_mgr::XLockMgr; -use crate::tx::x_lock_mgr::XLockMgrImpl; -use crate::tx::x_snap_mgr::XSnapMgr; -use mudu_utils::notifier::NotifyWait; -use std::sync::Arc; - -pub struct TxMgrFactory {} - -impl TxMgrFactory { - pub fn create_lock_mgr() -> Arc { - Arc::new(XLockMgrImpl::new()) - } - - pub fn create_snap_mgr( - canceller: NotifyWait, - xid_max: TimeSeq, - snap_request_queue_size: usize, - ) -> Arc { - Arc::new(XSnapMgr::new(canceller, xid_max, snap_request_queue_size)) - } -} diff --git a/mudu_kernel/src/tx/x_lock_mgr.rs b/mudu_kernel/src/tx/x_lock_mgr.rs deleted file mode 100644 index a034edf..0000000 --- a/mudu_kernel/src/tx/x_lock_mgr.rs +++ /dev/null @@ -1,105 +0,0 @@ -use crate::contract::x_lock_mgr::{LockResult, XLockMgr}; -use crate::tx::lock_table::LockTable; -use async_trait::async_trait; -use mudu::common::buf::Buf; -use mudu::common::id::OID; -use mudu::common::result::RS; -use mudu::common::xid::XID; -use mudu::error::ec::EC as ER; -use mudu::m_error; -use mudu_contract::tuple::tuple_binary_desc::TupleBinaryDesc as TupleDesc; -use mudu_utils::sync::notify_wait::Notify; -use scc::HashMap; -use std::sync::Arc; - -#[derive(Clone)] -pub struct XLockMgrImpl { - inner: Arc<_XLockMgrInner>, -} - -struct _XLockMgrInner { - map: HashMap, -} - -impl XLockMgrImpl { - pub fn new() -> Self { - Self { - inner: Arc::new(_XLockMgrInner { - map: HashMap::new(), - }), - } - } -} - -impl _XLockMgrInner { - fn new() -> Self { - Self { - map: HashMap::new(), - } - } - - fn create_table(&self, table: OID, tuple_desc: TupleDesc) -> RS<()> { - let r = self.map.insert_sync(table, LockTable::new(tuple_desc)); - if r.is_err() { - return Err(m_error!(ER::ExistingSuchElement)); - } - Ok(()) - } - - fn drop_table(&self, table: OID) -> RS<()> { - let r = self.map.remove_sync(&table); - if r.is_none() { - return Err(m_error!(ER::NoSuchElement)); - } - Ok(()) - } - fn lock(&self, notify: Notify, xid: XID, table_id: OID, key: Buf) -> RS<()> { - let table = self.get_lock_table(table_id)?; - table.lock(notify, xid, key)?; - Ok(()) - } - - fn release(&self, xid: XID, table_id: OID, key: &Buf) -> RS<()> { - let table = self.get_lock_table(table_id)?; - table.release(xid, key)?; - Ok(()) - } - - fn get_lock_table(&self, table_id: OID) -> RS { - let lock_table = { - let opt = self.map.get_sync(&table_id); - match opt { - Some(e) => e.get().clone(), - None => { - return Err(m_error!( - ER::NoSuchElement, - format!("no such table {:}", table_id) - )); - } - } - }; - Ok(lock_table) - } -} - -#[async_trait] -impl XLockMgr for XLockMgrImpl { - async fn create_table(&self, table: OID, tuple_desc: TupleDesc) -> RS<()> { - self.inner.create_table(table, tuple_desc) - } - - async fn drop_table(&self, table: OID) -> RS<()> { - self.inner.drop_table(table) - } - - async fn lock(&self, notify: Notify, xid: XID, table_id: OID, key: Buf) -> RS<()> { - self.inner.lock(notify, xid, table_id, key) - } - - async fn release(&self, xid: XID, table_id: OID, key: &Buf) -> RS<()> { - self.inner.release(xid, table_id, key) - } -} - -unsafe impl Send for XLockMgrImpl {} -unsafe impl Sync for XLockMgrImpl {} diff --git a/mudu_kernel/src/wal/xl_batch.rs b/mudu_kernel/src/wal/xl_batch.rs index 09eda9d..3abe69e 100644 --- a/mudu_kernel/src/wal/xl_batch.rs +++ b/mudu_kernel/src/wal/xl_batch.rs @@ -16,3 +16,9 @@ pub use crate::wal::xl_batch_worker_log::{ deserialize_batch, new_xl_batch_worker_log, new_xl_batch_writer, serialize_batch, NoopXLBatchRecoveryHandler, XLBatchWorkerLog, }; + +#[allow(unused)] +pub mod _fuzz { + #[allow(dead_code)] + pub fn _de_en_x_l_batch(_data: &[u8]) {} +} diff --git a/mudu_kernel/src/x_engine/api.rs b/mudu_kernel/src/x_engine/api.rs index d08860d..497141b 100644 --- a/mudu_kernel/src/x_engine/api.rs +++ b/mudu_kernel/src/x_engine/api.rs @@ -5,9 +5,9 @@ use std::sync::Arc; use crate::contract::schema_table::SchemaTable; use crate::x_engine::dat_bin::DatBin; use crate::x_engine::operator::Operator; +use crate::x_engine::tx_mgr::TxMgr; use mudu::common::id::{AttrIndex, OID}; use mudu::common::result::RS; -use mudu::common::xid::XID; use mudu_contract::tuple::tuple_field::TupleField; pub type TupleRow = TupleField; @@ -87,7 +87,7 @@ pub struct OptDelete {} /// [`XContract`] is the storage-facing contract behind SQL execution and the /// worker-local runtime. All stable schema objects are addressed by immutable /// object identifiers such as [`OID`], while each write/read statement is -/// executed inside a transaction identified by [`XID`]. +/// executed inside a transaction identified by a [`TxMgr`] handle. /// /// Conventions: /// - `table_id` always identifies the target table by OID. @@ -99,31 +99,36 @@ pub struct OptDelete {} pub trait XContract: Send + Sync { /// Creates a table described by `schema`. /// - /// `xid` is accepted for interface uniformity; implementations may treat + /// `tx_mgr` is accepted for interface uniformity; implementations may treat /// DDL as autocommit if transactional DDL is not supported. - async fn create_table(&self, xid: XID, schema: &SchemaTable) -> RS<()>; + async fn create_table(&self, tx_mgr: Arc, schema: &SchemaTable) -> RS<()>; /// Drops the table identified by `oid`. - async fn drop_table(&self, xid: XID, oid: OID) -> RS<()>; + async fn drop_table(&self, tx_mgr: Arc, oid: OID) -> RS<()>; /// Applies an alter-table operation to the target table. - async fn alter_table(&self, xid: XID, oid: OID, alter_table: &AlterTable) -> RS<()>; + async fn alter_table( + &self, + tx_mgr: Arc, + oid: OID, + alter_table: &AlterTable, + ) -> RS<()>; - /// Starts a new transaction and returns its transaction id. - async fn begin_tx(&self) -> RS; + /// Starts a new transaction and returns its transaction manager. + async fn begin_tx(&self) -> RS>; - /// Commits the transaction identified by `xid`. - async fn commit_tx(&self, xid: XID) -> RS<()>; + /// Commits the transaction identified by `tx_mgr`. + async fn commit_tx(&self, tx_mgr: Arc) -> RS<()>; - /// Aborts the transaction identified by `xid`. - async fn abort_tx(&self, xid: XID) -> RS<()>; + /// Aborts the transaction identified by `tx_mgr`. + async fn abort_tx(&self, tx_mgr: Arc) -> RS<()>; /// Updates rows that match the provided key and non-key predicates. /// /// Returns the number of visible rows updated. async fn update( &self, - xid: XID, + tx_mgr: Arc, table_id: OID, pred_key: &VecDatum, pred_non_key: &Predicate, @@ -136,7 +141,7 @@ pub trait XContract: Send + Sync { /// Returns `None` when the key is not visible in the transaction snapshot. async fn read_key( &self, - xid: XID, + tx_mgr: Arc, table_id: OID, pred_key: &VecDatum, select: &VecSelTerm, @@ -149,7 +154,7 @@ pub trait XContract: Send + Sync { /// order of the range scan. async fn read_range( &self, - xid: XID, + tx_mgr: Arc, table_id: OID, pred_key: &RangeData, pred_non_key: &Predicate, @@ -162,7 +167,7 @@ pub trait XContract: Send + Sync { /// Returns the number of visible rows deleted. async fn delete( &self, - xid: XID, + tx_mgr: Arc, table_id: OID, pred_key: &VecDatum, pred_non_key: &Predicate, @@ -172,7 +177,7 @@ pub trait XContract: Send + Sync { /// Inserts one row identified by `keys` with payload columns from `values`. async fn insert( &self, - xid: XID, + tx_mgr: Arc, table_id: OID, keys: &VecDatum, values: &VecDatum, diff --git a/mudu_kernel/src/x_engine/mod.rs b/mudu_kernel/src/x_engine/mod.rs index ccd896e..fd6457b 100644 --- a/mudu_kernel/src/x_engine/mod.rs +++ b/mudu_kernel/src/x_engine/mod.rs @@ -4,5 +4,5 @@ pub mod api; pub mod operator; mod dat_bin; -pub mod thd_ctx; pub mod x_param; +pub(crate) mod tx_mgr; diff --git a/mudu_kernel/src/x_engine/thd_ctx.rs b/mudu_kernel/src/x_engine/thd_ctx.rs deleted file mode 100644 index ec907be..0000000 --- a/mudu_kernel/src/x_engine/thd_ctx.rs +++ /dev/null @@ -1,472 +0,0 @@ -use crate::contract::mem_store::MemStore; -use crate::contract::meta_mgr::MetaMgr; -use crate::contract::x_lock_mgr::{LockResult, XLockMgr}; -use crate::tx::x_snap_mgr::SnapshotRequester; -use async_trait::async_trait; -use std::cell::RefCell; -use std::sync::Arc; - -use crate::contract::data_row::DataRow; -use crate::contract::schema_table::SchemaTable; -use crate::contract::table_desc::TableDesc; -use crate::storage::pst_op_ch::PstOpCh; -use crate::tx::tx_ctx::TxCtx; -use crate::x_engine::api::{ - AlterTable, OptDelete, OptInsert, OptRead, OptUpdate, Predicate, RSCursor, RangeData, VecDatum, - VecSelTerm, XContract, -}; -use mudu::common::buf::Buf; -use mudu::common::id::{AttrIndex, ThdID, OID}; -use mudu::common::result::RS; -use mudu::common::result_of::rs_of_opt; -use mudu::common::update_delta::UpdateDelta; -use mudu::common::xid::XID; -use mudu::error::ec::EC as ER; -use mudu::m_error; -use mudu_contract::tuple::build_tuple::build_tuple; -use mudu_contract::tuple::tuple_binary::TupleBinary as TupleRaw; -use mudu_contract::tuple::update_tuple::update_tuple; -use mudu_utils::sync::notify_wait::create_notify_wait; -use mudu_utils::task_trace; -use scc::HashMap; - -#[derive(Clone)] -pub struct ThdCtx { - inner: Arc, -} - -struct ThdCtxInner { - id: u64, - meta_mgr: Arc, - snap_req: Arc, - x_lock_mgr: Arc, - tree_store: Arc, - pst_op_ch: Arc, - tx_ctx: HashMap, -} - -impl ThdCtx { - pub fn new( - id: u64, - meta_mgr: Arc, - snap_req: Arc, - x_lock_mgr: Arc, - tree_store: Arc, - pst_op_ch: Arc, - ) -> Self { - Self { - inner: Arc::new(ThdCtxInner::new( - id, meta_mgr, snap_req, x_lock_mgr, tree_store, pst_op_ch, - )), - } - } - - pub fn thd_id(&self) -> ThdID { - self.inner.id - } - - pub fn snap_req(&self) -> &SnapshotRequester { - self.inner.snap_req() - } - - pub fn meta_mgr(&self) -> &dyn MetaMgr { - self.inner.meta_mgr() - } - - pub fn tree_store(&self) -> &dyn MemStore { - self.inner.tree_store() - } - - pub fn x_lock_mgr(&self) -> &dyn XLockMgr { - self.inner.x_lock_mgr() - } - - pub fn pst_op_ch(&self) -> &dyn PstOpCh { - self.inner.pst_op_ch() - } -} - -impl ThdCtxInner { - fn new( - id: u64, - meta_mgr: Arc, - snap_req: Arc, - x_lock_mgr: Arc, - tree_store: Arc, - pst_op_ch: Arc, - ) -> Self { - Self { - id, - snap_req, - meta_mgr, - tree_store, - x_lock_mgr, - pst_op_ch, - tx_ctx: Default::default(), - } - } - - fn snap_req(&self) -> &SnapshotRequester { - self.snap_req.as_ref() - } - - fn meta_mgr(&self) -> &dyn MetaMgr { - self.meta_mgr.as_ref() - } - - fn tree_store(&self) -> &dyn MemStore { - self.tree_store.as_ref() - } - - fn x_lock_mgr(&self) -> &dyn XLockMgr { - self.x_lock_mgr.as_ref() - } - fn pst_op_ch(&self) -> &dyn PstOpCh { - self.pst_op_ch.as_ref() - } - - async fn create_table(&self, _xid: XID, schema: &SchemaTable) -> RS<()> { - task_trace!(); - let table_id = schema.id(); - self.meta_mgr.create_table(schema).await?; - let kv_desc = self.meta_mgr.get_table_by_id(table_id).await?; - self.x_lock_mgr - .create_table(table_id, kv_desc.key_desc().clone()) - .await?; - self.tree_store - .create_table(table_id, kv_desc.key_desc().clone()) - .await?; - Ok(()) - } - - async fn get_desc(&self, table_id: OID) -> RS> { - self.meta_mgr().get_table_by_id(table_id).await - } - - fn pk_build_tuple(pkey: &VecDatum, desc: &TableDesc) -> RS { - Self::_build_tuple::(pkey.data(), desc) - } - - fn val_build_tuple(val: &VecDatum, desc: &TableDesc) -> RS { - Self::_build_tuple::(val.data(), desc) - } - - fn val_update_tuple( - tuple: &TupleRaw, - val: &VecDatum, - desc: &TableDesc, - ) -> RS> { - Self::_update_tuple(tuple, val.data(), desc) - } - - // build update tuple for this row - fn _update_tuple( - tuple: &TupleRaw, - datum: &Vec<(AttrIndex, Buf)>, - table_desc: &TableDesc, - ) -> RS> { - let mut delta = vec![]; - for (id, dat) in datum.iter() { - let field = table_desc.get_attr(*id); - if field.is_primary() { - return Err(m_error!( - ER::IOErr, - format!( - "column {} in table {} is a primary key", - id, - table_desc.id() - ) - )); - } - let datum_index = field.datum_index(); - update_tuple(datum_index, dat, table_desc.value_desc(), tuple, &mut delta)?; - } - Ok(delta) - } - - fn _build_tuple(data: &Vec<(AttrIndex, Buf)>, desc: &TableDesc) -> RS { - let mut vec_data = data.clone(); - let ok = RefCell::new(true); - vec_data.sort_by(|(id1, _), (id2, _)| { - let (f1, f2) = (desc.get_attr(*id1), desc.get_attr(*id2)); - if f1.is_primary() != IS_KEY || f2.is_primary() != IS_KEY { - *ok.borrow_mut() = false; - } - f1.datum_index().cmp(&f2.datum_index()) - }); - if !*ok.borrow() { - return Err(m_error!(ER::TupleErr)); - } - let vec_data: Vec<_> = vec_data.into_iter().map(|(_, v)| v).collect(); - let desc = if IS_KEY { - desc.key_desc() - } else { - desc.value_desc() - }; - if desc.field_count() != vec_data.len() { - return Err(m_error!(ER::TupleErr)); - } - let tuple = build_tuple(&vec_data, desc)?; - Ok(tuple) - } - - async fn lock_x(&self, tx_ctx: &TxCtx, table_id: OID, key: Buf) -> RS<()> { - task_trace!(); - let xid = tx_ctx.xid(); - tx_ctx.write(table_id, key.clone()).await?; - let (notify, wait) = create_notify_wait(); - self.x_lock_mgr - .lock(notify, xid, table_id, key.clone()) - .await?; - let opt = wait.wait().await?; - match opt { - Some(lock_r) => match lock_r { - LockResult::Locked => Ok(()), - LockResult::LockFailed => Err(m_error!( - ER::TxErr, - format!("transaction {} lock failed", xid) - )), - }, - None => Err(m_error!( - ER::TxErr, - format!("transaction {} lock failed", tx_ctx.xid()) - )), - } - } - - async fn insert( - &self, - xid: XID, - table_id: OID, - keys: &VecDatum, - values: &VecDatum, - _opt_insert: &OptInsert, - ) -> RS<()> { - task_trace!(); - let tx_ctx = self.get_tx_ctx(xid)?; - let (key, value) = { - let desc = self.get_desc(table_id).await?; - let key = Self::pk_build_tuple(keys, &desc)?; - let value = Self::val_build_tuple(values, &desc)?; - (key, value) - }; - self.lock_x(&tx_ctx, table_id, key.clone()).await?; - let opt = self.tree_store.get_key(table_id, key.clone()).await?; - if opt.is_some() { - return Err(m_error!( - ER::ExistingSuchElement, - format!("existing key for table {}", table_id) - )); - } - let data_row = DataRow::new(0); - tx_ctx.insert(table_id, key, value, data_row).await?; - Ok(()) - } - - async fn update( - &self, - xid: XID, - table_id: OID, - pred_key: &VecDatum, - _pred_non_key: &Predicate, - values: &VecDatum, - _opt_update: &OptUpdate, - ) -> RS { - let tx_ctx = self.get_tx_ctx(xid)?; - let desc = self.get_desc(table_id).await?; - let key = { Self::pk_build_tuple(pred_key, &desc)? }; - let opt = self.tree_store.get_key(table_id, key.clone()).await?; - let data_row = match opt { - Some(row) => row, - None => { - return Err(m_error!( - ER::NoSuchElement, - format!("no existing key for table {} update", table_id) - )); - } - }; - let opt_tuple_id = data_row.tuple_id().await?; - let tuple_id = rs_of_opt(opt_tuple_id, || { - m_error!( - ER::NoSuchElement, - format!("no existing key for table {} update", table_id) - ) - })?; - let opt_tuple_version = data_row.read_latest().await?; - let tuple_version = rs_of_opt(opt_tuple_version, || { - m_error!( - ER::NoSuchElement, - format!("no existing key for table {} update", table_id) - ) - })?; - let tuple = tuple_version.tuple(); - let vec_delta = Self::val_update_tuple(tuple, values, &desc)?; - tx_ctx - .update(table_id, tuple_id, key, vec_delta, data_row) - .await?; - Ok(1) - } - - async fn read_key( - &self, - xid: XID, - table_id: OID, - pred_key: &VecDatum, - select: &VecSelTerm, - _opt_read: &OptRead, - ) -> RS>> { - let _tx_ctx = self.get_tx_ctx(xid)?; - let desc = self.get_desc(table_id).await?; - let key = { Self::pk_build_tuple(pred_key, &desc)? }; - let opt = self.tree_store.get_key(table_id, key.clone()).await?; - let data_row = match opt { - Some(row) => row, - None => { - return Ok(None); - } - }; - let opt_row = data_row.read_latest().await?; - let tuple = match &opt_row { - Some(version) => version.tuple(), - None => { - return Ok(None); - } - }; - let mut tuple_ret = vec![]; - for i in select.vec() { - let f = desc.get_attr(*i); - let index = f.datum_index(); - let desc = if f.is_primary() { - desc.key_desc().get_field_desc(index) - } else { - desc.value_desc().get_field_desc(index) - }; - let slice = desc.get(tuple)?; - tuple_ret.push(slice.to_vec()); - } - Ok(Some(tuple_ret)) - } - - async fn _begin_tx(&self) -> RS { - task_trace!(); - let snapshot = self.snap_req.start_tx().await?; - let xid = snapshot.xid(); - let tx_ctx = TxCtx::new(xid, snapshot); - let _ = self.tx_ctx.insert_sync(xid, tx_ctx); - Ok(xid) - } - - async fn _commit_tx(&self, xid: XID) -> RS<()> { - task_trace!(); - let tx_ctx = self.get_tx_ctx(xid)?; - tx_ctx.commit(&*self.x_lock_mgr).await?; - self.snap_req.end_tx(xid).await?; - self.remove_tx_ctx(xid); - Ok(()) - } - - fn get_tx_ctx(&self, xid: XID) -> RS { - let opt = self.tx_ctx.get_sync(&xid); - let entry = rs_of_opt(opt, || { - m_error!(ER::NoSuchElement, format!("no such transaction {}", xid)) - })?; - let ctx = entry.get().clone(); - Ok(ctx) - } - - fn remove_tx_ctx(&self, xid: XID) { - let _ = self.tx_ctx.remove_sync(&xid); - } -} - -#[async_trait] -impl XContract for ThdCtx { - async fn create_table(&self, xid: XID, schema: &SchemaTable) -> RS<()> { - task_trace!(); - self.inner.create_table(xid, schema).await - } - - async fn drop_table(&self, _xid: XID, _oid: OID) -> RS<()> { - todo!() - } - - async fn alter_table(&self, _xid: XID, _oid: OID, _alter_table: &AlterTable) -> RS<()> { - todo!() - } - - async fn begin_tx(&self) -> RS { - self.inner._begin_tx().await - } - - async fn commit_tx(&self, xid: XID) -> RS<()> { - self.inner._commit_tx(xid).await - } - - async fn abort_tx(&self, _xid: XID) -> RS<()> { - todo!() - } - - async fn update( - &self, - xid: XID, - table_id: OID, - pred_key: &VecDatum, - pred_non_key: &Predicate, - values: &VecDatum, - opt_update: &OptUpdate, - ) -> RS { - self.inner - .update(xid, table_id, pred_key, pred_non_key, values, opt_update) - .await - } - - async fn read_key( - &self, - xid: XID, - table_id: OID, - pred_key: &VecDatum, - vec_proj: &VecSelTerm, - opt_read: &OptRead, - ) -> RS>> { - self.inner - .read_key(xid, table_id, pred_key, vec_proj, opt_read) - .await - } - - async fn read_range( - &self, - _xid: XID, - _table_id: OID, - _pred_key: &RangeData, - _pred_non_key: &Predicate, - _select: &VecSelTerm, - _opt_read: &OptRead, - ) -> RS> { - todo!() - } - - async fn delete( - &self, - _xid: XID, - _table_id: OID, - _pred_key: &VecDatum, - _pred_non_key: &Predicate, - _opt_delete: &OptDelete, - ) -> RS { - todo!() - } - - async fn insert( - &self, - xid: XID, - table_id: OID, - keys: &VecDatum, - values: &VecDatum, - opt_insert: &OptInsert, - ) -> RS<()> { - task_trace!(); - self.inner - .insert(xid, table_id, keys, values, opt_insert) - .await - } -} diff --git a/mudu_kernel/src/x_engine/tx_mgr.rs b/mudu_kernel/src/x_engine/tx_mgr.rs new file mode 100644 index 0000000..b58107e --- /dev/null +++ b/mudu_kernel/src/x_engine/tx_mgr.rs @@ -0,0 +1,48 @@ +use crate::server::worker_snapshot::WorkerSnapshot; +use crate::wal::xl_batch::XLBatch; +use mudu::common::id::OID; +use std::collections::BTreeMap; + +pub trait TxMgr: Send + Sync { + fn xid(&self) -> u64; + + fn snapshot(&self) -> WorkerSnapshot; + + fn put(&self, key: Vec, value: Vec); + + fn delete(&self, key: Vec); + + fn get(&self, key: &[u8]) -> Option>>; + + fn put_relation(&self, oid: OID, key: Vec, value: Vec); + + fn delete_relation(&self, oid: OID, key: Vec); + + fn get_relation(&self, oid: OID, key: &[u8]) -> Option>>; + + fn staged_relation_items_in_range( + &self, + oid: OID, + start_key: &[u8], + end_key: &[u8], + ) -> Vec<(Vec, Option>)>; + + fn staged_relation_ops(&self) -> BTreeMap, Option>>>; + + fn staged_items_in_range( + &self, + start_key: &[u8], + end_key: &[u8], + ) -> Vec<(Vec, Option>)>; + + fn staged_put_items(&self) -> BTreeMap, Option>>; + + fn is_empty(&self) -> bool; + + fn write_ops(&self) -> Vec<(OID, Vec)>; + + fn build_write_ops(&self); + + fn xl_batch(&self) -> XLBatch; +} + diff --git a/mudu_kernel/src/x_engine/x_param.rs b/mudu_kernel/src/x_engine/x_param.rs index aa20bd8..aa9f8a7 100644 --- a/mudu_kernel/src/x_engine/x_param.rs +++ b/mudu_kernel/src/x_engine/x_param.rs @@ -1,11 +1,12 @@ use crate::contract::schema_table::SchemaTable; use crate::x_engine::api::{OptRead, Predicate, RangeData, VecDatum, VecSelTerm}; +use crate::x_engine::tx_mgr::TxMgr; use mudu::common::id::OID; -use mudu::common::xid::XID; +use std::sync::Arc; -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct PAccessKey { - pub xid: XID, + pub tx_mgr: Arc, pub table_id: OID, pub pred_key: VecDatum, pub select: VecSelTerm, @@ -13,7 +14,7 @@ pub struct PAccessKey { } pub struct PAccessRange { - pub xid: XID, + pub tx_mgr: Arc, pub table_id: OID, pub pred_key: RangeData, pub pred_non_key: Predicate, @@ -21,37 +22,37 @@ pub struct PAccessRange { pub opt_read: OptRead, } -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct PCreateTable { - pub xid: XID, + pub tx_mgr: Arc, pub schema: SchemaTable, } -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct PDropTable { - pub xid: XID, + pub tx_mgr: Arc, pub oid: Option, } -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct PInsertKeyValue { - pub xid: XID, + pub tx_mgr: Arc, pub table_id: OID, pub key: VecDatum, pub value: VecDatum, } -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct PUpdateKeyValue { - pub xid: XID, + pub tx_mgr: Arc, pub table_id: OID, pub key: VecDatum, pub value: VecDatum, } -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct PDeleteKeyValue { - pub xid: XID, + pub tx_mgr: Arc, pub table_id: OID, pub key: VecDatum, } diff --git a/mudu_runtime/src/backend/http_api/io_uring_http_api.rs b/mudu_runtime/src/backend/http_api/io_uring_http_api.rs index c972ce2..8fcab7e 100644 --- a/mudu_runtime/src/backend/http_api/io_uring_http_api.rs +++ b/mudu_runtime/src/backend/http_api/io_uring_http_api.rs @@ -1,6 +1,6 @@ use super::{ - find_app, parse_json_object_body, to_param, AsyncIoUringInvokeClientFactory, HttpApi, - ServerTopology, TokioIoUringInvokeClientFactory, WorkerTopology, + AsyncIoUringInvokeClientFactory, HttpApi, ServerTopology, TokioIoUringInvokeClientFactory, + WorkerTopology, find_app, parse_json_object_body, to_param, }; use crate::backend::app_mgr::AppMgr; use crate::backend::mududb_cfg::MuduDBCfg; diff --git a/mudu_runtime/src/backend/http_api/legacy_http_api.rs b/mudu_runtime/src/backend/http_api/legacy_http_api.rs index 19df831..482d1aa 100644 --- a/mudu_runtime/src/backend/http_api/legacy_http_api.rs +++ b/mudu_runtime/src/backend/http_api/legacy_http_api.rs @@ -1,6 +1,6 @@ use super::{ - legacy_invoke_async_proc, legacy_invoke_sync_proc, parse_json_object_body, - runtime_get_app_and_desc, HttpApi, + HttpApi, legacy_invoke_async_proc, legacy_invoke_sync_proc, parse_json_object_body, + runtime_get_app_and_desc, }; use crate::service::runtime::Runtime; use async_trait::async_trait; diff --git a/mudu_runtime/src/backend/http_api/mod.rs b/mudu_runtime/src/backend/http_api/mod.rs index e750b8c..3477a30 100644 --- a/mudu_runtime/src/backend/http_api/mod.rs +++ b/mudu_runtime/src/backend/http_api/mod.rs @@ -28,7 +28,7 @@ use crate::service::app_inst::AppInst; use crate::service::runtime::Runtime; use actix_cors::Cors; use actix_web::http::StatusCode; -use actix_web::{delete, get, post, web, App, HttpResponse, HttpServer, Responder}; +use actix_web::{App, HttpResponse, HttpServer, Responder, delete, get, post, web}; use async_trait::async_trait; use base64::Engine; use mudu::common::id::OID; @@ -514,7 +514,7 @@ async fn find_app(app_mgr: &dyn AppMgr, app_name: &str) -> RS { #[cfg(test)] mod test { use super::*; - use actix_web::{test, App}; + use actix_web::{App, test}; #[cfg(target_os = "linux")] use mudu::common::app_info::AppInfo; #[cfg(target_os = "linux")] diff --git a/mudu_runtime/src/backend/mod.rs b/mudu_runtime/src/backend/mod.rs index 0e11117..c926615 100644 --- a/mudu_runtime/src/backend/mod.rs +++ b/mudu_runtime/src/backend/mod.rs @@ -9,6 +9,8 @@ mod session_handle_task; mod test_backend; mod test_pg_cli; mod test_sql; +#[cfg(all(test, target_os = "linux"))] +mod sql_async_client_test; pub mod web_handle_task; pub mod web_serve; @@ -18,9 +20,5 @@ mod app_mgr; mod iouring_admin; #[cfg(target_os = "linux")] pub mod mudu_app_mgr; -pub mod mudu_conn_async; -mod mudu_conn_core; -mod mudu_prepared_stmt; -mod mudu_result_set_async; #[cfg(target_os = "linux")] pub mod server_ur; diff --git a/mudu_runtime/src/backend/mudu_conn_async.rs b/mudu_runtime/src/backend/mudu_conn_async.rs deleted file mode 100644 index 42454c9..0000000 --- a/mudu_runtime/src/backend/mudu_conn_async.rs +++ /dev/null @@ -1,94 +0,0 @@ -use crate::backend::mudu_conn_core::MuduConnCore; -use crate::backend::mudu_prepared_stmt::MuduPreparedStmt; -use async_trait::async_trait; -use mudu::common::result::RS; -use mudu::common::xid::XID; -use mudu_contract::database::db_conn::DBConnAsync; -use mudu_contract::database::prepared_stmt::PreparedStmt; -use mudu_contract::database::result_set::ResultSetAsync; -use mudu_contract::database::sql_params::SQLParams; -use mudu_contract::database::sql_stmt::SQLStmt; -use mudu_kernel::contract::meta_mgr::MetaMgr; -use mudu_kernel::x_engine::api::XContract; -use std::sync::Arc; - -pub struct MuduConnAsync { - core: Arc, -} - -impl MuduConnAsync { - pub fn new(meta_mgr: Arc, x_contract: Arc) -> Self { - Self { - core: Arc::new(MuduConnCore::new(meta_mgr, x_contract)), - } - } -} - -#[async_trait] -impl DBConnAsync for MuduConnAsync { - async fn prepare(&self, stmt: Box) -> RS> { - let parsed = self.core.parse_one(stmt.as_ref())?; - let desc = self.core.describe_stmt(parsed.clone()).await?; - Ok(Arc::new(MuduPreparedStmt::new( - self.core.clone(), - parsed, - desc, - ))) - } - - async fn exec_silent(&self, sql_text: String) -> RS<()> { - let stmts = self.core.parse_many(&sql_text)?; - for stmt in stmts { - match stmt { - sql_parser::ast::stmt_type::StmtType::Select(_) => { - let _ = self.core.query(stmt, Box::new(())).await?; - } - sql_parser::ast::stmt_type::StmtType::Command(_) => { - let _ = self.core.execute(stmt, Box::new(())).await?; - } - } - } - Ok(()) - } - - async fn begin_tx(&self) -> RS { - self.core.begin_tx().await - } - - async fn rollback_tx(&self) -> RS<()> { - self.core.rollback_tx().await - } - - async fn commit_tx(&self) -> RS<()> { - self.core.commit_tx().await - } - - async fn query( - &self, - sql: Box, - param: Box, - ) -> RS> { - let parsed = self.core.parse_one(sql.as_ref())?; - self.core.query(parsed, param).await - } - - async fn execute(&self, sql: Box, param: Box) -> RS { - let parsed = self.core.parse_one(sql.as_ref())?; - self.core.execute(parsed, param).await - } - - async fn batch(&self, sql: Box, param: Box) -> RS { - if param.size() != 0 { - return Err(mudu::m_error!( - mudu::error::ec::EC::NotImplemented, - "batch with parameters is not implemented" - )); - } - let stmts = self.core.parse_many(sql.as_ref())?; - let mut total = 0; - for stmt in stmts { - total += self.core.execute(stmt, Box::new(())).await?; - } - Ok(total) - } -} diff --git a/mudu_runtime/src/backend/mudu_conn_core.rs b/mudu_runtime/src/backend/mudu_conn_core.rs deleted file mode 100644 index 4aaa8a9..0000000 --- a/mudu_runtime/src/backend/mudu_conn_core.rs +++ /dev/null @@ -1,201 +0,0 @@ -use crate::backend::mudu_result_set_async::MuduResultSetAsync; -use mudu::common::result::RS; -use mudu::common::xid::XID; -use mudu::error::ec::EC; -use mudu::m_error; -use mudu_contract::database::result_set::ResultSetAsync; -use mudu_contract::database::sql_params::SQLParams; -use mudu_contract::tuple::tuple_field_desc::TupleFieldDesc; -use mudu_contract::tuple::tuple_value::TupleValue; -use mudu_contract::tuple::typed_bin::TypedBin; -use mudu_kernel::contract::meta_mgr::MetaMgr; -use mudu_kernel::contract::query_exec::QueryExec; -use mudu_kernel::sql::binder::Binder; -use mudu_kernel::sql::bound_stmt::BoundStmt; -use mudu_kernel::sql::describer::Describer; -use mudu_kernel::sql::plan_ctx::PlanCtx; -use mudu_kernel::sql::planner::Planner; -use mudu_kernel::x_engine::api::XContract; -use mudu_type::datum::DatumDyn; -use sql_parser::ast::parser::SQLParser; -use sql_parser::ast::stmt_type::StmtType; -use std::sync::Arc; -use tokio::sync::Mutex; - -pub struct MuduConnCore { - meta_mgr: Arc, - x_contract: Arc, - parser: Arc, - tx_state: Arc>>, -} - -enum TxScope { - Auto(XID), - Existing, -} - -impl MuduConnCore { - pub fn new(meta_mgr: Arc, x_contract: Arc) -> Self { - Self { - meta_mgr, - x_contract, - parser: Arc::new(SQLParser::new()), - tx_state: Arc::new(Mutex::new(None)), - } - } - - pub fn parse_one(&self, sql: &dyn mudu_contract::database::sql_stmt::SQLStmt) -> RS { - let stmt_list = self.parser.parse(&sql.to_sql_string())?; - let mut stmts = stmt_list.into_stmts(); - if stmts.len() != 1 { - return Err(m_error!(EC::ParseErr, "expected exactly one statement")); - } - Ok(stmts.remove(0)) - } - - pub fn parse_many( - &self, - sql: &dyn mudu_contract::database::sql_stmt::SQLStmt, - ) -> RS> { - Ok(self.parser.parse(&sql.to_sql_string())?.into_stmts()) - } - - pub async fn describe_stmt(&self, stmt: StmtType) -> RS> { - let desc = Describer::new(self.meta_mgr.clone()).describe(stmt).await?; - Ok(Arc::new(desc)) - } - - pub async fn query( - &self, - stmt: StmtType, - params: Box, - ) -> RS> { - let (scope, xid) = self.enter_tx().await?; - let result = self.query_inner(stmt, params, xid).await; - match self.leave_tx(scope, result.is_ok()).await { - Ok(()) => {} - Err(e) => return Err(e), - } - result.map(|rs| Arc::new(rs) as Arc) - } - - pub async fn execute(&self, stmt: StmtType, params: Box) -> RS { - let (scope, xid) = self.enter_tx().await?; - let result = self.execute_inner(stmt, params, xid).await; - self.leave_tx(scope, result.is_ok()).await?; - result - } - - async fn query_inner( - &self, - stmt: StmtType, - params: Box, - xid: XID, - ) -> RS { - let bound = Binder::new(self.meta_mgr.clone()) - .bind(stmt, params.as_ref()) - .await?; - let BoundStmt::Query(bound_query) = bound else { - return Err(m_error!(EC::TypeErr, "statement is not a query")); - }; - let planner = Planner::new(PlanCtx { - xid, - meta_mgr: self.meta_mgr.clone(), - x_contract: self.x_contract.clone(), - }); - let exec = planner.plan_query(bound_query).await?; - MuduResultSetAsync::from_query_exec(exec).await - } - - async fn execute_inner(&self, stmt: StmtType, params: Box, xid: XID) -> RS { - let bound = Binder::new(self.meta_mgr.clone()) - .bind(stmt, params.as_ref()) - .await?; - let BoundStmt::Command(bound_command) = bound else { - return Err(m_error!(EC::TypeErr, "statement is not a command")); - }; - let planner = Planner::new(PlanCtx { - xid, - meta_mgr: self.meta_mgr.clone(), - x_contract: self.x_contract.clone(), - }); - let cmd = planner.plan_command(bound_command).await?; - cmd.prepare().await?; - cmd.run().await?; - cmd.affected_rows().await - } - - async fn enter_tx(&self) -> RS<(TxScope, XID)> { - let guard = self.tx_state.lock().await; - if let Some(xid) = *guard { - return Ok((TxScope::Existing, xid)); - } - drop(guard); - let xid = self.x_contract.begin_tx().await?; - Ok((TxScope::Auto(xid), xid)) - } - - async fn leave_tx(&self, scope: TxScope, success: bool) -> RS<()> { - match scope { - TxScope::Existing => Ok(()), - TxScope::Auto(xid) => { - if success { - self.x_contract.commit_tx(xid).await - } else { - self.x_contract.abort_tx(xid).await - } - } - } - } - - pub async fn begin_tx(&self) -> RS { - let mut guard = self.tx_state.lock().await; - if let Some(xid) = *guard { - return Ok(xid); - } - let xid = self.x_contract.begin_tx().await?; - *guard = Some(xid); - Ok(xid) - } - - pub async fn commit_tx(&self) -> RS<()> { - let mut guard = self.tx_state.lock().await; - let xid = guard - .take() - .ok_or_else(|| m_error!(EC::NoSuchElement, "no active transaction"))?; - drop(guard); - self.x_contract.commit_tx(xid).await - } - - pub async fn rollback_tx(&self) -> RS<()> { - let mut guard = self.tx_state.lock().await; - let xid = guard - .take() - .ok_or_else(|| m_error!(EC::NoSuchElement, "no active transaction"))?; - drop(guard); - self.x_contract.abort_tx(xid).await - } -} - -pub async fn query_exec_to_rows(exec: Arc) -> RS<(Vec, TupleFieldDesc)> { - exec.open().await?; - let desc = exec.tuple_desc()?; - let mut rows = Vec::new(); - while let Some(row) = exec.next().await? { - rows.push(tuple_field_to_value(row, &desc)?); - } - Ok((rows, desc)) -} - -fn tuple_field_to_value( - row: mudu_contract::tuple::tuple_field::TupleField, - desc: &TupleFieldDesc, -) -> RS { - let mut values = Vec::with_capacity(row.fields().len()); - for (index, field) in row.fields().iter().enumerate() { - let datum_desc = &desc.fields()[index]; - let typed = TypedBin::new(datum_desc.dat_type_id(), field.clone()); - values.push(typed.to_value(datum_desc.dat_type())?); - } - Ok(TupleValue::from(values)) -} diff --git a/mudu_runtime/src/backend/mudu_prepared_stmt.rs b/mudu_runtime/src/backend/mudu_prepared_stmt.rs deleted file mode 100644 index 097d01e..0000000 --- a/mudu_runtime/src/backend/mudu_prepared_stmt.rs +++ /dev/null @@ -1,40 +0,0 @@ -use crate::backend::mudu_conn_core::MuduConnCore; -use async_trait::async_trait; -use mudu::common::result::RS; -use mudu_contract::database::prepared_stmt::PreparedStmt; -use mudu_contract::database::result_set::ResultSetAsync; -use mudu_contract::database::sql_params::SQLParams; -use mudu_contract::tuple::tuple_field_desc::TupleFieldDesc; -use sql_parser::ast::stmt_type::StmtType; -use std::sync::Arc; - -pub struct MuduPreparedStmt { - core: Arc, - stmt: StmtType, - desc: Arc, -} - -impl MuduPreparedStmt { - pub fn new(core: Arc, stmt: StmtType, desc: Arc) -> Self { - Self { core, stmt, desc } - } -} - -#[async_trait] -impl PreparedStmt for MuduPreparedStmt { - async fn query(&self, params: Box) -> RS> { - self.core.query(self.stmt.clone(), params).await - } - - async fn execute(&self, params: Box) -> RS { - self.core.execute(self.stmt.clone(), params).await - } - - async fn desc(&self) -> RS> { - Ok(self.desc.clone()) - } - - async fn reset(&self) -> RS<()> { - Ok(()) - } -} diff --git a/mudu_runtime/src/backend/sql_async_client_test.rs b/mudu_runtime/src/backend/sql_async_client_test.rs new file mode 100644 index 0000000..07a9080 --- /dev/null +++ b/mudu_runtime/src/backend/sql_async_client_test.rs @@ -0,0 +1,399 @@ +#[cfg(test)] +mod tests { + use crate::backend::backend::Backend; + use crate::backend::mududb_cfg::{MuduDBCfg, ServerMode}; + use lazy_static::lazy_static; + use mudu::common::result::RS; + use mudu_cli::client::async_client::{AsyncClient, AsyncClientImpl}; + use mudu_contract::protocol::ClientRequest; + use mudu_utils::notifier::notify_wait; + use std::fs; + use std::net::TcpListener; + use std::path::PathBuf; + use std::thread; + use std::thread::JoinHandle; + use std::time::{Duration, Instant}; + use tokio::sync::Mutex as AsyncMutex; + use tokio::time::{sleep, timeout}; + + lazy_static! { + static ref SQL_ASYNC_BACKEND_TEST_LOCK: AsyncMutex<()> = AsyncMutex::new(()); + } + + fn temp_dir(prefix: &str) -> PathBuf { + std::env::temp_dir().join(format!( + "{}_{}", + prefix, + mudu_sys::random::next_uuid_v4_string() + )) + } + + fn reserve_port() -> Option { + TcpListener::bind("127.0.0.1:0") + .ok() + .and_then(|listener| listener.local_addr().ok().map(|addr| addr.port())) + } + + fn test_cfg() -> Option { + let tcp_port = reserve_port()?; + let db_path = temp_dir("mudu_sql_async_db"); + let mpk_path = temp_dir("mudu_sql_async_mpk"); + fs::create_dir_all(&db_path).ok()?; + fs::create_dir_all(&mpk_path).ok()?; + Some(MuduDBCfg { + mpk_path: mpk_path.to_string_lossy().into_owned(), + db_path: db_path.to_string_lossy().into_owned(), + listen_ip: "127.0.0.1".to_string(), + http_listen_port: 0, + pg_listen_port: 0, + tcp_listen_port: tcp_port, + server_mode: ServerMode::IOUring, + io_uring_worker_threads: 1, + ..Default::default() + }) + } + + async fn wait_for_client(addr: &str, timeout: Duration) -> RS { + let deadline = Instant::now() + timeout; + loop { + match AsyncClientImpl::connect(addr).await { + Ok(client) => return Ok(client), + Err(err) => { + if Instant::now() >= deadline { + return Err(err); + } + sleep(Duration::from_millis(50)).await; + } + } + } + } + + async fn with_timeout(future: impl std::future::Future>) -> RS { + timeout(Duration::from_secs(20), future) + .await + .map_err(|_| mudu::m_error!(mudu::error::ec::EC::TokioErr, "sql async client test timed out"))? + } + + fn stop_server( + client: AsyncClientImpl, + stop_notifier: mudu_utils::notifier::Notifier, + server: JoinHandle>, + ) -> RS<()> { + drop(client); + stop_notifier.notify_all(); + server + .join() + .map_err(|_| mudu::m_error!(mudu::error::ec::EC::ThreadErr, "join sql async backend thread error"))? + } + + async fn start_client_backend() -> Option>, + )>> { + let Some(cfg) = test_cfg() else { + return None; + }; + let addr = format!("127.0.0.1:{}", cfg.tcp_listen_port); + let (stop_notifier, stop_waiter) = notify_wait(); + let server = thread::spawn(move || Backend::sync_serve_with_stop(cfg, stop_waiter)); + let client = match wait_for_client(&addr, Duration::from_secs(10)).await { + Ok(client) => client, + Err(err) => { + stop_notifier.notify_all(); + let _ = server.join(); + return Some(Err(err)); + } + }; + Some(Ok((client, stop_notifier, server))) + } + + async fn exec_sql(client: &mut AsyncClientImpl, sql: &str) -> RS<()> { + with_timeout(client.execute(ClientRequest::new("default", sql))) + .await + .map(|_| ()) + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn async_client_roundtrip_sql_crud_over_iouring_backend() -> RS<()> { + let _guard = SQL_ASYNC_BACKEND_TEST_LOCK.lock().await; + let Some(cfg) = test_cfg() else { + return Ok(()); + }; + let addr = format!("127.0.0.1:{}", cfg.tcp_listen_port); + let (stop_notifier, stop_waiter) = notify_wait(); + let server = thread::spawn(move || Backend::sync_serve_with_stop(cfg, stop_waiter)); + + let mut client = match wait_for_client(&addr, Duration::from_secs(10)).await { + Ok(client) => client, + Err(err) => { + stop_notifier.notify_all(); + let _ = server.join(); + if err.to_string().contains("connect io_uring tcp server error") { + return Ok(()); + } + return Err(err); + } + }; + + with_timeout(client + .execute(ClientRequest::new( + "default", + "CREATE TABLE t(id INT, v INT, PRIMARY KEY(id))", + ))) + .await?; + let inserted = with_timeout(client + .execute(ClientRequest::new( + "default", + "INSERT INTO t(id, v) VALUES (1, 10)", + ))) + .await?; + assert_eq!(inserted.affected_rows(), 1); + + let selected = with_timeout(client + .query(ClientRequest::new( + "default", + "SELECT id, v FROM t WHERE id = 1", + ))) + .await?; + assert_eq!(selected.rows(), &[vec!["1".to_string(), "10".to_string()]]); + + let updated = with_timeout(client + .execute(ClientRequest::new( + "default", + "UPDATE t SET v = 20 WHERE id = 1", + ))) + .await?; + assert_eq!(updated.affected_rows(), 1); + + let selected = with_timeout(client.query(ClientRequest::new( + "default", + "SELECT v FROM t WHERE id = 1", + ))) + .await?; + assert_eq!(selected.rows(), &[vec!["20".to_string()]]); + + let deleted = with_timeout(client + .execute(ClientRequest::new( + "default", + "DELETE FROM t WHERE id = 1", + ))) + .await?; + assert_eq!(deleted.affected_rows(), 1); + + let selected = with_timeout(client.query(ClientRequest::new( + "default", + "SELECT id FROM t WHERE id = 1", + ))) + .await?; + assert!(selected.rows().is_empty()); + + stop_server(client, stop_notifier, server)?; + Ok(()) + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn async_client_batch_executes_multiple_sql_commands() -> RS<()> { + let _guard = SQL_ASYNC_BACKEND_TEST_LOCK.lock().await; + let Some(cfg) = test_cfg() else { + return Ok(()); + }; + let addr = format!("127.0.0.1:{}", cfg.tcp_listen_port); + let (stop_notifier, stop_waiter) = notify_wait(); + let server = thread::spawn(move || Backend::sync_serve_with_stop(cfg, stop_waiter)); + + let mut client = match wait_for_client(&addr, Duration::from_secs(10)).await { + Ok(client) => client, + Err(err) => { + stop_notifier.notify_all(); + let _ = server.join(); + if err.to_string().contains("connect io_uring tcp server error") { + return Ok(()); + } + return Err(err); + } + }; + + with_timeout(client + .batch(ClientRequest::new( + "default", + "CREATE TABLE t(id INT, v INT, PRIMARY KEY(id));\ + INSERT INTO t(id, v) VALUES (1, 11);", + ))) + .await?; + + let selected = with_timeout(client + .query(ClientRequest::new( + "default", + "SELECT id, v FROM t WHERE id = 1", + ))) + .await?; + assert_eq!(selected.rows(), &[vec!["1".to_string(), "11".to_string()]]); + + stop_server(client, stop_notifier, server)?; + Ok(()) + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn async_client_drop_table_removes_table_from_catalog() -> RS<()> { + let _guard = SQL_ASYNC_BACKEND_TEST_LOCK.lock().await; + let Some(started) = start_client_backend().await else { + return Ok(()); + }; + let (mut client, stop_notifier, server) = started?; + + exec_sql( + &mut client, + "CREATE TABLE t(id INT, v INT, PRIMARY KEY(id))", + ) + .await?; + exec_sql(&mut client, "INSERT INTO t(id, v) VALUES (1, 10)").await?; + exec_sql(&mut client, "DROP TABLE t").await?; + + let err = with_timeout(client.query(ClientRequest::new( + "default", + "SELECT id, v FROM t WHERE id = 1", + ))) + .await + .expect_err("query on dropped table should fail"); + assert!(err.to_string().contains("no such table")); + + stop_server(client, stop_notifier, server)?; + Ok(()) + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn async_client_range_scan_over_primary_key() -> RS<()> { + let _guard = SQL_ASYNC_BACKEND_TEST_LOCK.lock().await; + let Some(started) = start_client_backend().await else { + return Ok(()); + }; + let (mut client, stop_notifier, server) = started?; + + exec_sql( + &mut client, + "CREATE TABLE t(id INT, v INT, PRIMARY KEY(id))", + ) + .await?; + with_timeout(client.batch(ClientRequest::new( + "default", + "INSERT INTO t(id, v) VALUES (5, 50);\ + INSERT INTO t(id, v) VALUES (1, 10);\ + INSERT INTO t(id, v) VALUES (3, 30);\ + INSERT INTO t(id, v) VALUES (2, 20);\ + INSERT INTO t(id, v) VALUES (4, 40);", + ))) + .await?; + + let selected = with_timeout(client.query(ClientRequest::new( + "default", + "SELECT id, v FROM t WHERE id >= 2 AND id <= 4", + ))) + .await?; + assert_eq!( + selected.rows(), + &[ + vec!["2".to_string(), "20".to_string()], + vec!["3".to_string(), "30".to_string()], + vec!["4".to_string(), "40".to_string()], + ] + ); + + let selected = with_timeout(client.query(ClientRequest::new( + "default", + "SELECT id FROM t WHERE id > 2 AND id <= 4", + ))) + .await?; + assert_eq!( + selected.rows(), + &[ + vec!["3".to_string()], + vec!["4".to_string()], + ] + ); + + let selected = with_timeout(client.query(ClientRequest::new( + "default", + "SELECT v FROM t WHERE id >= 4", + ))) + .await?; + assert_eq!( + selected.rows(), + &[ + vec!["40".to_string()], + vec!["50".to_string()], + ] + ); + + let selected = with_timeout(client.query(ClientRequest::new( + "default", + "SELECT id FROM t WHERE id > 10", + ))) + .await?; + assert!(selected.rows().is_empty()); + + let selected = with_timeout(client.query(ClientRequest::new( + "default", + "SELECT id FROM t WHERE id >= 3 AND id <= 3", + ))) + .await?; + assert_eq!(selected.rows(), &[vec!["3".to_string()]]); + + let selected = with_timeout(client.query(ClientRequest::new( + "default", + "SELECT id FROM t WHERE id < 3", + ))) + .await?; + assert_eq!( + selected.rows(), + &[ + vec!["1".to_string()], + vec!["2".to_string()], + ] + ); + + let selected = with_timeout(client.query(ClientRequest::new( + "default", + "SELECT v FROM t WHERE id >= 2 AND id <= 4", + ))) + .await?; + assert_eq!( + selected.rows(), + &[ + vec!["20".to_string()], + vec!["30".to_string()], + vec!["40".to_string()], + ] + ); + + stop_server(client, stop_notifier, server)?; + Ok(()) + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn async_client_rejects_mixed_equality_and_range_key_predicates() -> RS<()> { + let _guard = SQL_ASYNC_BACKEND_TEST_LOCK.lock().await; + let Some(started) = start_client_backend().await else { + return Ok(()); + }; + let (mut client, stop_notifier, server) = started?; + + exec_sql( + &mut client, + "CREATE TABLE t(k1 INT, k2 INT, v INT, PRIMARY KEY(k1, k2))", + ) + .await?; + let err = with_timeout(client.query(ClientRequest::new( + "default", + "SELECT k1, k2 FROM t WHERE k1 = 1 AND k2 >= 2 AND k2 <= 4", + ))) + .await + .expect_err("mixed equality and range predicate should be rejected"); + assert!(err + .to_string() + .contains("mixed equality and range predicates are not implemented")); + + stop_server(client, stop_notifier, server)?; + Ok(()) + } +} diff --git a/mudu_runtime/src/db_connector.rs b/mudu_runtime/src/db_connector.rs index 1ef6971..f5c020f 100644 --- a/mudu_runtime/src/db_connector.rs +++ b/mudu_runtime/src/db_connector.rs @@ -8,7 +8,9 @@ use mudu::error::ec::EC; use mudu::m_error; use mudu_contract::database::db_conn::DBConnSync; use mudu_contract::database::sql::DBConn; +use mudu_kernel::mudu_conn::mudu_conn_async::MuduConnAsync; use std::str::FromStr; +use std::sync::Arc; use strum_macros::EnumString; pub struct DBConnector {} @@ -19,6 +21,7 @@ enum DBType { LibSQL, Turso, LibSQLAsync, + Mudu, } impl DBConnector { @@ -64,6 +67,7 @@ impl DBConnector { DBType::LibSQL => create_ls_conn(&db_path, &app_name, &ddl_path), DBType::Turso => create_turso_conn(&db_path, &app_name).await, DBType::LibSQLAsync => create_libsql_async_conn(&db_path, &app_name).await, + DBType::Mudu => create_mudu_conn().await, }, None => Err(m_error!(EC::ParseErr, "not a valid DB type")), } @@ -74,6 +78,10 @@ impl DBConnector { } } +async fn create_mudu_conn() -> RS { + Ok(DBConn::Async(Arc::new(MuduConnAsync::new()))) +} + fn parse_key_value(s: &str) -> RS<(String, String)> { let parts: Vec<&str> = s.splitn(2, '=').collect(); if parts.len() != 2 { diff --git a/mudu_runtime/src/db_libsql_async/libsql_async_conn_inner.rs b/mudu_runtime/src/db_libsql_async/libsql_async_conn_inner.rs index 192c877..4dadf2b 100644 --- a/mudu_runtime/src/db_libsql_async/libsql_async_conn_inner.rs +++ b/mudu_runtime/src/db_libsql_async/libsql_async_conn_inner.rs @@ -4,9 +4,9 @@ use crate::db_libsql_async::result_set::{LibSQLAsyncResultSet, ResultSetLease}; use async_trait::async_trait; use futures::TryFutureExt; use lazy_static::lazy_static; -use libsql::{params_from_iter, Builder, Connection, Database, Statement, Transaction}; +use libsql::{Builder, Connection, Database, Statement, Transaction, params_from_iter}; use mudu::common::result::RS; -use mudu::common::xid::{new_xid, XID}; +use mudu::common::xid::{XID, new_xid}; use mudu::error::ec::EC; use mudu::error::err::MError; use mudu::m_error; @@ -259,6 +259,8 @@ fn _to_libsql_value(datum: &DatValue, ty: &DatType) -> RS { let v = match id { DatTypeID::I32 => libsql::Value::Integer(datum.expect_i32().clone() as _), DatTypeID::I64 => libsql::Value::Integer(datum.expect_i64().clone() as _), + DatTypeID::U128 => libsql::Value::Text(datum.expect_u128().to_string()), + DatTypeID::I128 => libsql::Value::Text(datum.expect_i128().to_string()), DatTypeID::F32 => libsql::Value::Real(datum.expect_f32().clone() as _), DatTypeID::F64 => libsql::Value::Real(datum.expect_f64().clone() as _), DatTypeID::String => libsql::Value::Text(datum.expect_string().clone()), @@ -397,7 +399,7 @@ impl ResultSetLease for PreparedSlotLease { #[cfg(test)] mod tests { - use libsql::{params, Builder, Value}; + use libsql::{Builder, Value, params}; use std::time::{SystemTime, UNIX_EPOCH}; fn temp_db_path(label: &str) -> String { diff --git a/mudu_runtime/src/db_libsql_async/result_set.rs b/mudu_runtime/src/db_libsql_async/result_set.rs index bd93055..99b4aad 100644 --- a/mudu_runtime/src/db_libsql_async/result_set.rs +++ b/mudu_runtime/src/db_libsql_async/result_set.rs @@ -125,6 +125,24 @@ fn turso_db_row_to_tuple_item(row: Row, item_desc: &[DatumDesc]) -> RS { + let val = row.get::(n).map_err(|e| { + m_error!(EC::DBInternalError, "libsql db get item of row error", e) + })?; + let val = val + .parse::() + .map_err(|e| m_error!(EC::DBInternalError, "libsql db oid parse error", e))?; + DatValue::from_u128(val) + } + DatTypeID::I128 => { + let val = row.get::(n).map_err(|e| { + m_error!(EC::DBInternalError, "libsql db get item of row error", e) + })?; + let val = val + .parse::() + .map_err(|e| m_error!(EC::DBInternalError, "libsql db i128 parse error", e))?; + DatValue::from_i128(val) + } DatTypeID::F32 => { let val = row.get::(n).map_err(|e| { m_error!(EC::DBInternalError, "libsql db get item of row error", e) diff --git a/mudu_runtime/src/db_postgres/result_set_pg.rs b/mudu_runtime/src/db_postgres/result_set_pg.rs index 6c1459b..6503124 100644 --- a/mudu_runtime/src/db_postgres/result_set_pg.rs +++ b/mudu_runtime/src/db_postgres/result_set_pg.rs @@ -38,6 +38,16 @@ impl ResultSet for ResultSetPG { let val: i64 = row.get(i); DatValue::from_i64(val) } + DatTypeID::U128 => { + let val: String = row.get(i); + let val = val.parse::().expect("postgres oid parse error"); + DatValue::from_u128(val) + } + DatTypeID::I128 => { + let val: String = row.get(i); + let val = val.parse::().expect("postgres i128 parse error"); + DatValue::from_i128(val) + } DatTypeID::F32 => { let val: f32 = row.get(i); DatValue::from_f32(val) diff --git a/mudu_runtime/src/db_turso/result_set.rs b/mudu_runtime/src/db_turso/result_set.rs index d7454f1..4266dbc 100644 --- a/mudu_runtime/src/db_turso/result_set.rs +++ b/mudu_runtime/src/db_turso/result_set.rs @@ -122,6 +122,24 @@ fn turso_db_row_to_tuple_item(row: Row, item_desc: &[DatumDesc]) -> RS { + let val = row.get::(n).map_err(|e| { + m_error!(EC::DBInternalError, "turso db get item of row error", e) + })?; + let val = val + .parse::() + .map_err(|e| m_error!(EC::DBInternalError, "turso db oid parse error", e))?; + DatValue::from_u128(val) + } + DatTypeID::I128 => { + let val = row.get::(n).map_err(|e| { + m_error!(EC::DBInternalError, "turso db get item of row error", e) + })?; + let val = val + .parse::() + .map_err(|e| m_error!(EC::DBInternalError, "turso db i128 parse error", e))?; + DatValue::from_i128(val) + } DatTypeID::F32 => { let val = row.get::(n).map_err(|e| { m_error!(EC::DBInternalError, "turso db get item of row error", e) diff --git a/mudu_runtime/src/db_turso/turso_conn_inner.rs b/mudu_runtime/src/db_turso/turso_conn_inner.rs index 5167d5d..39488da 100644 --- a/mudu_runtime/src/db_turso/turso_conn_inner.rs +++ b/mudu_runtime/src/db_turso/turso_conn_inner.rs @@ -262,6 +262,8 @@ fn _to_turso_value(datum: &DatValue, ty: &DatType) -> RS { let v = match id { DatTypeID::I32 => turso::Value::Integer(datum.expect_i32().clone() as _), DatTypeID::I64 => turso::Value::Integer(datum.expect_i64().clone() as _), + DatTypeID::U128 => turso::Value::Text(datum.expect_u128().to_string()), + DatTypeID::I128 => turso::Value::Text(datum.expect_i128().to_string()), DatTypeID::F32 => turso::Value::Real(datum.expect_f32().clone() as _), DatTypeID::F64 => turso::Value::Real(datum.expect_f64().clone() as _), DatTypeID::String => turso::Value::Text(datum.expect_string().clone()), diff --git a/mudu_runtime/src/service/mod.rs b/mudu_runtime/src/service/mod.rs index 0299009..1a009f5 100644 --- a/mudu_runtime/src/service/mod.rs +++ b/mudu_runtime/src/service/mod.rs @@ -3,17 +3,22 @@ pub mod app_inst; pub mod app_inst_impl; mod file_name; pub(crate) mod mudu_package; +#[cfg(test)] +mod mudu_package_test; pub mod package_module; pub mod runtime; pub mod runtime_impl; +#[cfg(test)] +mod runtime_impl_test; mod runtime_simple; pub mod test_wasm_mod_path; pub mod procedure_invoke_component; +#[cfg(test)] +mod runtime_simple_test; pub mod service; mod service_impl; mod service_trait; -mod test_runtime_simple; pub mod wt_instance_pre; mod wt_runtime; @@ -27,3 +32,5 @@ mod wt_runtime_component; pub mod app_list; #[allow(unused)] mod kernel_function_p2_async; +#[cfg(test)] +mod wt_runtime_component_test; diff --git a/mudu_runtime/src/service/mudu_package.rs b/mudu_runtime/src/service/mudu_package.rs index e9f98e3..c2a002b 100644 --- a/mudu_runtime/src/service/mudu_package.rs +++ b/mudu_runtime/src/service/mudu_package.rs @@ -143,84 +143,3 @@ fn align_single_module_name( aligned_modules.insert(expected_module_name, byte_code); aligned_modules } - -#[cfg(test)] -mod tests { - use crate::service::file_name; - use crate::service::mudu_package::MuduPackage; - use mudu_contract::procedure::mod_proc_desc::ModProcDesc; - use std::collections::HashMap; - use std::env::temp_dir; - use std::fs; - use std::io::Write; - - #[test] - fn test_app_package() { - let package_file = - temp_dir().join(format!("app_json_desc_{}.mpk", mudu_sys::random::uuid_v4())); - let file = fs::File::create(&package_file).unwrap(); - let mut zip = zip::ZipWriter::new(file); - let options = zip::write::SimpleFileOptions::default(); - - zip.start_file(file_name::PACKAGE_CFG, options).unwrap(); - zip.write_all(br#"{"name":"app-json","lang":"rust","version":"0.1.0","use_async":true}"#) - .unwrap(); - - zip.start_file(file_name::PROCEDURE_DESC, options).unwrap(); - let desc = serde_json::to_vec(&ModProcDesc::new(HashMap::new())).unwrap(); - zip.write_all(&desc).unwrap(); - - zip.start_file(file_name::DDL_SQL, options).unwrap(); - zip.write_all(b"create table t(id integer);\n").unwrap(); - - zip.start_file(file_name::INIT_DB_SQL, options).unwrap(); - zip.write_all(b"").unwrap(); - - zip.start_file("module.wasm", options).unwrap(); - zip.write_all(b"\0asm").unwrap(); - - zip.finish().unwrap(); - - let package = MuduPackage::load(&package_file).unwrap(); - assert_eq!(package.name(), "app-json"); - - fs::remove_file(package_file).unwrap(); - } - - #[test] - fn test_single_module_package_aligns_desc_module_name() { - let package_file = temp_dir().join(format!( - "app_json_align_{}.mpk", - mudu_sys::random::uuid_v4() - )); - let file = fs::File::create(&package_file).unwrap(); - let mut zip = zip::ZipWriter::new(file); - let options = zip::write::SimpleFileOptions::default(); - - zip.start_file(file_name::PACKAGE_CFG, options).unwrap(); - zip.write_all(br#"{"name":"app-json","lang":"rust","version":"0.1.0","use_async":true}"#) - .unwrap(); - - zip.start_file(file_name::PROCEDURE_DESC, options).unwrap(); - zip.write_all( - br#"{"modules":{"module":[{"module_name":"module","proc_name":"proc","param_desc":{"fields":[]},"return_desc":{"fields":[]}}]}}"#, - ) - .unwrap(); - - zip.start_file(file_name::DDL_SQL, options).unwrap(); - zip.write_all(b"create table t(id integer);\n").unwrap(); - - zip.start_file(file_name::INIT_DB_SQL, options).unwrap(); - zip.write_all(b"").unwrap(); - - zip.start_file("key_value.wasm", options).unwrap(); - zip.write_all(b"\0asm").unwrap(); - zip.finish().unwrap(); - - let package = MuduPackage::load(&package_file).unwrap(); - assert!(package.modules.contains_key("module")); - assert!(!package.modules.contains_key("key_value")); - - fs::remove_file(package_file).unwrap(); - } -} diff --git a/mudu_runtime/src/service/mudu_package_test.rs b/mudu_runtime/src/service/mudu_package_test.rs new file mode 100644 index 0000000..e310d45 --- /dev/null +++ b/mudu_runtime/src/service/mudu_package_test.rs @@ -0,0 +1,223 @@ +#[cfg(test)] +mod tests { + use crate::service::file_name; + use crate::service::mudu_package::MuduPackage; + use mudu_contract::procedure::mod_proc_desc::ModProcDesc; + use std::collections::HashMap; + use std::env::temp_dir; + use std::fs; + use std::io::Write; + use std::path::{Path, PathBuf}; + + fn package_file(name: &str) -> PathBuf { + temp_dir().join(format!("{}_{}.mpk", name, mudu_sys::random::uuid_v4())) + } + + fn write_package( + path: &Path, + package_cfg: Option<&[u8]>, + procedure_desc: Option<&[u8]>, + ddl_sql: Option<&[u8]>, + initdb_sql: Option<&[u8]>, + modules: &[(&str, &[u8])], + ) { + let file = fs::File::create(path).unwrap(); + let mut zip = zip::ZipWriter::new(file); + let options = zip::write::SimpleFileOptions::default(); + + if let Some(package_cfg) = package_cfg { + zip.start_file(file_name::PACKAGE_CFG, options).unwrap(); + zip.write_all(package_cfg).unwrap(); + } + if let Some(procedure_desc) = procedure_desc { + zip.start_file(file_name::PROCEDURE_DESC, options).unwrap(); + zip.write_all(procedure_desc).unwrap(); + } + if let Some(ddl_sql) = ddl_sql { + zip.start_file(file_name::DDL_SQL, options).unwrap(); + zip.write_all(ddl_sql).unwrap(); + } + if let Some(initdb_sql) = initdb_sql { + zip.start_file(file_name::INIT_DB_SQL, options).unwrap(); + zip.write_all(initdb_sql).unwrap(); + } + for (name, bytes) in modules { + zip.start_file(*name, options).unwrap(); + zip.write_all(bytes).unwrap(); + } + zip.finish().unwrap(); + } + + fn standard_cfg() -> &'static [u8] { + br#"{"name":"app-json","lang":"rust","version":"0.1.0","use_async":true}"# + } + + fn standard_desc() -> Vec { + serde_json::to_vec(&ModProcDesc::new(HashMap::new())).unwrap() + } + + #[test] + fn loads_valid_package() { + let package_file = package_file("app_json_desc"); + write_package( + &package_file, + Some(standard_cfg()), + Some(&standard_desc()), + Some(b"create table t(id integer);\n"), + Some(b""), + &[("module.wasm", b"\0asm\x01\0\0\0")], + ); + + let package = MuduPackage::load(&package_file).unwrap(); + assert_eq!(package.name(), "app-json"); + + fs::remove_file(package_file).unwrap(); + } + + #[test] + fn single_module_package_aligns_desc_module_name() { + let package_file = package_file("app_json_align"); + write_package( + &package_file, + Some(standard_cfg()), + Some(br#"{"modules":{"module":[{"module_name":"module","proc_name":"proc","param_desc":{"fields":[]},"return_desc":{"fields":[]}}]}}"#), + Some(b"create table t(id integer);\n"), + Some(b""), + &[("key_value.wasm", b"\0asm\x01\0\0\0")], + ); + + let package = MuduPackage::load(&package_file).unwrap(); + assert!(package.modules.contains_key("module")); + assert!(!package.modules.contains_key("key_value")); + + fs::remove_file(package_file).unwrap(); + } + + #[test] + fn load_package_requires_package_cfg() { + let package_file = package_file("missing_cfg"); + write_package( + &package_file, + None, + Some(&standard_desc()), + Some(b"create table t(id integer);\n"), + Some(b""), + &[], + ); + + let err = MuduPackage::load(&package_file).unwrap_err(); + assert!(err.to_string().contains(file_name::PACKAGE_CFG)); + + fs::remove_file(package_file).unwrap(); + } + + #[test] + fn load_package_requires_ddl_sql() { + let package_file = package_file("missing_ddl"); + write_package( + &package_file, + Some(standard_cfg()), + Some(&standard_desc()), + None, + Some(b""), + &[], + ); + + let err = MuduPackage::load(&package_file).unwrap_err(); + assert!(err.to_string().contains("ddl.sql")); + + fs::remove_file(package_file).unwrap(); + } + + #[test] + fn load_package_requires_procedure_desc() { + let package_file = package_file("missing_desc"); + write_package( + &package_file, + Some(standard_cfg()), + None, + Some(b"create table t(id integer);\n"), + Some(b""), + &[], + ); + + let err = MuduPackage::load(&package_file).unwrap_err(); + assert!(err.to_string().contains(file_name::PROCEDURE_DESC)); + + fs::remove_file(package_file).unwrap(); + } + + #[test] + fn load_package_rejects_invalid_procedure_desc_json() { + let package_file = package_file("invalid_desc"); + write_package( + &package_file, + Some(standard_cfg()), + Some(br#"{"modules":"bad"}"#), + Some(b"create table t(id integer);\n"), + Some(b""), + &[], + ); + + let err = MuduPackage::load(&package_file).unwrap_err(); + assert!( + err.to_string() + .contains("parse app procedure description error") + ); + + fs::remove_file(package_file).unwrap(); + } + + #[test] + fn load_package_rejects_invalid_package_cfg_json() { + let package_file = package_file("invalid_cfg"); + write_package( + &package_file, + Some(br#"{"name":1}"#), + Some(&standard_desc()), + Some(b"create table t(id integer);\n"), + Some(b""), + &[], + ); + + let err = MuduPackage::load(&package_file).unwrap_err(); + assert!(err.to_string().contains("parse app configuration error")); + + fs::remove_file(package_file).unwrap(); + } + + #[test] + fn load_package_rejects_corrupt_zip_archive() { + let package_file = package_file("corrupt_zip"); + fs::write(&package_file, b"not-a-zip").unwrap(); + + let err = MuduPackage::load(&package_file).unwrap_err(); + assert!(err.to_string().contains("read achieve file failed")); + + fs::remove_file(package_file).unwrap(); + } + + #[test] + fn multi_module_package_does_not_align_names() { + let package_file = package_file("multi_mod"); + write_package( + &package_file, + Some(standard_cfg()), + Some(br#"{"modules":{"module_a":[{"module_name":"module_a","proc_name":"proc_a","param_desc":{"fields":[]},"return_desc":{"fields":[]}}],"module_b":[{"module_name":"module_b","proc_name":"proc_b","param_desc":{"fields":[]},"return_desc":{"fields":[]}}]}}"#), + Some(b"create table t(id integer);\n"), + Some(b""), + &[ + ("first.wasm", b"\0asm\x01\0\0\0"), + ("second.wasm", b"\0asm\x01\0\0\0"), + ], + ); + + let package = MuduPackage::load(&package_file).unwrap(); + assert!(package.modules.contains_key("first")); + assert!(package.modules.contains_key("second")); + assert!(!package.modules.contains_key("module_a")); + assert!(!package.modules.contains_key("module_b")); + + fs::remove_file(package_file).unwrap(); + } +} diff --git a/mudu_runtime/src/service/procedure_invoke_component.rs b/mudu_runtime/src/service/procedure_invoke_component.rs index 20b2adf..d00d365 100644 --- a/mudu_runtime/src/service/procedure_invoke_component.rs +++ b/mudu_runtime/src/service/procedure_invoke_component.rs @@ -1,6 +1,6 @@ use crate::procedure::procedure::Procedure; use crate::service::runtime_opt::ComponentTarget; -use crate::service::wasi_context_component::{build_wasi_component_context, WasiContextComponent}; +use crate::service::wasi_context_component::{WasiContextComponent, build_wasi_component_context}; use mudu::common::result::RS; use mudu::error::ec::EC; use mudu::m_error; @@ -10,8 +10,8 @@ use mudu_contract::procedure::procedure_param::ProcedureParam; use mudu_contract::procedure::procedure_result::ProcedureResult; use mudu_kernel::server::worker_local::WorkerLocalRef; use std::sync::Mutex; -use wasmtime::component::{InstancePre, TypedFunc}; use wasmtime::Store; +use wasmtime::component::{InstancePre, TypedFunc}; pub struct ProcedureInvokeComponent { inner: Mutex, diff --git a/mudu_runtime/src/service/runtime_impl_test.rs b/mudu_runtime/src/service/runtime_impl_test.rs new file mode 100644 index 0000000..788ea35 --- /dev/null +++ b/mudu_runtime/src/service/runtime_impl_test.rs @@ -0,0 +1,59 @@ +#[cfg(test)] +mod tests { + use crate::service::runtime_impl::create_runtime_service; + use crate::service::runtime_opt::RuntimeOpt; + use crate::service::test_wasm_mod_path::wasm_mod_path; + use mudu_utils::notifier::notify_wait; + use std::env::temp_dir; + use std::fs; + use std::path::PathBuf; + + fn temp_path(prefix: &str) -> PathBuf { + temp_dir().join(format!("{}_{}", prefix, mudu_sys::random::uuid_v4())) + } + + #[tokio::test] + async fn create_runtime_service_rejects_file_db_path() { + let package_path = wasm_mod_path(); + let db_file = temp_path("runtime_impl_db_file"); + fs::write(&db_file, b"not-a-directory").unwrap(); + + let err = match create_runtime_service( + &package_path, + &db_file.to_string_lossy().to_string(), + None, + RuntimeOpt::default(), + ) + .await + { + Ok(_) => panic!("expected invalid db path error"), + Err(err) => err, + }; + + assert!(err.to_string().contains("is not a directory")); + fs::remove_file(db_file).unwrap(); + } + + #[tokio::test] + async fn create_runtime_service_notifies_after_initialization() { + let package_dir = temp_path("runtime_impl_pkg_dir"); + let db_path = temp_path("runtime_impl_db_dir"); + fs::create_dir_all(&package_dir).unwrap(); + let (notifier, waiter) = notify_wait(); + + let runtime = create_runtime_service( + &package_dir.to_string_lossy().to_string(), + &db_path.to_string_lossy().to_string(), + Some(notifier), + RuntimeOpt::default(), + ) + .await + .unwrap(); + + waiter.wait().await; + assert!(runtime.list().await.is_empty()); + + let _ = fs::remove_dir_all(package_dir); + let _ = fs::remove_dir_all(db_path); + } +} diff --git a/mudu_runtime/src/service/test_runtime_simple.rs b/mudu_runtime/src/service/runtime_simple_test.rs similarity index 86% rename from mudu_runtime/src/service/test_runtime_simple.rs rename to mudu_runtime/src/service/runtime_simple_test.rs index 6f64b08..c5259bf 100644 --- a/mudu_runtime/src/service/test_runtime_simple.rs +++ b/mudu_runtime/src/service/runtime_simple_test.rs @@ -1,4 +1,3 @@ -#[allow(unused)] #[cfg(test)] mod tests { use crate::service::runtime::Runtime; @@ -24,18 +23,13 @@ mod tests { ProcSysCall, ProvSysCallAsync, } - /// - /// See proc function definition [proc](mudu_wasm/src/wasm/proc.rs#L5)。 - /// - //#[test] + + #[allow(dead_code)] fn test_proc() { test_runtime_simple(TestProc::Proc) } - /// - /// See proc_sys_call function definition [proc_sys_call](mudu_wasm/src/wasm/proc2.rs#L11)。 - /// - //#[test] + #[allow(dead_code)] fn test_proc_syscall() { test_runtime_simple(TestProc::ProcSysCall) } @@ -44,6 +38,7 @@ mod tests { fn test_async() { test_runtime_simple(TestProc::ProvSysCallAsync) } + fn test_runtime_simple(test_kind: TestProc) { log_setup_ex("debug", "", false); tokio::runtime::Builder::new_multi_thread() @@ -51,7 +46,7 @@ mod tests { .build() .unwrap() .block_on(async { - let r = test_async_runtime_simple(true, test_kind).await; + let r = test_async_runtime_simple(test_kind).await; println!("{:?}", r); }); } @@ -62,7 +57,7 @@ mod tests { path.to_str().unwrap().to_string() } - async fn test_async_runtime_simple(_enable_component: bool, test_kind: TestProc) -> RS<()> { + async fn test_async_runtime_simple(test_kind: TestProc) -> RS<()> { let pkg_path = wasm_mod_path(); let db_path = db_path(); let enable_async = @@ -113,8 +108,7 @@ mod tests { let proc_result = app .invoke(id, &"mod_0".to_string(), &"proc".to_string(), param, None) .await?; - let result = proc_result.to::<(i32, String)>(&<(i32, String)>::tuple_desc_static(&[]))?; - println!("result: {:?}", result); + let _result = proc_result.to::<(i32, String)>(&<(i32, String)>::tuple_desc_static(&[]))?; app.task_end(id)?; Ok(()) } @@ -139,13 +133,12 @@ mod tests { None, ) .await?; - let result = proc_result.to::<(i32, String)>(&<(i32, String)>::tuple_desc_static(&[]))?; - println!("result: {:?}", result); + let _result = proc_result.to::<(i32, String)>(&<(i32, String)>::tuple_desc_static(&[]))?; app.task_end(id)?; Ok(()) } - #[allow(unused)] + #[allow(dead_code)] async fn async_session_sys_call(service: Arc) -> RS<()> { println!("task id {}", this_task_id()); let tuple = (1i32, 100i64, "string argument".to_string()); @@ -166,8 +159,7 @@ mod tests { None, ) .await?; - let result = proc_result.to::<(i32, String)>(&<(i32, String)>::tuple_desc_static(&[]))?; - println!("result: {:?}", result); + let _result = proc_result.to::<(i32, String)>(&<(i32, String)>::tuple_desc_static(&[]))?; app.task_end(id)?; Ok(()) } diff --git a/mudu_runtime/src/service/wt_runtime_component_test.rs b/mudu_runtime/src/service/wt_runtime_component_test.rs new file mode 100644 index 0000000..bdcf87f --- /dev/null +++ b/mudu_runtime/src/service/wt_runtime_component_test.rs @@ -0,0 +1,79 @@ +#[cfg(test)] +mod tests { + use crate::service::mudu_package::MuduPackage; + use crate::service::runtime_opt::{ComponentTarget, RuntimeOpt}; + use crate::service::wt_runtime_component::WTRuntimeComponent; + use mudu::common::app_info::AppInfo; + use mudu_contract::procedure::mod_proc_desc::ModProcDesc; + use mudu_contract::procedure::proc_desc::ProcDesc; + use mudu_contract::tuple::tuple_datum::TupleDatum; + use std::collections::HashMap; + + fn test_proc_desc(module_name: &str, proc_name: &str) -> ProcDesc { + ProcDesc::new( + module_name.to_string(), + proc_name.to_string(), + <()>::tuple_desc_static(&[]), + <()>::tuple_desc_static(&[]), + false, + ) + } + + fn test_package(desc: ModProcDesc, modules: HashMap>) -> MuduPackage { + MuduPackage { + package_cfg: AppInfo { + name: "app".to_string(), + lang: "rust".to_string(), + version: "0.1.0".to_string(), + use_async: false, + }, + ddl_sql: "create table t(id int primary key);".to_string(), + package_desc: desc, + initdb_sql: String::new(), + modules, + } + } + + #[test] + fn instantiate_rejects_p3_target() { + let mut runtime = WTRuntimeComponent::build(&RuntimeOpt { + component_target: ComponentTarget::P3, + enable_async: false, + }) + .unwrap(); + + let err = runtime.instantiate().unwrap_err(); + assert!(err.to_string().contains("not implemented yet")); + } + + #[test] + fn compile_modules_requires_declared_module_bytes() { + let mut desc_map = HashMap::new(); + desc_map.insert("mod_0".to_string(), vec![test_proc_desc("mod_0", "proc")]); + let package = test_package(ModProcDesc::new(desc_map), HashMap::new()); + + let runtime = WTRuntimeComponent::build(&RuntimeOpt::default()).unwrap(); + let err = match runtime.compile_modules(&package) { + Ok(_) => panic!("expected missing module error"), + Err(err) => err, + }; + assert!(err.to_string().contains("no such module named mod_0")); + } + + #[test] + fn compile_modules_rejects_plain_wasm_module_for_component_runtime() { + let mut desc_map = HashMap::new(); + desc_map.insert("mod_0".to_string(), vec![test_proc_desc("mod_0", "proc")]); + let mut modules = HashMap::new(); + modules.insert("mod_0".to_string(), b"\0asm\x01\0\0\0".to_vec()); + let package = test_package(ModProcDesc::new(desc_map), modules); + + let mut runtime = WTRuntimeComponent::build(&RuntimeOpt::default()).unwrap(); + runtime.instantiate().unwrap(); + let err = match runtime.compile_modules(&package) { + Ok(_) => panic!("expected component validation error"), + Err(err) => err, + }; + assert!(err.to_string().contains("runtime target is component")); + } +} diff --git a/mudu_sys/src/net.rs b/mudu_sys/src/net.rs index 94dbd23..293c283 100644 --- a/mudu_sys/src/net.rs +++ b/mudu_sys/src/net.rs @@ -76,7 +76,9 @@ pub fn sockaddr_to_socket_addr(addr: &SockAddrBuf) -> RS { if addr.len() < std::mem::size_of::() { return Err(m_error!(EC::NetErr, "short sockaddr_in length")); } - let raw = unsafe { &*(addr.raw() as *const rliburing::sockaddr_storage as *const libc::sockaddr_in) }; + let raw = unsafe { + &*(addr.raw() as *const rliburing::sockaddr_storage as *const libc::sockaddr_in) + }; let ip = std::net::Ipv4Addr::from(u32::from_be(raw.sin_addr.s_addr).to_be_bytes()); Ok(SocketAddr::from((ip, u16::from_be(raw.sin_port)))) } @@ -84,8 +86,9 @@ pub fn sockaddr_to_socket_addr(addr: &SockAddrBuf) -> RS { if addr.len() < std::mem::size_of::() { return Err(m_error!(EC::NetErr, "short sockaddr_in6 length")); } - let raw = - unsafe { &*(addr.raw() as *const rliburing::sockaddr_storage as *const libc::sockaddr_in6) }; + let raw = unsafe { + &*(addr.raw() as *const rliburing::sockaddr_storage as *const libc::sockaddr_in6) + }; let ip = std::net::Ipv6Addr::from(raw.sin6_addr.s6_addr); Ok(SocketAddr::from((ip, u16::from_be(raw.sin6_port)))) } diff --git a/mudu_sys/src/uring.rs b/mudu_sys/src/uring.rs index 57bed22..2cca839 100644 --- a/mudu_sys/src/uring.rs +++ b/mudu_sys/src/uring.rs @@ -31,7 +31,8 @@ mod linux { pub fn new(entries: u32) -> Result { let mut raw = unsafe { std::mem::zeroed() }; let mut param = unsafe { std::mem::zeroed() }; - let rc = unsafe { rliburing::io_uring_queue_init_params(entries, &mut raw, &mut param) }; + let rc = + unsafe { rliburing::io_uring_queue_init_params(entries, &mut raw, &mut param) }; if rc != 0 { return Err(rc); } @@ -65,8 +66,9 @@ mod linux { tv_sec: timeout.as_secs() as i64, tv_nsec: timeout.subsec_nanos() as i64, }; - let rc = - unsafe { rliburing::io_uring_wait_cqe_timeout(&mut self.raw, &mut cqe_ptr, &mut ts) }; + let rc = unsafe { + rliburing::io_uring_wait_cqe_timeout(&mut self.raw, &mut cqe_ptr, &mut ts) + }; if rc < 0 { return Err(rc); } @@ -138,25 +140,13 @@ mod linux { pub fn prep_read_raw(&mut self, fd: RawFd, buf: *mut u8, len: usize, offset: u64) { unsafe { - rliburing::io_uring_prep_read( - self.raw, - fd, - buf.cast(), - len as _, - offset as _, - ); + rliburing::io_uring_prep_read(self.raw, fd, buf.cast(), len as _, offset as _); } } pub fn prep_write_raw(&mut self, fd: RawFd, buf: *const u8, len: usize, offset: u64) { unsafe { - rliburing::io_uring_prep_write( - self.raw, - fd, - buf.cast(), - len as _, - offset as _, - ); + rliburing::io_uring_prep_write(self.raw, fd, buf.cast(), len as _, offset as _); } } diff --git a/mudu_transpiler/src/rust/rust_type.rs b/mudu_transpiler/src/rust/rust_type.rs index 0a503d2..d30a5f4 100644 --- a/mudu_transpiler/src/rust/rust_type.rs +++ b/mudu_transpiler/src/rust/rust_type.rs @@ -106,11 +106,14 @@ impl RustType { RustType::Primitive(s) => match s.as_str() { "i32" => DatType::default_for(DatTypeID::I32), "i64" => DatType::default_for(DatTypeID::I64), + "i128" => DatType::default_for(DatTypeID::I128), + "u128" => DatType::default_for(DatTypeID::U128), "f32" => DatType::default_for(DatTypeID::F32), "f64" => DatType::default_for(DatTypeID::F64), _ => return Err(m_error!(EC::TypeErr, format!("not support type {}", s))), }, RustType::Custom(s) => match s.as_str() { + "OID" => DatType::default_for(DatTypeID::U128), "String" => DatType::default_for(DatTypeID::String), _ => { let ty = custom_types.types.get(s).map_or_else( diff --git a/mudu_type/src/dat_type_id.rs b/mudu_type/src/dat_type_id.rs index 96dc6fe..4be3f94 100644 --- a/mudu_type/src/dat_type_id.rs +++ b/mudu_type/src/dat_type_id.rs @@ -36,6 +36,8 @@ pub enum DatTypeID { F32 = 2, F64 = 3, String = 4, + U128 = 5, + I128 = 6, // Complex types (start after primitive range) Array = PRIMITIVE_ID_MAX + 1, @@ -176,7 +178,12 @@ impl DatTypeID { pub fn has_param(&self) -> bool { match self { - DatTypeID::I32 | DatTypeID::I64 | DatTypeID::F32 | DatTypeID::F64 => false, + DatTypeID::I32 + | DatTypeID::I64 + | DatTypeID::I128 + | DatTypeID::F32 + | DatTypeID::F64 + | DatTypeID::U128 => false, _ => true, } } diff --git a/mudu_type/src/dat_typed.rs b/mudu_type/src/dat_typed.rs index d435928..37b1b12 100644 --- a/mudu_type/src/dat_typed.rs +++ b/mudu_type/src/dat_typed.rs @@ -23,6 +23,20 @@ impl DatTyped { ) } + pub fn from_i128(val: i128) -> Self { + Self::new( + DatType::default_for(DatTypeID::I128), + DatValue::from_i128(val), + ) + } + + pub fn from_oid(val: u128) -> Self { + Self::new( + DatType::default_for(DatTypeID::U128), + DatValue::from_u128(val), + ) + } + pub fn from_f32(val: f32) -> Self { Self::new( DatType::default_for(DatTypeID::F32), diff --git a/mudu_type/src/dat_value.rs b/mudu_type/src/dat_value.rs index ba8d84e..5faa505 100644 --- a/mudu_type/src/dat_value.rs +++ b/mudu_type/src/dat_value.rs @@ -34,6 +34,8 @@ enum ValueKind { F64(f64), I32(i32), I64(i64), + I128(i128), + U128(u128), String(String), Record(Vec), Array(Vec), @@ -186,6 +188,14 @@ impl DatValue { pub fn to_i64(&self) -> i64 { self.expect_i64().clone() } + + pub fn to_i128(&self) -> i128 { + self.expect_i128().clone() + } + + pub fn to_oid(&self) -> u128 { + self.expect_u128().clone() + } } /// Safe wrapper for unsafe pointer casting between types @@ -210,6 +220,8 @@ unsafe impl Sync for ValueKind {} impl_dat_value_methods! { (i32, I32, i32), (i64, I64, i64), + (i128, I128, i128), + (u128, U128, u128), (f32, F32, f32), (f64, F64, f64), (String, String, string), diff --git a/mudu_type/src/datum.rs b/mudu_type/src/datum.rs index 7e23c30..197a1ac 100644 --- a/mudu_type/src/datum.rs +++ b/mudu_type/src/datum.rs @@ -253,6 +253,8 @@ macro_rules! impl_datum_trait { impl_datum_trait!( (I32, i32, i32), (I64, i64, i64), + (I128, i128, i128), + (U128, u128, u128), (F32, f32, f32), (F64, f64, f64), (String, string, String) diff --git a/mudu_type/src/dt_impl/dat_table.rs b/mudu_type/src/dt_impl/dat_table.rs index ad39cdb..eab8d8f 100644 --- a/mudu_type/src/dt_impl/dat_table.rs +++ b/mudu_type/src/dt_impl/dat_table.rs @@ -78,6 +78,26 @@ lazy_static! { fixed_length: None, opt_fn_param: Some(dt_impl::fn_string_param::FN_CHAR_FIXED_PARAM), }, + DatTypeDef { + id: DatTypeID::U128, + type_name: "oid".to_string(), + fn_base: dt_impl::fn_u128::FN_OID_CONVERT, + opt_fn_compare: Some(dt_impl::fn_u128::FN_OID_COMPARE), + #[cfg(any(test, feature = "test"))] + fn_arbitrary: dt_impl::fn_u128_arb::FN_OID_ARBITRARY, + fixed_length: Some(size_of::() as u32), + opt_fn_param: None, + }, + DatTypeDef { + id: DatTypeID::I128, + type_name: "i128".to_string(), + fn_base: dt_impl::fn_i128::FN_I128_CONVERT, + opt_fn_compare: Some(dt_impl::fn_i128::FN_I128_COMPARE), + #[cfg(any(test, feature = "test"))] + fn_arbitrary: dt_impl::fn_i128_arb::FN_I128_ARBITRARY, + fixed_length: Some(size_of::() as u32), + opt_fn_param: None, + }, DatTypeDef { id: DatTypeID::Array, type_name: "array".to_string(), diff --git a/mudu_type/src/dt_impl/fn_i128.rs b/mudu_type/src/dt_impl/fn_i128.rs new file mode 100644 index 0000000..1e1c310 --- /dev/null +++ b/mudu_type/src/dt_impl/fn_i128.rs @@ -0,0 +1,153 @@ +use crate::dat_binary::DatBinary; +use crate::dat_json::DatJson; +use crate::dat_textual::DatTextual; +use crate::dat_type::DatType; +use crate::dat_value::DatValue; +use crate::dt_fn_compare::{ErrCompare, FnCompare}; +use crate::dt_fn_convert::FnBase; +use crate::type_error::{TyEC, TyErr}; +use byteorder::ByteOrder; +use mudu::common::endian::Endian; +use mudu::utils::json::{JsonValue, from_json_str}; +use mudu::utils::msg_pack::{MsgPackUtf8String, MsgPackValue}; +use std::cmp::Ordering; +use std::hash::Hasher; +use std::str::FromStr; + +fn parse_i128_str(value: &str) -> Result { + i128::from_str(value).map_err(|e| TyErr::new(TyEC::TypeConvertFailed, e.to_string())) +} + +fn parse_i128_json(value: &JsonValue) -> Result { + if let Some(s) = value.as_str() { + return parse_i128_str(s); + } + if let Some(n) = value.as_i64() { + return Ok(n as i128); + } + if let Some(n) = value.as_u64() { + return Ok(n as i128); + } + Err(TyErr::new( + TyEC::TypeConvertFailed, + format!("cannot convert json {} to i128", value), + )) +} + +fn fn_i128_in_textual(v: &str, dt: &DatType) -> Result { + let json = from_json_str::(v) + .map_err(|e| TyErr::new(TyEC::TypeConvertFailed, e.to_string()))?; + fn_i128_in_json(&json, dt) +} + +fn fn_i128_out_textual(v: &DatValue, dt: &DatType) -> Result { + let json = fn_i128_out_json(v, dt)?; + Ok(DatTextual::from(json.to_string())) +} + +fn fn_i128_in_json(v: &JsonValue, _: &DatType) -> Result { + Ok(DatValue::from_i128(parse_i128_json(v)?)) +} + +fn fn_i128_out_json(v: &DatValue, _: &DatType) -> Result { + Ok(DatJson::from(JsonValue::String(v.to_i128().to_string()))) +} + +fn fn_i128_in_msgpack(msg_pack: &MsgPackValue, _: &DatType) -> Result { + if let Some(s) = msg_pack.as_str() { + return Ok(DatValue::from_i128(parse_i128_str(s)?)); + } + if let Some(n) = msg_pack.as_i64() { + return Ok(DatValue::from_i128(n as i128)); + } + if let Some(n) = msg_pack.as_u64() { + return Ok(DatValue::from_i128(n as i128)); + } + Err(TyErr::new( + TyEC::TypeConvertFailed, + "cannot convert msg pack to i128".to_string(), + )) +} + +fn fn_i128_out_msgpack(v: &DatValue, _: &DatType) -> Result { + Ok(MsgPackValue::String(MsgPackUtf8String::from( + v.to_i128().to_string(), + ))) +} + +fn fn_i128_len(_: &DatType) -> Result, TyErr> { + Ok(Some(size_of::() as u32)) +} + +fn fn_i128_dat_output_len(_: &DatValue, ty: &DatType) -> Result { + Ok(fn_i128_len(ty)?.unwrap()) +} + +fn fn_i128_send(v: &DatValue, _: &DatType) -> Result { + let value = v.to_i128(); + let mut buf = vec![0; size_of::()]; + Endian::write_i128(&mut buf, value); + Ok(DatBinary::from(buf)) +} + +fn fn_i128_send_to(v: &DatValue, _: &DatType, buf: &mut [u8]) -> Result { + if buf.len() < size_of::() { + return Err(TyErr::new( + TyEC::InsufficientSpace, + "insufficient space".to_string(), + )); + } + Endian::write_i128(buf, v.to_i128()); + Ok(size_of::() as u32) +} + +fn fn_i128_recv(buf: &[u8], _: &DatType) -> Result<(DatValue, u32), TyErr> { + if buf.len() < size_of::() { + return Err(TyErr::new( + TyEC::InsufficientSpace, + "insufficient space".to_string(), + )); + } + Ok(( + DatValue::from_i128(Endian::read_i128(buf)), + size_of::() as u32, + )) +} + +fn fn_i128_default(_: &DatType) -> Result { + Ok(DatValue::from_i128(i128::default())) +} + +fn fn_i128_order(v1: &DatValue, v2: &DatValue) -> Result { + Ok(v1.to_i128().cmp(&v2.to_i128())) +} + +fn fn_i128_equal(v1: &DatValue, v2: &DatValue) -> Result { + Ok(v1.to_i128() == v2.to_i128()) +} + +fn fn_i128_hash(v: &DatValue, hasher: &mut dyn Hasher) -> Result<(), ErrCompare> { + hasher.write_i128(v.to_i128()); + Ok(()) +} + +pub const FN_I128_COMPARE: FnCompare = FnCompare { + order: fn_i128_order, + equal: fn_i128_equal, + hash: fn_i128_hash, +}; + +pub const FN_I128_CONVERT: FnBase = FnBase { + input_textual: fn_i128_in_textual, + output_textual: fn_i128_out_textual, + input_json: fn_i128_in_json, + output_json: fn_i128_out_json, + input_msg_pack: fn_i128_in_msgpack, + output_msg_pack: fn_i128_out_msgpack, + type_len: fn_i128_len, + data_len: fn_i128_dat_output_len, + receive: fn_i128_recv, + send: fn_i128_send, + send_to: fn_i128_send_to, + default: fn_i128_default, +}; diff --git a/mudu_type/src/dt_impl/fn_i128_arb.rs b/mudu_type/src/dt_impl/fn_i128_arb.rs new file mode 100644 index 0000000..7a633d8 --- /dev/null +++ b/mudu_type/src/dt_impl/fn_i128_arb.rs @@ -0,0 +1,23 @@ +use crate::dat_type::DatType; +use crate::dat_type_id::DatTypeID; +use crate::dat_value::DatValue; +use crate::dt_fn_arbitrary::FnArbitrary; +use arbitrary::{Arbitrary, Unstructured}; + +pub fn fn_i128_arb_val(u: &mut Unstructured, _: &DatType) -> arbitrary::Result { + Ok(DatValue::from_i128(i128::arbitrary(u)?)) +} + +pub fn fn_i128_arb_printable(u: &mut Unstructured, _: &DatType) -> arbitrary::Result { + Ok(i128::arbitrary(u)?.to_string()) +} + +pub fn fn_i128_arb_dt_param(_u: &mut Unstructured) -> arbitrary::Result { + Ok(DatType::new_no_param(DatTypeID::I128)) +} + +pub const FN_I128_ARBITRARY: FnArbitrary = FnArbitrary { + param: fn_i128_arb_dt_param, + value_object: fn_i128_arb_val, + value_print: fn_i128_arb_printable, +}; diff --git a/mudu_type/src/dt_impl/fn_u128.rs b/mudu_type/src/dt_impl/fn_u128.rs new file mode 100644 index 0000000..e0201af --- /dev/null +++ b/mudu_type/src/dt_impl/fn_u128.rs @@ -0,0 +1,146 @@ +use crate::dat_binary::DatBinary; +use crate::dat_json::DatJson; +use crate::dat_textual::DatTextual; +use crate::dat_type::DatType; +use crate::dat_value::DatValue; +use crate::dt_fn_compare::{ErrCompare, FnCompare}; +use crate::dt_fn_convert::FnBase; +use crate::type_error::{TyEC, TyErr}; +use mudu::common::endian; +use mudu::utils::json::{JsonValue, from_json_str}; +use mudu::utils::msg_pack::{MsgPackUtf8String, MsgPackValue}; +use std::cmp::Ordering; +use std::hash::Hasher; +use std::str::FromStr; + +fn parse_u128_str(value: &str) -> Result { + u128::from_str(value).map_err(|e| TyErr::new(TyEC::TypeConvertFailed, e.to_string())) +} + +fn parse_u128_json(value: &JsonValue) -> Result { + if let Some(s) = value.as_str() { + return parse_u128_str(s); + } + if let Some(n) = value.as_u64() { + return Ok(n as u128); + } + Err(TyErr::new( + TyEC::TypeConvertFailed, + format!("cannot convert json {} to oid", value), + )) +} + +fn fn_u128_in_textual(v: &str, dt: &DatType) -> Result { + let json = from_json_str::(v) + .map_err(|e| TyErr::new(TyEC::TypeConvertFailed, e.to_string()))?; + fn_u128_in_json(&json, dt) +} + +fn fn_u128_out_textual(v: &DatValue, dt: &DatType) -> Result { + let json = fn_u128_out_json(v, dt)?; + Ok(DatTextual::from(json.to_string())) +} + +fn fn_u128_in_json(v: &JsonValue, _: &DatType) -> Result { + Ok(DatValue::from_u128(parse_u128_json(v)?)) +} + +fn fn_u128_out_json(v: &DatValue, _: &DatType) -> Result { + Ok(DatJson::from(JsonValue::String(v.to_oid().to_string()))) +} + +fn fn_u128_in_msgpack(msg_pack: &MsgPackValue, _: &DatType) -> Result { + if let Some(s) = msg_pack.as_str() { + return Ok(DatValue::from_u128(parse_u128_str(s)?)); + } + if let Some(n) = msg_pack.as_u64() { + return Ok(DatValue::from_u128(n as u128)); + } + Err(TyErr::new( + TyEC::TypeConvertFailed, + "cannot convert msg pack to oid".to_string(), + )) +} + +fn fn_u128_out_msgpack(v: &DatValue, _: &DatType) -> Result { + Ok(MsgPackValue::String(MsgPackUtf8String::from( + v.to_oid().to_string(), + ))) +} + +fn fn_u128_len(_: &DatType) -> Result, TyErr> { + Ok(Some(size_of::() as u32)) +} + +fn fn_u128_dat_output_len(_: &DatValue, ty: &DatType) -> Result { + Ok(fn_u128_len(ty)?.unwrap()) +} + +fn fn_u128_send(v: &DatValue, _: &DatType) -> Result { + let oid = v.to_oid(); + let mut buf = vec![0; size_of::()]; + endian::write_u128(&mut buf, oid); + Ok(DatBinary::from(buf)) +} + +fn fn_u128_send_to(v: &DatValue, _: &DatType, buf: &mut [u8]) -> Result { + if buf.len() < size_of::() { + return Err(TyErr::new( + TyEC::InsufficientSpace, + "insufficient space".to_string(), + )); + } + endian::write_u128(buf, v.to_oid()); + Ok(size_of::() as u32) +} + +fn fn_u128_recv(buf: &[u8], _: &DatType) -> Result<(DatValue, u32), TyErr> { + if buf.len() < size_of::() { + return Err(TyErr::new( + TyEC::InsufficientSpace, + "insufficient space".to_string(), + )); + } + Ok(( + DatValue::from_u128(endian::read_u128(buf)), + size_of::() as u32, + )) +} + +fn fn_u128_default(_: &DatType) -> Result { + Ok(DatValue::from_u128(u128::default())) +} + +fn fn_u128_order(v1: &DatValue, v2: &DatValue) -> Result { + Ok(v1.to_oid().cmp(&v2.to_oid())) +} + +fn fn_u128_equal(v1: &DatValue, v2: &DatValue) -> Result { + Ok(v1.to_oid() == v2.to_oid()) +} + +fn fn_u128_hash(v: &DatValue, hasher: &mut dyn Hasher) -> Result<(), ErrCompare> { + hasher.write_u128(v.to_oid()); + Ok(()) +} + +pub const FN_OID_COMPARE: FnCompare = FnCompare { + order: fn_u128_order, + equal: fn_u128_equal, + hash: fn_u128_hash, +}; + +pub const FN_OID_CONVERT: FnBase = FnBase { + input_textual: fn_u128_in_textual, + output_textual: fn_u128_out_textual, + input_json: fn_u128_in_json, + output_json: fn_u128_out_json, + input_msg_pack: fn_u128_in_msgpack, + output_msg_pack: fn_u128_out_msgpack, + type_len: fn_u128_len, + data_len: fn_u128_dat_output_len, + receive: fn_u128_recv, + send: fn_u128_send, + send_to: fn_u128_send_to, + default: fn_u128_default, +}; diff --git a/mudu_type/src/dt_impl/fn_u128_arb.rs b/mudu_type/src/dt_impl/fn_u128_arb.rs new file mode 100644 index 0000000..320b8bf --- /dev/null +++ b/mudu_type/src/dt_impl/fn_u128_arb.rs @@ -0,0 +1,23 @@ +use crate::dat_type::DatType; +use crate::dat_type_id::DatTypeID; +use crate::dat_value::DatValue; +use crate::dt_fn_arbitrary::FnArbitrary; +use arbitrary::{Arbitrary, Unstructured}; + +pub fn fn_u128_arb_val(u: &mut Unstructured, _: &DatType) -> arbitrary::Result { + Ok(DatValue::from_u128(u128::arbitrary(u)?)) +} + +pub fn fn_u128_arb_printable(u: &mut Unstructured, _: &DatType) -> arbitrary::Result { + Ok(u128::arbitrary(u)?.to_string()) +} + +pub fn fn_u128_arb_dt_param(_u: &mut Unstructured) -> arbitrary::Result { + Ok(DatType::new_no_param(DatTypeID::U128)) +} + +pub const FN_OID_ARBITRARY: FnArbitrary = FnArbitrary { + param: fn_u128_arb_dt_param, + value_object: fn_u128_arb_val, + value_print: fn_u128_arb_printable, +}; diff --git a/mudu_type/src/dt_impl/lang/rust.rs b/mudu_type/src/dt_impl/lang/rust.rs index 2f49760..00e73c2 100644 --- a/mudu_type/src/dt_impl/lang/rust.rs +++ b/mudu_type/src/dt_impl/lang/rust.rs @@ -7,6 +7,8 @@ lazy_static! { static ref _id_lang_type_name: Vec<(DatTypeID, &'static str)> = vec![ (DatTypeID::I32, "i32"), (DatTypeID::I64, "i64"), + (DatTypeID::I128, "i128"), + (DatTypeID::U128, "OID"), (DatTypeID::F32, "f32"), (DatTypeID::F64, "f64"), (DatTypeID::String, "String"), @@ -16,8 +18,11 @@ lazy_static! { ]; static ref _id2name: HashMap = dat_type_id_2_lang_type_name(&_id_lang_type_name); - static ref _name2id: HashMap)> = - lang_type_name_2_dat_type_id(&_id_lang_type_name); + static ref _name2id: HashMap)> = { + let mut map = lang_type_name_2_dat_type_id(&_id_lang_type_name); + map.insert("u128".to_string(), (DatTypeID::U128, Default::default())); + map + }; } pub fn dt_lang_name_to_id(name: &str) -> Option<(DatTypeID, Vec)> { diff --git a/mudu_type/src/dt_impl/mod.rs b/mudu_type/src/dt_impl/mod.rs index 882c87e..fac7cff 100644 --- a/mudu_type/src/dt_impl/mod.rs +++ b/mudu_type/src/dt_impl/mod.rs @@ -4,10 +4,12 @@ pub mod lang; mod fn_f32; mod fn_f64; +mod fn_i128; mod fn_i32; mod fn_i64; mod fn_string; mod fn_string_param; +mod fn_u128; mod fn_array; #[cfg(any(test, feature = "test"))] @@ -21,6 +23,8 @@ mod fn_f32_arb; #[cfg(any(test, feature = "test"))] mod fn_f64_arb; #[cfg(any(test, feature = "test"))] +mod fn_i128_arb; +#[cfg(any(test, feature = "test"))] mod fn_i32_arb; #[cfg(any(test, feature = "test"))] mod fn_i64_arb; @@ -30,3 +34,5 @@ mod fn_object_arb; mod fn_object_param; #[cfg(any(test, feature = "test"))] mod fn_string_arb; +#[cfg(any(test, feature = "test"))] +mod fn_u128_arb; diff --git a/sql_parser/src/ast/column_def.rs b/sql_parser/src/ast/column_def.rs index ef0a5b3..8914cac 100644 --- a/sql_parser/src/ast/column_def.rs +++ b/sql_parser/src/ast/column_def.rs @@ -1,3 +1,4 @@ +use mudu::common::id::AttrIndex; use mudu_binding::universal::uni_dat_type::UniDatType; use mudu_binding::universal::uni_dat_value::UniDatValue; @@ -6,8 +7,8 @@ pub struct ColumnDef { column_name: String, data_type_def: UniDatType, data_type_param: Option>, - is_primary_key: bool, - index: u32, + opt_primary_key_index: Option, + index: AttrIndex, } impl ColumnDef { @@ -15,14 +16,13 @@ impl ColumnDef { column_name: String, data_type_def: UniDatType, data_type_param: Option>, - is_primary_key: bool, ) -> Self { Self { column_name, data_type_def, data_type_param, - is_primary_key, - index: u32::MAX, + opt_primary_key_index: None, + index: AttrIndex::MAX, } } @@ -35,23 +35,31 @@ impl ColumnDef { } pub fn is_primary_key(&self) -> bool { - self.is_primary_key + self.opt_primary_key_index.is_some() } pub fn column_name(&self) -> &String { &self.column_name } - pub fn set_primary_key(&mut self, is_primary: bool) { - self.is_primary_key = is_primary; + pub fn primary_key_index(&self) -> Option { + self.opt_primary_key_index } - pub fn set_index(&mut self, index: u32) { + pub fn expect_primary_key_index(&self) -> AttrIndex { + self.opt_primary_key_index.unwrap() + } + + pub fn set_primary_key_index(&mut self, index: Option) { + self.opt_primary_key_index = index; + } + + pub fn set_index(&mut self, index: AttrIndex) { self.index = index; } // column index in table schema - pub fn column_index(&self) -> u32 { + pub fn column_index(&self) -> AttrIndex { self.index } } diff --git a/sql_parser/src/ast/expr_arithmetic.rs b/sql_parser/src/ast/expr_arithmetic.rs index 440516c..2ec5e91 100644 --- a/sql_parser/src/ast/expr_arithmetic.rs +++ b/sql_parser/src/ast/expr_arithmetic.rs @@ -30,7 +30,7 @@ impl ExprArithmetic { impl Debug for ExprArithmetic { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "op: ")?; + write!(f, "arithmetic op: ")?; self.op.fmt(f)?; write!(f, "left: ")?; self.left.fmt(f)?; diff --git a/sql_parser/src/ast/expr_compare.rs b/sql_parser/src/ast/expr_compare.rs index 7c93a4e..a5a44b7 100644 --- a/sql_parser/src/ast/expr_compare.rs +++ b/sql_parser/src/ast/expr_compare.rs @@ -47,14 +47,7 @@ impl ExprCompare { // fn revert_cmp_op(op: ValueCompare) -> ValueCompare { - match op { - ValueCompare::EQ => ValueCompare::EQ, - ValueCompare::LE => ValueCompare::GT, - ValueCompare::LT => ValueCompare::GE, - ValueCompare::GE => ValueCompare::LT, - ValueCompare::GT => ValueCompare::LE, - ValueCompare::NE => ValueCompare::NE, - } + ValueCompare::revert_cmp_op(op) } } diff --git a/sql_parser/src/ast/expr_operator.rs b/sql_parser/src/ast/expr_operator.rs index b6708d0..491d212 100644 --- a/sql_parser/src/ast/expr_operator.rs +++ b/sql_parser/src/ast/expr_operator.rs @@ -78,3 +78,17 @@ impl Operator { } } } + + +impl ValueCompare { + pub fn revert_cmp_op(op: ValueCompare) -> ValueCompare { + match op { + ValueCompare::EQ => ValueCompare::EQ, + ValueCompare::LE => ValueCompare::GT, + ValueCompare::LT => ValueCompare::GE, + ValueCompare::GE => ValueCompare::LT, + ValueCompare::GT => ValueCompare::LE, + ValueCompare::NE => ValueCompare::NE, + } + } +} \ No newline at end of file diff --git a/sql_parser/src/ast/mod.rs b/sql_parser/src/ast/mod.rs index a1484d9..be82a58 100644 --- a/sql_parser/src/ast/mod.rs +++ b/sql_parser/src/ast/mod.rs @@ -18,6 +18,8 @@ pub mod stmt_delete; pub mod column_def; mod expr_arithmetic; +#[cfg(test)] +mod parser_test; pub mod stmt_copy_from; pub mod stmt_copy_to; pub mod stmt_drop; @@ -27,5 +29,4 @@ pub mod stmt_list; pub mod stmt_select; pub mod stmt_type; pub mod stmt_update; -mod test_parser; pub mod type_declare; diff --git a/sql_parser/src/ast/parser.rs b/sql_parser/src/ast/parser.rs index 399675a..9bd8786 100644 --- a/sql_parser/src/ast/parser.rs +++ b/sql_parser/src/ast/parser.rs @@ -24,6 +24,7 @@ use crate::ast::stmt_select::StmtSelect; use crate::ast::stmt_type::{StmtCommand, StmtType}; use crate::ast::stmt_update::{AssignedValue, Assignment, StmtUpdate}; use crate::ts_const::{ts_field_name, ts_kind_id}; +use mudu::common::id::AttrIndex; use mudu::error::err::MError; use mudu::m_error; use mudu_binding::universal::uni_dat_type::UniDatType; @@ -731,8 +732,7 @@ impl SQLParser { let mut index = 0; let mut f = |name: String| { if let Some(n) = map.get_mut(&name) { - n.set_primary_key(true); - n.set_index(index); + n.set_primary_key_index(Some(index)); index += 1; Ok(()) } else { @@ -756,23 +756,33 @@ impl SQLParser { let opt_n = node.child_by_field_name(ts_field_name::DATA_TYPE); let n_data_type = rs_option(opt_n, "")?; let (dat_type, opt_type_params) = self.visit_data_type(context, n_data_type)?; - let mut column_def = ColumnDef::new(column_name, dat_type, opt_type_params, false); + let mut column_def = ColumnDef::new(column_name, dat_type, opt_type_params); let mut cursor = node.walk(); let iter = node.children_by_field_name(ts_field_name::COLUMN_CONSTRAINT, &mut cursor); + let mut index_map = HashMap::new(); for n in iter { - self.visit_column_constraint(n, &mut column_def)?; + self.visit_column_constraint(n, &mut column_def, &mut index_map)?; } stmt.add_column_def(column_def); Ok(()) } - fn visit_column_constraint(&self, node: Node, column_def: &mut ColumnDef) -> RS<()> { + fn visit_column_constraint( + &self, + node: Node, + column_def: &mut ColumnDef, + index_map: &mut HashMap, + ) -> RS<()> { if node .child_by_field_name(ts_field_name::PRIMARY_KEY) .is_some() { - column_def.set_primary_key(true); + let next_index = index_map + .entry(ts_field_name::PRIMARY_KEY.to_string()) + .or_insert(0); + column_def.set_primary_key_index(Some(*next_index)); + *next_index += 1; } Ok(()) } diff --git a/sql_parser/src/ast/parser_test.rs b/sql_parser/src/ast/parser_test.rs new file mode 100644 index 0000000..65dbecb --- /dev/null +++ b/sql_parser/src/ast/parser_test.rs @@ -0,0 +1,299 @@ +#[cfg(test)] +mod tests { + use crate::ast::expr_item::{ExprItem, ExprValue}; + use crate::ast::expr_operator::{Arithmetic, ValueCompare}; + use crate::ast::expression::ExprType; + use crate::ast::parser::SQLParser; + use crate::ast::stmt_create_table::StmtCreateTable; + use crate::ast::stmt_type::{StmtCommand, StmtType}; + use crate::ast::stmt_update::AssignedValue; + use mudu::common::result::RS; + use project_root::get_project_root; + use std::fs; + use std::path::Path; + + fn parse_sql(sql: &str) -> RS> { + let parser = SQLParser::new(); + Ok(parser.parse(sql)?.stmts().clone()) + } + + fn parse_create_table(sql: &str) -> RS { + let stmts = parse_sql(sql)?; + let stmt = stmts.first().ok_or_else(|| { + mudu::m_error!(mudu::error::ec::EC::ParseErr, "expected one statement") + })?; + match stmt { + StmtType::Command(StmtCommand::CreateTable(stmt)) => Ok(stmt.clone()), + _ => Err(mudu::m_error!( + mudu::error::ec::EC::ParseErr, + "expected create table statement" + )), + } + } + + fn parse_file>(path: P) -> RS> { + let sql = fs::read_to_string(path).unwrap(); + parse_sql(&sql) + } + + #[test] + fn parse_select_where_extracts_compare_predicates() { + let stmts = + parse_sql("select id, name from users where id = 1 AND name = 'alice';").unwrap(); + + let StmtType::Select(stmt) = &stmts[0] else { + panic!("expected select"); + }; + assert_eq!(stmt.get_table_reference(), "users"); + assert_eq!(stmt.get_select_term_list().len(), 2); + assert_eq!(stmt.get_where_predicate().len(), 2); + assert!(matches!( + stmt.get_where_predicate()[0].op(), + ValueCompare::EQ + )); + assert!(matches!( + stmt.get_where_predicate()[1].op(), + ValueCompare::EQ + )); + } + + #[test] + fn parse_select_with_placeholder_keeps_value_placeholder() { + let stmts = parse_sql("select id from users where id = ?;").unwrap(); + + let StmtType::Select(stmt) = &stmts[0] else { + panic!("expected select"); + }; + let predicate = &stmt.get_where_predicate()[0]; + match predicate.right() { + ExprItem::ItemValue(ExprValue::ValuePlaceholder) => {} + other => panic!("expected placeholder, got {other:?}"), + } + } + + #[test] + fn parse_insert_without_column_list() { + let stmts = parse_sql("insert into users values (1, 'alice');").unwrap(); + + let StmtType::Command(StmtCommand::Insert(stmt)) = &stmts[0] else { + panic!("expected insert"); + }; + assert_eq!(stmt.table_name(), "users"); + assert!(stmt.columns().is_empty()); + assert_eq!(stmt.values_list().len(), 1); + assert_eq!(stmt.values_list()[0].len(), 2); + } + + #[test] + fn parse_multi_row_insert_keeps_each_row() { + let stmts = + parse_sql("insert into users (id, name) values (1, 'alice'), (2, 'bob');").unwrap(); + + let StmtType::Command(StmtCommand::Insert(stmt)) = &stmts[0] else { + panic!("expected insert"); + }; + assert_eq!(stmt.columns(), &vec!["id".to_string(), "name".to_string()]); + assert_eq!(stmt.values_list().len(), 2); + } + + #[test] + fn parse_update_distinguishes_value_and_expression_assignments() { + let stmts = + parse_sql("update users set count = 1, total = count + 1 where id = 1;").unwrap(); + + let StmtType::Command(StmtCommand::Update(stmt)) = &stmts[0] else { + panic!("expected update"); + }; + assert_eq!(stmt.get_set_values().len(), 2); + assert!(matches!( + stmt.get_set_values()[0].get_set_value(), + AssignedValue::Value(_) + )); + match stmt.get_set_values()[1].get_set_value() { + AssignedValue::Expression(ExprType::Arithmetic(expr)) => { + assert!(matches!(expr.op(), Arithmetic::PLUS)); + } + other => panic!("expected arithmetic assignment, got {other:?}"), + } + assert_eq!(stmt.get_where_predicate().len(), 1); + } + + #[test] + fn parse_delete_with_and_predicates() { + let stmts = parse_sql("delete from users where id = 1 AND name = 'alice';").unwrap(); + + let StmtType::Command(StmtCommand::Delete(stmt)) = &stmts[0] else { + panic!("expected delete"); + }; + assert_eq!(stmt.get_table_reference(), "users"); + assert_eq!(stmt.get_where_predicate().len(), 2); + } + + #[test] + fn parse_drop_table_if_exists() { + let stmts = parse_sql("drop table if exists users;").unwrap(); + + let StmtType::Command(StmtCommand::DropTable(stmt)) = &stmts[0] else { + panic!("expected drop table"); + }; + assert_eq!(stmt.table_name(), "users"); + assert!(stmt.drop_if_exists()); + } + + #[test] + fn parse_copy_from_statement() { + let stmts = parse_sql("copy users from 'users.csv';").unwrap(); + + let StmtType::Command(StmtCommand::CopyFrom(stmt)) = &stmts[0] else { + panic!("expected copy from"); + }; + assert_eq!(stmt.copy_to_table_name(), "users"); + assert_eq!(stmt.copy_from_file_path(), "'users.csv'"); + assert!(stmt.table_columns().is_empty()); + } + + #[test] + fn parse_invalid_sql_reports_syntax_context() { + let err = parse_sql("select from users where").unwrap_err(); + let text = err.to_string(); + assert!(text.contains("Syntax error")); + assert!(text.contains("select from users where")); + assert!(text.contains("position")); + } + + #[test] + fn parse_update_without_where_returns_error() { + let err = parse_sql("update users set id = 1;").unwrap_err(); + assert!(err + .to_string() + .contains("no where clause in update statement")); + } + + #[test] + fn parse_delete_without_where_returns_error() { + let err = parse_sql("delete from users;").unwrap_err(); + let text = err.to_string(); + assert!( + text.contains("Syntax error") || text.contains("no where clause in delete statement") + ); + } + + #[test] + fn parse_insert_without_values_returns_error() { + let err = parse_sql("insert into users (id, name);").unwrap_err(); + assert!(err.to_string().contains("Syntax error")); + } + + #[test] + fn parse_create_table_with_unsupported_type_returns_error() { + let err = parse_sql("create table users (id boolean primary key);").unwrap_err(); + assert!(err.to_string().contains("not yet implemented")); + } + + #[test] + #[should_panic] + fn parse_copy_to_reaches_current_todo_branch() { + let _ = parse_sql("copy users to 'users.csv';").unwrap(); + } + + #[test] + fn test_create_table() { + let sql = " + CREATE TABLE Persons ( + PersonID int PRIMARY KEY, + LastName char(255), + FirstName char(255), + Address char(255), + City char(255) + );"; + let r = parse_sql(sql); + assert!(r.is_ok()); + + let sql = " + CREATE TABLE CUSTOMERS( + ID1 INT, + ID2 INT, + NAME CHAR (20), + AGE INT, + ADDRESS CHAR (25), + SALARY INT, + PRIMARY KEY (ID1, ID2) + );"; + let r = parse_sql(sql); + assert!(r.is_ok()); + } + + #[test] + fn test_create_table_ast_column_primary_key_index() { + let stmt = parse_create_table( + " + CREATE TABLE Persons ( + PersonID int PRIMARY KEY, + LastName char(255), + FirstName char(255) + ); + ", + ) + .unwrap(); + + let primary_columns = stmt.primary_columns(); + assert_eq!(primary_columns.len(), 1); + assert_eq!(primary_columns[0].column_name(), "PersonID"); + assert_eq!(primary_columns[0].primary_key_index(), Some(0)); + assert_eq!(primary_columns[0].column_index(), 0); + + let non_primary_columns = stmt.non_primary_columns(); + assert_eq!(non_primary_columns.len(), 2); + assert_eq!(non_primary_columns[0].column_name(), "LastName"); + assert_eq!(non_primary_columns[0].primary_key_index(), None); + assert_eq!(non_primary_columns[0].column_index(), 1); + assert_eq!(non_primary_columns[1].column_name(), "FirstName"); + assert_eq!(non_primary_columns[1].primary_key_index(), None); + assert_eq!(non_primary_columns[1].column_index(), 2); + } + + #[test] + fn test_create_table_ast_table_primary_key_index_and_idempotent() { + let mut stmt = parse_create_table( + " + CREATE TABLE CUSTOMERS( + ID1 INT, + ID2 INT, + NAME CHAR(20), + PRIMARY KEY (ID1, ID2) + ); + ", + ) + .unwrap(); + + stmt.assign_index_for_columns(); + + let primary_columns = stmt.primary_columns(); + assert_eq!(primary_columns.len(), 2); + assert_eq!(primary_columns[0].column_name(), "ID1"); + assert_eq!(primary_columns[0].primary_key_index(), Some(0)); + assert_eq!(primary_columns[0].column_index(), 0); + assert_eq!(primary_columns[1].column_name(), "ID2"); + assert_eq!(primary_columns[1].primary_key_index(), Some(1)); + assert_eq!(primary_columns[1].column_index(), 1); + + let non_primary_columns = stmt.non_primary_columns(); + assert_eq!(non_primary_columns.len(), 1); + assert_eq!(non_primary_columns[0].column_name(), "NAME"); + assert_eq!(non_primary_columns[0].primary_key_index(), None); + assert_eq!(non_primary_columns[0].column_index(), 2); + } + + #[test] + fn test_parse_ddl_file() { + let path = get_project_root().unwrap(); + let path = if path.file_name().unwrap().to_str().unwrap().eq("sql_parser") { + path + } else { + path.join("sql_parser") + }; + let path = path.join("data/ddl.sql"); + let r = parse_file(path); + assert!(r.is_ok()) + } +} diff --git a/sql_parser/src/ast/stmt_create_table.rs b/sql_parser/src/ast/stmt_create_table.rs index 16e5667..c52cb7d 100644 --- a/sql_parser/src/ast/stmt_create_table.rs +++ b/sql_parser/src/ast/stmt_create_table.rs @@ -1,13 +1,14 @@ use crate::ast::ast_node::ASTNode; use crate::ast::column_def::ColumnDef; +use mudu::common::id::AttrIndex; use std::fmt::Debug; #[derive(Clone, Debug)] pub struct StmtCreateTable { table_name: String, column_def: Vec, - primary_key_column_def: Vec, - non_primary_key_column_def: Vec, + primary_key_column_def: Vec, + non_primary_key_column_def: Vec, } impl StmtCreateTable { @@ -29,36 +30,54 @@ impl StmtCreateTable { } pub fn add_column_def(&mut self, def: ColumnDef) { - self.column_def.push(def) + let mut _def = def; + _def.set_index(self.column_def.len()); + self.column_def.push(_def) } pub fn mutable_column_def(&mut self) -> &mut Vec { &mut self.column_def } - pub fn non_primary_columns(&self) -> &Vec { + pub fn column_def_by_index(&self, index: AttrIndex) -> &ColumnDef { + &self.column_def[index] + } + + pub fn non_primary_column_indices(&self) -> &Vec { &self.non_primary_key_column_def } - pub fn primary_columns(&self) -> &Vec { + pub fn primary_column_indices(&self) -> &Vec { &self.primary_key_column_def } + pub fn non_primary_columns(&self) -> Vec<&ColumnDef> { + self.non_primary_key_column_def + .iter() + .map(|index| &self.column_def[*index]) + .collect() + } + + pub fn primary_columns(&self) -> Vec<&ColumnDef> { + self.primary_key_column_def + .iter() + .map(|index| &self.column_def[*index]) + .collect() + } + pub fn assign_index_for_columns(&mut self) { - let mut index_non_primary = 0; - let column_def_list = self.column_def.clone(); - for mut c in column_def_list { + self.primary_key_column_def.clear(); + self.non_primary_key_column_def.clear(); + + for (index, c) in self.column_def.iter_mut().enumerate() { if c.is_primary_key() { - self.primary_key_column_def.push(c); + self.primary_key_column_def.push(index); } else { - c.set_index(index_non_primary); - index_non_primary += 1; - self.non_primary_key_column_def.push(c); + self.non_primary_key_column_def.push(index); } } - self.primary_key_column_def.sort_by(|x, y| { - return x.column_index().cmp(&y.column_index()); - }) + self.primary_key_column_def + .sort_by_key(|index| self.column_def[*index].expect_primary_key_index()); } } diff --git a/sql_parser/src/ast/test_parser.rs b/sql_parser/src/ast/test_parser.rs deleted file mode 100644 index 66dbb66..0000000 --- a/sql_parser/src/ast/test_parser.rs +++ /dev/null @@ -1,153 +0,0 @@ -#[cfg(test)] -mod _test { - use crate::ast::parser::SQLParser; - use mudu::common::result::RS; - - use project_root::get_project_root; - use std::fs; - use std::path::Path; - - fn parse_sql(sql: &String) -> RS<()> { - let parser = SQLParser::new(); - let stmt_list = parser.parse(sql)?; - println!("stmt: {:?}", stmt_list); - Ok(()) - } - - #[test] - fn test_select() { - let sql = " - select - distinct column1, - column2 - from table1 - where column3 = 1;" - .to_string(); - let r = parse_sql(&sql); - assert!(r.is_ok()); - - let sql2 = " - select - distinct column1, - column2 - from table1 - where column3 = 1" - .to_string(); - let r = parse_sql(&sql2); - assert!(r.is_ok()); - } - - #[test] - fn test_update() { - let sql = "\ -UPDATE Customers \ -SET ContactName = 'Alfred Schmidt', City= 'Frankfurt' \ -WHERE CustomerID = 1;" - .to_string(); - let r = parse_sql(&sql); - if r.is_err() { - println!("{:#?}", r); - } - assert!(r.is_ok()); - } - - #[test] - fn test_delete() { - let sql = " DELETE FROM Customers - WHERE CustomerName='Alfreds Futterkiste'; " - .to_string(); - let r = parse_sql(&sql); - assert!(r.is_ok()); - } - - #[test] - fn test_insert() { - let sql = " - INSERT INTO Customers ( - CustomerName, - ContactName, - Address, - City, - PostalCode, - Country - ) - VALUES ( - 'Cardinal', - 'Tom B. Erichsen', - 'Skagen 21', - 'Stavanger', - '4006', - 'Norway' - ); - INSERT INTO Customers ( - CustomerName, - ContactName, - Address, - City, - PostalCode, - Country - ) - VALUES ( - 'Cardinal', - 'Tom B. Erichsen', - 'Skagen 21', - 'Stavanger', - '4006', - 'Norway' - ); - " - .to_string(); - let r = parse_sql(&sql); - assert!(r.is_ok()); - } - - #[test] - fn test_create_table() { - let sql = " - CREATE TABLE Persons ( - PersonID int PRIMARY KEY, - LastName char(255), - FirstName char(255), - Address char(255), - City char(255) - );" - .to_string(); - let r = parse_sql(&sql); - assert!(r.is_ok()); - - let sql = " - CREATE TABLE CUSTOMERS( - ID1 INT, - ID2 INT, - NAME CHAR (20), - AGE INT, - ADDRESS CHAR (25), - SALARY INT, - PRIMARY KEY (ID1, ID2) - );" - .to_string(); - let r = parse_sql(&sql); - assert!(r.is_ok()); - } - - fn parse_file>(path: P) -> RS<()> { - let sql = fs::read_to_string(path).unwrap(); - parse_sql(&sql)?; - Ok(()) - } - - #[test] - fn test_parse_ddl_file() { - let path = get_project_root().unwrap(); - let path = if path.file_name().unwrap().to_str().unwrap().eq("sql_parser") { - path - } else { - path.join("sql_parser") - }; - println!("path: {:?}", path); - let path = path.join("data/ddl.sql"); - let r = parse_file(path); - println!("{:?}", r); - assert!(r.is_ok()) - } -} diff --git a/sql_parser/src/lib.rs b/sql_parser/src/lib.rs index 4882f4b..f4ed808 100644 --- a/sql_parser/src/lib.rs +++ b/sql_parser/src/lib.rs @@ -2,17 +2,5 @@ pub mod ast; pub mod parser; pub mod ts_const; -pub fn add(left: u64, right: u64) -> u64 { - left + right -} - #[cfg(test)] -mod tests { - use super::*; - - #[test] - fn it_works() { - let result = add(2, 2); - assert_eq!(result, 4); - } -} +mod lib_test; diff --git a/sql_parser/src/lib_test.rs b/sql_parser/src/lib_test.rs new file mode 100644 index 0000000..a461d1c --- /dev/null +++ b/sql_parser/src/lib_test.rs @@ -0,0 +1,9 @@ +#[cfg(test)] +mod tests { + #[test] + fn sql_parser_crate_loads() { + let parser = crate::ast::parser::SQLParser::new(); + let stmt_list = parser.parse("select id from users;").unwrap(); + assert_eq!(stmt_list.stmts().len(), 1); + } +} diff --git a/sql_parser/src/parser/ddl_parser.rs b/sql_parser/src/parser/ddl_parser.rs index 202e4a9..dc1e568 100644 --- a/sql_parser/src/parser/ddl_parser.rs +++ b/sql_parser/src/parser/ddl_parser.rs @@ -40,7 +40,7 @@ impl DDLParser { let column_def = FieldDef::new( d.column_name().clone(), d.data_type().clone(), - d.is_primary_key(), + d.primary_key_index().is_some(), ); column_def })