Skip to content

Commit 5f48205

Browse files
blackbeamdengfuping
authored andcommitted
fix(connection): prefer_socket=false should work, fix prisma/prisma#24010
1 parent 0d40d0d commit 5f48205

File tree

1 file changed

+110
-28
lines changed

1 file changed

+110
-28
lines changed

src/conn/mod.rs

Lines changed: 110 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ lazy_static::lazy_static! {
6161
static ref FIXED_MARIADB_VERSION_RE: Regex =
6262
Regex::new(r"^(?:5.5.5-)?(\d{1,2})\.(\d{1,2})\.(\d{1,3})-MariaDB").unwrap();
6363
}
64+
const DEFAULT_WAIT_TIMEOUT: usize = 28800;
6465

6566
/// Helper that asynchronously disconnects the givent connection on the default tokio executor.
6667
fn disconnect(mut conn: Conn) {
@@ -917,42 +918,123 @@ impl Conn {
917918
/// * It reads and stores `wait_timeout` in the connection unless it's already in [`Opts`]
918919
///
919920
async fn read_settings(&mut self) -> Result<()> {
920-
let read_socket = self.inner.opts.prefer_socket() && self.inner.socket.is_none();
921-
let read_max_allowed_packet = self.opts().max_allowed_packet().is_none();
922-
let read_wait_timeout = self.opts().wait_timeout().is_none();
921+
enum Action {
922+
Load(Cfg),
923+
Apply(CfgData),
924+
}
923925

924-
let settings: Option<Row> = if read_socket || read_max_allowed_packet || read_wait_timeout {
925-
self.query_internal("SELECT @@socket, @@max_allowed_packet, @@wait_timeout")
926-
.await?
927-
} else {
928-
None
929-
};
926+
enum CfgData {
927+
MaxAllowedPacket(usize),
928+
WaitTimeout(usize),
929+
}
930930

931-
// set socket inside the connection
932-
if read_socket {
933-
self.inner.socket = settings.as_ref().map(|s| s.get("@@socket")).unwrap_or(None);
931+
impl CfgData {
932+
fn apply(&self, conn: &mut Conn) {
933+
match self {
934+
Self::MaxAllowedPacket(value) => {
935+
if let Some(stream) = conn.inner.stream.as_mut() {
936+
stream.set_max_allowed_packet(*value);
937+
}
938+
}
939+
Self::WaitTimeout(value) => {
940+
conn.inner.wait_timeout = Duration::from_secs(*value as u64);
941+
}
942+
}
943+
}
934944
}
935945

936-
// set max_allowed_packet
937-
let max_allowed_packet = if read_max_allowed_packet {
938-
settings
939-
.as_ref()
940-
.map(|s| s.get("@@max_allowed_packet"))
941-
.unwrap()
942-
} else {
943-
self.opts().max_allowed_packet()
944-
};
945-
if let Some(stream) = self.inner.stream.as_mut() {
946-
stream.set_max_allowed_packet(max_allowed_packet.unwrap_or(DEFAULT_MAX_ALLOWED_PACKET));
946+
enum Cfg {
947+
Socket,
948+
MaxAllowedPacket,
949+
WaitTimeout,
950+
}
951+
952+
impl Cfg {
953+
const fn name(&self) -> &'static str {
954+
match self {
955+
Self::Socket => "@@socket",
956+
Self::MaxAllowedPacket => "@@max_allowed_packet",
957+
Self::WaitTimeout => "@@wait_timeout",
958+
}
959+
}
960+
961+
fn apply(&self, conn: &mut Conn, value: Option<crate::Value>) {
962+
match self {
963+
Cfg::Socket => {
964+
conn.inner.socket = value.map(crate::from_value).flatten();
965+
}
966+
Cfg::MaxAllowedPacket => {
967+
if let Some(stream) = conn.inner.stream.as_mut() {
968+
stream.set_max_allowed_packet(
969+
value
970+
.map(crate::from_value)
971+
.flatten()
972+
.unwrap_or(DEFAULT_MAX_ALLOWED_PACKET),
973+
);
974+
}
975+
}
976+
Cfg::WaitTimeout => {
977+
conn.inner.wait_timeout = Duration::from_secs(
978+
value
979+
.map(crate::from_value)
980+
.flatten()
981+
.unwrap_or(DEFAULT_WAIT_TIMEOUT) as u64,
982+
);
983+
}
984+
}
985+
}
947986
}
948987

949-
// set read_wait_timeout
950-
let wait_timeout = if read_wait_timeout {
951-
settings.as_ref().map(|s| s.get("@@wait_timeout")).unwrap()
988+
let mut actions = vec![
989+
if let Some(x) = self.opts().max_allowed_packet() {
990+
Action::Apply(CfgData::MaxAllowedPacket(x))
991+
} else {
992+
Action::Load(Cfg::MaxAllowedPacket)
993+
},
994+
if let Some(x) = self.opts().wait_timeout() {
995+
Action::Apply(CfgData::WaitTimeout(x))
996+
} else {
997+
Action::Load(Cfg::WaitTimeout)
998+
},
999+
];
1000+
1001+
if self.inner.opts.prefer_socket() && self.inner.socket.is_none() {
1002+
actions.push(Action::Load(Cfg::Socket))
1003+
}
1004+
1005+
let loads = actions
1006+
.iter()
1007+
.filter_map(|x| match x {
1008+
Action::Load(x) => Some(x),
1009+
Action::Apply(_) => None,
1010+
})
1011+
.collect::<Vec<_>>();
1012+
1013+
let loaded = if !loads.is_empty() {
1014+
let query = loads
1015+
.iter()
1016+
.zip(std::iter::once(' ').chain(std::iter::repeat(',')))
1017+
.fold("SELECT".to_owned(), |mut acc, (cfg, prefix)| {
1018+
acc.push(prefix);
1019+
acc.push_str(cfg.name());
1020+
acc
1021+
});
1022+
1023+
self.query_internal::<Row, String>(query)
1024+
.await?
1025+
.map(|row| row.unwrap())
1026+
.unwrap_or_else(|| vec![crate::Value::NULL; loads.len()])
9521027
} else {
953-
self.opts().wait_timeout()
1028+
vec![]
9541029
};
955-
self.inner.wait_timeout = Duration::from_secs(wait_timeout.unwrap_or(28800) as u64);
1030+
let mut loaded = loaded.into_iter();
1031+
1032+
for action in actions {
1033+
match action {
1034+
Action::Load(cfg) => cfg.apply(self, loaded.next()),
1035+
Action::Apply(cfg) => cfg.apply(self),
1036+
}
1037+
}
9561038

9571039
Ok(())
9581040
}

0 commit comments

Comments
 (0)