diff --git a/Cargo.toml b/Cargo.toml index 351e3133..09308c42 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,14 +11,14 @@ readme = "README.md" vendored = ["openssl/vendored"] [target.'cfg(any(target_os = "macos", target_os = "ios"))'.dependencies] -security-framework = "0.3.1" -security-framework-sys = "0.3.1" +security-framework = {version = "0.3.1", features = ["OSX_10_12"]} +security-framework-sys = {version = "0.3.1", features = ["OSX_10_12"]} lazy_static = "1.0" libc = "0.2" tempfile = "3.0" [target.'cfg(target_os = "windows")'.dependencies] -schannel = "0.1.13" +schannel = {version = "0.1.15"} [target.'cfg(not(any(target_os = "windows", target_os = "macos", target_os = "ios")))'.dependencies] log = "0.4.5" diff --git a/src/imp/openssl.rs b/src/imp/openssl.rs index 75264d24..34a0582e 100644 --- a/src/imp/openssl.rs +++ b/src/imp/openssl.rs @@ -9,7 +9,8 @@ use self::openssl::ssl::{ self, MidHandshakeSslStream, SslAcceptor, SslConnector, SslContextBuilder, SslMethod, SslVerifyMode, }; -use self::openssl::x509::{X509, X509VerifyResult}; +use self::openssl::stack; +use self::openssl::x509::{X509VerifyResult, X509}; use std::error; use std::fmt; use std::io; @@ -177,6 +178,12 @@ impl Certificate { let der = self.0.to_der()?; Ok(der) } + + pub fn public_key_info_der(&self) -> Result, Error> { + let pk = self.0.public_key()?; + let der = pk.public_key_to_der()?; + Ok(der) + } } pub struct MidHandshakeTlsStream(MidHandshakeSslStream); @@ -324,6 +331,19 @@ impl TlsAcceptor { } } +pub struct ChainIterator<'a, S: 'a>(Option>, &'a TlsStream); + +impl<'a, S> Iterator for ChainIterator<'a, S> { + type Item = Certificate; + + fn next(&mut self) -> Option { + if let Some(i) = self.0.as_mut() { + return i.next().map(|c| Certificate(c.to_owned())); + } + None + } +} + pub struct TlsStream(ssl::SslStream); impl fmt::Debug for TlsStream { @@ -349,6 +369,13 @@ impl TlsStream { Ok(self.0.ssl().peer_certificate().map(Certificate)) } + pub fn certificate_chain(&mut self) -> Result, Error> { + Ok(ChainIterator( + self.0.ssl().peer_cert_chain().map(|stack| stack.iter()), + self, + )) + } + pub fn tls_server_end_point(&self) -> Result>, Error> { let cert = if self.0.ssl().is_server() { self.0.ssl().certificate().map(|x| x.to_owned()) diff --git a/src/imp/schannel.rs b/src/imp/schannel.rs index fee17e89..fb059301 100644 --- a/src/imp/schannel.rs +++ b/src/imp/schannel.rs @@ -1,7 +1,7 @@ extern crate schannel; use self::schannel::cert_context::{CertContext, HashAlgorithm}; -use self::schannel::cert_store::{CertAdd, CertStore, Memory, PfxImportOptions}; +use self::schannel::cert_store::{CertAdd, CertStore, Certs, Memory, PfxImportOptions}; use self::schannel::schannel_cred::{Direction, Protocol, SchannelCred}; use self::schannel::tls_stream; use std::error; @@ -89,7 +89,8 @@ impl Identity { return Err(io::Error::new( io::ErrorKind::InvalidInput, "No identity found in PKCS #12 archive", - ).into()); + ) + .into()); } }; @@ -115,13 +116,18 @@ impl Certificate { Err(_) => Err(io::Error::new( io::ErrorKind::InvalidInput, "PEM representation contains non-UTF-8 bytes", - ).into()), + ) + .into()), } } pub fn to_der(&self) -> Result, Error> { Ok(self.0.to_der().to_vec()) } + + pub fn public_key_info_der(&self) -> Result, Error> { + Ok(self.0.subject_public_key_info_der()?) + } } pub struct MidHandshakeTlsStream(tls_stream::MidHandshakeTlsStream); @@ -149,7 +155,10 @@ where pub fn handshake(self) -> Result, HandshakeError> { match self.0.handshake() { - Ok(s) => Ok(TlsStream(s)), + Ok(stream) => Ok(TlsStream { + stream, + store: None, + }), Err(e) => Err(e.into()), } } @@ -227,7 +236,10 @@ impl TlsConnector { builder.verify_callback(|_| Ok(())); } match builder.connect(cred, stream) { - Ok(s) => Ok(TlsStream(s)), + Ok(stream) => Ok(TlsStream { + stream, + store: None, + }), Err(e) => Err(e.into()), } } @@ -259,46 +271,85 @@ impl TlsAcceptor { // FIXME we're probably missing the certificate chain? let cred = builder.acquire(Direction::Inbound)?; match tls_stream::Builder::new().accept(cred, stream) { - Ok(s) => Ok(TlsStream(s)), + Ok(stream) => Ok(TlsStream { + stream, + store: None, + }), Err(e) => Err(e.into()), } } } -pub struct TlsStream(tls_stream::TlsStream); +pub struct ChainIterator<'a, S: 'a> { + certs: Option>, + _stream: &'a TlsStream, +} +impl<'a, S> Iterator for ChainIterator<'a, S> { + type Item = Certificate; + + fn next(&mut self) -> Option { + if let Some(certs) = self.certs.as_mut() { + return certs.next().map(Certificate); + } + None + } +} + +pub struct TlsStream { + stream: tls_stream::TlsStream, + store: Option, +} impl fmt::Debug for TlsStream { fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - fmt::Debug::fmt(&self.0, fmt) + fmt::Debug::fmt(&self.stream, fmt) } } impl TlsStream { pub fn get_ref(&self) -> &S { - self.0.get_ref() + self.stream.get_ref() } pub fn get_mut(&mut self) -> &mut S { - self.0.get_mut() + self.stream.get_mut() } pub fn buffered_read_size(&self) -> Result { - Ok(self.0.get_buf().len()) + Ok(self.stream.get_buf().len()) } pub fn peer_certificate(&self) -> Result, Error> { - match self.0.peer_certificate() { + match self.stream.peer_certificate() { Ok(cert) => Ok(Some(Certificate(cert))), Err(ref e) if e.raw_os_error() == Some(SEC_E_NO_CREDENTIALS as i32) => Ok(None), Err(e) => Err(Error(e)), } } + pub fn certificate_chain(&mut self) -> Result, Error> { + if self.store.is_none() { + match self.stream.peer_certificate() { + Ok(cert) => { + self.store = cert.cert_store(); + } + Err(ref e) if e.raw_os_error() == Some(SEC_E_NO_CREDENTIALS as i32) => { + self.store = None; + } + Err(e) => return Err(Error(e)), + } + } + Ok(ChainIterator { + certs: self.store.as_ref().map(|c| c.certs()), + _stream: self, + }) + } + pub fn tls_server_end_point(&self) -> Result>, Error> { - let cert = if self.0.is_server() { - self.0.certificate() + let cert = if self.stream.is_server() { + self.stream.certificate() } else { - self.0.peer_certificate() + self.stream.peer_certificate() }; let cert = match cert { @@ -320,23 +371,23 @@ impl TlsStream { } pub fn shutdown(&mut self) -> io::Result<()> { - self.0.shutdown()?; + self.stream.shutdown()?; Ok(()) } } impl io::Read for TlsStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.0.read(buf) + self.stream.read(buf) } } impl io::Write for TlsStream { fn write(&mut self, buf: &[u8]) -> io::Result { - self.0.write(buf) + self.stream.write(buf) } fn flush(&mut self) -> io::Result<()> { - self.0.flush() + self.stream.flush() } } diff --git a/src/imp/security_framework.rs b/src/imp/security_framework.rs index f23fe131..f9bd897b 100644 --- a/src/imp/security_framework.rs +++ b/src/imp/security_framework.rs @@ -3,18 +3,20 @@ extern crate security_framework; extern crate security_framework_sys; extern crate tempfile; -use self::security_framework::base; use self::security_framework::certificate::SecCertificate; use self::security_framework::identity::SecIdentity; use self::security_framework::import_export::{ImportedIdentity, Pkcs12ImportOptions}; use self::security_framework::secure_transport::{ self, ClientBuilder, SslConnectionType, SslContext, SslProtocol, SslProtocolSide, }; +use self::security_framework::{base, trust::SecTrust}; use self::security_framework_sys::base::errSecIO; + use self::tempfile::TempDir; use std::error; use std::fmt; use std::io; + use std::sync::Mutex; use std::sync::{Once, ONCE_INIT}; @@ -174,6 +176,10 @@ impl Certificate { pub fn to_der(&self) -> Result, Error> { Ok(self.0.to_der()) } + + pub fn public_key_info_der(&self) -> Result, Error> { + Ok(self.0.public_key_info_der()?.unwrap_or(Vec::new())) + } } pub enum HandshakeError { @@ -351,6 +357,24 @@ impl TlsAcceptor { } } +pub struct ChainIterator<'a, S: 'a> { + trust: Option, + pos: usize, + _stream: &'a TlsStream, +} +impl<'a, S> Iterator for ChainIterator<'a, S> { + type Item = Certificate; + + fn next(&mut self) -> Option { + if let Some(trust) = self.trust.as_ref() { + let pos = self.pos; + self.pos += 1; + return trust.certificate_at_index(pos as _).map(Certificate); + } + None + } +} + pub struct TlsStream { stream: secure_transport::SslStream, cert: Option, @@ -385,6 +409,21 @@ impl TlsStream { Ok(trust.certificate_at_index(0).map(Certificate)) } + pub fn certificate_chain(&mut self) -> Result, Error> { + let trust = match self.stream.context().peer_trust2()? { + Some(trust) => { + trust.evaluate()?; + Some(trust) + } + None => None, + }; + Ok(ChainIterator { + trust, + pos: 0, + _stream: self, + }) + } + #[cfg(target_os = "ios")] pub fn tls_server_end_point(&self) -> Result>, Error> { Ok(None) diff --git a/src/lib.rs b/src/lib.rs index 5efb08fc..bb08914e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -206,6 +206,23 @@ impl Certificate { let der = self.0.to_der()?; Ok(der) } + + /// Returns der encoded subjectPublicKeyInfo. + pub fn public_key_info_der(&self) -> Result> { + let der = self.0.public_key_info_der()?; + Ok(der) + } +} + +/// An iterator over a certificate chain. +pub struct ChainIterator<'a, S: 'a>(imp::ChainIterator<'a, S>); + +impl<'a, S> Iterator for ChainIterator<'a, S> { + type Item = Certificate; + + fn next(&mut self) -> Option { + self.0.next().map(Certificate) + } } /// A TLS stream which has been interrupted midway through the handshake process. @@ -630,6 +647,11 @@ impl TlsStream { Ok(self.0.peer_certificate()?.map(Certificate)) } + /// Returns an iterator over certificate chain. It may be an empty iterator if chain not available. + pub fn certificate_chain(&mut self) -> Result> { + Ok(ChainIterator(self.0.certificate_chain()?)) + } + /// Returns the tls-server-end-point channel binding data as defined in [RFC 5929]. /// /// [RFC 5929]: https://tools.ietf.org/html/rfc5929