Skip to content

Add and implement KEM traits #1053

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
6 changes: 6 additions & 0 deletions curve25519/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,9 @@ repository.workspace = true
[dependencies]
libcrux-hacl-rs = { version = "=0.0.3", path = "../hacl-rs/" }
libcrux-macros = { version = "=0.0.3", path = "../macros" }
libcrux-traits = { version = "=0.0.3", path = "../traits" }

[dev-dependencies]
libcrux-traits = { version = "0.0.3", path = "../traits", features = [
"generic-tests",
] }
8 changes: 4 additions & 4 deletions curve25519/src/impl_hacl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,29 @@ impl Curve25519 for HaclCurve25519 {
// The hacl::ecdh function requires all parameters to be 32 byte long, which we enforce using
// types.
#[inline(always)]
fn secret_to_public(pk: &mut [u8; PK_LEN], sk: &[u8; SK_LEN]) {
fn secret_to_public(pk: &mut [u8; EK_LEN], sk: &[u8; DK_LEN]) {
secret_to_public(pk, sk)
}

// The hacl::ecdh function requires all parameters to be 32 byte long, which we enforce using
// types.
#[inline(always)]
fn ecdh(out: &mut [u8; SHK_LEN], pk: &[u8; PK_LEN], sk: &[u8; SK_LEN]) -> Result<(), Error> {
fn ecdh(out: &mut [u8; SS_LEN], pk: &[u8; EK_LEN], sk: &[u8; DK_LEN]) -> Result<(), Error> {
ecdh(out, pk, sk)
}
}

// The hacl::ecdh function requires all parameters to be 32 byte long, which we enforce using
// types.
#[inline(always)]
pub fn secret_to_public(pk: &mut [u8; PK_LEN], sk: &[u8; SK_LEN]) {
pub fn secret_to_public(pk: &mut [u8; EK_LEN], sk: &[u8; DK_LEN]) {
crate::hacl::secret_to_public(pk, sk)
}

// The hacl::ecdh function requires all parameters to be 32 byte long, which we enforce using
// types.
#[inline(always)]
pub fn ecdh(out: &mut [u8; SHK_LEN], pk: &[u8; PK_LEN], sk: &[u8; SK_LEN]) -> Result<(), Error> {
pub fn ecdh(out: &mut [u8; SS_LEN], pk: &[u8; EK_LEN], sk: &[u8; DK_LEN]) -> Result<(), Error> {
match crate::hacl::ecdh(out, sk, pk) {
true => Ok(()),
false => Err(Error),
Expand Down
56 changes: 51 additions & 5 deletions curve25519/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ mod impl_hacl;
pub use impl_hacl::{ecdh, secret_to_public};

/// The length of Curve25519 secret keys.
pub const SK_LEN: usize = 32;
pub const DK_LEN: usize = 32;

/// The length of Curve25519 public keys.
pub const PK_LEN: usize = 32;
pub const EK_LEN: usize = 32;

/// The length of Curve25519 shared keys.
pub const SHK_LEN: usize = 32;
pub const SS_LEN: usize = 32;

/// Indicates that an error occurred
pub struct Error;
Expand All @@ -23,9 +23,55 @@ pub struct Error;
#[allow(dead_code)]
trait Curve25519 {
/// Computes a public key from a secret key.
fn secret_to_public(pk: &mut [u8; PK_LEN], sk: &[u8; SK_LEN]);
fn secret_to_public(pk: &mut [u8; EK_LEN], sk: &[u8; DK_LEN]);

/// Computes the scalar multiplication between the provided public and secret keys. Returns an
/// error if the result is 0.
fn ecdh(out: &mut [u8; SHK_LEN], pk: &[u8; PK_LEN], sk: &[u8; SK_LEN]) -> Result<(), Error>;
fn ecdh(out: &mut [u8; SS_LEN], pk: &[u8; EK_LEN], sk: &[u8; DK_LEN]) -> Result<(), Error>;
}

pub struct X25519;

impl libcrux_traits::kem::arrayref::Kem<DK_LEN, EK_LEN, EK_LEN, SS_LEN, DK_LEN, DK_LEN> for X25519 {
fn keygen(
ek: &mut [u8; DK_LEN],
dk: &mut [u8; EK_LEN],
rand: &[u8; DK_LEN],
) -> Result<(), libcrux_traits::kem::arrayref::KeyGenError> {
dk.copy_from_slice(rand);
clamp(dk);
secret_to_public(ek, dk);
Ok(())
}

fn encaps(
ct: &mut [u8; EK_LEN],
ss: &mut [u8; SS_LEN],
ek: &[u8; EK_LEN],
rand: &[u8; DK_LEN],
) -> Result<(), libcrux_traits::kem::arrayref::EncapsError> {
let mut eph_dk = *rand;
clamp(&mut eph_dk);
secret_to_public(ct, &eph_dk);

ecdh(ss, ek, &eph_dk).map_err(|_| libcrux_traits::kem::arrayref::EncapsError::Unknown)
}

fn decaps(
ss: &mut [u8; SS_LEN],
ct: &[u8; DK_LEN],
dk: &[u8; EK_LEN],
) -> Result<(), libcrux_traits::kem::arrayref::DecapsError> {
ecdh(ss, ct, dk).map_err(|_| libcrux_traits::kem::arrayref::DecapsError::Unknown)
}
}

libcrux_traits::kem::slice::impl_trait!(X25519 => EK_LEN, DK_LEN, EK_LEN, EK_LEN, DK_LEN, DK_LEN);

/// Clamp a scalar.
fn clamp(scalar: &mut [u8; DK_LEN]) {
// We clamp the key already to make sure it can't be misused.
scalar[0] &= 248u8;
scalar[31] &= 127u8;
scalar[31] |= 64u8;
}
5 changes: 5 additions & 0 deletions libcrux-kem/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ libcrux-ml-kem = { version = "=0.0.3", path = "../libcrux-ml-kem", default-featu
] }
libcrux-sha3 = { version = "=0.0.3", path = "../libcrux-sha3" }
libcrux-ecdh = { version = "=0.0.3", path = "../libcrux-ecdh", default-features = false }
libcrux-curve25519 = { version = "=0.0.3", path = "../curve25519", default-features = false }
libcrux-p256 = { version = "=0.0.3", path = "../p256", default-features = false }
libcrux-traits = { version = "=0.0.3", path = "../traits" }
rand = { version = "0.9", default-features = false }

Expand All @@ -33,3 +35,6 @@ pre-verification = []
libcrux-kem = { path = "./", features = ["tests"] }
rand = { version = "0.9", features = ["os_rng"] }
hex = { version = "0.4.3", features = ["serde"] }
libcrux-traits = { version = "0.0.3", path = "../traits", features = [
"generic-tests",
] }
135 changes: 135 additions & 0 deletions libcrux-kem/src/kem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,13 @@ pub mod deterministic {
pub use libcrux_ml_kem::mlkem768::generate_key_pair as mlkem768_generate_keypair_derand;
}

pub use libcrux_curve25519::X25519;
pub use libcrux_ml_kem::mlkem1024::MlKem1024;
pub use libcrux_ml_kem::mlkem512::MlKem512;
pub use libcrux_ml_kem::mlkem768::MlKem768;
pub use libcrux_p256::P256;
pub use xwing::XWing;

use libcrux_ml_kem::MlKemSharedSecret;
pub use libcrux_ml_kem::{
mlkem1024::{MlKem1024Ciphertext, MlKem1024PrivateKey, MlKem1024PublicKey},
Expand Down Expand Up @@ -883,6 +890,134 @@ mod xwing {

use super::*;

const MLKEM768_EK_LEN: usize = libcrux_ml_kem::mlkem768::MlKem768PublicKey::len();
const MLKEM768_DK_LEN: usize = libcrux_ml_kem::mlkem768::MlKem768PrivateKey::len();
const MLKEM768_CT_LEN: usize = libcrux_ml_kem::mlkem768::MlKem768Ciphertext::len();
const MLKEM768_SS_LEN: usize = libcrux_ml_kem::SHARED_SECRET_SIZE;
const MLKEM768_RAND_KEYGEN_LEN: usize = libcrux_ml_kem::KEY_GENERATION_SEED_SIZE;
const MLKEM768_RAND_ENCAPS_LEN: usize = MLKEM768_SS_LEN;

const X25519_EK_LEN: usize = libcrux_curve25519::EK_LEN;
const X25519_DK_LEN: usize = libcrux_curve25519::DK_LEN;
const X25519_CT_LEN: usize = X25519_EK_LEN;
const X25519_RAND_KEYGEN_LEN: usize = X25519_DK_LEN;
const X25519_RAND_ENCAPS_LEN: usize = X25519_DK_LEN;

const EK_LEN: usize = MLKEM768_EK_LEN + X25519_EK_LEN;
const DK_LEN: usize = MLKEM768_DK_LEN + X25519_DK_LEN;
const CT_LEN: usize = MLKEM768_CT_LEN + X25519_CT_LEN;
const SS_LEN: usize = 32;
const RAND_KEYGEN_LEN: usize = 32; // gets expanded later
const RAND_ENCAPS_LEN: usize = MLKEM768_RAND_ENCAPS_LEN + X25519_RAND_ENCAPS_LEN;

use libcrux_curve25519::X25519;
use libcrux_ml_kem::mlkem768::MlKem768;

pub struct XWing;

impl
libcrux_traits::kem::arrayref::Kem<
EK_LEN,
DK_LEN,
CT_LEN,
SS_LEN,
RAND_KEYGEN_LEN,
RAND_ENCAPS_LEN,
> for XWing
{
fn keygen(
ek: &mut [u8; EK_LEN],
dk: &mut [u8; DK_LEN],
rand: &[u8; RAND_KEYGEN_LEN],
) -> Result<(), libcrux_traits::kem::owned::KeyGenError> {
let expanded: [u8; MLKEM768_RAND_KEYGEN_LEN + X25519_RAND_KEYGEN_LEN] =
libcrux_sha3::shake256(rand);

let (rand_m, rand_x) = expanded.split_at(MLKEM768_RAND_KEYGEN_LEN);
let rand_m: &[u8; MLKEM768_RAND_KEYGEN_LEN] = rand_m.try_into().unwrap();
let rand_x: &[u8; X25519_RAND_KEYGEN_LEN] = rand_x.try_into().unwrap();

let (ek_m, ek_x) = ek.split_at_mut(MLKEM768_EK_LEN);
let ek_m: &mut [u8; MLKEM768_EK_LEN] = ek_m.try_into().unwrap();
let ek_x: &mut [u8; X25519_EK_LEN] = ek_x.try_into().unwrap();

let (dk_m, dk_x) = dk.split_at_mut(MLKEM768_DK_LEN);
let dk_m: &mut [u8; MLKEM768_DK_LEN] = dk_m.try_into().unwrap();
let dk_x: &mut [u8; X25519_DK_LEN] = dk_x.try_into().unwrap();

MlKem768::keygen(ek_m, dk_m, rand_m)?;
X25519::keygen(ek_x, dk_x, rand_x)?;

Ok(())
}

fn encaps(
ct: &mut [u8; CT_LEN],
ss: &mut [u8; SS_LEN],
ek: &[u8; EK_LEN],
rand: &[u8; RAND_ENCAPS_LEN],
) -> Result<(), libcrux_traits::kem::owned::EncapsError> {
let (rand_m, rand_x) = rand.split_at(MLKEM768_RAND_ENCAPS_LEN);
let rand_m: &[u8; MLKEM768_RAND_ENCAPS_LEN] = rand_m.try_into().unwrap();
let rand_x: &[u8; X25519_RAND_ENCAPS_LEN] = rand_x.try_into().unwrap();

let (ek_m, ek_x) = ek.split_at(MLKEM768_EK_LEN);
let ek_m: &[u8; MLKEM768_EK_LEN] = ek_m.try_into().unwrap();
let ek_x: &[u8; X25519_EK_LEN] = ek_x.try_into().unwrap();

let (ct_m, ct_x) = ct.split_at_mut(MLKEM768_CT_LEN);
let ct_m: &mut [u8; MLKEM768_CT_LEN] = ct_m.try_into().unwrap();
let ct_x: &mut [u8; X25519_CT_LEN] = ct_x.try_into().unwrap();

let mut hash_buffer = [0u8; 32 + 32 + X25519_CT_LEN + X25519_EK_LEN + 6];
hash_buffer[96..128].copy_from_slice(ek_x);
hash_buffer[128..134].copy_from_slice(&[0x5c, 0x2e, 0x2f, 0x2f, 0x5e, 0x5c]);

let ss_m: &mut [u8; 32] = (&mut hash_buffer[0..32]).try_into().unwrap();
MlKem768::encaps(ct_m, ss_m, ek_m, rand_m)?;

let ss_x: &mut [u8; 32] = (&mut hash_buffer[32..64]).try_into().unwrap();
X25519::encaps(ct_x, ss_x, ek_x, rand_x)?;
hash_buffer[64..96].copy_from_slice(ct_x);
sha3::sha256_ema(ss, &hash_buffer);

Ok(())
}

fn decaps(
ss: &mut [u8; SS_LEN],
ct: &[u8; CT_LEN],
dk: &[u8; DK_LEN],
) -> Result<(), libcrux_traits::kem::owned::DecapsError> {
let (dk_m, dk_x) = dk.split_at(MLKEM768_DK_LEN);
let dk_m: &[u8; MLKEM768_DK_LEN] = dk_m.try_into().unwrap();
let dk_x: &[u8; X25519_DK_LEN] = dk_x.try_into().unwrap();

let (ct_m, ct_x) = ct.split_at(MLKEM768_CT_LEN);
let ct_m: &[u8; MLKEM768_CT_LEN] = ct_m.try_into().unwrap();
let ct_x: &[u8; X25519_CT_LEN] = ct_x.try_into().unwrap();

let mut ek_x = [0u8; X25519_EK_LEN];
libcrux_curve25519::secret_to_public(&mut ek_x, dk_x);

let mut hash_buffer = [0u8; 32 + 32 + X25519_CT_LEN + X25519_EK_LEN + 6];
hash_buffer[64..96].copy_from_slice(ct_x);
hash_buffer[96..128].copy_from_slice(&ek_x);
hash_buffer[128..134].copy_from_slice(&[0x5c, 0x2e, 0x2f, 0x2f, 0x5e, 0x5c]);

let ss_m: &mut [u8; 32] = (&mut hash_buffer[0..32]).try_into().unwrap();
MlKem768::decaps(ss_m, ct_m, dk_m)?;

let ss_x: &mut [u8; 32] = (&mut hash_buffer[32..64]).try_into().unwrap();
X25519::decaps(ss_x, ct_x, dk_x)?;
sha3::sha256_ema(ss, &hash_buffer);

Ok(())
}
}

libcrux_traits::kem::slice::impl_trait!(XWing => EK_LEN, DK_LEN, CT_LEN, SS_LEN, RAND_KEYGEN_LEN, RAND_ENCAPS_LEN);

pub struct XWingSharedSecret {
pub(super) value: [u8; 32],
}
Expand Down
4 changes: 4 additions & 0 deletions libcrux-ml-kem/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ libcrux-platform = { version = "0.0.2", path = "../sys/platform" }
libcrux-sha3 = { version = "0.0.3", path = "../libcrux-sha3" }
libcrux-intrinsics = { version = "0.0.3", path = "../libcrux-intrinsics" }
libcrux-secrets = { version = "0.0.3", path = "../secrets" }
libcrux-traits = { version = "0.0.3", path = "../traits" }
hax-lib.workspace = true

[features]
Expand Down Expand Up @@ -71,6 +72,9 @@ serde_json = { version = "1.0" }
serde = { version = "1.0", features = ["derive"] }
hex = { version = "0.4.3", features = ["serde"] }
criterion = "0.6"
libcrux-traits = { version = "0.0.3", path = "../traits", features = [
"generic-tests",
] }

[[bench]]
name = "ml-kem"
Expand Down
67 changes: 67 additions & 0 deletions libcrux-ml-kem/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ cfg_kyber! {
pub mod kyber512 {
//! Kyber 512 (NIST PQC Round 3)
cfg_no_eurydice! {
pub use crate::mlkem512::kyber::Kyber512;
pub use crate::mlkem512::kyber::generate_key_pair;
pub use crate::mlkem512::kyber::decapsulate;
pub use crate::mlkem512::kyber::encapsulate;
Expand All @@ -150,6 +151,7 @@ cfg_kyber! {
pub mod kyber768 {
//! Kyber 768 (NIST PQC Round 3)
cfg_no_eurydice! {
pub use crate::mlkem768::kyber::Kyber768;
pub use crate::mlkem768::kyber::generate_key_pair;
pub use crate::mlkem768::kyber::decapsulate;
pub use crate::mlkem768::kyber::encapsulate;
Expand All @@ -163,6 +165,7 @@ cfg_kyber! {
pub mod kyber1024 {
//! Kyber 1024 (NIST PQC Round 3)
cfg_no_eurydice! {
pub use crate::mlkem1024::kyber::Kyber1024;
pub use crate::mlkem1024::kyber::generate_key_pair;
pub use crate::mlkem1024::kyber::decapsulate;
pub use crate::mlkem1024::kyber::encapsulate;
Expand All @@ -171,3 +174,67 @@ cfg_kyber! {
}
}
}

macro_rules! impl_kem_trait {
($variant:ty, $pk:ty, $sk:ty, $ct:ty) => {
impl
libcrux_traits::kem::arrayref::Kem<
CPA_PKE_PUBLIC_KEY_SIZE,
SECRET_KEY_SIZE,
CPA_PKE_CIPHERTEXT_SIZE,
SHARED_SECRET_SIZE,
KEY_GENERATION_SEED_SIZE,
SHARED_SECRET_SIZE,
> for $variant
{
fn keygen(
ek: &mut [u8; CPA_PKE_PUBLIC_KEY_SIZE],
dk: &mut [u8; SECRET_KEY_SIZE],
rand: &[u8; KEY_GENERATION_SEED_SIZE],
) -> Result<(), libcrux_traits::kem::owned::KeyGenError> {
let key_pair = generate_key_pair(*rand);
ek.copy_from_slice(key_pair.pk());
dk.copy_from_slice(key_pair.sk());

Ok(())
}

fn encaps(
ct: &mut [u8; CPA_PKE_CIPHERTEXT_SIZE],
ss: &mut [u8; SHARED_SECRET_SIZE],
ek: &[u8; CPA_PKE_PUBLIC_KEY_SIZE],
rand: &[u8; SHARED_SECRET_SIZE],
) -> Result<(), libcrux_traits::kem::owned::EncapsError> {
let public_key: $pk = ek.into();

let (ct_, ss_) = encapsulate(&public_key, *rand);
ct.copy_from_slice(ct_.as_slice());
ss.copy_from_slice(ss_.as_slice());

Ok(())
}

fn decaps(
ss: &mut [u8; SHARED_SECRET_SIZE],
ct: &[u8; CPA_PKE_CIPHERTEXT_SIZE],
dk: &[u8; SECRET_KEY_SIZE],
) -> Result<(), libcrux_traits::kem::owned::DecapsError> {
let secret_key: $sk = dk.into();
let ciphertext: $ct = ct.into();

let ss_ = decapsulate(&secret_key, &ciphertext);

ss.copy_from_slice(ss_.as_slice());

Ok(())
}
}

libcrux_traits::kem::slice::impl_trait!($variant =>
CPA_PKE_PUBLIC_KEY_SIZE, SECRET_KEY_SIZE,
CPA_PKE_CIPHERTEXT_SIZE, SHARED_SECRET_SIZE,
KEY_GENERATION_SEED_SIZE, SHARED_SECRET_SIZE);
};
}

use impl_kem_trait;
Loading
Loading