diff --git a/src/tls.rs b/src/tls.rs index aa743875..48cc8a8b 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -24,12 +24,15 @@ pub(crate) enum TlsConfigError { /// An Error parsing the Certificate CertParseError, /// Identity PEM is invalid + #[allow(dead_code)] InvalidIdentityPem, /// Identity PEM is missing a private key such as RSA, ECC or PKCS8 MissingPrivateKey, /// Unknown private key format + #[allow(dead_code)] UnknownPrivateKeyFormat, /// An error from an empty key + #[allow(dead_code)] EmptyKey, /// An error from an invalid key InvalidKey(TlsError), @@ -171,38 +174,16 @@ impl TlsConfigBuilder { self } - pub(crate) fn build(mut self) -> Result { + pub(crate) fn build(self) -> Result { let mut cert_rdr = BufReader::new(self.cert); let cert = rustls_pemfile::certs(&mut cert_rdr) .collect::, _>>() .map_err(|_e| TlsConfigError::CertParseError)?; - let mut key_vec = Vec::new(); - self.key - .read_to_end(&mut key_vec) - .map_err(TlsConfigError::Io)?; - - if key_vec.is_empty() { - return Err(TlsConfigError::EmptyKey); - } - - let mut key_opt = None; - let mut key_cur = std::io::Cursor::new(key_vec); - for item in rustls_pemfile::read_all(&mut key_cur) - .collect::, _>>() - .map_err(|_e| TlsConfigError::InvalidIdentityPem)? - { - match item { - rustls_pemfile::Item::Pkcs1Key(k) => key_opt = Some(k.into()), - rustls_pemfile::Item::Pkcs8Key(k) => key_opt = Some(k.into()), - rustls_pemfile::Item::Sec1Key(k) => key_opt = Some(k.into()), - _ => return Err(TlsConfigError::UnknownPrivateKeyFormat), - } - } - let key = match key_opt { - Some(v) => v, - _ => return Err(TlsConfigError::MissingPrivateKey), - }; + let mut key_rdr = BufReader::new(self.key); + let key = rustls_pemfile::private_key(&mut key_rdr) + .map_err(TlsConfigError::Io)? + .ok_or(TlsConfigError::MissingPrivateKey)?; fn read_trust_anchor( trust_anchor: Box, @@ -442,4 +423,16 @@ mod tests { .build() .unwrap(); } + + #[test] + fn cert_key_as_one() { + let key = include_str!("../examples/tls/key.ecc"); + let cert = include_str!("../examples/tls/cert.ecc.pem"); + let combined = format!("{cert}\n{key}"); + TlsConfigBuilder::new() + .key(combined.as_bytes()) + .cert(combined.as_bytes()) + .build() + .unwrap(); + } }