Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 114 additions & 4 deletions ml-dsa/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -373,12 +373,18 @@ impl<P: MlDsaParams> SigningKey<P> {
// the concatenated M'.
// XXX(RLB) Should the API represent this as an input?
let mu = message_representative(&self.tr, Mp);
self.raw_sign_mu(&mu, rnd)
}

fn raw_sign_mu(&self, mu: &B64, rnd: &B32) -> Signature<P>
where
P: MlDsaParams,
{
// Compute the private random seed
let rhopp: B64 = H::default()
.absorb(&self.K)
.absorb(rnd)
.absorb(&mu)
.absorb(mu)
.squeeze_new();

// Rejection sampling loop
Expand All @@ -389,7 +395,7 @@ impl<P: MlDsaParams> SigningKey<P> {

let w1_tilde = P::encode_w1(&w1);
let c_tilde = H::default()
.absorb(&mu)
.absorb(mu)
.absorb(&w1_tilde)
.squeeze_new::<P::Lambda>();
let c = sample_in_ball(&c_tilde, P::TAU);
Expand Down Expand Up @@ -448,6 +454,24 @@ impl<P: MlDsaParams> SigningKey<P> {
Ok(self.sign_internal(Mp, &rnd))
}

/// This method reflects the randomized ML-DSA.Sign algorithm with a pre-computed μ.
///
/// # Errors
///
/// This method can return an opaque error if it fails to get enough randomness.
// Algorithm 2 ML-DSA.Sign (optional pre-computed μ variant)
#[cfg(feature = "rand_core")]
pub fn sign_mu_randomized<R: TryCryptoRng + ?Sized>(
&self,
mu: &B64,
rng: &mut R,
) -> Result<Signature<P>, Error> {
let mut rnd = B32::default();
rng.try_fill_bytes(&mut rnd).map_err(|_| Error::new())?;

Ok(self.raw_sign_mu(mu, &rnd))
}

/// This method reflects the optional deterministic variant of the ML-DSA.Sign algorithm.
///
/// # Errors
Expand All @@ -458,6 +482,14 @@ impl<P: MlDsaParams> SigningKey<P> {
self.raw_sign_deterministic(&[M], ctx)
}

/// This method reflects the optional deterministic variant of the ML-DSA.Sign algorithm with a
/// pre-computed μ.
// Algorithm 2 ML-DSA.Sign (optional deterministic and pre-computed μ variant)
pub fn sign_mu_deterministic(&self, mu: &B64) -> Signature<P> {
let rnd = B32::default();
self.raw_sign_mu(mu, &rnd)
}

fn raw_sign_deterministic(&self, M: &[&[u8]], ctx: &[u8]) -> Result<Signature<P>, Error> {
if ctx.len() > 255 {
return Err(Error::new());
Expand Down Expand Up @@ -616,7 +648,13 @@ impl<P: MlDsaParams> VerifyingKey<P> {
{
// Compute the message representative
let mu = message_representative(&self.tr, Mp);
self.raw_verify_mu(&mu, sigma)
}

fn raw_verify_mu(&self, mu: &B64, sigma: &Signature<P>) -> bool
where
P: MlDsaParams,
{
// Reconstruct w
let c = sample_in_ball(&sigma.c_tilde, P::TAU);

Expand All @@ -630,19 +668,25 @@ impl<P: MlDsaParams> VerifyingKey<P> {

let w1p_tilde = P::encode_w1(&w1p);
let cp_tilde = H::default()
.absorb(&mu)
.absorb(mu)
.absorb(&w1p_tilde)
.squeeze_new::<P::Lambda>();

sigma.c_tilde == cp_tilde
}

/// This algorithm reflect the ML-DSA.Verify algorithm from FIPS 204.
/// This algorithm reflects the ML-DSA.Verify algorithm from FIPS 204.
// Algorithm 3 ML-DSA.Verify
pub fn verify_with_context(&self, M: &[u8], ctx: &[u8], sigma: &Signature<P>) -> bool {
self.raw_verify_with_context(&[M], ctx, sigma)
}

/// This algorithm reflects the ML-DSA.Verify algorithm with a pre-computed μ from FIPS 204.
// Algorithm 3 ML-DSA.Verify (optional pre-computed μ variant)
pub fn verify_mu(&self, mu: &B64, sigma: &Signature<P>) -> bool {
self.raw_verify_mu(mu, sigma)
}

fn raw_verify_with_context(&self, M: &[&[u8]], ctx: &[u8], sigma: &Signature<P>) -> bool {
if ctx.len() > 255 {
return false;
Expand Down Expand Up @@ -1006,4 +1050,70 @@ mod test {
many_round_trip_test::<MlDsa65>();
many_round_trip_test::<MlDsa87>();
}

#[test]
fn sign_mu_verify_mu_round_trip() {
fn sign_mu_verify_mu<P>()
where
P: MlDsaParams,
{
let kp = P::key_gen_internal(&Array::default());
let sk = kp.signing_key;
let vk = kp.verifying_key;

let M = b"Hello world";
let rnd = Array([0u8; 32]);
let mu = message_representative(&sk.tr, &[&[M]]);
let sig = sk.raw_sign_mu(&mu, &rnd);

assert!(vk.raw_verify_mu(&mu, &sig));
}
sign_mu_verify_mu::<MlDsa44>();
sign_mu_verify_mu::<MlDsa65>();
sign_mu_verify_mu::<MlDsa87>();
}

#[test]
fn sign_mu_verify_internal_round_trip() {
fn sign_mu_verify_internal<P>()
where
P: MlDsaParams,
{
let kp = P::key_gen_internal(&Array::default());
let sk = kp.signing_key;
let vk = kp.verifying_key;

let M = b"Hello world";
let rnd = Array([0u8; 32]);
let mu = message_representative(&sk.tr, &[&[M]]);
let sig = sk.raw_sign_mu(&mu, &rnd);

assert!(vk.verify_internal(&[M], &sig));
}
sign_mu_verify_internal::<MlDsa44>();
sign_mu_verify_internal::<MlDsa65>();
sign_mu_verify_internal::<MlDsa87>();
}

#[test]
fn sign_internal_verify_mu_round_trip() {
fn sign_internal_verify_mu<P>()
where
P: MlDsaParams,
{
let kp = P::key_gen_internal(&Array::default());
let sk = kp.signing_key;
let vk = kp.verifying_key;

let M = b"Hello world";
let rnd = Array([0u8; 32]);
let mu = message_representative(&sk.tr, &[&[M]]);
let sig = sk.sign_internal(&[M], &rnd);

assert!(vk.raw_verify_mu(&mu, &sig));
}
sign_internal_verify_mu::<MlDsa44>();
sign_internal_verify_mu::<MlDsa65>();
sign_internal_verify_mu::<MlDsa87>();
}
}