Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 14 additions & 51 deletions payjoin/src/core/hpke.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ impl<'de> serde::Deserialize<'de> for HpkePublicKey {

/// Message A is sent from the sender to the receiver containing an Original PSBT payload
pub fn encrypt_message_a(
body: Vec<u8>,
body: &[u8; PADDED_PLAINTEXT_A_LENGTH],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@spacebear21 I think the main motivation was just to get these function types to be strongly typed with this function signature. Other than that, I think it was a style thing to get rid of the function. How exactly it's implemented is not of the absolute greatest importance.

reply_pk: &HpkePublicKey,
receiver_pk: &HpkePublicKey,
) -> Result<Vec<u8>, HpkeError> {
Expand All @@ -182,8 +182,6 @@ pub fn encrypt_message_a(
INFO_A,
&mut OsRng,
)?;
let mut body = body;
pad_plaintext(&mut body, PADDED_PLAINTEXT_A_LENGTH)?;
let mut plaintext = compressed_bytes_from_pubkey(reply_pk).to_vec();
plaintext.extend(body);
let ciphertext = encryption_context.seal(&plaintext, &[])?;
Expand Down Expand Up @@ -223,7 +221,7 @@ pub fn decrypt_message_a(

/// Message B is sent from the receiver to the sender containing a Payjoin PSBT payload or an error
pub fn encrypt_message_b(
mut plaintext: Vec<u8>,
body: &[u8; PADDED_PLAINTEXT_B_LENGTH],
receiver_keypair: &HpkeKeyPair,
sender_pk: &HpkePublicKey,
) -> Result<Vec<u8>, HpkeError> {
Expand All @@ -237,8 +235,7 @@ pub fn encrypt_message_b(
INFO_B,
&mut OsRng,
)?;
let plaintext: &[u8] = pad_plaintext(&mut plaintext, PADDED_PLAINTEXT_B_LENGTH)?;
let ciphertext = encryption_context.seal(plaintext, &[])?;
let ciphertext = encryption_context.seal(body, &[])?;
let mut message_b = ellswift_bytes_from_encapped_key(&encapsulated_key)?.to_vec();
message_b.extend(&ciphertext);
Ok(message_b)
Expand All @@ -261,14 +258,6 @@ pub fn decrypt_message_b(
Ok(plaintext)
}

fn pad_plaintext(msg: &mut Vec<u8>, padded_length: usize) -> Result<&[u8], HpkeError> {
if msg.len() > padded_length {
return Err(HpkeError::PayloadTooLarge { actual: msg.len(), max: padded_length });
}
msg.resize(padded_length, 0);
Ok(msg)
}

/// Error from de/encrypting a v2 Hybrid Public Key Encryption payload.
#[derive(Debug, PartialEq, Eq)]
pub enum HpkeError {
Expand Down Expand Up @@ -304,7 +293,7 @@ impl fmt::Display for HpkeError {
PayloadTooLarge { actual, max } => {
write!(
f,
"Plaintext too large, max size is {max} bytes, actual size is {actual} bytes"
"Plaintext length incorrect, expected size is {max} bytes, actual size is {actual} bytes"
)
}
PayloadTooShort => write!(f, "Payload too small"),
Expand Down Expand Up @@ -332,13 +321,13 @@ mod test {

#[test]
fn message_a_round_trip() {
let mut plaintext = "foo".as_bytes().to_vec();
let mut plaintext = [0u8; PADDED_PLAINTEXT_A_LENGTH];

let reply_keypair = HpkeKeyPair::gen_keypair();
let receiver_keypair = HpkeKeyPair::gen_keypair();

let message_a = encrypt_message_a(
plaintext.clone(),
&plaintext,
reply_keypair.public_key(),
receiver_keypair.public_key(),
)
Expand All @@ -350,14 +339,12 @@ mod test {

assert_eq!(decrypted.0.len(), PADDED_PLAINTEXT_A_LENGTH);

// decrypted plaintext is padded, so pad the expected plaintext
plaintext.resize(PADDED_PLAINTEXT_A_LENGTH, 0);
assert_eq!(decrypted, (plaintext.to_vec(), reply_keypair.public_key().clone()));

// ensure full plaintext round trips
plaintext[PADDED_PLAINTEXT_A_LENGTH - 1] = 42;
let message_a = encrypt_message_a(
plaintext.clone(),
&plaintext,
reply_keypair.public_key(),
receiver_keypair.public_key(),
)
Expand Down Expand Up @@ -387,30 +374,17 @@ mod test {
decrypt_message_a(&corrupted_message_a, receiver_keypair.secret_key().clone()),
Err(HpkeError::Hpke(hpke::HpkeError::OpenError))
);

plaintext.resize(PADDED_PLAINTEXT_A_LENGTH + 1, 0);
assert_eq!(
encrypt_message_a(
plaintext.clone(),
reply_keypair.public_key(),
receiver_keypair.public_key(),
),
Err(HpkeError::PayloadTooLarge {
actual: PADDED_PLAINTEXT_A_LENGTH + 1,
max: PADDED_PLAINTEXT_A_LENGTH,
})
);
}

#[test]
fn message_b_round_trip() {
let mut plaintext = "foo".as_bytes().to_vec();
let mut plaintext = [0u8; PADDED_PLAINTEXT_B_LENGTH];

let reply_keypair = HpkeKeyPair::gen_keypair();
let receiver_keypair = HpkeKeyPair::gen_keypair();

let message_b =
encrypt_message_b(plaintext.clone(), &receiver_keypair, reply_keypair.public_key())
encrypt_message_b(&plaintext, &receiver_keypair, reply_keypair.public_key())
.expect("encryption should work");

assert_eq!(message_b.len(), PADDED_MESSAGE_BYTES);
Expand All @@ -423,13 +397,11 @@ mod test {
.expect("decryption should work");

assert_eq!(decrypted.len(), PADDED_PLAINTEXT_B_LENGTH);
// decrypted plaintext is padded, so pad the expected plaintext
plaintext.resize(PADDED_PLAINTEXT_B_LENGTH, 0);
assert_eq!(decrypted, plaintext.to_vec());

plaintext[PADDED_PLAINTEXT_B_LENGTH - 1] = 42;
let message_b =
encrypt_message_b(plaintext.clone(), &receiver_keypair, reply_keypair.public_key())
encrypt_message_b(&plaintext, &receiver_keypair, reply_keypair.public_key())
.expect("encryption should work");

assert_eq!(message_b.len(), PADDED_MESSAGE_BYTES);
Expand Down Expand Up @@ -481,15 +453,6 @@ mod test {
),
Err(HpkeError::Hpke(hpke::HpkeError::OpenError))
);

plaintext.resize(PADDED_PLAINTEXT_B_LENGTH + 1, 0);
assert_eq!(
encrypt_message_b(plaintext.clone(), &receiver_keypair, reply_keypair.public_key()),
Err(HpkeError::PayloadTooLarge {
actual: PADDED_PLAINTEXT_B_LENGTH + 1,
max: PADDED_PLAINTEXT_B_LENGTH
})
);
}

/// Test that the encrypted payloads are uniform.
Expand All @@ -508,17 +471,17 @@ mod test {
let receiver_keypair = HpkeKeyPair::gen_keypair();
let reply_keypair = HpkeKeyPair::gen_keypair();

let plaintext_a = vec![0u8; PADDED_PLAINTEXT_A_LENGTH];
let plaintext_a = [0u8; PADDED_PLAINTEXT_A_LENGTH];
let message_a = encrypt_message_a(
plaintext_a,
&plaintext_a,
reply_keypair.public_key(),
receiver_keypair.public_key(),
)
.expect("encryption should work");

let plaintext_b = vec![0u8; PADDED_PLAINTEXT_B_LENGTH];
let plaintext_b = [0u8; PADDED_PLAINTEXT_B_LENGTH];
let message_b =
encrypt_message_b(plaintext_b, &receiver_keypair, sender_keypair.public_key())
encrypt_message_b(&plaintext_b, &receiver_keypair, sender_keypair.public_key())
.expect("encryption should work");

messages_a.push(message_a);
Expand Down
13 changes: 13 additions & 0 deletions payjoin/src/core/receive/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::{error, fmt};
use crate::error_codes::ErrorCode::{
self, NotEnoughMoney, OriginalPsbtRejected, Unavailable, VersionUnsupported,
};
// use crate::hpke::HpkeError::PayloadTooLarge;

/// The top-level error type for the payjoin receiver
#[derive(Debug)]
Expand All @@ -14,13 +15,18 @@ pub enum Error {
///
/// e.g. database errors, network failures, wallet errors
Implementation(crate::ImplementationError),
PayloadTooLarge {
actual: usize,
max: usize,
},
}

impl From<&Error> for JsonReply {
fn from(e: &Error) -> Self {
match e {
Error::Protocol(e) => e.into(),
Error::Implementation(_) => JsonReply::new(Unavailable, "Receiver error"),
Error::PayloadTooLarge { actual: _, max: _ } => todo!("unimplemented"),
}
}
}
Expand All @@ -34,6 +40,12 @@ impl fmt::Display for Error {
match self {
Error::Protocol(e) => write!(f, "Protocol error: {e}"),
Error::Implementation(e) => write!(f, "Implementation error: {e}"),
Error::PayloadTooLarge { actual, max } => {
write!(
f,
"Plaintext length incorrect, expected size is {max} bytes, actual size is {actual} bytes"
)
}
}
}
}
Expand All @@ -43,6 +55,7 @@ impl error::Error for Error {
match self {
Error::Protocol(e) => e.source(),
Error::Implementation(e) => e.source(),
Error::PayloadTooLarge { .. } => None,
}
}
}
Expand Down
17 changes: 15 additions & 2 deletions payjoin/src/core/receive/v2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
//! Note: Even fresh requests may be linkable via metadata (e.g. client IP, request timing),
//! but request reuse makes correlation trivial for the relay.

use std::io::Write;
use std::str::FromStr;
use std::time::{Duration, SystemTime};

Expand All @@ -42,7 +43,9 @@ use super::error::{Error, InputContributionError};
use super::{
common, InternalPayloadError, JsonReply, OutputSubstitutionError, ProtocolError, SelectionError,
};
use crate::hpke::{decrypt_message_a, encrypt_message_b, HpkeKeyPair, HpkePublicKey};
use crate::hpke::{
decrypt_message_a, encrypt_message_b, HpkeKeyPair, HpkePublicKey, PADDED_PLAINTEXT_B_LENGTH,
};
use crate::ohttp::{
ohttp_encapsulate, process_get_res, process_post_res, OhttpEncapsulationError, OhttpKeys,
};
Expand Down Expand Up @@ -1035,7 +1038,17 @@ impl Receiver<PayjoinProposal> {
let payjoin_bytes = self.psbt.serialize();
let sender_mailbox = short_id_from_pubkey(e);
target_resource = mailbox_endpoint(&self.session_context.directory, &sender_mailbox);
body = encrypt_message_b(payjoin_bytes, &self.session_context.receiver_key, e)?;

let mut buf = [0u8; PADDED_PLAINTEXT_B_LENGTH];

(&mut &mut buf[..]).write_all(&payjoin_bytes).map_err(|e| {
assert!(e.kind() == std::io::ErrorKind::WriteZero);
Error::PayloadTooLarge {
actual: payjoin_bytes.len(),
max: PADDED_PLAINTEXT_B_LENGTH,
}
})?;
body = encrypt_message_b(&buf, &self.session_context.receiver_key, e)?;
method = "POST";
} else {
// Prepare v2 wrapped and backwards-compatible v1 payload
Expand Down
Loading