Skip to content

Commit 04b409a

Browse files
committed
feat(sqlx-core): split connector creation from handshake
1 parent b180eba commit 04b409a

3 files changed

Lines changed: 66 additions & 17 deletions

File tree

sqlx-core/src/net/tls/mod.rs

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,15 +60,37 @@ impl std::fmt::Display for CertificateInput {
6060
pub struct TlsConfig<'a> {
6161
pub accept_invalid_certs: bool,
6262
pub accept_invalid_hostnames: bool,
63-
pub hostname: &'a str,
6463
pub root_cert_path: Option<&'a CertificateInput>,
6564
pub client_cert_path: Option<&'a CertificateInput>,
6665
pub client_key_path: Option<&'a CertificateInput>,
6766
}
6867

68+
#[cfg(feature = "_tls-native-tls")]
69+
pub use self::tls_native_tls::NativeTlsConnector as TlsConnector;
70+
#[cfg(all(feature = "_tls-rustls", not(feature = "_tls-native-tls")))]
71+
pub use self::tls_rustls::RustlsConnector as TlsConnector;
72+
#[cfg(not(any(feature = "_tls-native-tls", feature = "_tls-rustls")))]
73+
#[derive(Clone)]
74+
pub struct TlsConnector(std::convert::Infallible);
75+
76+
pub async fn connector(config: TlsConfig<'_>) -> crate::Result<TlsConnector> {
77+
#[cfg(feature = "_tls-native-tls")]
78+
return Ok(tls_native_tls::connector(config).await?);
79+
80+
#[cfg(all(feature = "_tls-rustls", not(feature = "_tls-native-tls")))]
81+
return Ok(tls_rustls::connector(config).await?);
82+
83+
#[cfg(not(any(feature = "_tls-native-tls", feature = "_tls-rustls")))]
84+
{
85+
_ = config;
86+
panic!("one of the `runtime-*-native-tls` or `runtime-*-rustls` features must be enabled")
87+
}
88+
}
89+
6990
pub async fn handshake<S, Ws>(
7091
socket: S,
71-
config: TlsConfig<'_>,
92+
hostname: &str,
93+
connector: TlsConnector,
7294
with_socket: Ws,
7395
) -> crate::Result<Ws::Output>
7496
where
@@ -77,18 +99,18 @@ where
7799
{
78100
#[cfg(feature = "_tls-native-tls")]
79101
return Ok(with_socket
80-
.with_socket(tls_native_tls::handshake(socket, config).await?)
102+
.with_socket(tls_native_tls::handshake(socket, hostname, connector).await?)
81103
.await);
82104

83105
#[cfg(all(feature = "_tls-rustls", not(feature = "_tls-native-tls")))]
84106
return Ok(with_socket
85-
.with_socket(tls_rustls::handshake(socket, config).await?)
107+
.with_socket(tls_rustls::handshake(socket, hostname, connector).await?)
86108
.await);
87109

88110
#[cfg(not(any(feature = "_tls-native-tls", feature = "_tls-rustls")))]
89111
{
90-
drop((socket, config, with_socket));
91-
panic!("one of the `runtime-*-native-tls` or `runtime-*-rustls` features must be enabled")
112+
drop((socket, hostname, with_socket));
113+
match connector.0 {}
92114
}
93115
}
94116

sqlx-core/src/net/tls/tls_native_tls.rs

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,12 @@ impl<S: Socket> Socket for NativeTlsSocket<S> {
3939
}
4040
}
4141

42-
pub async fn handshake<S: Socket>(
43-
socket: S,
44-
config: TlsConfig<'_>,
45-
) -> crate::Result<NativeTlsSocket<S>> {
42+
#[derive(Clone)]
43+
pub struct NativeTlsConnector {
44+
connector: native_tls::TlsConnector,
45+
}
46+
47+
pub async fn connector(config: TlsConfig<'_>) -> crate::Result<NativeTlsConnector> {
4648
let mut builder = native_tls::TlsConnector::builder();
4749

4850
builder
@@ -67,8 +69,18 @@ pub async fn handshake<S: Socket>(
6769
let connector = rt::spawn_blocking(move || builder.build())
6870
.await
6971
.map_err(Error::tls)?;
72+
Ok(NativeTlsConnector { connector })
73+
}
7074

71-
let mut mid_handshake = match connector.connect(config.hostname, StdSocket::new(socket)) {
75+
pub async fn handshake<S: Socket>(
76+
socket: S,
77+
hostname: &str,
78+
connector: NativeTlsConnector,
79+
) -> crate::Result<NativeTlsSocket<S>> {
80+
let mut mid_handshake = match connector
81+
.connector
82+
.connect(hostname, StdSocket::new(socket))
83+
{
7284
Ok(tls_stream) => return Ok(NativeTlsSocket { stream: tls_stream }),
7385
Err(HandshakeError::Failure(e)) => return Err(Error::tls(e)),
7486
Err(HandshakeError::WouldBlock(mid_handshake)) => mid_handshake,

sqlx-core/src/net/tls/tls_rustls.rs

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,12 @@ impl<S: Socket> Socket for RustlsSocket<S> {
8787
}
8888
}
8989

90-
pub async fn handshake<S>(socket: S, tls_config: TlsConfig<'_>) -> Result<RustlsSocket<S>, Error>
91-
where
92-
S: Socket,
93-
{
90+
#[derive(Clone)]
91+
pub struct RustlsConnector {
92+
config: Arc<ClientConfig>,
93+
}
94+
95+
pub async fn connector(tls_config: TlsConfig<'_>) -> Result<RustlsConnector, Error> {
9496
#[cfg(all(
9597
feature = "_tls-rustls-aws-lc-rs",
9698
not(feature = "_tls-rustls-ring-webpki"),
@@ -180,11 +182,24 @@ where
180182
}
181183
};
182184

183-
let host = ServerName::try_from(tls_config.hostname.to_owned()).map_err(Error::tls)?;
185+
Ok(RustlsConnector {
186+
config: Arc::new(config),
187+
})
188+
}
189+
190+
pub async fn handshake<S>(
191+
socket: S,
192+
hostname: &str,
193+
connector: RustlsConnector,
194+
) -> Result<RustlsSocket<S>, Error>
195+
where
196+
S: Socket,
197+
{
198+
let host = ServerName::try_from(hostname.to_owned()).map_err(Error::tls)?;
184199

185200
let mut socket = RustlsSocket {
186201
inner: StdSocket::new(socket),
187-
state: ClientConnection::new(Arc::new(config), host).map_err(Error::tls)?,
202+
state: ClientConnection::new(connector.config, host).map_err(Error::tls)?,
188203
close_notify_sent: false,
189204
};
190205

0 commit comments

Comments
 (0)