diff --git a/README.md b/README.md index b5865977..c256c54a 100644 --- a/README.md +++ b/README.md @@ -79,6 +79,11 @@ If you want to set the `kid` parameter or change the algorithm for example: ```rust let mut header = Header::new(Algorithm::HS512); header.kid = Some("blabla".to_owned()); + +let mut extras = HashMap::with_capacity(1); +extras.insert("custom".to_string(), "header".to_string()); +header.extras = Some(extras); + let token = encode(&header, &my_claims, &EncodingKey::from_secret("secret".as_ref()))?; ``` Look at `examples/custom_header.rs` for a full working example. diff --git a/benches/jwt.rs b/benches/jwt.rs index d2fee79e..5949183d 100644 --- a/benches/jwt.rs +++ b/benches/jwt.rs @@ -1,6 +1,7 @@ use criterion::{black_box, criterion_group, criterion_main, Criterion}; use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation}; use serde::{Deserialize, Serialize}; +use std::collections::HashMap; #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] struct Claims { @@ -17,6 +18,18 @@ fn bench_encode(c: &mut Criterion) { }); } +fn bench_encode_custom_extra_headers(c: &mut Criterion) { + let claim = Claims { sub: "b@b.com".to_owned(), company: "ACME".to_owned() }; + let key = EncodingKey::from_secret("secret".as_ref()); + let mut extras = HashMap::with_capacity(1); + extras.insert("custom".to_string(), "header".to_string()); + let header = &Header { extras, ..Default::default() }; + + c.bench_function("bench_encode", |b| { + b.iter(|| encode(black_box(header), black_box(&claim), black_box(&key))) + }); +} + fn bench_decode(c: &mut Criterion) { let token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ"; let key = DecodingKey::from_secret("secret".as_ref()); @@ -32,5 +45,5 @@ fn bench_decode(c: &mut Criterion) { }); } -criterion_group!(benches, bench_encode, bench_decode); +criterion_group!(benches, bench_encode, bench_encode_custom_extra_headers, bench_decode); criterion_main!(benches); diff --git a/examples/custom_header.rs b/examples/custom_header.rs index 5041b773..18f3cb81 100644 --- a/examples/custom_header.rs +++ b/examples/custom_header.rs @@ -1,4 +1,5 @@ use serde::{Deserialize, Serialize}; +use std::collections::HashMap; use jsonwebtoken::errors::ErrorKind; use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation}; @@ -15,8 +16,15 @@ fn main() { Claims { sub: "b@b.com".to_owned(), company: "ACME".to_owned(), exp: 10000000000 }; let key = b"secret"; - let header = - Header { kid: Some("signing_key".to_owned()), alg: Algorithm::HS512, ..Default::default() }; + let mut extras = HashMap::with_capacity(1); + extras.insert("custom".to_string(), "header".to_string()); + + let header = Header { + kid: Some("signing_key".to_owned()), + alg: Algorithm::HS512, + extras, + ..Default::default() + }; let token = match encode(&header, &my_claims, &EncodingKey::from_secret(key)) { Ok(t) => t, diff --git a/src/header.rs b/src/header.rs index 220f0fa4..0947d468 100644 --- a/src/header.rs +++ b/src/header.rs @@ -1,3 +1,4 @@ +use std::collections::HashMap; use std::result; use base64::{engine::general_purpose::STANDARD, Engine}; @@ -10,7 +11,7 @@ use crate::serialization::b64_decode; /// A basic JWT header, the alg defaults to HS256 and typ is automatically /// set to `JWT`. All the other fields are optional. -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct Header { /// The type of JWS: it can only be "JWT" here /// @@ -64,6 +65,12 @@ pub struct Header { #[serde(skip_serializing_if = "Option::is_none")] #[serde(rename = "x5t#S256")] pub x5t_s256: Option, + + /// Any additional non-standard headers not defined in [RFC7515#4.1](https://datatracker.ietf.org/doc/html/rfc7515#section-4.1). + /// Once serialized, all keys will be converted to fields at the root level of the header payload + /// Ex: Dict("custom" -> "header") will be converted to "{"typ": "JWT", ..., "custom": "header"}" + #[serde(flatten)] + pub extras: HashMap, } impl Header { @@ -80,6 +87,7 @@ impl Header { x5c: None, x5t: None, x5t_s256: None, + extras: Default::default(), } } diff --git a/src/jwk.rs b/src/jwk.rs index 49c58003..d89bc085 100644 --- a/src/jwk.rs +++ b/src/jwk.rs @@ -43,7 +43,7 @@ impl<'de> Deserialize<'de> for PublicKeyUse { D: Deserializer<'de>, { struct PublicKeyUseVisitor; - impl<'de> de::Visitor<'de> for PublicKeyUseVisitor { + impl de::Visitor<'_> for PublicKeyUseVisitor { type Value = PublicKeyUse; fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -116,7 +116,7 @@ impl<'de> Deserialize<'de> for KeyOperations { D: Deserializer<'de>, { struct KeyOperationsVisitor; - impl<'de> de::Visitor<'de> for KeyOperationsVisitor { + impl de::Visitor<'_> for KeyOperationsVisitor { type Value = KeyOperations; fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { diff --git a/src/serialization.rs b/src/serialization.rs index 2ec8cc91..dcad4f53 100644 --- a/src/serialization.rs +++ b/src/serialization.rs @@ -3,10 +3,12 @@ use serde::{Deserialize, Serialize}; use crate::errors::Result; +#[inline] pub(crate) fn b64_encode>(input: T) -> String { URL_SAFE_NO_PAD.encode(input) } +#[inline] pub(crate) fn b64_decode>(input: T) -> Result> { URL_SAFE_NO_PAD.decode(input).map_err(|e| e.into()) } diff --git a/src/validation.rs b/src/validation.rs index 289ba03c..bfceded5 100644 --- a/src/validation.rs +++ b/src/validation.rs @@ -337,13 +337,20 @@ where { struct NumericType(PhantomData TryParse>); - impl<'de> Visitor<'de> for NumericType { + impl Visitor<'_> for NumericType { type Value = TryParse; fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { formatter.write_str("A NumericType that can be reasonably coerced into a u64") } + fn visit_u64(self, value: u64) -> std::result::Result + where + E: de::Error, + { + Ok(TryParse::Parsed(value)) + } + fn visit_f64(self, value: f64) -> std::result::Result where E: de::Error, @@ -354,13 +361,6 @@ where Err(serde::de::Error::custom("NumericType must be representable as a u64")) } } - - fn visit_u64(self, value: u64) -> std::result::Result - where - E: de::Error, - { - Ok(TryParse::Parsed(value)) - } } match deserializer.deserialize_any(NumericType(PhantomData)) { diff --git a/tests/hmac.rs b/tests/hmac.rs index 47ca4484..dba0a169 100644 --- a/tests/hmac.rs +++ b/tests/hmac.rs @@ -5,6 +5,7 @@ use jsonwebtoken::{ decode, decode_header, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation, }; use serde::{Deserialize, Serialize}; +use std::collections::HashMap; use time::OffsetDateTime; use wasm_bindgen_test::wasm_bindgen_test; @@ -51,6 +52,56 @@ fn encode_with_custom_header() { .unwrap(); assert_eq!(my_claims, token_data.claims); assert_eq!("kid", token_data.header.kid.unwrap()); + assert!(token_data.header.extras.is_empty()); +} + +#[test] +#[wasm_bindgen_test] +fn encode_with_extra_custom_header() { + let my_claims = Claims { + sub: "b@b.com".to_string(), + company: "ACME".to_string(), + exp: OffsetDateTime::now_utc().unix_timestamp() + 10000, + }; + let mut extras = HashMap::with_capacity(1); + extras.insert("custom".to_string(), "header".to_string()); + let header = Header { kid: Some("kid".to_string()), extras, ..Default::default() }; + let token = encode(&header, &my_claims, &EncodingKey::from_secret(b"secret")).unwrap(); + let token_data = decode::( + &token, + &DecodingKey::from_secret(b"secret"), + &Validation::new(Algorithm::HS256), + ) + .unwrap(); + assert_eq!(my_claims, token_data.claims); + assert_eq!("kid", token_data.header.kid.unwrap()); + assert_eq!("header", token_data.header.extras.get("custom").unwrap().as_str()); +} + +#[test] +#[wasm_bindgen_test] +fn encode_with_multiple_extra_custom_headers() { + let my_claims = Claims { + sub: "b@b.com".to_string(), + company: "ACME".to_string(), + exp: OffsetDateTime::now_utc().unix_timestamp() + 10000, + }; + let mut extras = HashMap::with_capacity(2); + extras.insert("custom1".to_string(), "header1".to_string()); + extras.insert("custom2".to_string(), "header2".to_string()); + let header = Header { kid: Some("kid".to_string()), extras, ..Default::default() }; + let token = encode(&header, &my_claims, &EncodingKey::from_secret(b"secret")).unwrap(); + let token_data = decode::( + &token, + &DecodingKey::from_secret(b"secret"), + &Validation::new(Algorithm::HS256), + ) + .unwrap(); + assert_eq!(my_claims, token_data.claims); + assert_eq!("kid", token_data.header.kid.unwrap()); + let extras = token_data.header.extras; + assert_eq!("header1", extras.get("custom1").unwrap().as_str()); + assert_eq!("header2", extras.get("custom2").unwrap().as_str()); } #[test] @@ -86,6 +137,25 @@ fn decode_token() { claims.unwrap(); } +#[test] +#[wasm_bindgen_test] +fn decode_token_custom_headers() { + let token = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsImN1c3RvbTEiOiJoZWFkZXIxIiwiY3VzdG9tMiI6ImhlYWRlcjIifQ.eyJzdWIiOiJiQGIuY29tIiwiY29tcGFueSI6IkFDTUUiLCJleHAiOjI1MzI1MjQ4OTF9.FtOHsoKcNH3SriK3tnR-uWJg4UV4FkOzvq_JCfLngfU"; + let claims = decode::( + token, + &DecodingKey::from_secret(b"secret"), + &Validation::new(Algorithm::HS256), + ) + .unwrap(); + let my_claims = + Claims { sub: "b@b.com".to_string(), company: "ACME".to_string(), exp: 2532524891 }; + assert_eq!(my_claims, claims.claims); + assert_eq!("kid", claims.header.kid.unwrap()); + let extras = claims.header.extras; + assert_eq!("header1", extras.get("custom1").unwrap().as_str()); + assert_eq!("header2", extras.get("custom2").unwrap().as_str()); +} + #[test] #[wasm_bindgen_test] #[should_panic(expected = "InvalidToken")]