diff --git a/libcrux-ml-kem/Cargo.toml b/libcrux-ml-kem/Cargo.toml index 6d0eaf3b7..62c4967eb 100644 --- a/libcrux-ml-kem/Cargo.toml +++ b/libcrux-ml-kem/Cargo.toml @@ -33,12 +33,18 @@ libcrux-intrinsics = { version = "0.0.3", path = "../libcrux-intrinsics" } libcrux-secrets = { version = "0.0.3", path = "../secrets" } libcrux-traits = { version = "0.0.3", path = "../traits" } hax-lib.workspace = true +tls_codec = { version = "0.4.2", features = [ + "derive", +], default-features = false, optional = true } [features] # By default all variants and std are enabled. default = ["default-no-std", "std"] default-no-std = ["mlkem512", "mlkem768", "mlkem1024", "rand"] +# Serialization & Deserialization using tls_codec +codec = ["dep:tls_codec"] + # Hardware features can be force enabled. # It is not recommended to use these. This crate performs CPU feature detection # and enables the features when they are available. @@ -57,7 +63,7 @@ kyber = [] rand = ["dep:rand"] # std support -std = ["alloc", "rand/std"] +std = ["alloc", "rand/std", "tls_codec/std"] alloc = [] # Incremental encapsulation API diff --git a/libcrux-ml-kem/src/lib.rs b/libcrux-ml-kem/src/lib.rs index 0d87aa59c..7e27bb4e2 100644 --- a/libcrux-ml-kem/src/lib.rs +++ b/libcrux-ml-kem/src/lib.rs @@ -75,6 +75,9 @@ #[cfg(feature = "std")] extern crate std; +#[cfg(all(feature = "alloc", feature = "incremental"))] +extern crate alloc; + /// Feature gating helper macros #[macro_use] mod cfg; @@ -90,9 +93,6 @@ pub(crate) mod hax_utils; // This is being tracked in https://github.com/hacspec/hacspec-v2/issues/27 pub(crate) mod constants; -#[cfg(all(feature = "alloc", feature = "incremental"))] -extern crate alloc; - /// Helpers for verification and extraction mod helper; diff --git a/libcrux-ml-kem/src/types.rs b/libcrux-ml-kem/src/types.rs index 0f47d2555..d2d18a8c0 100644 --- a/libcrux-ml-kem/src/types.rs +++ b/libcrux-ml-kem/src/types.rs @@ -146,6 +146,122 @@ mod index_impls { impl_index_impls_for_generic_struct!(MlKemPublicKey); } +#[cfg(all(feature = "codec", feature = "alloc"))] +mod codec { + use super::*; + + macro_rules! impl_tls_codec_for_generic_struct { + ($name:ident) => { + // XXX: `tls_codec::{Serialize, Deserialize}` are only + // available for feature `std`. For `no_std` scenarios, we + // need to implement `tls_codec::{SerializeBytes, + // DeserializeBytes}`, but `SerializeBytes` is not + // implemented for `VLByteSlice`. + impl tls_codec::DeserializeBytes for $name { + fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), tls_codec::Error> { + let (bytes, remainder) = tls_codec::VLBytes::tls_deserialize_bytes(bytes)?; + Ok(( + Self { + value: bytes + .as_ref() + .try_into() + .map_err(|_| tls_codec::Error::InvalidInput)?, + }, + remainder, + )) + } + } + + #[cfg(feature = "std")] + impl tls_codec::Serialize for $name { + fn tls_serialize( + &self, + writer: &mut W, + ) -> Result { + let out = tls_codec::VLByteSlice(self.as_ref()); + out.tls_serialize(writer) + } + } + + #[cfg(feature = "std")] + impl tls_codec::Serialize for &$name { + fn tls_serialize( + &self, + writer: &mut W, + ) -> Result { + (*self).tls_serialize(writer) + } + } + + #[cfg(feature = "std")] + impl tls_codec::Deserialize for $name { + fn tls_deserialize( + bytes: &mut R, + ) -> Result { + let bytes = tls_codec::VLBytes::tls_deserialize(bytes)?; + Ok(Self { + value: bytes + .as_ref() + .try_into() + .map_err(|_| tls_codec::Error::InvalidInput)?, + }) + } + } + + impl tls_codec::Size for $name { + fn tls_serialized_len(&self) -> usize { + tls_codec::VLByteSlice(self.as_ref()).tls_serialized_len() + } + } + + impl tls_codec::Size for &$name { + fn tls_serialized_len(&self) -> usize { + (*self).tls_serialized_len() + } + } + }; + } + + impl_tls_codec_for_generic_struct!(MlKemCiphertext); + impl_tls_codec_for_generic_struct!(MlKemPublicKey); + + #[cfg(test)] + mod test { + use tls_codec::{Deserialize, Serialize, Size}; + + use super::*; + + #[test] + #[cfg(feature = "std")] + fn ser_de() { + use tls_codec::DeserializeBytes; + + const SIZE: usize = 1568; + let test_struct = MlKemCiphertext::::default(); + + assert_eq!(test_struct.tls_serialized_len(), SIZE + 2); + let test_struct_serialized = test_struct.tls_serialize_detached().unwrap(); + assert_eq!( + test_struct_serialized.len(), + test_struct.tls_serialized_len() + ); + + let test_struct_deserialized = + MlKemCiphertext::::tls_deserialize_exact(&test_struct_serialized).unwrap(); + + let test_struct_deserialized_bytes = + MlKemCiphertext::::tls_deserialize_exact_bytes(&test_struct_serialized) + .unwrap(); + + assert_eq!(test_struct.as_ref(), test_struct_deserialized.as_ref()); + assert_eq!( + test_struct.as_ref(), + test_struct_deserialized_bytes.as_ref() + ) + } + } +} + /// An ML-KEM key pair pub struct MlKemKeyPair { pub(crate) sk: MlKemPrivateKey, diff --git a/libcrux-psq/Cargo.toml b/libcrux-psq/Cargo.toml index 50715ffc1..3033a0011 100644 --- a/libcrux-psq/Cargo.toml +++ b/libcrux-psq/Cargo.toml @@ -20,16 +20,22 @@ libcrux-kem = { version = "=0.0.3", path = "../libcrux-kem" } libcrux-chacha20poly1305 = { version = "0.0.3", path = "../chacha20poly1305" } libcrux-hkdf = { version = "=0.0.3", path = "../libcrux-hkdf" } libcrux-hmac = { version = "=0.0.3", path = "../libcrux-hmac" } +libcrux-sha2 = { version = "=0.0.3", path = "../sha2" } classic-mceliece-rust = { version = "3.1.0", features = [ "mceliece460896f", "zeroize", ], optional = true } rand = { version = "0.9" } rand_old = { version = "0.8", package = "rand", optional = true } -libcrux-ecdh = { version = "0.0.3", path = "../libcrux-ecdh", optional = true } +libcrux-ecdh = { version = "0.0.3", path = "../libcrux-ecdh" } +libcrux-ml-kem = { version = "0.0.3", path = "../libcrux-ml-kem", features = [ + "codec", + "rand", +] } libcrux-ed25519 = { version = "0.0.3", path = "../ed25519", features = [ "rand", ] } +tls_codec = { version = "0.4.2", features = ["derive"] } [dev-dependencies] libcrux-psq = { path = ".", features = ["test-utils"] } @@ -43,9 +49,13 @@ classic-mceliece = ["dep:classic-mceliece-rust", "rand_old"] # DO NOT USE: This feature enables implementations backed # by non-post-quantum KEMs and should only be used for # testing purposes and benchmark baselines. -test-utils = ["libcrux-ecdh"] +test-utils = [] [[bench]] name = "psq" harness = false required-features = ["classic-mceliece", "test-utils"] + +[[bench]] +name = "psq_v2" +harness = false diff --git a/libcrux-psq/benches/psq_v2.rs b/libcrux-psq/benches/psq_v2.rs new file mode 100644 index 000000000..3b2bcb630 --- /dev/null +++ b/libcrux-psq/benches/psq_v2.rs @@ -0,0 +1,763 @@ +use criterion::{criterion_group, criterion_main, BatchSize, Criterion}; + +use libcrux_psq::protocol::{ + api::{Builder, IntoTransport, Protocol}, + dhkem::DHKeyPair, + initiator::{QueryInitiator, RegistrationInitiator}, + pqkem::PQKeyPair, + responder::Responder, +}; +use rand::CryptoRng; + +pub fn randombytes(n: usize) -> Vec { + use rand::rngs::OsRng; + use rand::TryRngCore; + + let mut bytes = vec![0u8; n]; + OsRng.try_fill_bytes(&mut bytes).unwrap(); + bytes +} + +fn query(c: &mut Criterion) { + let mut rng = rand::rng(); + let ciphersuite = if PQ { "x25519" } else { "x25519-mlkem768" }; + let ctx = b"Test Context"; + let aad_initiator = b"Test Data I"; + let aad_responder = b"Test Data R"; + + // External setup + let responder_ecdh_keys = DHKeyPair::new(&mut rng); + let responder_pq_keys = PQKeyPair::new(&mut rng); + + // x25519 + + // Setup initiator + let mut initiator = query_initiator(rand::rng(), ctx, aad_initiator, &responder_ecdh_keys); + c.bench_function(&format!("[Query] Initiator setup"), |b| { + b.iter_batched( + || rand::rng(), + |rng| { + initiator = query_initiator(rng, ctx, aad_initiator, &responder_ecdh_keys); + }, + BatchSize::SmallInput, + ) + }); + + // Setup responder + let mut responder = build_responder::( + rand::rng(), + ctx, + aad_responder, + &responder_ecdh_keys, + &responder_pq_keys, + ); + c.bench_function(&format!("[Query] Responder setup {ciphersuite}"), |b| { + b.iter_batched( + || rand::rng(), + |rng| { + responder = build_responder::( + rng, + ctx, + aad_responder, + &responder_ecdh_keys, + &responder_pq_keys, + ); + }, + BatchSize::SmallInput, + ) + }); + + // Setup for sending messages. + + // Send first message + c.bench_function( + &format!("[Query] Initiator send query {ciphersuite}"), + |b| { + b.iter_batched_ref( + || { + let msg_channel = vec![0u8; 4096]; + + let initiator = + query_initiator(rand::rng(), ctx, aad_initiator, &responder_ecdh_keys); + + (initiator, msg_channel) + }, + |(initiator, msg_channel)| { + let query_payload_initiator = b"Query_init"; + let _len_i = initiator + .write_message(query_payload_initiator, msg_channel) + .unwrap(); + }, + BatchSize::SmallInput, + ) + }, + ); + + // Read first message + c.bench_function( + &format!("[Query] Responder read message {ciphersuite}"), + |b| { + b.iter_batched_ref( + || { + let mut msg_channel = vec![0u8; 4096]; + let payload_buf_responder = vec![0u8; 4096]; + + let mut initiator = + query_initiator(rand::rng(), ctx, aad_initiator, &responder_ecdh_keys); + + let query_payload_initiator = b"Query_init"; + let _len_i = initiator + .write_message(query_payload_initiator, &mut msg_channel) + .unwrap(); + + let responder = build_responder::( + rand::rng(), + ctx, + aad_responder, + &responder_ecdh_keys, + &responder_pq_keys, + ); + + (responder, msg_channel, payload_buf_responder) + }, + |(responder, msg_channel, payload_buf_responder)| { + let (_len_r_deserialized, _len_r_payload) = responder + .read_message(msg_channel, payload_buf_responder) + .unwrap(); + }, + BatchSize::SmallInput, + ) + }, + ); + + // Respond + c.bench_function(&format!("[Query] Responder respond {ciphersuite}"), |b| { + b.iter_batched_ref( + || { + let mut msg_channel = vec![0u8; 4096]; + let mut payload_buf_responder = vec![0u8; 4096]; + + let mut initiator = + query_initiator(rand::rng(), ctx, aad_initiator, &responder_ecdh_keys); + + let query_payload_initiator = b"Query_init"; + let _len_i = initiator + .write_message(query_payload_initiator, &mut msg_channel) + .unwrap(); + + let mut responder = build_responder::( + rand::rng(), + ctx, + aad_responder, + &responder_ecdh_keys, + &responder_pq_keys, + ); + + let (_len_r_deserialized, _len_r_payload) = responder + .read_message(&msg_channel, &mut payload_buf_responder) + .unwrap(); + + (responder, msg_channel) + }, + |(responder, msg_channel)| { + let query_payload_responder = b"Query_respond"; + let _len_r = responder + .write_message(query_payload_responder, msg_channel) + .unwrap(); + }, + BatchSize::SmallInput, + ) + }); + + // Finalize on query initiator + c.bench_function(&format!("[Query] Finalize initiator {ciphersuite}"), |b| { + b.iter_batched_ref( + || { + let mut msg_channel = vec![0u8; 4096]; + let mut payload_buf_responder = vec![0u8; 4096]; + let payload_buf_initiator = vec![0u8; 4096]; + + let mut initiator = + query_initiator(rand::rng(), ctx, aad_initiator, &responder_ecdh_keys); + + let query_payload_initiator = b"Query_init"; + let _len_i = initiator + .write_message(query_payload_initiator, &mut msg_channel) + .unwrap(); + + let mut responder = build_responder::( + rand::rng(), + ctx, + aad_responder, + &responder_ecdh_keys, + &responder_pq_keys, + ); + + let (_len_r_deserialized, _len_r_payload) = responder + .read_message(&msg_channel, &mut payload_buf_responder) + .unwrap(); + + let query_payload_responder = b"Query_respond"; + let _len_r = responder + .write_message(query_payload_responder, &mut msg_channel) + .unwrap(); + + (initiator, msg_channel, payload_buf_initiator) + }, + |(initiator, msg_channel, payload_buf_initiator)| { + let (_len_i_deserialized, _len_i_payload) = initiator + .read_message(msg_channel, payload_buf_initiator) + .unwrap(); + }, + BatchSize::SmallInput, + ) + }); +} + +fn registration(c: &mut Criterion) { + let mut rng = rand::rng(); + let ciphersuite = if PQ { "x25519" } else { "x25519-mlkem768" }; + let ctx = b"Test Context"; + let aad_initiator_outer = b"Test Data I Outer"; + let aad_initiator_inner = b"Test Data I Inner"; + let aad_responder = b"Test Data R"; + + // External setup + let responder_ecdh_keys = DHKeyPair::new(&mut rng); + let responder_pq_keys = PQKeyPair::new(&mut rng); + let initiator_ecdh_keys = DHKeyPair::new(&mut rng); + + // x25519 + + // Setup initiator + let mut initiator = registration_initiator::( + rand::rng(), + ctx, + aad_initiator_outer, + aad_initiator_inner, + &responder_ecdh_keys, + &responder_pq_keys, + &initiator_ecdh_keys, + ); + c.bench_function(&format!("[Registration] Initiator setup"), |b| { + b.iter_batched( + || rand::rng(), + |rng| { + initiator = registration_initiator::( + rng, + ctx, + aad_initiator_outer, + aad_initiator_inner, + &responder_ecdh_keys, + &responder_pq_keys, + &initiator_ecdh_keys, + ); + }, + BatchSize::SmallInput, + ) + }); + + // Setup responder + let mut responder = build_responder::( + rand::rng(), + ctx, + aad_responder, + &responder_ecdh_keys, + &responder_pq_keys, + ); + c.bench_function( + &format!("[Registration] Responder setup {ciphersuite}"), + |b| { + b.iter_batched( + || rand::rng(), + |rng| { + responder = build_responder::( + rng, + ctx, + aad_responder, + &responder_ecdh_keys, + &responder_pq_keys, + ); + }, + BatchSize::SmallInput, + ) + }, + ); + + // Setup for sending messages. + + // Send first message + c.bench_function( + &format!("[Registration] Initiator send registration {ciphersuite}"), + |b| { + b.iter_batched_ref( + || { + let msg_channel = vec![0u8; 4096]; + + let initiator = registration_initiator::( + rand::rng(), + ctx, + aad_initiator_outer, + aad_initiator_inner, + &responder_ecdh_keys, + &responder_pq_keys, + &initiator_ecdh_keys, + ); + + (initiator, msg_channel) + }, + |(initiator, msg_channel)| { + let registration_payload_initiator = b"Registration_init"; + let _len_i = initiator + .write_message(registration_payload_initiator, msg_channel) + .unwrap(); + }, + BatchSize::SmallInput, + ) + }, + ); + + // Read first message + c.bench_function( + &format!("[Registration] Responder read message {ciphersuite}"), + |b| { + b.iter_batched_ref( + || { + let mut msg_channel = vec![0u8; 4096]; + let payload_buf_responder = vec![0u8; 4096]; + + let mut initiator = registration_initiator::( + rand::rng(), + ctx, + aad_initiator_outer, + aad_initiator_inner, + &responder_ecdh_keys, + &responder_pq_keys, + &initiator_ecdh_keys, + ); + + let registration_payload_initiator = b"Registration_init"; + let _len_i = initiator + .write_message(registration_payload_initiator, &mut msg_channel) + .unwrap(); + + let responder = build_responder::( + rand::rng(), + ctx, + aad_responder, + &responder_ecdh_keys, + &responder_pq_keys, + ); + + (responder, msg_channel, payload_buf_responder) + }, + |(responder, msg_channel, payload_buf_responder)| { + let (_len_r_deserialized, _len_r_payload) = responder + .read_message(msg_channel, payload_buf_responder) + .unwrap(); + }, + BatchSize::SmallInput, + ) + }, + ); + + // Respond + c.bench_function( + &format!("[Registration] Responder respond {ciphersuite}"), + |b| { + b.iter_batched_ref( + || { + let mut msg_channel = vec![0u8; 4096]; + let mut payload_buf_responder = vec![0u8; 4096]; + + let mut initiator = registration_initiator::( + rand::rng(), + ctx, + aad_initiator_outer, + aad_initiator_inner, + &responder_ecdh_keys, + &responder_pq_keys, + &initiator_ecdh_keys, + ); + + let registration_payload_initiator = b"Registration_init"; + let _len_i = initiator + .write_message(registration_payload_initiator, &mut msg_channel) + .unwrap(); + + let mut responder = build_responder::( + rand::rng(), + ctx, + aad_responder, + &responder_ecdh_keys, + &responder_pq_keys, + ); + + let (_len_r_deserialized, _len_r_payload) = responder + .read_message(&msg_channel, &mut payload_buf_responder) + .unwrap(); + + (responder, msg_channel) + }, + |(responder, msg_channel)| { + let registration_payload_responder = b"Registration_respond"; + let _len_r = responder + .write_message(registration_payload_responder, msg_channel) + .unwrap(); + }, + BatchSize::SmallInput, + ) + }, + ); + + // Finalize on registration initiator + c.bench_function( + &format!("[Registration] Finalize initiator {ciphersuite}"), + |b| { + b.iter_batched_ref( + || { + let mut msg_channel = vec![0u8; 4096]; + let mut payload_buf_responder = vec![0u8; 4096]; + let payload_buf_initiator = vec![0u8; 4096]; + + let mut initiator = registration_initiator::( + rand::rng(), + ctx, + aad_initiator_outer, + aad_initiator_inner, + &responder_ecdh_keys, + &responder_pq_keys, + &initiator_ecdh_keys, + ); + + let registration_payload_initiator = b"Registration_init"; + let _len_i = initiator + .write_message(registration_payload_initiator, &mut msg_channel) + .unwrap(); + + let mut responder = build_responder::( + rand::rng(), + ctx, + aad_responder, + &responder_ecdh_keys, + &responder_pq_keys, + ); + + let (_len_r_deserialized, _len_r_payload) = responder + .read_message(&msg_channel, &mut payload_buf_responder) + .unwrap(); + + let registration_payload_responder = b"Registration_respond"; + let _len_r = responder + .write_message(registration_payload_responder, &mut msg_channel) + .unwrap(); + + (initiator, msg_channel, payload_buf_initiator) + }, + |(initiator, msg_channel, payload_buf_initiator)| { + let (_len_i_deserialized, _len_i_payload) = initiator + .read_message(msg_channel, payload_buf_initiator) + .unwrap(); + }, + BatchSize::SmallInput, + ) + }, + ); + + // IntoTransport transform Initiator + c.bench_function( + &format!("[Registration] IntoTransport Responder {ciphersuite}"), + |b| { + b.iter_batched( + || { + let mut msg_channel = vec![0u8; 4096]; + let mut payload_buf_responder = vec![0u8; 4096]; + + let mut initiator = registration_initiator::( + rand::rng(), + ctx, + aad_initiator_outer, + aad_initiator_inner, + &responder_ecdh_keys, + &responder_pq_keys, + &initiator_ecdh_keys, + ); + + let registration_payload_initiator = b"Registration_init"; + let _len_i = initiator + .write_message(registration_payload_initiator, &mut msg_channel) + .unwrap(); + + let mut responder = build_responder::( + rand::rng(), + ctx, + aad_responder, + &responder_ecdh_keys, + &responder_pq_keys, + ); + + let (_len_r_deserialized, _len_r_payload) = responder + .read_message(&msg_channel, &mut payload_buf_responder) + .unwrap(); + + let registration_payload_responder = b"Registration_respond"; + let _len_r = responder + .write_message(registration_payload_responder, &mut msg_channel) + .unwrap(); + + responder + }, + |responder| { + let _ = responder.into_transport_mode(); + }, + BatchSize::SmallInput, + ) + }, + ); + + // IntoTransport transform Initiator + c.bench_function( + &format!("[Registration] IntoTransport Initiator {ciphersuite}"), + |b| { + b.iter_batched( + || { + let mut msg_channel = vec![0u8; 4096]; + let mut payload_buf_responder = vec![0u8; 4096]; + let mut payload_buf_initiator = vec![0u8; 4096]; + + let mut initiator = registration_initiator::( + rand::rng(), + ctx, + aad_initiator_outer, + aad_initiator_inner, + &responder_ecdh_keys, + &responder_pq_keys, + &initiator_ecdh_keys, + ); + + let registration_payload_initiator = b"Registration_init"; + let _len_i = initiator + .write_message(registration_payload_initiator, &mut msg_channel) + .unwrap(); + + let mut responder = build_responder::( + rand::rng(), + ctx, + aad_responder, + &responder_ecdh_keys, + &responder_pq_keys, + ); + + let (_len_r_deserialized, _len_r_payload) = responder + .read_message(&msg_channel, &mut payload_buf_responder) + .unwrap(); + + let registration_payload_responder = b"Registration_respond"; + let _len_r = responder + .write_message(registration_payload_responder, &mut msg_channel) + .unwrap(); + + let (_len_i_deserialized, _len_i_payload) = initiator + .read_message(&msg_channel, &mut payload_buf_initiator) + .unwrap(); + + initiator + }, + |initiator| { + let _ = initiator.into_transport_mode(); + }, + BatchSize::SmallInput, + ) + }, + ); + + // Transport write message + c.bench_function( + &format!("[Registration] Transport Write {ciphersuite}"), + |b| { + b.iter_batched_ref( + || { + let mut msg_channel = vec![0u8; 5050]; + let mut payload_buf_responder = vec![0u8; 4096]; + let mut payload_buf_initiator = vec![0u8; 4096]; + + let mut initiator = registration_initiator::( + rand::rng(), + ctx, + aad_initiator_outer, + aad_initiator_inner, + &responder_ecdh_keys, + &responder_pq_keys, + &initiator_ecdh_keys, + ); + + let registration_payload_initiator = b"Registration_init"; + let _len_i = initiator + .write_message(registration_payload_initiator, &mut msg_channel) + .unwrap(); + + let mut responder = build_responder::( + rand::rng(), + ctx, + aad_responder, + &responder_ecdh_keys, + &responder_pq_keys, + ); + + let (_len_r_deserialized, _len_r_payload) = responder + .read_message(&msg_channel, &mut payload_buf_responder) + .unwrap(); + + let registration_payload_responder = b"Registration_respond"; + let _len_r = responder + .write_message(registration_payload_responder, &mut msg_channel) + .unwrap(); + + let (_len_i_deserialized, _len_i_payload) = initiator + .read_message(&msg_channel, &mut payload_buf_initiator) + .unwrap(); + + let initiator = initiator.into_transport_mode().unwrap(); + let payload = randombytes(4096); + (initiator, msg_channel, payload) + }, + |(initiator, msg_channel, payload)| { + let _ = initiator.write_message(payload, msg_channel).unwrap(); + }, + BatchSize::SmallInput, + ) + }, + ); + + // Transport read message + c.bench_function( + &format!("[Registration] Transport Read {ciphersuite}"), + |b| { + b.iter_batched_ref( + || { + let mut msg_channel = vec![0u8; 5050]; + let mut payload_buf_responder = vec![0u8; 4096]; + let mut payload_buf_initiator = vec![0u8; 4096]; + + let mut initiator = registration_initiator::( + rand::rng(), + ctx, + aad_initiator_outer, + aad_initiator_inner, + &responder_ecdh_keys, + &responder_pq_keys, + &initiator_ecdh_keys, + ); + + let registration_payload_initiator = b"Registration_init"; + let _len_i = initiator + .write_message(registration_payload_initiator, &mut msg_channel) + .unwrap(); + + let mut responder = build_responder::( + rand::rng(), + ctx, + aad_responder, + &responder_ecdh_keys, + &responder_pq_keys, + ); + + let (_len_r_deserialized, _len_r_payload) = responder + .read_message(&msg_channel, &mut payload_buf_responder) + .unwrap(); + + let registration_payload_responder = b"Registration_respond"; + let _len_r = responder + .write_message(registration_payload_responder, &mut msg_channel) + .unwrap(); + + let (_len_i_deserialized, _len_i_payload) = initiator + .read_message(&msg_channel, &mut payload_buf_initiator) + .unwrap(); + + let mut initiator = initiator.into_transport_mode().unwrap(); + let _ = initiator + .write_message(&randombytes(4096), &mut msg_channel) + .unwrap(); + + let responder = responder.into_transport_mode().unwrap(); + (responder, msg_channel, payload_buf_responder) + }, + |(responder, msg_channel, payload_buf_responder)| { + let _ = responder + .read_message(msg_channel, payload_buf_responder) + .unwrap(); + }, + BatchSize::SmallInput, + ) + }, + ); +} + +#[inline(always)] +fn build_responder<'a, const PQ: bool>( + rng: impl CryptoRng, + ctx: &'a [u8], + aad_responder: &'a [u8], + responder_ecdh_keys: &'a DHKeyPair, + responder_pq_keys: &'a PQKeyPair, +) -> Responder<'a, impl CryptoRng> { + let mut responder = Builder::new(rng) + .context(ctx) + .outer_aad(aad_responder) + .longterm_ecdh_keys(responder_ecdh_keys) + .recent_keys_upper_bound(30); + if PQ { + responder = responder.longterm_pq_keys(&responder_pq_keys); + } + responder.build_responder().unwrap() +} + +#[inline(always)] +fn query_initiator<'a>( + rng: impl CryptoRng, + ctx: &'a [u8], + aad_initiator: &'a [u8], + responder_ecdh_keys: &'a DHKeyPair, +) -> QueryInitiator<'a> { + Builder::new(rng) + .outer_aad(aad_initiator) + .context(ctx) + .peer_longterm_ecdh_pk(&responder_ecdh_keys.pk) + .build_query_initiator() + .unwrap() +} + +#[inline(always)] +fn registration_initiator<'a, const PQ: bool>( + rng: impl CryptoRng, + ctx: &'a [u8], + aad_initiator_outer: &'a [u8], + aad_initiator_inner: &'a [u8], + responder_ecdh_keys: &'a DHKeyPair, + responder_pq_keys: &'a PQKeyPair, + initiator_ecdh_keys: &'a DHKeyPair, +) -> RegistrationInitiator<'a, impl CryptoRng> { + let mut builder = Builder::new(rng) + .outer_aad(aad_initiator_outer) + .outer_aad(aad_initiator_inner) + .context(ctx) + .peer_longterm_ecdh_pk(&responder_ecdh_keys.pk) + .longterm_ecdh_keys(initiator_ecdh_keys); + if PQ { + builder = builder.peer_longterm_pq_pk(&responder_pq_keys.pk); + } + builder.build_registration_initiator().unwrap() +} + +pub fn protocol(c: &mut Criterion) { + // PSQv2 query protocol + query::(c); + query::(c); + // PSQv2 registration protocol + registration::(c); + registration::(c); +} + +criterion_group!(benches, protocol); +criterion_main!(benches); diff --git a/libcrux-psq/src/lib.rs b/libcrux-psq/src/lib.rs index 06a43dd48..7fbfdfe5b 100644 --- a/libcrux-psq/src/lib.rs +++ b/libcrux-psq/src/lib.rs @@ -10,6 +10,10 @@ use std::array::TryFromSliceError; #[derive(Debug)] /// PSQ Errors. pub enum Error { + /// An error during serialization. + Serialization, + /// The Initiator message was stale + TimestampElapsed, /// An invalid public key was provided InvalidPublicKey, /// An invalid private key was provided @@ -75,6 +79,7 @@ const PSK_LENGTH: usize = 32; type Psk = [u8; PSK_LENGTH]; pub mod cred; +pub mod protocol; pub mod psk_registration; #[cfg(feature = "classic-mceliece")] diff --git a/libcrux-psq/src/protocol.rs b/libcrux-psq/src/protocol.rs new file mode 100644 index 000000000..d3e1ff1a8 --- /dev/null +++ b/libcrux-psq/src/protocol.rs @@ -0,0 +1,44 @@ +//! The PSQ registration protocol +#![allow(missing_docs)] + +use api::Error; +use dhkem::DHPublicKey; +use pqkem::PQCiphertext; +use tls_codec::{TlsDeserialize, TlsSerialize, TlsSize, VLByteSlice, VLBytes}; + +pub mod dhkem; +pub mod initiator; +mod keys; +pub mod pqkem; +pub mod responder; +pub mod session; +mod transcript; + +pub mod api; + +#[derive(TlsDeserialize, TlsSize)] +pub struct Message { + pk: DHPublicKey, + ciphertext: VLBytes, + tag: [u8; 16], + aad: VLBytes, + pq_encapsulation: Option, +} + +#[derive(TlsSerialize, TlsSize)] +pub struct MessageOut<'a> { + pk: &'a DHPublicKey, + ciphertext: VLByteSlice<'a>, + tag: [u8; 16], // XXX: implement Serialize for &[T; N] + aad: VLByteSlice<'a>, + pq_encapsulation: Option<&'a PQCiphertext>, +} + +pub(crate) fn write_output(payload: &[u8], out: &mut [u8]) -> Result { + let payload_len = payload.len(); + if out.len() < payload_len { + return Err(Error::OutputBufferShort); + } + out[..payload_len].copy_from_slice(payload); + Ok(payload_len) +} diff --git a/libcrux-psq/src/protocol/api.rs b/libcrux-psq/src/protocol/api.rs new file mode 100644 index 000000000..b1bb25e57 --- /dev/null +++ b/libcrux-psq/src/protocol/api.rs @@ -0,0 +1,273 @@ +use std::io::Cursor; + +use rand::CryptoRng; + +use tls_codec::{ + Deserialize, Serialize, Size, TlsDeserialize, TlsSerialize, TlsSize, VLByteSlice, VLBytes, +}; + +use crate::protocol::write_output; + +use super::{ + dhkem::{DHKeyPair, DHPublicKey}, + initiator::{QueryInitiator, RegistrationInitiator}, + keys::{derive_session_key, AEADKey}, + pqkem::{PQKeyPair, PQPublicKey}, + responder::Responder, + session::{SessionKey, SESSION_ID_LENGTH}, + transcript::Transcript, +}; + +#[derive(Debug)] +pub enum Error { + BuilderState, + Serialize(tls_codec::Error), + Deserialize(tls_codec::Error), + CryptoError, + InitiatorState, + ResponderState, + TransportState, + OutputBufferShort, + PayloadTooLong, + OtherError, +} + +#[derive(Debug)] +pub(crate) struct ToTransportState { + pub(crate) tx2: Transcript, + pub(crate) k2: AEADKey, +} + +pub struct Transport { + session_key: SessionKey, +} +impl Transport { + pub(crate) fn new(tx2: Transcript, k2: AEADKey) -> Result { + Ok(Self { + session_key: derive_session_key(k2, tx2)?, + }) + } + + pub fn id(&self) -> &[u8; SESSION_ID_LENGTH] { + &self.session_key.identifier + } +} + +#[derive(TlsSerialize, TlsSize)] +struct TransportMessageOut<'a> { + ciphertext: VLByteSlice<'a>, + tag: [u8; 16], +} + +#[derive(TlsDeserialize, TlsSize)] +struct TransportMessage { + ciphertext: VLBytes, + tag: [u8; 16], +} + +impl Protocol for Transport { + fn write_message(&mut self, payload: &[u8], out: &mut [u8]) -> Result { + // We match the maximum payload length of Noise. + if payload.len() > 65535 { + return Err(Error::PayloadTooLong); + } + let mut ciphertext = vec![0u8; payload.len()]; + let tag = self + .session_key + .key + .encrypt(payload, &[], &mut ciphertext)?; + let message = TransportMessageOut { + ciphertext: VLByteSlice(ciphertext.as_ref()), + tag, + }; + + message + .tls_serialize(&mut &mut out[..]) + .map_err(|e| Error::Serialize(e)) + } + + fn read_message(&mut self, message: &[u8], out: &mut [u8]) -> Result<(usize, usize), Error> { + let message = TransportMessage::tls_deserialize(&mut Cursor::new(message)) + .map_err(|e| Error::Deserialize(e))?; + + let bytes_deserialized = message.tls_serialized_len(); + + let payload = + self.session_key + .key + .decrypt(message.ciphertext.as_slice(), &message.tag, &[])?; + + let out_bytes_written = write_output(&payload, out)?; + + Ok((bytes_deserialized, out_bytes_written)) + } +} + +pub trait IntoTransport { + fn into_transport_mode(self) -> Result; + fn is_handshake_finished(&self) -> bool; +} + +pub trait Protocol { + /// Write a handshake message to `out` to drive the handshake forward. + /// + /// The message may include a `payload`. Returns the number of + /// bytes written to `out`. If the internal state is not ready to + /// write a message, nothing is written to `out` and `Ok(0)` is + /// returned. + fn write_message(&mut self, payload: &[u8], out: &mut [u8]) -> Result; + + /// Reads the bytes in `message` as input to the handshake, and + /// writes any payload bytes to `payload`. + /// + /// Returns a pair of `(bytes_deserialized, bytes_written)`, where + /// `bytes_deserialized` is the number of bytes read from + /// `message` and `bytes_written` is the number of bytes written + /// to `payload`. If the internal state is not ready to read a + /// message, nothing is written to `payload` and `Ok((0,0))` is + /// returned. + fn read_message(&mut self, message: &[u8], payload: &mut [u8]) + -> Result<(usize, usize), Error>; +} + +pub struct Builder<'a, Rng: CryptoRng> { + rng: Rng, + context: &'a [u8], + inner_aad: &'a [u8], + outer_aad: &'a [u8], + longterm_ecdh_keys: Option<&'a DHKeyPair>, + longterm_pq_keys: Option<&'a PQKeyPair>, + peer_longterm_ecdh_pk: Option<&'a DHPublicKey>, + peer_longterm_pq_pk: Option<&'a PQPublicKey>, + recent_keys_upper_bound: Option, +} + +impl<'a, Rng: CryptoRng> Builder<'a, Rng> { + /// Create a new builder. + pub fn new(rng: Rng) -> Self { + Self { + rng, + context: &[], + inner_aad: &[], + outer_aad: &[], + longterm_ecdh_keys: None, + longterm_pq_keys: None, + peer_longterm_ecdh_pk: None, + peer_longterm_pq_pk: None, + recent_keys_upper_bound: None, + } + } + + // properties + + /// Set the context. + pub fn context(mut self, context: &'a [u8]) -> Self { + self.context = context; + self + } + + /// Set the inner AAD. + pub fn inner_aad(mut self, inner_aad: &'a [u8]) -> Self { + self.inner_aad = inner_aad; + self + } + + /// Set the outer AAD. + pub fn outer_aad(mut self, outer_aad: &'a [u8]) -> Self { + self.outer_aad = outer_aad; + self + } + + /// Set the long-term ECDH key pair. + pub fn longterm_ecdh_keys(mut self, longterm_ecdh_keys: &'a DHKeyPair) -> Self { + self.longterm_ecdh_keys = Some(longterm_ecdh_keys); + self + } + + /// Set the long-term PQ key pair. + pub fn longterm_pq_keys(mut self, longterm_pq_keys: &'a PQKeyPair) -> Self { + self.longterm_pq_keys = Some(longterm_pq_keys); + self + } + + /// Set the peer's long-term ECDH public key. + pub fn peer_longterm_ecdh_pk(mut self, peer_longterm_ecdh_pk: &'a DHPublicKey) -> Self { + self.peer_longterm_ecdh_pk = Some(peer_longterm_ecdh_pk); + self + } + + /// Set the peer's long-term PQ public key. + pub fn peer_longterm_pq_pk(mut self, peer_longterm_pq_pk: &'a PQPublicKey) -> Self { + self.peer_longterm_pq_pk = Some(peer_longterm_pq_pk); + self + } + + /// Set the maximum number of recent keys stored for DoS protection. + pub fn recent_keys_upper_bound(mut self, recent_keys_upper_bound: usize) -> Self { + self.recent_keys_upper_bound = Some(recent_keys_upper_bound); + self + } + + // builders + + /// Build a new [`QueryInitiator`]. + /// + /// This requires that a `responder_ecdh_pk` is set. + /// It also uses the `context` and `outer_aad`. + pub fn build_query_initiator(self) -> Result, Error> { + let Some(responder_ecdh_pk) = self.peer_longterm_ecdh_pk else { + return Err(Error::BuilderState); + }; + + QueryInitiator::new(responder_ecdh_pk, self.context, self.outer_aad, self.rng) + } + + /// Build a new [`RegistrationInitiator`]. + /// + /// This requires that a `longterm_ecdh_keys` and a `peer_longterm_ecdh_pk` + /// is set. + /// It also uses the `context`, `inner_aad`, `outer_aad`, and + /// `peer_longterm_pq_pk`. + pub fn build_registration_initiator(self) -> Result, Error> { + let Some(longterm_ecdh_keys) = self.longterm_ecdh_keys else { + return Err(Error::BuilderState); + }; + + let Some(peer_longterm_ecdh_pk) = self.peer_longterm_ecdh_pk else { + return Err(Error::BuilderState); + }; + + RegistrationInitiator::new( + longterm_ecdh_keys, + peer_longterm_ecdh_pk, + self.peer_longterm_pq_pk, + self.context, + self.inner_aad, + self.outer_aad, + self.rng, + ) + } + + /// Build a new [`Responder`]. + /// + /// This requires that a `longterm_ecdh_keys`, and `recent_keys_upper_bound` is set. + /// It also uses the `context`, `outer_aad`, and `longterm_pq_keys`. + pub fn build_responder(self) -> Result, Error> { + let Some(longterm_ecdh_keys) = self.longterm_ecdh_keys else { + return Err(Error::BuilderState); + }; + + let Some(recent_keys_upper_bound) = self.recent_keys_upper_bound else { + return Err(Error::BuilderState); + }; + + Ok(Responder::new( + longterm_ecdh_keys, + self.longterm_pq_keys, + self.context, + self.outer_aad, + recent_keys_upper_bound, + self.rng, + )) + } +} diff --git a/libcrux-psq/src/protocol/dhkem.rs b/libcrux-psq/src/protocol/dhkem.rs new file mode 100644 index 000000000..19618dde0 --- /dev/null +++ b/libcrux-psq/src/protocol/dhkem.rs @@ -0,0 +1,92 @@ +//! Diffie-Hellman KEM type wrappers +//! +//! This module provides wrappers around KEM types, assuming a DH-KEM +//! style API. +use libcrux_ecdh::{secret_to_public, Algorithm}; +use rand::CryptoRng; +use tls_codec::{TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSerializeBytes, TlsSize}; + +use crate::protocol::api::Error; + +#[derive(TlsSerializeBytes, TlsSize)] +/// A wrapper around a KEM shared secret. +/// +/// We don't directly expose this. +pub(crate) struct DHSharedSecret(Vec); + +impl AsRef<[u8]> for DHSharedSecret { + fn as_ref(&self) -> &[u8] { + self.0.as_slice() + } +} + +#[derive( + Eq, + Debug, + Hash, + PartialEq, + Clone, + TlsDeserializeBytes, + TlsSerializeBytes, + TlsSize, + TlsSerialize, + TlsDeserialize, +)] +/// A wrapper around a KEM public key. +pub struct DHPublicKey(Vec); + +impl AsRef<[u8]> for DHPublicKey { + fn as_ref(&self) -> &[u8] { + self.0.as_slice() + } +} + +/// A wrapper around a KEM private key. +pub struct DHPrivateKey(Vec); + +impl AsRef<[u8]> for DHPrivateKey { + fn as_ref(&self) -> &[u8] { + self.0.as_slice() + } +} + +impl DHSharedSecret { + /// Derive a shared secret, DH-KEM style. + pub(crate) fn derive(sk: &DHPrivateKey, pk: &DHPublicKey) -> Result { + Ok(DHSharedSecret( + libcrux_ecdh::derive(Algorithm::X25519, &pk.0, &sk.0) + .map_err(|_| Error::CryptoError)?, + )) + } +} + +impl DHPrivateKey { + /// Creates a new KEM private key. + pub fn new(rng: &mut impl CryptoRng) -> Self { + Self( + libcrux_ecdh::generate_secret(libcrux_ecdh::Algorithm::X25519, rng) + .expect("Insufficient Randomness"), + ) + } + + /// Compute the KEM public key from the KEM private key. + pub fn to_public(&self) -> DHPublicKey { + DHPublicKey( + secret_to_public(libcrux_ecdh::Algorithm::X25519, &self.0) + .expect("secret key is honestly generated X25519 key"), + ) + } +} + +pub struct DHKeyPair { + pub(crate) sk: DHPrivateKey, + pub pk: DHPublicKey, +} + +impl DHKeyPair { + pub fn new(rng: &mut impl CryptoRng) -> Self { + let sk = DHPrivateKey::new(rng); + let pk = sk.to_public(); + Self { sk, pk } + } +} diff --git a/libcrux-psq/src/protocol/initiator.rs b/libcrux-psq/src/protocol/initiator.rs new file mode 100644 index 000000000..56a7942a8 --- /dev/null +++ b/libcrux-psq/src/protocol/initiator.rs @@ -0,0 +1,323 @@ +use std::{io::Cursor, mem::take}; + +use rand::CryptoRng; +use tls_codec::{ + Deserialize, Serialize, Size, TlsDeserialize, TlsSerialize, TlsSize, VLByteSlice, VLBytes, +}; + +use crate::protocol::MessageOut; + +use super::{ + api::{Error, IntoTransport, Protocol, Transport}, + dhkem::{DHKeyPair, DHPrivateKey, DHPublicKey}, + keys::{ + derive_k0, derive_k1, derive_k2_query_initiator, derive_k2_registration_initiator, AEADKey, + }, + pqkem::PQPublicKey, + responder::{ResponderQueryPayload, ResponderRegistrationPayload}, + transcript::{tx1, tx2, Transcript}, + write_output, Message, +}; + +pub struct QueryInitiator<'a> { + responder_longterm_ecdh_pk: &'a DHPublicKey, + initiator_ephemeral_keys: DHKeyPair, + tx0: Transcript, + k0: AEADKey, + outer_aad: &'a [u8], +} + +pub struct RegistrationInitiator<'a, Rng: CryptoRng> { + responder_longterm_ecdh_pk: &'a DHPublicKey, + responder_longterm_pq_pk: Option<&'a PQPublicKey>, + initiator_longterm_ecdh_keys: &'a DHKeyPair, + inner_aad: &'a [u8], + outer_aad: &'a [u8], + rng: Rng, + state: RegistrationInitiatorState, +} + +#[derive(TlsSerialize, TlsSize)] +#[repr(u8)] +pub enum InitiatorOuterPayloadOut<'a> { + Query(VLByteSlice<'a>), + Registration(MessageOut<'a>), +} + +#[derive(TlsDeserialize, TlsSize)] +pub struct InitiatorInnerPayload(pub VLBytes); + +#[derive(TlsSerialize, TlsSize)] +pub struct InitiatorInnerPayloadOut<'a>(pub VLByteSlice<'a>); + +pub struct InitialState { + initiator_ephemeral_keys: DHKeyPair, + tx0: Transcript, + k0: AEADKey, +} + +pub struct WaitingState { + initiator_ephemeral_ecdh_sk: DHPrivateKey, + tx1: Transcript, + k1: AEADKey, +} + +pub struct ToTransportState { + tx2: Transcript, + k2: AEADKey, +} + +#[derive(Default)] +pub enum RegistrationInitiatorState { + #[default] + InProgress, // A placeholder while computing the next state + Initial(Box), + Waiting(Box), + ToTransport(Box), +} + +impl<'a> QueryInitiator<'a> { + /// Create a new [`QueryInitiator`]. + pub(crate) fn new( + responder_longterm_ecdh_pk: &'a DHPublicKey, + ctx: &[u8], + outer_aad: &'a [u8], + mut rng: impl CryptoRng, + ) -> Result { + let initiator_ephemeral_keys = DHKeyPair::new(&mut rng); + + let (tx0, k0) = derive_k0( + responder_longterm_ecdh_pk, + &initiator_ephemeral_keys.pk, + &initiator_ephemeral_keys.sk, + ctx, + false, + )?; + + Ok(Self { + responder_longterm_ecdh_pk, + tx0, + k0, + outer_aad, + initiator_ephemeral_keys, + }) + } + + fn read_response(&self, responder_msg: &Message) -> Result { + let tx2 = tx2(&self.tx0, &responder_msg.pk)?; + + let mut k2 = derive_k2_query_initiator( + &self.k0, + &responder_msg.pk, + &self.initiator_ephemeral_keys.sk, + self.responder_longterm_ecdh_pk, + &tx2, + )?; + + k2.decrypt_deserialize( + responder_msg.ciphertext.as_slice(), + &responder_msg.tag, + responder_msg.aad.as_slice(), + ) + } +} + +impl<'a, Rng: CryptoRng> RegistrationInitiator<'a, Rng> { + /// Create a new [`RegistrationInitiator`]. + pub(crate) fn new( + initiator_longterm_ecdh_keys: &'a DHKeyPair, + responder_longterm_ecdh_pk: &'a DHPublicKey, + responder_longterm_pq_pk: Option<&'a PQPublicKey>, + ctx: &[u8], + inner_aad: &'a [u8], + outer_aad: &'a [u8], + mut rng: Rng, + ) -> Result { + let initiator_ephemeral_keys = DHKeyPair::new(&mut rng); + + let (tx0, k0) = derive_k0( + responder_longterm_ecdh_pk, + &initiator_ephemeral_keys.pk, + &initiator_ephemeral_keys.sk, + ctx, + false, + )?; + + let state = RegistrationInitiatorState::Initial( + InitialState { + tx0, + k0, + initiator_ephemeral_keys, + } + .into(), + ); + + Ok(Self { + responder_longterm_ecdh_pk, + responder_longterm_pq_pk, + initiator_longterm_ecdh_keys, + inner_aad, + outer_aad, + rng, + state, + }) + } +} + +impl<'a> Protocol for QueryInitiator<'a> { + fn write_message(&mut self, payload: &[u8], out: &mut [u8]) -> Result { + let outer_payload = InitiatorOuterPayloadOut::Query(VLByteSlice(payload)); + let (ciphertext, tag) = self.k0.serialize_encrypt(&outer_payload, self.outer_aad)?; + + let msg = MessageOut { + pk: &self.initiator_ephemeral_keys.pk, + ciphertext: VLByteSlice(&ciphertext), + tag, + aad: VLByteSlice(self.outer_aad), + pq_encapsulation: None, + }; + + msg.tls_serialize(&mut &mut out[..]) + .map_err(Error::Serialize) + } + + fn read_message( + &mut self, + message_bytes: &[u8], + out: &mut [u8], + ) -> Result<(usize, usize), Error> { + let msg = Message::tls_deserialize(&mut Cursor::new(&message_bytes[..])) + .map_err(Error::Deserialize)?; + + let result = self.read_response(&msg)?; + let out_bytes_written = write_output(result.0.as_slice(), out)?; + + Ok((msg.tls_serialized_len(), out_bytes_written)) + } +} + +impl<'a, Rng: CryptoRng> Protocol for RegistrationInitiator<'a, Rng> { + fn write_message(&mut self, payload: &[u8], out: &mut [u8]) -> Result { + let out_bytes_written; + + let RegistrationInitiatorState::Initial(mut state) = take(&mut self.state) else { + // If we're not in the initial state, we write nothing + return Ok(0); + }; + + let pq_encaps_pair = self + .responder_longterm_pq_pk + .map(|pk| pk.encapsulate(&mut self.rng)); + + let (pq_encapsulation, pq_shared_secret) = + if let Some((pq_encaps, pq_shared_secret)) = pq_encaps_pair { + (Some(pq_encaps), Some(pq_shared_secret)) + } else { + (None, None) + }; + + let tx1 = tx1( + &state.tx0, + &self.initiator_longterm_ecdh_keys.pk, + self.responder_longterm_pq_pk, + pq_encapsulation.as_ref(), + )?; + + let mut k1 = derive_k1( + &state.k0, + &self.initiator_longterm_ecdh_keys.sk, + self.responder_longterm_ecdh_pk, + &pq_shared_secret, + &tx1, + )?; + + let inner_payload = InitiatorInnerPayloadOut(VLByteSlice(payload)); + let (inner_ciphertext, inner_tag) = k1.serialize_encrypt(&inner_payload, self.inner_aad)?; + + let outer_payload = InitiatorOuterPayloadOut::Registration(MessageOut { + pk: &self.initiator_longterm_ecdh_keys.pk, + ciphertext: VLByteSlice(&inner_ciphertext), + tag: inner_tag, + aad: VLByteSlice(self.inner_aad), + pq_encapsulation: pq_encapsulation.as_ref(), + }); + let (outer_ciphertext, outer_tag) = + state.k0.serialize_encrypt(&outer_payload, self.outer_aad)?; + + let msg = MessageOut { + pk: &state.initiator_ephemeral_keys.pk, + ciphertext: VLByteSlice(&outer_ciphertext), + tag: outer_tag, + aad: VLByteSlice(self.outer_aad), + pq_encapsulation: None, + }; + + out_bytes_written = msg + .tls_serialize(&mut &mut out[..]) + .map_err(Error::Serialize)?; + + self.state = RegistrationInitiatorState::Waiting( + WaitingState { + initiator_ephemeral_ecdh_sk: state.initiator_ephemeral_keys.sk, + tx1, + k1, + } + .into(), + ); + + Ok(out_bytes_written) + } + + fn read_message( + &mut self, + message_bytes: &[u8], + out: &mut [u8], + ) -> Result<(usize, usize), Error> { + let RegistrationInitiatorState::Waiting(state) = take(&mut self.state) else { + // If we're not in the waiting state, we do nothing. + return Ok((0, 0)); + }; + + // Deserialize the message. + let responder_msg = Message::tls_deserialize(&mut Cursor::new(&message_bytes)) + .map_err(Error::Deserialize)?; + let bytes_deserialized = responder_msg.tls_serialized_len(); + + // Derive K2 + let tx2 = tx2(&state.tx1, &responder_msg.pk)?; + let mut k2 = derive_k2_registration_initiator( + &state.k1, + &tx2, + &self.initiator_longterm_ecdh_keys.sk, + &state.initiator_ephemeral_ecdh_sk, + &responder_msg.pk, + )?; + + // Decrypt Payload + let registration_response: ResponderRegistrationPayload = k2.decrypt_deserialize( + responder_msg.ciphertext.as_slice(), + &responder_msg.tag, + responder_msg.aad.as_slice(), + )?; + + let out_bytes_written = write_output(registration_response.0.as_slice(), out)?; + + self.state = RegistrationInitiatorState::ToTransport(ToTransportState { tx2, k2 }.into()); + + Ok((bytes_deserialized, out_bytes_written)) + } +} + +impl<'a, Rng: CryptoRng> IntoTransport for RegistrationInitiator<'a, Rng> { + fn into_transport_mode(self) -> Result { + let RegistrationInitiatorState::ToTransport(state) = self.state else { + return Err(Error::InitiatorState); + }; + + Transport::new(state.tx2, state.k2) + } + + fn is_handshake_finished(&self) -> bool { + matches!(self.state, RegistrationInitiatorState::ToTransport(_)) + } +} diff --git a/libcrux-psq/src/protocol/keys.rs b/libcrux-psq/src/protocol/keys.rs new file mode 100644 index 000000000..a15611ce0 --- /dev/null +++ b/libcrux-psq/src/protocol/keys.rs @@ -0,0 +1,296 @@ +use libcrux_chacha20poly1305::{decrypt_detached, encrypt_detached, KEY_LEN, NONCE_LEN}; +use libcrux_hkdf::Algorithm; +use tls_codec::{Deserialize, Serialize, SerializeBytes, TlsSerializeBytes, TlsSize}; + +use crate::protocol::pqkem::PQSharedSecret; + +use super::{ + api::Error, + dhkem::{DHPrivateKey, DHPublicKey, DHSharedSecret}, + session::{SessionKey, SESSION_ID_LENGTH}, + transcript::{self, Transcript}, +}; + +#[derive(Default, Clone, TlsSerializeBytes, TlsSize)] +pub struct AEADKey([u8; KEY_LEN], #[tls_codec(skip)] [u8; NONCE_LEN]); + +impl std::fmt::Debug for AEADKey { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_tuple("AEADKey").field(&"***").finish() + } +} + +fn serialize_error(e: tls_codec::Error) -> Error { + Error::Serialize(e) +} + +impl AEADKey { + fn new(ikm: &impl SerializeBytes, info: &impl SerializeBytes) -> Result { + let prk = libcrux_hkdf::extract( + Algorithm::Sha256, + [], + ikm.tls_serialize().map_err(serialize_error)?, + ) + .map_err(|_| Error::CryptoError)?; + + Ok(AEADKey( + libcrux_hkdf::expand( + Algorithm::Sha256, + prk, + info.tls_serialize().map_err(serialize_error)?, + KEY_LEN, + ) + .map_err(|_| Error::CryptoError)? + .try_into() + .map_err(|_| Error::CryptoError)?, // We don't expect this to fail, unless HDKF gave us the wrong output length, + [0u8; NONCE_LEN], + )) + } + + fn increment_nonce(&mut self) -> Result<(), Error> { + if self.1 == [0xff; NONCE_LEN] { + return Err(Error::CryptoError); + } + let mut buf = [0u8; 16]; + buf[16 - NONCE_LEN..].copy_from_slice(self.1.as_slice()); + let mut nonce = u128::from_be_bytes(buf); + nonce += 1; + let buf = nonce.to_be_bytes(); + self.1.copy_from_slice(&buf[16 - NONCE_LEN..]); + Ok(()) + } + + pub(crate) fn encrypt( + &mut self, + payload: &[u8], + aad: &[u8], + ciphertext: &mut [u8], + ) -> Result<[u8; 16], crate::protocol::api::Error> { + let mut tag = [0u8; 16]; + + self.increment_nonce()?; + + // XXX: We could do better if we'd have an inplace API here. + let _ = encrypt_detached(&self.0, payload, ciphertext, &mut tag, aad, &self.1) + .map_err(|_| Error::CryptoError)?; + + Ok(tag) + } + + pub(crate) fn serialize_encrypt( + &mut self, + payload: &T, + aad: &[u8], + ) -> Result<(Vec, [u8; 16]), crate::protocol::api::Error> { + let serialization_buffer = payload.tls_serialize_detached().map_err(Error::Serialize)?; + + let mut ciphertext = vec![0u8; serialization_buffer.len()]; + let tag = self.encrypt(&serialization_buffer, aad, &mut ciphertext)?; + + Ok((ciphertext, tag)) + } + + pub(crate) fn decrypt( + &mut self, + ciphertext: &[u8], + tag: &[u8; 16], + aad: &[u8], + ) -> Result, Error> { + self.increment_nonce()?; + let mut plaintext = vec![0u8; ciphertext.len()]; + + let _ = decrypt_detached(&self.0, &mut plaintext, ciphertext, tag, aad, &self.1) + .map_err(|_| Error::CryptoError)?; + + Ok(plaintext) + } + + pub(crate) fn decrypt_deserialize( + &mut self, + ciphertext: &[u8], + tag: &[u8; 16], + aad: &[u8], + ) -> Result { + let payload_serialized_buf = self.decrypt(ciphertext, tag, aad)?; + + T::tls_deserialize_exact(&payload_serialized_buf).map_err(Error::Deserialize) + } +} +impl AsRef<[u8; KEY_LEN]> for AEADKey { + fn as_ref(&self) -> &[u8; KEY_LEN] { + &self.0 + } +} + +#[derive(TlsSerializeBytes, TlsSize)] +struct K0Ikm<'a> { + g_xs: &'a DHSharedSecret, +} + +const SESSION_KEY_INFO: &[u8] = b"shared key id"; + +// id_skCS = KDF(skCS, "shared key id") +fn session_key_id(key: &AEADKey) -> Result<[u8; SESSION_ID_LENGTH], Error> { + let prk = libcrux_hkdf::extract( + Algorithm::Sha256, + [], + key.tls_serialize().map_err(serialize_error)?, + ) + .map_err(|_| Error::CryptoError)?; + + Ok( + libcrux_hkdf::expand(Algorithm::Sha256, prk, SESSION_KEY_INFO, SESSION_ID_LENGTH) + .map_err(|_| Error::CryptoError)? + .try_into() + .map_err(|_| Error::CryptoError)?, // We don't expect this to fail, unless HDKF gave us the wrong output length + ) +} + +// skCS = KDF(K2, "shared secret" | tx2) +pub(super) fn derive_session_key(k2: AEADKey, tx2: Transcript) -> Result { + #[derive(TlsSerializeBytes, TlsSize)] + struct SessionKeyInfo<'a> { + domain_separator: &'static [u8], + tx2: &'a Transcript, + } + + const SHARED_KEY_LABEL: &'static [u8] = b"shared key"; + let key = AEADKey::new( + &k2, + &SessionKeyInfo { + domain_separator: SHARED_KEY_LABEL, + tx2: &tx2, + }, + )?; + let identifier = session_key_id(&key)?; + Ok(SessionKey { key, identifier }) +} + +// K0 = KDF(g^xs, tx0) +pub(super) fn derive_k0( + peer_pk: &DHPublicKey, + own_pk: &DHPublicKey, + own_sk: &DHPrivateKey, + ctx: &[u8], + is_responder: bool, +) -> Result<(Transcript, AEADKey), Error> { + let tx0 = if is_responder { + transcript::tx0(ctx, own_pk, peer_pk)? + } else { + transcript::tx0(ctx, peer_pk, own_pk)? + }; + let ikm = K0Ikm { + g_xs: &DHSharedSecret::derive(own_sk, peer_pk)?, + }; + + Ok((tx0, AEADKey::new(&ikm, &tx0)?)) +} + +// K1 = KDF(K0 | g^cs | SS, tx1) +pub(super) fn derive_k1( + k0: &AEADKey, + own_longterm_key: &DHPrivateKey, + peer_longterm_pk: &DHPublicKey, + pq_shared_secret: &Option, + tx1: &Transcript, +) -> Result { + #[derive(TlsSerializeBytes, TlsSize)] + struct K1Ikm<'a, 'b, 'c> { + k0: &'a AEADKey, + ecdh_shared_secret: &'b DHSharedSecret, + pq_shared_secret: &'c Option, + } + + let ecdh_shared_secret = DHSharedSecret::derive(own_longterm_key, peer_longterm_pk)?; + + AEADKey::new( + &K1Ikm { + k0, + ecdh_shared_secret: &ecdh_shared_secret, + pq_shared_secret, + }, + &tx1, + ) +} + +#[derive(TlsSerializeBytes, TlsSize)] +struct K2IkmQuery<'a> { + k0: &'a AEADKey, + g_xs: &'a DHSharedSecret, + g_xy: &'a DHSharedSecret, +} + +#[derive(TlsSerializeBytes, TlsSize)] +struct K2IkmRegistration<'a, 'b> { + k1: &'a AEADKey, + g_cy: &'b DHSharedSecret, + g_xy: &'b DHSharedSecret, +} + +// K2 = KDF(K1 | g^cy | g^xy, tx2) +pub(super) fn derive_k2_registration_responder( + k1: &AEADKey, + tx2: &Transcript, + initiator_longterm_pk: &DHPublicKey, + initiator_ephemeral_pk: &DHPublicKey, + responder_ephemeral_sk: &DHPrivateKey, +) -> Result { + let responder_ikm = K2IkmRegistration { + k1, + g_cy: &DHSharedSecret::derive(responder_ephemeral_sk, initiator_longterm_pk)?, + g_xy: &DHSharedSecret::derive(responder_ephemeral_sk, initiator_ephemeral_pk)?, + }; + + Ok(AEADKey::new(&responder_ikm, tx2)?) +} + +// K2 = KDF(K1 | g^cy | g^xy, tx2) +pub(super) fn derive_k2_registration_initiator( + k1: &AEADKey, + tx2: &Transcript, + initiator_longterm_sk: &DHPrivateKey, + initiator_ephemeral_sk: &DHPrivateKey, + responder_ephemeral_pk: &DHPublicKey, +) -> Result { + let responder_ikm = K2IkmRegistration { + k1, + g_cy: &DHSharedSecret::derive(initiator_longterm_sk, responder_ephemeral_pk)?, + g_xy: &DHSharedSecret::derive(initiator_ephemeral_sk, responder_ephemeral_pk)?, + }; + + AEADKey::new(&responder_ikm, tx2) +} + +// K2 = KDF(K0 | g^xs | g^xy, tx2) +pub(super) fn derive_k2_query_responder( + k0: &AEADKey, + initiator_ephemeral_ecdh_pk: &DHPublicKey, + responder_ephemeral_ecdh_sk: &DHPrivateKey, + responder_longterm_ecdh_sk: &DHPrivateKey, + tx2: &Transcript, +) -> Result { + let responder_ikm = K2IkmQuery { + k0, + g_xs: &DHSharedSecret::derive(responder_longterm_ecdh_sk, initiator_ephemeral_ecdh_pk)?, + g_xy: &DHSharedSecret::derive(responder_ephemeral_ecdh_sk, initiator_ephemeral_ecdh_pk)?, + }; + + AEADKey::new(&responder_ikm, tx2) +} + +// K2 = KDF(K0 | g^xs | g^xy, tx2) +pub(super) fn derive_k2_query_initiator( + k0: &AEADKey, + responder_ephemeral_ecdh_pk: &DHPublicKey, + initiator_ephemeral_ecdh_sk: &DHPrivateKey, + responder_longterm_ecdh_pk: &DHPublicKey, + tx2: &Transcript, +) -> Result { + let initiator_ikm = K2IkmQuery { + k0, + g_xs: &DHSharedSecret::derive(initiator_ephemeral_ecdh_sk, responder_longterm_ecdh_pk)?, + g_xy: &DHSharedSecret::derive(initiator_ephemeral_ecdh_sk, responder_ephemeral_ecdh_pk)?, + }; + + AEADKey::new(&initiator_ikm, tx2) +} diff --git a/libcrux-psq/src/protocol/pqkem.rs b/libcrux-psq/src/protocol/pqkem.rs new file mode 100644 index 000000000..bf2d722d0 --- /dev/null +++ b/libcrux-psq/src/protocol/pqkem.rs @@ -0,0 +1,45 @@ +use libcrux_ml_kem::{ + mlkem768::{ + decapsulate, + rand::{encapsulate, generate_key_pair}, + MlKem768Ciphertext, MlKem768PrivateKey, MlKem768PublicKey, + }, + MlKemSharedSecret, +}; +use rand::CryptoRng; +use tls_codec::{TlsDeserialize, TlsSerialize, TlsSerializeBytes, TlsSize}; + +#[derive(TlsSerialize, TlsSize)] +pub struct PQPublicKey(MlKem768PublicKey); +pub struct PQPrivateKey(MlKem768PrivateKey); +pub struct PQKeyPair { + pub pk: PQPublicKey, + pub(crate) sk: PQPrivateKey, +} +#[derive(TlsSerialize, TlsDeserialize, TlsSize)] +pub struct PQCiphertext(MlKem768Ciphertext); +#[derive(TlsSerializeBytes, TlsSize)] +pub struct PQSharedSecret(MlKemSharedSecret); + +impl PQPublicKey { + pub(crate) fn encapsulate(&self, rng: &mut impl CryptoRng) -> (PQCiphertext, PQSharedSecret) { + let (ciphertext, shared_secret) = encapsulate(&self.0, rng); + (PQCiphertext(ciphertext), PQSharedSecret(shared_secret)) + } +} + +impl PQPrivateKey { + pub(crate) fn decapsulate(&self, enc: &PQCiphertext) -> PQSharedSecret { + PQSharedSecret(decapsulate(&self.0, &enc.0)) + } +} + +impl PQKeyPair { + pub fn new(rng: &mut impl CryptoRng) -> Self { + let (sk, pk) = generate_key_pair(rng).into_parts(); + PQKeyPair { + pk: PQPublicKey(pk), + sk: PQPrivateKey(sk), + } + } +} diff --git a/libcrux-psq/src/protocol/responder.rs b/libcrux-psq/src/protocol/responder.rs new file mode 100644 index 000000000..fa1bd9503 --- /dev/null +++ b/libcrux-psq/src/protocol/responder.rs @@ -0,0 +1,375 @@ +use std::{collections::VecDeque, io::Cursor, mem::take}; + +use rand::CryptoRng; +use tls_codec::{ + Deserialize, Serialize, Size, TlsDeserialize, TlsSerialize, TlsSize, VLByteSlice, VLBytes, +}; + +use crate::protocol::MessageOut; + +use super::{ + api::{Error, IntoTransport, Protocol, ToTransportState, Transport}, + dhkem::{DHKeyPair, DHPrivateKey, DHPublicKey}, + initiator::InitiatorInnerPayload, + keys::{ + derive_k0, derive_k1, derive_k2_query_responder, derive_k2_registration_responder, AEADKey, + }, + pqkem::PQKeyPair, + transcript::{tx1, tx2, Transcript}, + write_output, Message, +}; + +#[derive(TlsDeserialize, TlsSize)] +#[repr(u8)] +pub enum InitiatorOuterPayload { + Query(VLBytes), + Registration(Message), +} + +#[derive(Debug)] +pub(crate) struct RespondQueryState { + pub(crate) tx0: Transcript, + pub(crate) k0: AEADKey, + pub(crate) initiator_ephemeral_ecdh_pk: DHPublicKey, +} + +#[derive(Debug)] +pub(crate) struct RespondRegistrationState { + pub(crate) tx1: Transcript, + pub(crate) k1: AEADKey, + pub(crate) initiator_ephemeral_ecdh_pk: DHPublicKey, + pub(crate) initiator_longterm_ecdh_pk: DHPublicKey, +} + +#[derive(Default, Debug)] +pub(crate) enum ResponderState { + #[default] + InProgress, // A placeholder while computing the next state + Initial, + RespondQuery(Box), + RespondRegistration(Box), + ToTransport(Box), +} + +pub struct Responder<'a, Rng: CryptoRng> { + pub(crate) state: ResponderState, + recent_keys: VecDeque, + recent_keys_upper_bound: usize, + longterm_ecdh_keys: &'a DHKeyPair, + longterm_pq_keys: Option<&'a PQKeyPair>, + context: &'a [u8], + aad: &'a [u8], + rng: Rng, +} + +#[derive(TlsDeserialize, TlsSize)] +pub struct ResponderQueryPayload(pub VLBytes); + +#[derive(TlsSerialize, TlsSize)] +pub struct ResponderQueryPayloadOut<'a>(VLByteSlice<'a>); + +#[derive(TlsDeserialize, TlsSize)] +pub struct ResponderRegistrationPayload(pub VLBytes); + +#[derive(TlsSerialize, TlsSize)] +pub struct ResponderRegistrationPayloadOut<'a>(VLByteSlice<'a>); + +impl<'a, Rng: CryptoRng> Responder<'a, Rng> { + pub fn new( + longterm_ecdh_keys: &'a DHKeyPair, + longterm_pq_keys: Option<&'a PQKeyPair>, + context: &'a [u8], + aad: &'a [u8], + recent_keys_upper_bound: usize, + rng: Rng, + ) -> Self { + Self { + state: ResponderState::Initial {}, + longterm_ecdh_keys, + longterm_pq_keys, + context, + aad, + rng, + recent_keys: VecDeque::with_capacity(recent_keys_upper_bound), + recent_keys_upper_bound, + } + } + + fn derive_query_key( + &self, + tx0: &Transcript, + k0: &AEADKey, + responder_ephemeral_ecdh_pk: &DHPublicKey, + responder_ephemeral_ecdh_sk: &DHPrivateKey, + initiator_ephemeral_ecdh_pk: &DHPublicKey, + ) -> Result<(Transcript, AEADKey), Error> { + let tx2 = tx2(tx0, responder_ephemeral_ecdh_pk)?; + let k2 = derive_k2_query_responder( + k0, + initiator_ephemeral_ecdh_pk, + responder_ephemeral_ecdh_sk, + &self.longterm_ecdh_keys.sk, + &tx2, + )?; + + Ok((tx2, k2)) + } + + fn derive_registration_key( + &self, + tx1: &Transcript, + k1: &AEADKey, + responder_ephemeral_ecdh_pk: &DHPublicKey, + responder_ephemeral_ecdh_sk: &DHPrivateKey, + initiator_longterm_ecdh_pk: &DHPublicKey, + initiator_ephemeral_ecdh_pk: &DHPublicKey, + ) -> Result<(Transcript, AEADKey), Error> { + let tx2 = tx2(tx1, responder_ephemeral_ecdh_pk)?; + let k2 = derive_k2_registration_responder( + k1, + &tx2, + initiator_longterm_ecdh_pk, + initiator_ephemeral_ecdh_pk, + responder_ephemeral_ecdh_sk, + )?; + + Ok((tx2, k2)) + } + + fn decrypt_outer_message( + &self, + initiator_outer_message: &Message, + ) -> Result<(InitiatorOuterPayload, Transcript, AEADKey), Error> { + let (tx0, mut k0) = derive_k0( + &initiator_outer_message.pk, + &self.longterm_ecdh_keys.pk, + &self.longterm_ecdh_keys.sk, + self.context, + true, + )?; + + let initiator_payload: InitiatorOuterPayload = k0.decrypt_deserialize( + initiator_outer_message.ciphertext.as_slice(), + &initiator_outer_message.tag, + initiator_outer_message.aad.as_slice(), + )?; + + Ok((initiator_payload, tx0, k0)) + } + + fn decrypt_inner_message( + &self, + tx0: &Transcript, + k0: &AEADKey, + initiator_inner_message: &Message, + ) -> Result<(InitiatorInnerPayload, Transcript, AEADKey), Error> { + let pq_shared_secret = initiator_inner_message + .pq_encapsulation + .as_ref() + .zip(self.longterm_pq_keys) + .map(|(enc, longterm_pq_keys)| longterm_pq_keys.sk.decapsulate(enc)); + + let responder_pq_pk_opt = self.longterm_pq_keys.map(|keys| &keys.pk); + + let tx1 = tx1( + tx0, + &initiator_inner_message.pk, + responder_pq_pk_opt, + initiator_inner_message.pq_encapsulation.as_ref(), + )?; + + let mut k1 = derive_k1( + k0, + &self.longterm_ecdh_keys.sk, + &initiator_inner_message.pk, + &pq_shared_secret, + &tx1, + )?; + + let inner_payload: InitiatorInnerPayload = k1.decrypt_deserialize( + initiator_inner_message.ciphertext.as_slice(), + &initiator_inner_message.tag, + initiator_inner_message.aad.as_slice(), + )?; + + Ok((inner_payload, tx1, k1)) + } + + fn registration( + &mut self, + payload: &[u8], + out: &mut [u8], + responder_ephemeral_ecdh_sk: DHPrivateKey, + responder_ephemeral_ecdh_pk: DHPublicKey, + state: Box, + ) -> Result { + let (tx2, mut k2) = self.derive_registration_key( + &state.tx1, + &state.k1, + &responder_ephemeral_ecdh_pk, + &responder_ephemeral_ecdh_sk, + &state.initiator_longterm_ecdh_pk, + &state.initiator_ephemeral_ecdh_pk, + )?; + + let outer_payload = ResponderRegistrationPayloadOut(VLByteSlice(payload)); + let (ciphertext, tag) = k2.serialize_encrypt(&outer_payload, self.aad)?; + + let out_msg = MessageOut { + pk: &responder_ephemeral_ecdh_pk, + ciphertext: VLByteSlice(&ciphertext), + tag, + aad: VLByteSlice(self.aad), + pq_encapsulation: None, + }; + + let out_len = out_msg + .tls_serialize(&mut &mut out[..]) + .map_err(Error::Serialize)?; + self.state = ResponderState::ToTransport(ToTransportState { tx2, k2 }.into()); + + Ok(out_len) + } + + fn query( + &mut self, + payload: &[u8], + out: &mut [u8], + responder_ephemeral_ecdh_sk: DHPrivateKey, + responder_ephemeral_ecdh_pk: DHPublicKey, + state: Box, + ) -> Result { + let (_tx2, mut k2) = self.derive_query_key( + &state.tx0, + &state.k0, + &responder_ephemeral_ecdh_pk, + &responder_ephemeral_ecdh_sk, + &state.initiator_ephemeral_ecdh_pk, + )?; + + let outer_payload = ResponderQueryPayloadOut(VLByteSlice(payload)); + let (ciphertext, tag) = k2.serialize_encrypt(&outer_payload, self.aad)?; + + let out_msg = MessageOut { + pk: &responder_ephemeral_ecdh_pk, + ciphertext: VLByteSlice(&ciphertext), + tag, + aad: VLByteSlice(self.aad), + pq_encapsulation: None, + }; + + out_msg + .tls_serialize(&mut &mut out[..]) + .map_err(Error::Serialize)?; + self.state = ResponderState::Initial; + + Ok(out_msg.tls_serialized_len()) + } +} + +impl<'a, Rng: CryptoRng> Protocol for Responder<'a, Rng> { + fn write_message(&mut self, payload: &[u8], out: &mut [u8]) -> Result { + let mut out_bytes_written = 0; + let responder_ephemeral_ecdh_sk = DHPrivateKey::new(&mut self.rng); + let responder_ephemeral_ecdh_pk = responder_ephemeral_ecdh_sk.to_public(); + + let state = take(&mut self.state); + if let ResponderState::RespondQuery(state) = state { + out_bytes_written = self.query( + payload, + out, + responder_ephemeral_ecdh_sk, + responder_ephemeral_ecdh_pk, + state, + )?; + } else if let ResponderState::RespondRegistration(state) = state { + out_bytes_written = self.registration( + payload, + out, + responder_ephemeral_ecdh_sk, + responder_ephemeral_ecdh_pk, + state, + )?; + } + + Ok(out_bytes_written) + } + + fn read_message( + &mut self, + message_bytes: &[u8], + out: &mut [u8], + ) -> Result<(usize, usize), Error> { + if !matches!(self.state, ResponderState::Initial {}) { + return Ok((0, 0)); + } + + // Deserialize the outer message. + let initiator_outer_message = Message::tls_deserialize(&mut Cursor::new(&message_bytes)) + .map_err(Error::Deserialize)?; + let bytes_deserialized = initiator_outer_message.tls_serialized_len(); + + // Check that the ephemeral key was not in the most recent keys. + if self.recent_keys.contains(&initiator_outer_message.pk) { + return Ok((0, 0)); + } else { + if self.recent_keys.len() == self.recent_keys_upper_bound { + self.recent_keys.pop_back(); + } + self.recent_keys + .push_front(initiator_outer_message.pk.clone()); + } + + // Decrypt the outer message payload. + let (initiator_outer_payload, tx0, k0) = + self.decrypt_outer_message(&initiator_outer_message)?; + + match initiator_outer_payload { + InitiatorOuterPayload::Query(initiator_query_payload) => { + // We're ready to respond to the query message. + self.state = ResponderState::RespondQuery( + RespondQueryState { + tx0, + k0, + initiator_ephemeral_ecdh_pk: initiator_outer_message.pk, + } + .into(), + ); + let out_bytes_written = write_output(initiator_query_payload.as_slice(), out)?; + Ok((bytes_deserialized, out_bytes_written)) + } + + InitiatorOuterPayload::Registration(initiator_inner_message) => { + // Decrypt the inner message payload. + let (initiator_inner_payload, tx1, k1) = + self.decrypt_inner_message(&tx0, &k0, &initiator_inner_message)?; + // We're ready to respond to the registration message. + self.state = ResponderState::RespondRegistration( + RespondRegistrationState { + tx1, + k1, + initiator_ephemeral_ecdh_pk: initiator_outer_message.pk, + initiator_longterm_ecdh_pk: initiator_inner_message.pk, + } + .into(), + ); + let out_bytes_written = write_output(initiator_inner_payload.0.as_slice(), out)?; + Ok((bytes_deserialized, out_bytes_written)) + } + } + } +} + +impl<'a, Rng: CryptoRng> IntoTransport for Responder<'a, Rng> { + fn into_transport_mode(self) -> Result { + let ResponderState::ToTransport(state) = self.state else { + return Err(Error::ResponderState); + }; + + Transport::new(state.tx2, state.k2) + } + + fn is_handshake_finished(&self) -> bool { + matches!(self.state, ResponderState::ToTransport { .. }) + } +} diff --git a/libcrux-psq/src/protocol/session.rs b/libcrux-psq/src/protocol/session.rs new file mode 100644 index 000000000..3733d6191 --- /dev/null +++ b/libcrux-psq/src/protocol/session.rs @@ -0,0 +1,13 @@ +use super::keys::AEADKey; + +/// The length of a session ID in bytes. +pub const SESSION_ID_LENGTH: usize = 32; + +/// The length of a sessin key in bytes. +pub const SESSION_KEY_LENGTH: usize = 32; + +// XXX: Session storage to be implemented (cf. https://github.com/cryspen/libcrux/issues/1077) +pub struct SessionKey { + pub(crate) identifier: [u8; SESSION_ID_LENGTH], + pub(crate) key: AEADKey, +} diff --git a/libcrux-psq/src/protocol/signature.rs b/libcrux-psq/src/protocol/signature.rs new file mode 100644 index 000000000..de58f4c4b --- /dev/null +++ b/libcrux-psq/src/protocol/signature.rs @@ -0,0 +1,22 @@ +use tls_codec::{TlsDeserializeBytes, TlsSerializeBytes, TlsSize}; + +#[derive(Debug, Clone, PartialEq, TlsSerializeBytes, TlsDeserializeBytes, TlsSize)] +#[repr(u8)] +pub(crate) enum Signature { + Ed25519(Vec), + MlDsa(Vec), +} +#[derive(Debug, Clone, PartialEq, TlsSerializeBytes, TlsDeserializeBytes, TlsSize)] +#[repr(u8)] +pub(crate) enum VerificationKey { + Ed25519(Vec), + MlDsa(Vec), +} + +#[derive(Debug, PartialEq)] +pub(crate) enum SigningKey { + Ed25519(Vec), + MlDsa(Vec), +} + +pub(crate) type CredentialKeyPair = (SigningKey, VerificationKey); diff --git a/libcrux-psq/src/protocol/transcript.rs b/libcrux-psq/src/protocol/transcript.rs new file mode 100644 index 000000000..a5a5ac338 --- /dev/null +++ b/libcrux-psq/src/protocol/transcript.rs @@ -0,0 +1,103 @@ +use tls_codec::{Serialize, SerializeBytes, TlsSerialize, TlsSerializeBytes, TlsSize}; + +pub const TX0_DOMAIN_SEP: u8 = 0; +pub const TX1_DOMAIN_SEP: u8 = 1; +pub const TX2_DOMAIN_SEP: u8 = 2; + +use crate::protocol::pqkem::PQCiphertext; + +use super::{api::Error, dhkem::DHPublicKey, pqkem::PQPublicKey}; +use libcrux_sha2::{Digest, SHA256_LENGTH}; + +/// The initial transcript hash. +#[derive(Debug, Default, Clone, Copy, TlsSerializeBytes, TlsSize)] +pub struct Transcript([u8; SHA256_LENGTH]); + +impl Transcript { + fn new(initial_input: &impl Serialize) -> Result { + Self::add_hash::(None, initial_input) + } + + fn add_hash( + old_transcript: Option<&Transcript>, + input: &impl Serialize, + ) -> Result { + let mut hasher = libcrux_sha2::Sha256::new(); + hasher.update(&[DOMAIN_SEPARATOR]); + hasher.update( + old_transcript + .tls_serialize() + .map_err(|e| Error::Serialize(e))? + .as_slice(), + ); + hasher.update( + input + .tls_serialize_detached() + .map_err(|e| Error::Serialize(e))? + .as_slice(), + ); + + let mut digest = [0u8; 32]; + hasher.finish(&mut digest); + Ok(Transcript(digest)) + } +} + +impl AsRef<[u8]> for Transcript { + fn as_ref(&self) -> &[u8] { + self.0.as_slice() + } +} + +// tx0 = hash(0 | ctx | g^s | g^x) +pub(crate) fn tx0( + context: &[u8], + responder_pk: &DHPublicKey, + initiator_pk: &DHPublicKey, +) -> Result { + #[derive(TlsSerialize, TlsSize)] + struct Transcript0Inputs<'a, 'b, 'c> { + context: &'a [u8], + responder_pk: &'b DHPublicKey, + initiator_pk: &'c DHPublicKey, + } + + Transcript::new(&Transcript0Inputs { + context, + responder_pk, + initiator_pk, + }) +} + +// tx1 = hash(1 | tx0 | g^c | pkS | encap(pkS, SS)) +pub(crate) fn tx1( + tx0: &Transcript, + initiator_longterm_pk: &DHPublicKey, + responder_pq_pk: Option<&PQPublicKey>, + pq_encaps: Option<&PQCiphertext>, +) -> Result { + #[derive(TlsSerialize, TlsSize)] + struct Transcript1Inputs<'a, 'b, 'c> { + initiator_longterm_pk: &'a DHPublicKey, + responder_pq_pk: Option<&'b PQPublicKey>, + pq_encaps: Option<&'c PQCiphertext>, + } + + Transcript::add_hash::( + Some(tx0), + &Transcript1Inputs { + initiator_longterm_pk, + pq_encaps, + responder_pq_pk, + }, + ) +} + +// Registration Mode: tx2 = hash(2 | tx1 | g^y) +// Query Mode: tx2 = hash(2 | tx0 | g^y) +pub(crate) fn tx2( + prev_tx: &Transcript, + responder_ephemeral_pk: &DHPublicKey, +) -> Result { + Transcript::add_hash::(Some(prev_tx), responder_ephemeral_pk) +} diff --git a/libcrux-psq/tests/query.rs b/libcrux-psq/tests/query.rs new file mode 100644 index 000000000..8ed049773 --- /dev/null +++ b/libcrux-psq/tests/query.rs @@ -0,0 +1,74 @@ +use libcrux_psq::protocol::{api::Protocol, dhkem::DHKeyPair, pqkem::PQKeyPair, *}; + +#[test] +fn query() { + let mut rng = rand::rng(); + let ctx = b"Test Context"; + let aad_initiator = b"Test Data I"; + let aad_responder = b"Test Data R"; + + let mut msg_channel = vec![0u8; 4096]; + let mut payload_buf_responder = vec![0u8; 4096]; + let mut payload_buf_initiator = vec![0u8; 4096]; + + // External setup + let responder_ecdh_keys = DHKeyPair::new(&mut rng); + + let responder_pq_keys = PQKeyPair::new(&mut rng); + + // Setup initiator + let mut initiator = api::Builder::new(rand::rng()) + .outer_aad(aad_initiator) + .context(ctx) + .peer_longterm_ecdh_pk(&responder_ecdh_keys.pk) + .build_query_initiator() + .unwrap(); + + // Setup responder + let mut responder = api::Builder::new(rand::rng()) + .context(ctx) + .outer_aad(aad_responder) + .longterm_ecdh_keys(&responder_ecdh_keys) + .longterm_pq_keys(&responder_pq_keys) + .recent_keys_upper_bound(30) + .build_responder() + .unwrap(); + + // Send first message + let query_payload_initiator = b"Query_init"; + let len_i = initiator + .write_message(query_payload_initiator, &mut msg_channel) + .unwrap(); + + // Read first message + let (len_r_deserialized, len_r_payload) = responder + .read_message(&msg_channel, &mut payload_buf_responder) + .unwrap(); + + // We read the same amount of data. + assert_eq!(len_r_deserialized, len_i); + assert_eq!(len_r_payload, b"Query init".len()); + assert_eq!( + &payload_buf_responder[0..len_r_payload], + query_payload_initiator + ); + + // Respond + let query_payload_responder = b"Query_respond"; + let len_r = responder + .write_message(query_payload_responder, &mut msg_channel) + .unwrap(); + + // Finalize on query initiator + let (len_i_deserialized, len_i_payload) = initiator + .read_message(&msg_channel, &mut payload_buf_initiator) + .unwrap(); + + // We read the same amount of data. + assert_eq!(len_r, len_i_deserialized); + assert_eq!(query_payload_responder.len(), len_i_payload); + assert_eq!( + &payload_buf_initiator[0..len_i_payload], + query_payload_responder + ); +} diff --git a/libcrux-psq/tests/registration.rs b/libcrux-psq/tests/registration.rs new file mode 100644 index 000000000..1ba0addc9 --- /dev/null +++ b/libcrux-psq/tests/registration.rs @@ -0,0 +1,132 @@ +use libcrux_psq::protocol::{ + api::{IntoTransport, Protocol}, + dhkem::DHKeyPair, + pqkem::PQKeyPair, + *, +}; + +fn registration(pq: bool) { + let mut rng = rand::rng(); + let ctx = b"Test Context"; + let aad_initiator_outer = b"Test Data I Outer"; + let aad_initiator_inner = b"Test Data I Inner"; + let aad_responder = b"Test Data R"; + + let mut msg_channel = vec![0u8; 4096]; + let mut payload_buf_responder = vec![0u8; 4096]; + let mut payload_buf_initiator = vec![0u8; 4096]; + + // External setup + let responder_pq_keys = PQKeyPair::new(&mut rng); + + let responder_ecdh_keys = DHKeyPair::new(&mut rng); + let initiator_ecdh_keys = DHKeyPair::new(&mut rng); + + // Setup initiator + let mut initiator = api::Builder::new(rand::rng()) + .outer_aad(aad_initiator_outer) + .inner_aad(aad_initiator_inner) + .context(ctx) + .longterm_ecdh_keys(&initiator_ecdh_keys) + .peer_longterm_ecdh_pk(&responder_ecdh_keys.pk); + + if pq { + initiator = initiator.peer_longterm_pq_pk(&responder_pq_keys.pk); + } + + let mut initiator = initiator.build_registration_initiator().unwrap(); + + // Setup responder + let mut builder = api::Builder::new(rand::rng()) + .context(ctx) + .outer_aad(aad_responder) + .longterm_ecdh_keys(&responder_ecdh_keys) + .recent_keys_upper_bound(30); + if pq { + builder = builder.longterm_pq_keys(&responder_pq_keys); + } + let mut responder = builder.build_responder().unwrap(); + + // Send first message + let registration_payload_initiator = b"Registration_init"; + let len_i = initiator + .write_message(registration_payload_initiator, &mut msg_channel) + .unwrap(); + + // Read first message + let (len_r_deserialized, len_r_payload) = responder + .read_message(&msg_channel, &mut payload_buf_responder) + .unwrap(); + + // We read the same amount of data. + assert_eq!(len_r_deserialized, len_i); + assert_eq!(len_r_payload, registration_payload_initiator.len()); + assert_eq!( + &payload_buf_responder[0..len_r_payload], + registration_payload_initiator + ); + + // Respond + let registration_payload_responder = b"Registration_respond"; + let len_r = responder + .write_message(registration_payload_responder, &mut msg_channel) + .unwrap(); + + // Finalize on registration initiator + let (len_i_deserialized, len_i_payload) = initiator + .read_message(&msg_channel, &mut payload_buf_initiator) + .unwrap(); + + // We read the same amount of data. + assert_eq!(len_r, len_i_deserialized); + assert_eq!(registration_payload_responder.len(), len_i_payload); + assert_eq!( + &payload_buf_initiator[0..len_i_payload], + registration_payload_responder + ); + + // Ready for transport mode + assert!(initiator.is_handshake_finished()); + assert!(responder.is_handshake_finished()); + + let mut i_transport = initiator.into_transport_mode().unwrap(); + let mut r_transport = responder.into_transport_mode().unwrap(); + + let app_data_i = b"Derived session hey".as_slice(); + let app_data_r = b"Derived session ho".as_slice(); + + let len_i = i_transport + .write_message(app_data_i, &mut msg_channel) + .unwrap(); + + let (len_r_deserialized, len_r_payload) = r_transport + .read_message(&msg_channel, &mut payload_buf_responder) + .unwrap(); + + // We read the same amount of data. + assert_eq!(len_r_deserialized, len_i); + assert_eq!(len_r_payload, app_data_i.len()); + assert_eq!(&payload_buf_responder[0..len_r_payload], app_data_i); + + let len_r = r_transport + .write_message(app_data_r, &mut msg_channel) + .unwrap(); + + let (len_i_deserialized, len_i_payload) = i_transport + .read_message(&msg_channel, &mut payload_buf_initiator) + .unwrap(); + + assert_eq!(len_r, len_i_deserialized); + assert_eq!(app_data_r.len(), len_i_payload); + assert_eq!(&payload_buf_initiator[0..len_i_payload], app_data_r); +} + +#[test] +fn registration_pq() { + registration(true); +} + +#[test] +fn registration_classic() { + registration(false); +}