diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 73cffedc..067757d2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -55,4 +55,33 @@ jobs: path: target key: target-${{ runner.os }}-${{ steps.rust-version.outputs.version }}-${{ hashFiles('Cargo.lock') }} - run: cargo test --features vendored - - run: cargo test --features vendored + + mbedtls: + name: test-mbedtls + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: sfackler/actions/rustup@master + with: + version: 1.63.0 + - run: echo "::set-output name=version::$(rustc --version)" + id: rust-version + # trigger mbedtls implementation + - run: sed -i 's/target_env = "sgx"/target_env = "gnu"/' {Cargo.toml,src/lib.rs,src/test.rs} + - uses: actions/cache@v1 + with: + path: ~/.cargo/registry/index + key: index-${{ runner.os }}-mbdetls-${{ github.run_number }} + restore-keys: | + index-${{ runner.os }}-mbedtls- + - run: cargo generate-lockfile + - uses: actions/cache@v1 + with: + path: ~/.cargo/registry/cache + key: registry-${{ runner.os }}-mbedtls-${{ steps.rust-version.outputs.version }}-${{ hashFiles('Cargo.lock') }} + - run: cargo fetch + - uses: actions/cache@v1 + with: + path: target + key: target-${{ runner.os }}-mbedtls-${{ steps.rust-version.outputs.version }}-${{ hashFiles('Cargo.lock') }} + - run: cargo test --features alpn --lib diff --git a/Cargo.toml b/Cargo.toml index f63b0223..b53d407b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,12 +26,20 @@ tempfile = "3.1.0" [target.'cfg(target_os = "windows")'.dependencies] schannel = "0.1.17" -[target.'cfg(not(any(target_os = "windows", target_os = "macos", target_os = "ios")))'.dependencies] +[target.'cfg(not(any(target_os = "windows", target_os = "macos", target_os = "ios", target_env = "sgx")))'.dependencies] log = "0.4.5" openssl = "0.10.29" openssl-sys = "0.9.55" openssl-probe = "0.1" +[target.'cfg(target_env = "sgx")'.dependencies] +mbedtls = { version = "0.9.1", features = ["std", "rdrand", "mpi_force_c_code" ], default-features = false } +pkcs5 = { version = "0.7.1", features = ["alloc", "pbes2"] } +p12 = "0.6.3" +yasna = "0.5" + [dev-dependencies] +lazy_static = "1.4.0" tempfile = "3.0" test-cert-gen = "0.9" +ureq = "2.6" diff --git a/src/imp/mbedtls.rs b/src/imp/mbedtls.rs new file mode 100644 index 00000000..329a3e7e --- /dev/null +++ b/src/imp/mbedtls.rs @@ -0,0 +1,523 @@ +extern crate mbedtls; + +use self::mbedtls::alloc::{Box as MbedtlsBox, List as MbedtlsList}; +use self::mbedtls::hash::{Md, Type as MdType}; +use self::mbedtls::pk::Pk; +use self::mbedtls::rng::{CtrDrbg, Rdseed}; +#[cfg(feature = "alpn")] +use self::mbedtls::ssl::config::NullTerminatedStrList; +use self::mbedtls::ssl::config::{Endpoint, Preset, Transport}; +use self::mbedtls::ssl::{Config, Context, Version}; +use self::mbedtls::x509::certificate::Certificate as MbedtlsCert; +use self::mbedtls::Error as TlsError; + +use std::convert::TryFrom; +use std::error; +use std::fmt::{self, Debug}; +use std::io; +use std::sync::Arc; + +use {Protocol, TlsAcceptorBuilder, TlsConnectorBuilder}; + +#[derive(Debug)] +pub enum Error { + Tls(TlsError), + Pkcs12(yasna::ASN1Error), + Pkcs5(pkcs5::Error), + Der(pkcs5::der::Error), + Custom(String), +} + +impl From for Error { + fn from(err: TlsError) -> Error { + Error::Tls(err) + } +} + +impl From for Error { + fn from(err: yasna::ASN1Error) -> Error { + Error::Pkcs12(err) + } +} + +impl From for Error { + fn from(err: pkcs5::Error) -> Error { + Error::Pkcs5(err) + } +} + +impl From for Error { + fn from(err: pkcs5::der::Error) -> Error { + Error::Der(err) + } +} + +impl From for HandshakeError { + fn from(e: TlsError) -> HandshakeError { + HandshakeError::Failure(e.into()) + } +} + +impl error::Error for Error { + fn source(&self) -> Option<&(dyn error::Error + 'static)> { + match *self { + Error::Tls(ref e) => e.source(), + Error::Pkcs12(ref e) => e.source(), + Error::Pkcs5(_) => None, + Error::Der(_) => None, + Error::Custom(_) => None, + } + } +} + +impl fmt::Display for Error { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + match *self { + Error::Tls(ref e) => fmt::Display::fmt(e, fmt), + Error::Pkcs12(ref e) => fmt::Display::fmt(e, fmt), + Error::Pkcs5(ref e) => fmt::Display::fmt(e, fmt), + Error::Der(ref e) => fmt::Display::fmt(e, fmt), + Error::Custom(ref e) => fmt::Display::fmt(e, fmt), + } + } +} + +fn to_mbedtls_version(protocol: Protocol) -> Version { + match protocol { + Protocol::Sslv3 => Version::Ssl3, + Protocol::Tlsv10 => Version::Tls1_0, + Protocol::Tlsv11 => Version::Tls1_1, + Protocol::Tlsv12 => Version::Tls1_2, + } +} + +trait NullTerminated { + fn null_terminated(&self) -> Vec; +} + +impl> NullTerminated for T { + fn null_terminated(&self) -> Vec { + let mut buf = self.as_ref().to_vec(); + buf.push(0); + buf + } +} + +fn pkcs12_decode_key_bag>( + key_bag: &p12::EncryptedPrivateKeyInfo, + pass: B, +) -> Result, Error> { + // try to decrypt the key with algorithms supported by p12 crate + if let Some(decrypted) = key_bag.decrypt(pass.as_ref()) { + Ok(decrypted) + // try to decrypt the key with algorithms supported by pkcs5 standard + } else if let p12::AlgorithmIdentifier::OtherAlg(_) = key_bag.encryption_algorithm { + // write the algorithm identifier back to DER format + let algorithm_der = + yasna::construct_der(|writer| key_bag.encryption_algorithm.write(writer)); + // and construct pkcs5 decoder from it + let scheme = pkcs5::EncryptionScheme::try_from(&algorithm_der[..])?; + + Ok(scheme.decrypt(pass.as_ref(), &key_bag.encrypted_data)?) + } else { + Err(Error::Custom( + "Unsupported key encryption algorithm".to_owned(), + )) + } +} + +#[derive(Clone)] +pub struct Identity { + key: Arc, + certificates: Arc>, +} + +impl Identity { + pub fn from_pkcs12(buf: &[u8], pass: &str) -> Result { + let pfx = p12::PFX::parse(buf)?; + let key = pfx + .bags(pass)? + .iter() + .find_map(|safe_bag| { + if let p12::SafeBagKind::Pkcs8ShroudedKeyBag(ref key_bag) = safe_bag.bag { + Some(pkcs12_decode_key_bag(key_bag, pass)) + } else { + None + } + }) + .ok_or(Error::Custom("No private key in pkcs12 DER".to_owned()))? + .map(|key| Pk::from_private_key(&key, Some(pass.as_bytes())))??; + let certificates: MbedtlsList<_> = pfx + .cert_bags(pass)? + .iter() + .map(|cert| MbedtlsCert::from_der(cert)) + .collect::>()?; + + if !certificates.is_empty() { + Ok(Identity { + key: Arc::new(key), + certificates: Arc::new(certificates), + }) + } else { + Err(Error::Custom( + "PKCS12 file is missing certificate chain".to_owned(), + )) + } + } + + pub fn from_pkcs8(buf: &[u8], key: &[u8]) -> Result { + let key = Pk::from_private_key(&key.null_terminated(), None)?; + let certificates = MbedtlsCert::from_pem_multiple(&buf.null_terminated())?; + + if !certificates.is_empty() { + Ok(Identity { + key: Arc::new(key), + certificates: Arc::new(certificates), + }) + } else { + Err(Error::Custom( + "X509 chain file is missing certificate chain".to_owned(), + )) + } + } + + fn certificates(&self) -> Arc> { + self.certificates.clone() + } + + fn private_key(&self) -> Arc { + self.key.clone() + } +} + +impl Debug for Identity { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Identity") + .field( + "certificates", + &self + .certificates + .iter() + .map(|cert| cert.as_der().to_vec()) + .collect::>(), + ) + .field( + "key_name", + &self.key.name().map(String::from).map_err(Error::Tls), + ) + .finish() + } +} + +#[derive(Clone)] +pub struct Certificate(MbedtlsBox); + +impl Certificate { + pub fn from_der(buf: &[u8]) -> Result { + let cert = MbedtlsCert::from_der(buf).map_err(Error::Tls)?; + Ok(Certificate(cert)) + } + + pub fn from_pem(buf: &[u8]) -> Result { + let cert = MbedtlsCert::from_pem(&buf.null_terminated()).map_err(Error::Tls)?; + Ok(Certificate(cert)) + } + + pub fn to_der(&self) -> Result, Error> { + let der = self.0.as_der().to_vec(); + Ok(der) + } +} + +impl Debug for Certificate { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("Certificate") + .field(&self.0.as_der()) + .finish() + } +} + +pub struct TlsStream { + ctx: Context, + role: Endpoint, + identity: Option, +} + +impl TlsStream { + pub fn get_ref(&self) -> &S { + self.ctx.io().expect("Not connected") + } + + pub fn get_mut(&mut self) -> &mut S { + self.ctx.io_mut().expect("Not connected") + } + + pub fn buffered_read_size(&self) -> Result { + Ok(self.ctx.bytes_available()) + } + + #[cfg(feature = "alpn")] + pub fn negotiated_alpn(&self) -> Result>, Error> { + Ok(self.ctx.get_alpn_protocol()?.map(|s| s.as_bytes().to_vec())) + } + + pub fn peer_certificate(&self) -> Result, Error> { + let cert = match self.ctx.peer_cert() { + Ok(Some(certs)) => certs.iter().next().map(|cert| Certificate(cert.clone())), + Ok(_) => None, + Err(e) => match e { + TlsError::SslBadInputData => None, + _ => return Err(Error::Tls(e)), + }, + }; + Ok(cert) + } + + fn server_certificate(&self) -> Result, Error> { + match self.role { + Endpoint::Client => self.peer_certificate(), + Endpoint::Server => match self.identity { + Some(ref idt) => Ok(idt + .certificates() + .iter() + .map(|cert| Certificate(cert.clone())) + .next()), + None => Ok(None), + }, + } + } + + pub fn tls_server_end_point(&self) -> Result>, Error> { + let cert = match self.server_certificate()? { + Some(cert) => cert, + None => return Ok(None), + }; + + let md = match cert.0.digest_type() { + MdType::Md5 | MdType::Sha1 => MdType::Sha256, + md => md, + }; + + let der = cert.to_der()?; + let mut digest = vec![0; 64]; + let len = Md::hash(md, &der, &mut digest).map_err(Error::Tls)?; + digest.truncate(len); + + Ok(Some(digest)) + } + + pub fn shutdown(&mut self) -> io::Result<()> { + self.ctx.close(); + Ok(()) + } +} + +impl io::Read for TlsStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.ctx.read(buf) + } +} + +impl io::Write for TlsStream { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.ctx.write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + self.ctx.flush() + } +} + +impl Debug for TlsStream { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("TlsStream") + .field( + "role", + &match self.role { + Endpoint::Client => "client", + Endpoint::Server => "server", + }, + ) + .field("identity", &self.identity) + .finish() + } +} + +#[derive(Debug)] +pub struct MidHandshakeTlsStream(TlsStream); + +pub enum HandshakeError { + Failure(Error), + // this is actually unused + WouldBlock(MidHandshakeTlsStream), +} + +impl MidHandshakeTlsStream { + pub fn get_ref(&self) -> &S { + self.0.get_ref() + } + + pub fn get_mut(&mut self) -> &mut S { + self.0.get_mut() + } +} + +impl MidHandshakeTlsStream +where + S: io::Read + io::Write, +{ + pub fn handshake(self) -> Result, HandshakeError> { + Ok(self.0) + } +} + +#[derive(Clone)] +pub struct TlsConnector { + config: Arc, + identity: Option<::Identity>, + accept_invalid_hostnames: bool, +} + +impl Debug for TlsConnector { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("TlsConnector") + .field("identity", &self.identity.as_ref().map(|idt| &idt.0)) + .field("accept_invalid_hostnames", &self.accept_invalid_hostnames) + .finish() + } +} + +impl TlsConnector { + pub fn new(builder: &TlsConnectorBuilder) -> Result { + let mut config = Config::new(Endpoint::Client, Transport::Stream, Preset::Default); + + // Set Rng + let entropy = Arc::new(Rdseed); + let rng = Arc::new(CtrDrbg::new(entropy, None)?); + config.set_rng(rng); + + // Set root certificates + let ca_list = builder + .root_certificates + .iter() + .map(|cert| (cert.0).0.clone()) + .collect(); + config.set_ca_list(Arc::new(ca_list), None); + + // Add identity certificates and key + if let Some(identity) = &builder.identity { + config.push_cert(identity.0.certificates(), identity.0.private_key())?; + } + + // Set authmode + if builder.accept_invalid_certs { + config.set_authmode(mbedtls::ssl::config::AuthMode::None); + } + + // Set minimum protocol version + if let Some(min_version) = builder.min_protocol.map(to_mbedtls_version) { + config.set_min_version(min_version)?; + } + + // Set maximum protocol version + if let Some(max_version) = builder.max_protocol.map(to_mbedtls_version) { + config.set_max_version(max_version)?; + } + + #[cfg(feature = "alpn")] + { + if !builder.alpn.is_empty() { + let alpns: Vec<_> = builder + .alpn + .iter() + .map(|protocol| protocol.as_str()) + .collect(); + config.set_alpn_protocols(Arc::new(NullTerminatedStrList::new(&alpns)?))?; + } + } + + Ok(TlsConnector { + config: Arc::new(config), + identity: builder.identity.clone(), + accept_invalid_hostnames: builder.accept_invalid_hostnames, + }) + } + + pub fn connect(&self, domain: &str, stream: S) -> Result, HandshakeError> + where + S: io::Read + io::Write, + { + // Create mbedtls context + let mut ctx = Context::new(self.config.clone()); + + // Establish connection + let hostname = if self.accept_invalid_hostnames { + None + } else { + Some(domain) + }; + + ctx.establish(stream, hostname)?; + + Ok(TlsStream { + ctx, + role: Endpoint::Client, + identity: self.identity.clone().map(|idt| idt.0), + }) + } +} + +#[derive(Clone)] +pub struct TlsAcceptor { + config: Arc, + identity: Identity, +} + +impl TlsAcceptor { + pub fn new(builder: &TlsAcceptorBuilder) -> Result { + let mut config = Config::new(Endpoint::Server, Transport::Stream, Preset::Default); + + // Set Rng + let entropy = Arc::new(Rdseed); + let rng = Arc::new(CtrDrbg::new(entropy, None)?); + config.set_rng(rng); + + // Add identity certificates and key + config.push_cert( + builder.identity.0.certificates(), + builder.identity.0.private_key(), + )?; + + // Set minimum protocol version + if let Some(min_version) = builder.min_protocol.map(to_mbedtls_version) { + config.set_min_version(min_version)?; + } + + // Set maximum protocol version + if let Some(max_version) = builder.max_protocol.map(to_mbedtls_version) { + config.set_max_version(max_version)?; + } + + Ok(TlsAcceptor { + config: Arc::new(config), + identity: (builder.identity.0).clone(), + }) + } + + pub fn accept(&self, stream: S) -> Result, HandshakeError> + where + S: io::Read + io::Write, + { + // Create mbedtls context + let mut ctx = Context::new(self.config.clone()); + + // Establish connection + ctx.establish(stream, None)?; + + Ok(TlsStream { + ctx, + role: Endpoint::Server, + identity: Some(self.identity.clone()), + }) + } +} diff --git a/src/lib.rs b/src/lib.rs index bc907130..6c06fd41 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -108,18 +108,32 @@ use std::fmt; use std::io; use std::result; -#[cfg(not(any(target_os = "macos", target_os = "windows", target_os = "ios")))] +#[cfg(not(any( + target_os = "macos", + target_os = "windows", + target_os = "ios", + target_env = "sgx" +)))] #[macro_use] extern crate log; #[cfg(any(target_os = "macos", target_os = "ios"))] #[path = "imp/security_framework.rs"] mod imp; -#[cfg(target_os = "windows")] +#[cfg(all(target_os = "windows", not(target_env = "sgx")))] #[path = "imp/schannel.rs"] mod imp; -#[cfg(not(any(target_os = "macos", target_os = "windows", target_os = "ios")))] +#[cfg(not(any( + target_os = "macos", + target_os = "windows", + target_os = "ios", + target_os = "espidf", + target_env = "sgx" +)))] #[path = "imp/openssl.rs"] mod imp; +#[cfg(target_env = "sgx")] +#[path = "imp/mbedtls.rs"] +mod imp; #[cfg(test)] mod test; diff --git a/src/test.rs b/src/test.rs index c51b0bc4..1b68bd38 100644 --- a/src/test.rs +++ b/src/test.rs @@ -7,6 +7,38 @@ use std::thread; use super::*; +#[cfg(target_env = "sgx")] +lazy_static::lazy_static! { + static ref ROOT_CERTIFICATES: Vec = { + // except digicert just because we have to provide any exclusion to get the rest + let mut root_certs = ureq::get("https://mkcert.org/generate/all/except/digicert") + .call() + .unwrap() + .into_string() + .unwrap(); + root_certs.push('\0'); + let root_certs = mbedtls::x509::certificate::Certificate::from_pem_multiple(root_certs.as_bytes()).unwrap(); + root_certs.iter().map(|cert| Certificate::from_der(cert.as_der()).unwrap()).collect() + }; +} + +// for mbedtls there is no 'standard' way to get default ca root chain +// so for tests where some default is needed we manually add mozilla trust chain. +macro_rules! connector { + () => {{ + #[cfg(target_env = "sgx")] + { + let mut builder = TlsConnector::builder(); + ROOT_CERTIFICATES.iter().for_each(|cert| { + builder.add_root_certificate(cert.clone()); + }); + builder + } + #[cfg(not(target_env = "sgx"))] + TlsConnector::builder() + }}; +} + macro_rules! p { ($e:expr) => { match $e { @@ -18,7 +50,7 @@ macro_rules! p { #[test] fn connect_google() { - let builder = p!(TlsConnector::new()); + let builder = p!(connector!().build()); let s = p!(TcpStream::connect("google.com:443")); let mut socket = p!(builder.connect("google.com", s)); @@ -26,23 +58,20 @@ fn connect_google() { let mut result = vec![]; p!(socket.read_to_end(&mut result)); - println!("{}", String::from_utf8_lossy(&result)); assert!(result.starts_with(b"HTTP/1.0")); assert!(result.ends_with(b"\r\n") || result.ends_with(b"")); } #[test] fn connect_bad_hostname() { - let builder = p!(TlsConnector::new()); + let builder = p!(connector!().build()); let s = p!(TcpStream::connect("google.com:443")); builder.connect("goggle.com", s).unwrap_err(); } #[test] fn connect_bad_hostname_ignored() { - let builder = p!(TlsConnector::builder() - .danger_accept_invalid_hostnames(true) - .build()); + let builder = p!(connector!().danger_accept_invalid_hostnames(true).build()); let s = p!(TcpStream::connect("google.com:443")); builder.connect("goggle.com", s).unwrap(); } @@ -408,7 +437,7 @@ fn shutdown() { #[test] #[cfg(feature = "alpn")] fn alpn_google_h2() { - let builder = p!(TlsConnector::builder().request_alpns(&["h2"]).build()); + let builder = p!(connector!().request_alpns(&["h2"]).build()); let s = p!(TcpStream::connect("google.com:443")); let socket = p!(builder.connect("google.com", s)); let alpn = p!(socket.negotiated_alpn()); @@ -418,7 +447,7 @@ fn alpn_google_h2() { #[test] #[cfg(feature = "alpn")] fn alpn_google_invalid() { - let builder = p!(TlsConnector::builder().request_alpns(&["h2c"]).build()); + let builder = p!(connector!().request_alpns(&["h2c"]).build()); let s = p!(TcpStream::connect("google.com:443")); let socket = p!(builder.connect("google.com", s)); let alpn = p!(socket.negotiated_alpn()); @@ -428,7 +457,7 @@ fn alpn_google_invalid() { #[test] #[cfg(feature = "alpn")] fn alpn_google_none() { - let builder = p!(TlsConnector::new()); + let builder = p!(connector!().build()); let s = p!(TcpStream::connect("google.com:443")); let socket = p!(builder.connect("google.com", s)); let alpn = p!(socket.negotiated_alpn());