diff --git a/Cargo.toml b/Cargo.toml index 788e526..9a85422 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -58,7 +58,8 @@ rand_core = { version = "0.6.4", default-features = false, features = [ once_cell = { version = "1.21.3", default-features = false, features = [ "critical-section", ] } -parking_lot = "0.12.3" +lock_api = "=0.4.13" # msrv 1.64 +parking_lot = { version = "=0.12.4", optional = true } # msrv 1.64 [target.'cfg(all(target_arch = "wasm32", target_os="unknown"))'.dependencies] # only for js (browser or node). if it's not js, like substrate, it won't build @@ -72,7 +73,7 @@ once_cell = { version = "1.21.3", default-features = false, features = ["std"] } [features] default = ["aes-openssl"] -std = ["hkdf/std", "sha2/std", "once_cell/std"] +std = ["hkdf/std", "sha2/std", "once_cell/std", "dep:parking_lot"] # curves # no usage, TODO: make optional after 0.3.0: secp256k1 = ["dep:libsecp256k1"] @@ -102,11 +103,11 @@ criterion = { version = "0.7.0", default-features = false } hex = { version = "0.4.3", default-features = false, features = ["alloc"] } [target.'cfg(target_arch = "wasm32")'.dev-dependencies] -wasm-bindgen-test = "0.3.50" +wasm-bindgen-test = "0.3.54" [target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies] futures-util = "0.3.31" -reqwest = "0.12.15" +reqwest = "0.12.23" tokio = { version = "1.44.1", default-features = false, features = [ "rt-multi-thread", ] } diff --git a/src/config.rs b/src/config.rs index ac111ba..4c6f8a5 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,4 +1,8 @@ use once_cell::sync::Lazy; + +#[cfg(not(feature = "std"))] +use crate::sync::RwLock; +#[cfg(feature = "std")] use parking_lot::RwLock; /// ECIES config. Make sure all parties use the same config diff --git a/src/lib.rs b/src/lib.rs index 6e428fe..c2bc31c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -19,6 +19,9 @@ pub mod utils; mod compat; mod elliptic; +#[cfg(not(feature = "std"))] +mod sync; + use config::{get_ephemeral_key_size, is_ephemeral_key_compressed}; use elliptic::{decapsulate, encapsulate, generate_keypair, parse_pk, parse_sk, pk_to_vec, Error}; use symmetric::{sym_decrypt, sym_encrypt}; diff --git a/src/sync.rs b/src/sync.rs new file mode 100644 index 0000000..d9d7f72 --- /dev/null +++ b/src/sync.rs @@ -0,0 +1,307 @@ +use core::sync::atomic::{AtomicU32, Ordering}; +use lock_api::{GuardSend, RawRwLock}; + +/// A raw reader-writer lock implementation for no_std environments +/// +/// This uses a spinlock approach with atomic operations: +/// - State 0: unlocked +/// - State 1..WRITER: number of readers +/// - State WRITER: exclusive writer lock +const WRITER: u32 = u32::MAX; +const MAX_READERS: u32 = u32::MAX - 1; + +pub struct RawSpinRwLock { + state: AtomicU32, +} + +unsafe impl RawRwLock for RawSpinRwLock { + const INIT: Self = RawSpinRwLock { + state: AtomicU32::new(0), + }; + + type GuardMarker = GuardSend; + + #[inline] + fn lock_shared(&self) { + while !self.try_lock_shared() { + // Spin with a hint to the CPU + core::hint::spin_loop(); + } + } + + #[inline] + fn try_lock_shared(&self) -> bool { + let mut state = self.state.load(Ordering::Relaxed); + + loop { + // Cannot acquire read lock if writer is present or too many readers + if state >= MAX_READERS { + return false; + } + + // Try to increment reader count + match self + .state + .compare_exchange_weak(state, state + 1, Ordering::Acquire, Ordering::Relaxed) + { + Ok(_) => return true, + Err(new_state) => state = new_state, + } + } + } + + #[inline] + unsafe fn unlock_shared(&self) { + // Decrement reader count + self.state.fetch_sub(1, Ordering::Release); + } + + #[inline] + fn lock_exclusive(&self) { + while !self.try_lock_exclusive() { + core::hint::spin_loop(); + } + } + + #[inline] + fn try_lock_exclusive(&self) -> bool { + // Try to acquire exclusive lock (state must be 0) + self.state + .compare_exchange(0, WRITER, Ordering::Acquire, Ordering::Relaxed) + .is_ok() + } + + #[inline] + unsafe fn unlock_exclusive(&self) { + self.state.store(0, Ordering::Release); + } +} + +/// A reader-writer lock for no_std environments +/// +/// This allows multiple concurrent readers or a single writer. +pub(crate) type RwLock = lock_api::RwLock; + +#[cfg(test)] +mod tests { + use super::*; + extern crate std; + use std::sync::Arc; + use std::thread; + use std::vec::Vec; + + #[test] + fn test_multiple_readers() { + let lock = RwLock::new(42); + + let r1 = lock.read(); + let r2 = lock.read(); + let r3 = lock.read(); + + assert_eq!(*r1, 42); + assert_eq!(*r2, 42); + assert_eq!(*r3, 42); + } + + #[test] + fn test_exclusive_writer() { + let lock = RwLock::new(42); + + let mut writer = lock.write(); + *writer = 100; + drop(writer); + + let reader = lock.read(); + assert_eq!(*reader, 100); + } + + #[test] + fn test_try_write_fails_with_readers() { + let lock = RwLock::new(42); + + let _reader = lock.read(); + assert!(lock.try_write().is_none()); + } + + #[test] + fn test_try_read_fails_with_writer() { + let lock = RwLock::new(42); + + let _writer = lock.write(); + assert!(lock.try_read().is_none()); + } + + #[test] + fn test_concurrent_readers() { + let lock = Arc::new(RwLock::new(0)); + let mut handles = Vec::new(); + + // Spawn 10 reader threads + for _ in 0..10 { + let lock = Arc::clone(&lock); + let handle = thread::spawn(move || { + for _ in 0..100 { + let value = lock.read(); + // Just read the value + let _ = *value; + } + }); + handles.push(handle); + } + + for handle in handles { + handle.join().unwrap(); + } + } + + #[test] + fn test_concurrent_writers() { + let lock = Arc::new(RwLock::new(0)); + let mut handles = Vec::new(); + + // Spawn 10 writer threads, each incrementing 100 times + for _ in 0..10 { + let lock = Arc::clone(&lock); + let handle = thread::spawn(move || { + for _ in 0..100 { + let mut value = lock.write(); + *value += 1; + } + }); + handles.push(handle); + } + + for handle in handles { + handle.join().unwrap(); + } + + // Should be exactly 1000 + assert_eq!(*lock.read(), 1000); + } + + #[test] + fn test_mixed_readers_writers() { + let lock = Arc::new(RwLock::new(0)); + let mut handles = Vec::new(); + + // Spawn 5 writer threads + for _ in 0..5 { + let lock = Arc::clone(&lock); + let handle = thread::spawn(move || { + for _ in 0..100 { + let mut value = lock.write(); + *value += 1; + } + }); + handles.push(handle); + } + + // Spawn 10 reader threads + for _ in 0..10 { + let lock = Arc::clone(&lock); + let handle = thread::spawn(move || { + for _ in 0..100 { + let value = lock.read(); + // Verify value is valid (between 0 and 500) + assert!(*value <= 500); + } + }); + handles.push(handle); + } + + for handle in handles { + handle.join().unwrap(); + } + + assert_eq!(*lock.read(), 500); + } + + #[test] + fn test_writer_blocks_readers() { + let lock = Arc::new(RwLock::new(0)); + let lock2 = Arc::clone(&lock); + + let writer = lock.write(); + + // Spawn a reader thread - it should block + let handle = thread::spawn(move || { + let value = lock2.read(); + *value + }); + + // Sleep a bit to ensure reader is blocked + thread::sleep(std::time::Duration::from_millis(50)); + + // Release writer lock + drop(writer); + + // Now reader should complete + let result = handle.join().unwrap(); + assert_eq!(result, 0); + } + + #[test] + fn test_readers_block_writer() { + let lock = Arc::new(RwLock::new(0)); + let lock2 = Arc::clone(&lock); + + let reader = lock.read(); + + // Spawn a writer thread - it should block + let handle = thread::spawn(move || { + let mut value = lock2.write(); + *value = 42; + }); + + // Sleep a bit to ensure writer is blocked + thread::sleep(std::time::Duration::from_millis(50)); + + // Release reader lock + drop(reader); + + // Wait for writer to complete + handle.join().unwrap(); + + assert_eq!(*lock.read(), 42); + } + + #[test] + fn test_stress_test() { + let lock = Arc::new(RwLock::new(Vec::new())); + let mut handles = Vec::new(); + + // Spawn multiple writers that push to the vec + for i in 0..5 { + let lock = Arc::clone(&lock); + let handle = thread::spawn(move || { + for j in 0..20 { + let mut vec = lock.write(); + vec.push(i * 100 + j); + } + }); + handles.push(handle); + } + + // Spawn multiple readers that check vec length + for _ in 0..5 { + let lock = Arc::clone(&lock); + let handle = thread::spawn(move || { + for _ in 0..50 { + let vec = lock.read(); + let len = vec.len(); + // Length should be between 0 and 100 + assert!(len <= 100); + thread::yield_now(); + } + }); + handles.push(handle); + } + + for handle in handles { + handle.join().unwrap(); + } + + // Final check - should have exactly 100 elements + assert_eq!(lock.read().len(), 100); + } +}