Skip to content

Commit 371ca9b

Browse files
committed
feat: add custom socket transport support for Postgres and MySQL
Add `connect_socket()` methods to `PgConnection` and `MySqlConnection` that accept any pre-connected socket implementing the `Socket` trait. This enables using custom transport layers (e.g., vsock for AWS Nitro Enclaves, QUIC, or other non-TCP/UDS transports) without forking sqlx. Re-export `Socket` and `ReadBuf` traits from `sqlx::net` so users can implement custom socket types.
1 parent 05c8dc1 commit 371ca9b

10 files changed

Lines changed: 200 additions & 60 deletions

File tree

sqlx-core/src/net/mod.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@ mod socket;
22
pub mod tls;
33

44
pub use socket::{
5-
connect_tcp, connect_uds, BufferedSocket, Socket, SocketIntoBox, WithSocket, WriteBuffer,
5+
connect_socket, connect_tcp, connect_uds, BufferedSocket, Socket, SocketIntoBox, WithSocket,
6+
WriteBuffer,
67
};

sqlx-core/src/net/socket/mod.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,20 @@ pub async fn connect_tcp<Ws: WithSocket>(
202202
}
203203
}
204204

205+
/// Connect using a pre-connected socket that implements [`Socket`].
206+
///
207+
/// This allows using custom transport layers (e.g., vsock, QUIC, or any
208+
/// `AsyncRead + AsyncWrite` type) with SQLx database connections.
209+
///
210+
/// The socket will be passed through the `with_socket` handler, which
211+
/// typically performs TLS upgrade negotiation.
212+
pub async fn connect_socket<S: Socket, Ws: WithSocket>(
213+
socket: S,
214+
with_socket: Ws,
215+
) -> crate::Result<Ws::Output> {
216+
Ok(with_socket.with_socket(socket).await)
217+
}
218+
205219
/// Open a TCP socket to `host` and `port`.
206220
///
207221
/// If `host` is a hostname, attempt to connect to each address it resolves to.

sqlx-mysql/src/connection/establish.rs

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,29 @@ impl MySqlConnection {
2222

2323
let stream = handshake?;
2424

25-
Ok(Self {
25+
Ok(Self::establish_with_stream(stream, options))
26+
}
27+
28+
pub(crate) async fn establish_with_socket<S: Socket>(
29+
socket: S,
30+
options: &MySqlConnectOptions,
31+
) -> Result<Self, Error> {
32+
let do_handshake = DoHandshake::new(options)?;
33+
let stream = do_handshake.with_socket(socket).await?;
34+
35+
Ok(Self::establish_with_stream(stream, options))
36+
}
37+
38+
fn establish_with_stream(stream: MySqlStream, options: &MySqlConnectOptions) -> Self {
39+
Self {
2640
inner: Box::new(MySqlConnectionInner {
2741
stream,
2842
transaction_depth: 0,
2943
status_flags: Default::default(),
3044
cache_statement: StatementCache::new(options.statement_cache_capacity),
3145
log_settings: options.log_settings.clone(),
3246
}),
33-
})
47+
}
3448
}
3549
}
3650

sqlx-mysql/src/connection/mod.rs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,38 @@ pub(crate) struct MySqlConnectionInner {
5252
}
5353

5454
impl MySqlConnection {
55+
/// Connect to a MySQL database using a pre-connected socket.
56+
///
57+
/// This allows using custom transport layers such as vsock, QUIC,
58+
/// or any type that implements [`sqlx_core::net::Socket`].
59+
///
60+
/// The provided socket will go through TLS upgrade negotiation based on the
61+
/// SSL mode configured in `options`.
62+
///
63+
/// # Example
64+
///
65+
/// ```rust,ignore
66+
/// use sqlx::mysql::{MySqlConnectOptions, MySqlConnection};
67+
///
68+
/// # async fn example() -> sqlx::Result<()> {
69+
/// let socket: tokio::net::TcpStream = todo!();
70+
/// let options = MySqlConnectOptions::new()
71+
/// .username("root")
72+
/// .database("mydb");
73+
///
74+
/// let _conn = MySqlConnection::connect_socket(socket, &options).await?;
75+
/// # Ok(())
76+
/// # }
77+
/// ```
78+
pub async fn connect_socket<S: sqlx_core::net::Socket>(
79+
socket: S,
80+
options: &MySqlConnectOptions,
81+
) -> Result<Self, Error> {
82+
let mut conn = Self::establish_with_socket(socket, options).await?;
83+
crate::options::apply_connect_options(&mut conn, options).await?;
84+
Ok(conn)
85+
}
86+
5587
pub(crate) fn in_transaction(&self) -> bool {
5688
self.inner
5789
.status_flags

sqlx-mysql/src/options/connect.rs

Lines changed: 67 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -24,80 +24,92 @@ impl ConnectOptions for MySqlConnectOptions {
2424
{
2525
let mut conn = MySqlConnection::establish(self).await?;
2626

27-
// After the connection is established, we initialize by configuring a few
28-
// connection parameters
27+
apply_connect_options(&mut conn, self).await?;
2928

30-
// https://mariadb.com/kb/en/sql-mode/
29+
Ok(conn)
30+
}
3131

32-
// PIPES_AS_CONCAT - Allows using the pipe character (ASCII 124) as string concatenation operator.
33-
// This means that "A" || "B" can be used in place of CONCAT("A", "B").
32+
fn log_statements(mut self, level: LevelFilter) -> Self {
33+
self.log_settings.log_statements(level);
34+
self
35+
}
3436

35-
// NO_ENGINE_SUBSTITUTION - If not set, if the available storage engine specified by a CREATE TABLE is
36-
// not available, a warning is given and the default storage
37-
// engine is used instead.
37+
fn log_slow_statements(mut self, level: LevelFilter, duration: Duration) -> Self {
38+
self.log_settings.log_slow_statements(level, duration);
39+
self
40+
}
41+
}
3842

39-
// NO_ZERO_DATE - Don't allow '0000-00-00'. This is invalid in Rust.
43+
pub(crate) async fn apply_connect_options(
44+
conn: &mut MySqlConnection,
45+
options: &MySqlConnectOptions,
46+
) -> Result<(), Error> {
47+
// After the connection is established, we initialize by configuring a few
48+
// connection parameters
4049

41-
// NO_ZERO_IN_DATE - Don't allow 'YYYY-00-00'. This is invalid in Rust.
50+
// https://mariadb.com/kb/en/sql-mode/
4251

43-
// --
52+
// PIPES_AS_CONCAT - Allows using the pipe character (ASCII 124) as string concatenation operator.
53+
// This means that "A" || "B" can be used in place of CONCAT("A", "B").
4454

45-
// Setting the time zone allows us to assume that the output
46-
// from a TIMESTAMP field is UTC
55+
// NO_ENGINE_SUBSTITUTION - If not set, if the available storage engine specified by a CREATE TABLE is
56+
// not available, a warning is given and the default storage
57+
// engine is used instead.
4758

48-
// --
59+
// NO_ZERO_DATE - Don't allow '0000-00-00'. This is invalid in Rust.
4960

50-
// https://mathiasbynens.be/notes/mysql-utf8mb4
61+
// NO_ZERO_IN_DATE - Don't allow 'YYYY-00-00'. This is invalid in Rust.
5162

52-
let mut sql_mode = Vec::new();
53-
if self.pipes_as_concat {
54-
sql_mode.push(r#"PIPES_AS_CONCAT"#);
55-
}
56-
if self.no_engine_substitution {
57-
sql_mode.push(r#"NO_ENGINE_SUBSTITUTION"#);
58-
}
63+
// --
5964

60-
let mut options = Vec::new();
61-
if !sql_mode.is_empty() {
62-
options.push(format!(
63-
r#"sql_mode=(SELECT CONCAT(@@sql_mode, ',{}'))"#,
64-
sql_mode.join(",")
65-
));
66-
}
65+
// Setting the time zone allows us to assume that the output
66+
// from a TIMESTAMP field is UTC
6767

68-
if let Some(timezone) = &self.timezone {
69-
options.push(format!(r#"time_zone='{}'"#, timezone));
70-
}
68+
// --
7169

72-
if self.set_names {
73-
// As it turns out, we don't _have_ to set a collation if we don't want to.
74-
// We can let the server choose the default collation for the charset.
75-
let set_names = if let Some(collation) = &self.collation {
76-
format!(r#"NAMES {} COLLATE {collation}"#, self.charset,)
77-
} else {
78-
// Leaves the default collation up to the server,
79-
// but ensures statements and results are encoded using the proper charset.
80-
format!("NAMES {}", self.charset)
81-
};
70+
// https://mathiasbynens.be/notes/mysql-utf8mb4
8271

83-
options.push(set_names);
84-
}
72+
let mut sql_mode = Vec::new();
73+
if options.pipes_as_concat {
74+
sql_mode.push(r#"PIPES_AS_CONCAT"#);
75+
}
76+
if options.no_engine_substitution {
77+
sql_mode.push(r#"NO_ENGINE_SUBSTITUTION"#);
78+
}
8579

86-
if !options.is_empty() {
87-
conn.execute(AssertSqlSafe(format!(r#"SET {};"#, options.join(","))))
88-
.await?;
89-
}
80+
let mut session_options = Vec::new();
81+
if !sql_mode.is_empty() {
82+
session_options.push(format!(
83+
r#"sql_mode=(SELECT CONCAT(@@sql_mode, ',{}'))"#,
84+
sql_mode.join(",")
85+
));
86+
}
9087

91-
Ok(conn)
88+
if let Some(timezone) = &options.timezone {
89+
session_options.push(format!(r#"time_zone='{}'"#, timezone));
9290
}
9391

94-
fn log_statements(mut self, level: LevelFilter) -> Self {
95-
self.log_settings.log_statements(level);
96-
self
92+
if options.set_names {
93+
// As it turns out, we don't _have_ to set a collation if we don't want to.
94+
// We can let the server choose the default collation for the charset.
95+
let set_names = if let Some(collation) = &options.collation {
96+
format!(r#"NAMES {} COLLATE {collation}"#, options.charset,)
97+
} else {
98+
// Leaves the default collation up to the server,
99+
// but ensures statements and results are encoded using the proper charset.
100+
format!("NAMES {}", options.charset)
101+
};
102+
103+
session_options.push(set_names);
97104
}
98105

99-
fn log_slow_statements(mut self, level: LevelFilter, duration: Duration) -> Self {
100-
self.log_settings.log_slow_statements(level, duration);
101-
self
106+
if !session_options.is_empty() {
107+
conn.execute(AssertSqlSafe(format!(
108+
r#"SET {};"#,
109+
session_options.join(",")
110+
)))
111+
.await?;
102112
}
113+
114+
Ok(())
103115
}

sqlx-mysql/src/options/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ mod parse;
55
mod ssl_mode;
66

77
use crate::{connection::LogSettings, net::tls::CertificateInput};
8+
pub(crate) use connect::apply_connect_options;
89
pub use ssl_mode::MySqlSslMode;
910

1011
/// Options and flags which can be used to configure a MySQL connection.

sqlx-postgres/src/connection/establish.rs

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use crate::io::StatementId;
77
use crate::message::{
88
Authentication, BackendKeyData, BackendMessageFormat, Password, ReadyForQuery, Startup,
99
};
10+
use crate::net::Socket;
1011
use crate::{PgConnectOptions, PgConnection};
1112

1213
use super::PgConnectionInner;
@@ -16,9 +17,22 @@ use super::PgConnectionInner;
1617

1718
impl PgConnection {
1819
pub(crate) async fn establish(options: &PgConnectOptions) -> Result<Self, Error> {
19-
// Upgrade to TLS if we were asked to and the server supports it
20-
let mut stream = PgStream::connect(options).await?;
20+
let stream = PgStream::connect(options).await?;
21+
Self::establish_with_stream(stream, options).await
22+
}
23+
24+
pub(crate) async fn establish_with_socket<S: Socket>(
25+
socket: S,
26+
options: &PgConnectOptions,
27+
) -> Result<Self, Error> {
28+
let stream = PgStream::connect_socket(socket, options).await?;
29+
Self::establish_with_stream(stream, options).await
30+
}
2131

32+
async fn establish_with_stream(
33+
mut stream: PgStream,
34+
options: &PgConnectOptions,
35+
) -> Result<Self, Error> {
2236
// To begin a session, a frontend opens a connection to the server
2337
// and sends a startup message.
2438

sqlx-postgres/src/connection/mod.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,36 @@ pub(crate) struct TableColumns {
8585
}
8686

8787
impl PgConnection {
88+
/// Connect to a PostgreSQL database using a pre-connected socket.
89+
///
90+
/// This allows using custom transport layers such as vsock, QUIC,
91+
/// or any type that implements [`sqlx_core::net::Socket`].
92+
///
93+
/// The provided socket will go through TLS upgrade negotiation based on the
94+
/// SSL mode configured in `options`.
95+
///
96+
/// # Example
97+
///
98+
/// ```rust,ignore
99+
/// use sqlx::postgres::{PgConnectOptions, PgConnection};
100+
///
101+
/// # async fn example() -> sqlx::Result<()> {
102+
/// let socket: tokio::net::TcpStream = todo!();
103+
/// let options = PgConnectOptions::new()
104+
/// .username("postgres")
105+
/// .database("mydb");
106+
///
107+
/// let _conn = PgConnection::connect_socket(socket, &options).await?;
108+
/// # Ok(())
109+
/// # }
110+
/// ```
111+
pub async fn connect_socket<S: sqlx_core::net::Socket>(
112+
socket: S,
113+
options: &PgConnectOptions,
114+
) -> Result<Self, Error> {
115+
Self::establish_with_socket(socket, options).await
116+
}
117+
88118
/// the version number of the server in `libpq` format
89119
pub fn server_version_num(&self) -> Option<u32> {
90120
self.inner.stream.server_version_num

sqlx-postgres/src/connection/stream.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,22 @@ impl PgStream {
5757
})
5858
}
5959

60+
pub(super) async fn connect_socket<S: Socket>(
61+
socket: S,
62+
options: &PgConnectOptions,
63+
) -> Result<Self, Error> {
64+
let socket = net::connect_socket(socket, MaybeUpgradeTls(options)).await?;
65+
66+
let socket = socket?;
67+
68+
Ok(Self {
69+
inner: BufferedSocket::new(socket),
70+
notifications: None,
71+
parameter_statuses: BTreeMap::default(),
72+
server_version_num: None,
73+
})
74+
}
75+
6076
#[inline(always)]
6177
pub(crate) fn write_msg(&mut self, message: impl FrontendMessage) -> Result<(), Error> {
6278
self.write(EncodeMessage(message))

src/lib.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,12 @@ pub mod decode {
153153

154154
pub use self::decode::Decode;
155155

156+
/// Networking traits for custom transport implementations.
157+
pub mod net {
158+
pub use sqlx_core::io::ReadBuf;
159+
pub use sqlx_core::net::Socket;
160+
}
161+
156162
/// Types and traits for the `query` family of functions and macros.
157163
pub mod query {
158164
pub use sqlx_core::query::{Map, Query};

0 commit comments

Comments
 (0)