@@ -61,6 +61,7 @@ lazy_static::lazy_static! {
61
61
static ref FIXED_MARIADB_VERSION_RE : Regex =
62
62
Regex :: new( r"^(?:5.5.5-)?(\d{1,2})\.(\d{1,2})\.(\d{1,3})-MariaDB" ) . unwrap( ) ;
63
63
}
64
+ const DEFAULT_WAIT_TIMEOUT : usize = 28800 ;
64
65
65
66
/// Helper that asynchronously disconnects the givent connection on the default tokio executor.
66
67
fn disconnect ( mut conn : Conn ) {
@@ -917,42 +918,123 @@ impl Conn {
917
918
/// * It reads and stores `wait_timeout` in the connection unless it's already in [`Opts`]
918
919
///
919
920
async fn read_settings ( & mut self ) -> Result < ( ) > {
920
- let read_socket = self . inner . opts . prefer_socket ( ) && self . inner . socket . is_none ( ) ;
921
- let read_max_allowed_packet = self . opts ( ) . max_allowed_packet ( ) . is_none ( ) ;
922
- let read_wait_timeout = self . opts ( ) . wait_timeout ( ) . is_none ( ) ;
921
+ enum Action {
922
+ Load ( Cfg ) ,
923
+ Apply ( CfgData ) ,
924
+ }
923
925
924
- let settings: Option < Row > = if read_socket || read_max_allowed_packet || read_wait_timeout {
925
- self . query_internal ( "SELECT @@socket, @@max_allowed_packet, @@wait_timeout" )
926
- . await ?
927
- } else {
928
- None
929
- } ;
926
+ enum CfgData {
927
+ MaxAllowedPacket ( usize ) ,
928
+ WaitTimeout ( usize ) ,
929
+ }
930
930
931
- // set socket inside the connection
932
- if read_socket {
933
- self . inner . socket = settings. as_ref ( ) . map ( |s| s. get ( "@@socket" ) ) . unwrap_or ( None ) ;
931
+ impl CfgData {
932
+ fn apply ( & self , conn : & mut Conn ) {
933
+ match self {
934
+ Self :: MaxAllowedPacket ( value) => {
935
+ if let Some ( stream) = conn. inner . stream . as_mut ( ) {
936
+ stream. set_max_allowed_packet ( * value) ;
937
+ }
938
+ }
939
+ Self :: WaitTimeout ( value) => {
940
+ conn. inner . wait_timeout = Duration :: from_secs ( * value as u64 ) ;
941
+ }
942
+ }
943
+ }
934
944
}
935
945
936
- // set max_allowed_packet
937
- let max_allowed_packet = if read_max_allowed_packet {
938
- settings
939
- . as_ref ( )
940
- . map ( |s| s. get ( "@@max_allowed_packet" ) )
941
- . unwrap ( )
942
- } else {
943
- self . opts ( ) . max_allowed_packet ( )
944
- } ;
945
- if let Some ( stream) = self . inner . stream . as_mut ( ) {
946
- stream. set_max_allowed_packet ( max_allowed_packet. unwrap_or ( DEFAULT_MAX_ALLOWED_PACKET ) ) ;
946
+ enum Cfg {
947
+ Socket ,
948
+ MaxAllowedPacket ,
949
+ WaitTimeout ,
950
+ }
951
+
952
+ impl Cfg {
953
+ const fn name ( & self ) -> & ' static str {
954
+ match self {
955
+ Self :: Socket => "@@socket" ,
956
+ Self :: MaxAllowedPacket => "@@max_allowed_packet" ,
957
+ Self :: WaitTimeout => "@@wait_timeout" ,
958
+ }
959
+ }
960
+
961
+ fn apply ( & self , conn : & mut Conn , value : Option < crate :: Value > ) {
962
+ match self {
963
+ Cfg :: Socket => {
964
+ conn. inner . socket = value. map ( crate :: from_value) . flatten ( ) ;
965
+ }
966
+ Cfg :: MaxAllowedPacket => {
967
+ if let Some ( stream) = conn. inner . stream . as_mut ( ) {
968
+ stream. set_max_allowed_packet (
969
+ value
970
+ . map ( crate :: from_value)
971
+ . flatten ( )
972
+ . unwrap_or ( DEFAULT_MAX_ALLOWED_PACKET ) ,
973
+ ) ;
974
+ }
975
+ }
976
+ Cfg :: WaitTimeout => {
977
+ conn. inner . wait_timeout = Duration :: from_secs (
978
+ value
979
+ . map ( crate :: from_value)
980
+ . flatten ( )
981
+ . unwrap_or ( DEFAULT_WAIT_TIMEOUT ) as u64 ,
982
+ ) ;
983
+ }
984
+ }
985
+ }
947
986
}
948
987
949
- // set read_wait_timeout
950
- let wait_timeout = if read_wait_timeout {
951
- settings. as_ref ( ) . map ( |s| s. get ( "@@wait_timeout" ) ) . unwrap ( )
988
+ let mut actions = vec ! [
989
+ if let Some ( x) = self . opts( ) . max_allowed_packet( ) {
990
+ Action :: Apply ( CfgData :: MaxAllowedPacket ( x) )
991
+ } else {
992
+ Action :: Load ( Cfg :: MaxAllowedPacket )
993
+ } ,
994
+ if let Some ( x) = self . opts( ) . wait_timeout( ) {
995
+ Action :: Apply ( CfgData :: WaitTimeout ( x) )
996
+ } else {
997
+ Action :: Load ( Cfg :: WaitTimeout )
998
+ } ,
999
+ ] ;
1000
+
1001
+ if self . inner . opts . prefer_socket ( ) && self . inner . socket . is_none ( ) {
1002
+ actions. push ( Action :: Load ( Cfg :: Socket ) )
1003
+ }
1004
+
1005
+ let loads = actions
1006
+ . iter ( )
1007
+ . filter_map ( |x| match x {
1008
+ Action :: Load ( x) => Some ( x) ,
1009
+ Action :: Apply ( _) => None ,
1010
+ } )
1011
+ . collect :: < Vec < _ > > ( ) ;
1012
+
1013
+ let loaded = if !loads. is_empty ( ) {
1014
+ let query = loads
1015
+ . iter ( )
1016
+ . zip ( std:: iter:: once ( ' ' ) . chain ( std:: iter:: repeat ( ',' ) ) )
1017
+ . fold ( "SELECT" . to_owned ( ) , |mut acc, ( cfg, prefix) | {
1018
+ acc. push ( prefix) ;
1019
+ acc. push_str ( cfg. name ( ) ) ;
1020
+ acc
1021
+ } ) ;
1022
+
1023
+ self . query_internal :: < Row , String > ( query)
1024
+ . await ?
1025
+ . map ( |row| row. unwrap ( ) )
1026
+ . unwrap_or_else ( || vec ! [ crate :: Value :: NULL ; loads. len( ) ] )
952
1027
} else {
953
- self . opts ( ) . wait_timeout ( )
1028
+ vec ! [ ]
954
1029
} ;
955
- self . inner . wait_timeout = Duration :: from_secs ( wait_timeout. unwrap_or ( 28800 ) as u64 ) ;
1030
+ let mut loaded = loaded. into_iter ( ) ;
1031
+
1032
+ for action in actions {
1033
+ match action {
1034
+ Action :: Load ( cfg) => cfg. apply ( self , loaded. next ( ) ) ,
1035
+ Action :: Apply ( cfg) => cfg. apply ( self ) ,
1036
+ }
1037
+ }
956
1038
957
1039
Ok ( ( ) )
958
1040
}
0 commit comments