Skip to content

Commit 9db9a93

Browse files
committed
feat(sqlx-mysql, sqlx-postgres): cache the TLS connector
1 parent 04b409a commit 9db9a93

10 files changed

Lines changed: 148 additions & 84 deletions

File tree

sqlx-core/src/net/tls/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ pub use self::tls_native_tls::NativeTlsConnector as TlsConnector;
7070
#[cfg(all(feature = "_tls-rustls", not(feature = "_tls-native-tls")))]
7171
pub use self::tls_rustls::RustlsConnector as TlsConnector;
7272
#[cfg(not(any(feature = "_tls-native-tls", feature = "_tls-rustls")))]
73-
#[derive(Clone)]
73+
#[derive(Debug, Clone)]
7474
pub struct TlsConnector(std::convert::Infallible);
7575

7676
pub async fn connector(config: TlsConfig<'_>) -> crate::Result<TlsConnector> {
@@ -90,7 +90,7 @@ pub async fn connector(config: TlsConfig<'_>) -> crate::Result<TlsConnector> {
9090
pub async fn handshake<S, Ws>(
9191
socket: S,
9292
hostname: &str,
93-
connector: TlsConnector,
93+
connector: &TlsConnector,
9494
with_socket: Ws,
9595
) -> crate::Result<Ws::Output>
9696
where

sqlx-core/src/net/tls/tls_native_tls.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ impl<S: Socket> Socket for NativeTlsSocket<S> {
3939
}
4040
}
4141

42-
#[derive(Clone)]
42+
#[derive(Debug, Clone)]
4343
pub struct NativeTlsConnector {
4444
connector: native_tls::TlsConnector,
4545
}
@@ -75,7 +75,7 @@ pub async fn connector(config: TlsConfig<'_>) -> crate::Result<NativeTlsConnecto
7575
pub async fn handshake<S: Socket>(
7676
socket: S,
7777
hostname: &str,
78-
connector: NativeTlsConnector,
78+
connector: &NativeTlsConnector,
7979
) -> crate::Result<NativeTlsSocket<S>> {
8080
let mut mid_handshake = match connector
8181
.connector

sqlx-core/src/net/tls/tls_rustls.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ impl<S: Socket> Socket for RustlsSocket<S> {
8787
}
8888
}
8989

90-
#[derive(Clone)]
90+
#[derive(Debug, Clone)]
9191
pub struct RustlsConnector {
9292
config: Arc<ClientConfig>,
9393
}
@@ -190,7 +190,7 @@ pub async fn connector(tls_config: TlsConfig<'_>) -> Result<RustlsConnector, Err
190190
pub async fn handshake<S>(
191191
socket: S,
192192
hostname: &str,
193-
connector: RustlsConnector,
193+
connector: &RustlsConnector,
194194
) -> Result<RustlsSocket<S>, Error>
195195
where
196196
S: Socket,
@@ -199,7 +199,7 @@ where
199199

200200
let mut socket = RustlsSocket {
201201
inner: StdSocket::new(socket),
202-
state: ClientConnection::new(connector.config, host).map_err(Error::tls)?,
202+
state: ClientConnection::new(connector.config.clone(), host).map_err(Error::tls)?,
203203
close_notify_sent: false,
204204
};
205205

sqlx-mysql/src/connection/establish.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ impl<'a> DoHandshake<'a> {
4242
fn new(options: &'a MySqlConnectOptions) -> Result<Self, Error> {
4343
if options.enable_cleartext_plugin
4444
&& matches!(
45-
options.ssl_mode,
45+
options.ssl_options.ssl_mode,
4646
MySqlSslMode::Disabled | MySqlSslMode::Preferred
4747
)
4848
{

sqlx-mysql/src/connection/tls.rs

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,13 @@ pub(super) async fn maybe_upgrade<S: Socket>(
2020
) -> Result<MySqlStream, Error> {
2121
let server_supports_tls = stream.capabilities.contains(Capabilities::SSL);
2222

23-
if matches!(options.ssl_mode, MySqlSslMode::Disabled) || !tls::available() {
23+
if matches!(options.ssl_options.ssl_mode, MySqlSslMode::Disabled) || !tls::available() {
2424
// remove the SSL capability if SSL has been explicitly disabled
2525
stream.capabilities.remove(Capabilities::SSL);
2626
}
2727

2828
// https://www.postgresql.org/docs/12/libpq-ssl.html#LIBPQ-SSL-SSLMODE-STATEMENTS
29-
match options.ssl_mode {
29+
match options.ssl_options.ssl_mode {
3030
MySqlSslMode::Disabled => return Ok(stream.boxed_socket()),
3131

3232
MySqlSslMode::Preferred => {
@@ -53,16 +53,27 @@ pub(super) async fn maybe_upgrade<S: Socket>(
5353
}
5454
}
5555

56-
let tls_config = TlsConfig {
57-
accept_invalid_certs: !matches!(
58-
options.ssl_mode,
59-
MySqlSslMode::VerifyCa | MySqlSslMode::VerifyIdentity
60-
),
61-
accept_invalid_hostnames: !matches!(options.ssl_mode, MySqlSslMode::VerifyIdentity),
62-
hostname: &options.host,
63-
root_cert_path: options.ssl_ca.as_ref(),
64-
client_cert_path: options.ssl_client_cert.as_ref(),
65-
client_key_path: options.ssl_client_key.as_ref(),
56+
let connector = if let Some(c) = options.ssl_options.cached_connector.get() {
57+
c
58+
} else {
59+
let tls_config = TlsConfig {
60+
accept_invalid_certs: !matches!(
61+
options.ssl_options.ssl_mode,
62+
MySqlSslMode::VerifyCa | MySqlSslMode::VerifyIdentity
63+
),
64+
accept_invalid_hostnames: !matches!(
65+
options.ssl_options.ssl_mode,
66+
MySqlSslMode::VerifyIdentity
67+
),
68+
root_cert_path: options.ssl_options.ssl_ca.as_ref(),
69+
client_cert_path: options.ssl_options.ssl_client_cert.as_ref(),
70+
client_key_path: options.ssl_options.ssl_client_key.as_ref(),
71+
};
72+
let connector = tls::connector(tls_config).await?;
73+
options
74+
.ssl_options
75+
.cached_connector
76+
.get_or_init(|| connector)
6677
};
6778

6879
// Request TLS upgrade
@@ -75,7 +86,8 @@ pub(super) async fn maybe_upgrade<S: Socket>(
7586

7687
tls::handshake(
7788
stream.socket.into_inner(),
78-
tls_config,
89+
&options.host,
90+
connector,
7991
MapStream {
8092
server_version: stream.server_version,
8193
capabilities: stream.capabilities,

sqlx-mysql/src/options/mod.rs

Lines changed: 39 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
1-
use std::path::{Path, PathBuf};
1+
use std::{
2+
path::{Path, PathBuf},
3+
sync::{Arc, OnceLock},
4+
};
25

36
mod connect;
47
mod parse;
58
mod ssl_mode;
69

710
use crate::{connection::LogSettings, net::tls::CertificateInput};
11+
use sqlx_core::net::tls::TlsConnector;
812
pub use ssl_mode::MySqlSslMode;
913

1014
/// Options and flags which can be used to configure a MySQL connection.
@@ -67,10 +71,7 @@ pub struct MySqlConnectOptions {
6771
pub(crate) username: String,
6872
pub(crate) password: Option<String>,
6973
pub(crate) database: Option<String>,
70-
pub(crate) ssl_mode: MySqlSslMode,
71-
pub(crate) ssl_ca: Option<CertificateInput>,
72-
pub(crate) ssl_client_cert: Option<CertificateInput>,
73-
pub(crate) ssl_client_key: Option<CertificateInput>,
74+
pub(crate) ssl_options: SslOptions,
7475
pub(crate) statement_cache_capacity: usize,
7576
pub(crate) charset: String,
7677
pub(crate) collation: Option<String>,
@@ -88,6 +89,15 @@ impl Default for MySqlConnectOptions {
8889
}
8990
}
9091

92+
#[derive(Debug, Clone)]
93+
pub(crate) struct SslOptions {
94+
pub(crate) ssl_mode: MySqlSslMode,
95+
pub(crate) ssl_ca: Option<CertificateInput>,
96+
pub(crate) ssl_client_cert: Option<CertificateInput>,
97+
pub(crate) ssl_client_key: Option<CertificateInput>,
98+
pub(crate) cached_connector: Arc<OnceLock<TlsConnector>>,
99+
}
100+
91101
impl MySqlConnectOptions {
92102
/// Creates a new, default set of options ready for configuration
93103
pub fn new() -> Self {
@@ -100,10 +110,13 @@ impl MySqlConnectOptions {
100110
database: None,
101111
charset: String::from("utf8mb4"),
102112
collation: None,
103-
ssl_mode: MySqlSslMode::Preferred,
104-
ssl_ca: None,
105-
ssl_client_cert: None,
106-
ssl_client_key: None,
113+
ssl_options: SslOptions {
114+
ssl_mode: MySqlSslMode::Preferred,
115+
ssl_ca: None,
116+
ssl_client_cert: None,
117+
ssl_client_key: None,
118+
cached_connector: Arc::new(OnceLock::new()),
119+
},
107120
statement_cache_capacity: 100,
108121
log_settings: Default::default(),
109122
pipes_as_concat: true,
@@ -158,6 +171,11 @@ impl MySqlConnectOptions {
158171
self
159172
}
160173

174+
fn ssl_options_mut(&mut self) -> &mut SslOptions {
175+
Arc::make_mut(&mut self.ssl_options.cached_connector).take();
176+
&mut self.ssl_options
177+
}
178+
161179
/// Sets whether or with what priority a secure SSL TCP/IP connection will be negotiated
162180
/// with the server.
163181
///
@@ -172,7 +190,7 @@ impl MySqlConnectOptions {
172190
/// .ssl_mode(MySqlSslMode::Required);
173191
/// ```
174192
pub fn ssl_mode(mut self, mode: MySqlSslMode) -> Self {
175-
self.ssl_mode = mode;
193+
self.ssl_options_mut().ssl_mode = mode;
176194
self
177195
}
178196

@@ -187,7 +205,7 @@ impl MySqlConnectOptions {
187205
/// .ssl_ca("path/to/ca.crt");
188206
/// ```
189207
pub fn ssl_ca(mut self, file_name: impl AsRef<Path>) -> Self {
190-
self.ssl_ca = Some(CertificateInput::File(file_name.as_ref().to_owned()));
208+
self.ssl_options_mut().ssl_ca = Some(CertificateInput::File(file_name.as_ref().to_owned()));
191209
self
192210
}
193211

@@ -202,7 +220,7 @@ impl MySqlConnectOptions {
202220
/// .ssl_ca_from_pem(vec![]);
203221
/// ```
204222
pub fn ssl_ca_from_pem(mut self, pem_certificate: Vec<u8>) -> Self {
205-
self.ssl_ca = Some(CertificateInput::Inline(pem_certificate));
223+
self.ssl_options_mut().ssl_ca = Some(CertificateInput::Inline(pem_certificate));
206224
self
207225
}
208226

@@ -217,7 +235,8 @@ impl MySqlConnectOptions {
217235
/// .ssl_client_cert("path/to/client.crt");
218236
/// ```
219237
pub fn ssl_client_cert(mut self, cert: impl AsRef<Path>) -> Self {
220-
self.ssl_client_cert = Some(CertificateInput::File(cert.as_ref().to_path_buf()));
238+
self.ssl_options_mut().ssl_client_cert =
239+
Some(CertificateInput::File(cert.as_ref().to_path_buf()));
221240
self
222241
}
223242

@@ -242,7 +261,8 @@ impl MySqlConnectOptions {
242261
/// .ssl_client_cert_from_pem(CERT);
243262
/// ```
244263
pub fn ssl_client_cert_from_pem(mut self, cert: impl AsRef<[u8]>) -> Self {
245-
self.ssl_client_cert = Some(CertificateInput::Inline(cert.as_ref().to_vec()));
264+
self.ssl_options_mut().ssl_client_cert =
265+
Some(CertificateInput::Inline(cert.as_ref().to_vec()));
246266
self
247267
}
248268

@@ -257,7 +277,8 @@ impl MySqlConnectOptions {
257277
/// .ssl_client_key("path/to/client.key");
258278
/// ```
259279
pub fn ssl_client_key(mut self, key: impl AsRef<Path>) -> Self {
260-
self.ssl_client_key = Some(CertificateInput::File(key.as_ref().to_path_buf()));
280+
self.ssl_options_mut().ssl_client_key =
281+
Some(CertificateInput::File(key.as_ref().to_path_buf()));
261282
self
262283
}
263284

@@ -282,7 +303,8 @@ impl MySqlConnectOptions {
282303
/// .ssl_client_key_from_pem(KEY);
283304
/// ```
284305
pub fn ssl_client_key_from_pem(mut self, key: impl AsRef<[u8]>) -> Self {
285-
self.ssl_client_key = Some(CertificateInput::Inline(key.as_ref().to_vec()));
306+
self.ssl_options_mut().ssl_client_key =
307+
Some(CertificateInput::Inline(key.as_ref().to_vec()));
286308
self
287309
}
288310

@@ -497,7 +519,7 @@ impl MySqlConnectOptions {
497519
/// assert!(matches!(options.get_ssl_mode(), MySqlSslMode::Preferred));
498520
/// ```
499521
pub fn get_ssl_mode(&self) -> MySqlSslMode {
500-
self.ssl_mode
522+
self.ssl_options.ssl_mode
501523
}
502524

503525
/// Get the server charset.

sqlx-mysql/src/options/parse.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ impl MySqlConnectOptions {
103103
url.set_path(database);
104104
}
105105

106-
let ssl_mode = match self.ssl_mode {
106+
let ssl_mode = match self.ssl_options.ssl_mode {
107107
MySqlSslMode::Disabled => "DISABLED",
108108
MySqlSslMode::Preferred => "PREFERRED",
109109
MySqlSslMode::Required => "REQUIRED",
@@ -112,7 +112,7 @@ impl MySqlConnectOptions {
112112
};
113113
url.query_pairs_mut().append_pair("ssl-mode", ssl_mode);
114114

115-
if let Some(ssl_ca) = &self.ssl_ca {
115+
if let Some(ssl_ca) = &self.ssl_options.ssl_ca {
116116
url.query_pairs_mut()
117117
.append_pair("ssl-ca", &ssl_ca.to_string());
118118
}
@@ -123,12 +123,12 @@ impl MySqlConnectOptions {
123123
url.query_pairs_mut().append_pair("charset", collation);
124124
}
125125

126-
if let Some(ssl_client_cert) = &self.ssl_client_cert {
126+
if let Some(ssl_client_cert) = &self.ssl_options.ssl_client_cert {
127127
url.query_pairs_mut()
128128
.append_pair("ssl-cert", &ssl_client_cert.to_string());
129129
}
130130

131-
if let Some(ssl_client_key) = &self.ssl_client_key {
131+
if let Some(ssl_client_key) = &self.ssl_options.ssl_client_key {
132132
url.query_pairs_mut()
133133
.append_pair("ssl-key", &ssl_client_key.to_string());
134134
}

sqlx-postgres/src/connection/tls.rs

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ async fn maybe_upgrade<S: Socket>(
2020
options: &PgConnectOptions,
2121
) -> Result<Box<dyn Socket>, Error> {
2222
// https://www.postgresql.org/docs/12/libpq-ssl.html#LIBPQ-SSL-SSLMODE-STATEMENTS
23-
match options.ssl_mode {
23+
match options.ssl_options.ssl_mode {
2424
// FIXME: Implement ALLOW
2525
PgSslMode::Allow | PgSslMode::Disable => return Ok(Box::new(socket)),
2626

@@ -45,22 +45,31 @@ async fn maybe_upgrade<S: Socket>(
4545
}
4646
}
4747

48-
let accept_invalid_certs = !matches!(
49-
options.ssl_mode,
50-
PgSslMode::VerifyCa | PgSslMode::VerifyFull
51-
);
52-
let accept_invalid_hostnames = !matches!(options.ssl_mode, PgSslMode::VerifyFull);
53-
54-
let config = TlsConfig {
55-
accept_invalid_certs,
56-
accept_invalid_hostnames,
57-
hostname: &options.host,
58-
root_cert_path: options.ssl_root_cert.as_ref(),
59-
client_cert_path: options.ssl_client_cert.as_ref(),
60-
client_key_path: options.ssl_client_key.as_ref(),
48+
let connector = if let Some(c) = options.ssl_options.cached_connector.get() {
49+
c
50+
} else {
51+
let accept_invalid_certs = !matches!(
52+
options.ssl_options.ssl_mode,
53+
PgSslMode::VerifyCa | PgSslMode::VerifyFull
54+
);
55+
let accept_invalid_hostnames =
56+
!matches!(options.ssl_options.ssl_mode, PgSslMode::VerifyFull);
57+
58+
let config = TlsConfig {
59+
accept_invalid_certs,
60+
accept_invalid_hostnames,
61+
root_cert_path: options.ssl_options.ssl_root_cert.as_ref(),
62+
client_cert_path: options.ssl_options.ssl_client_cert.as_ref(),
63+
client_key_path: options.ssl_options.ssl_client_key.as_ref(),
64+
};
65+
let connector = tls::connector(config).await?;
66+
options
67+
.ssl_options
68+
.cached_connector
69+
.get_or_init(|| connector)
6170
};
6271

63-
tls::handshake(socket, config, SocketIntoBox).await
72+
tls::handshake(socket, &options.host, connector, SocketIntoBox).await
6473
}
6574

6675
async fn request_upgrade(

0 commit comments

Comments
 (0)