@@ -5,7 +5,8 @@ use super::ext::Ext as _;
55use crate :: { features, message:: cmsg:: Encoder } ;
66use libc:: msghdr;
77use s2n_quic_core:: {
8- inet:: { AncillaryData , SocketAddressV4 } ,
8+ ensure,
9+ inet:: { AncillaryData , SocketAddressV4 , Unspecified } ,
910 path:: { self , LocalAddress , RemoteAddress } ,
1011} ;
1112
@@ -54,8 +55,8 @@ impl path::Handle for Handle {
5455 }
5556
5657 #[ inline]
57- fn set_remote_port ( & mut self , port : u16 ) {
58- self . remote_address . 0 . set_port ( port ) ;
58+ fn set_remote_address ( & mut self , addr : RemoteAddress ) {
59+ self . remote_address = addr ;
5960 }
6061
6162 #[ inline]
@@ -64,15 +65,41 @@ impl path::Handle for Handle {
6465 }
6566
6667 #[ inline]
67- fn eq ( & self , other : & Self ) -> bool {
68- let mut eq = true ;
68+ fn set_local_address ( & mut self , addr : LocalAddress ) {
69+ self . local_address = addr;
70+ }
71+
72+ #[ inline]
73+ fn unmapped_eq ( & self , other : & Self ) -> bool {
74+ ensure ! (
75+ self . remote_address. unmapped_eq( & other. remote_address) ,
76+ false
77+ ) ;
6978
7079 // only compare local addresses if the OS returns them
71- if features:: pktinfo:: IS_SUPPORTED {
72- eq &= self . local_address . eq ( & other. local_address ) ;
80+ ensure ! ( features:: pktinfo:: IS_SUPPORTED , true ) ;
81+
82+ // Make sure to only compare the fields if they're both set
83+ //
84+ // This avoids cases where we don't have the full context for the local address and find it
85+ // out with a later packet.
86+ if !self . local_address . ip ( ) . is_unspecified ( ) && !other. local_address . ip ( ) . is_unspecified ( ) {
87+ ensure ! (
88+ self . local_address
89+ . ip( )
90+ . unmapped_eq( & other. local_address. ip( ) ) ,
91+ false
92+ ) ;
7393 }
7494
75- eq && path:: Handle :: eq ( & self . remote_address , & other. remote_address )
95+ if self . local_address . port ( ) > 0 && other. local_address . port ( ) > 0 {
96+ ensure ! (
97+ self . local_address. port( ) == other. local_address. port( ) ,
98+ false
99+ ) ;
100+ }
101+
102+ true
76103 }
77104
78105 #[ inline]
@@ -92,3 +119,50 @@ impl path::Handle for Handle {
92119 }
93120 }
94121}
122+
123+ #[ cfg( test) ]
124+ mod tests {
125+ use crate :: message:: msg:: Handle ;
126+ use s2n_quic_core:: {
127+ inet:: { IpAddress , IpV4Address } ,
128+ path:: { Handle as _, LocalAddress } ,
129+ } ;
130+
131+ /// Checks that unmapped_eq is correct independent of argument ordering
132+ fn reflexive_check ( a : Handle , b : Handle ) {
133+ assert ! ( a. unmapped_eq( & b) ) ;
134+ assert ! ( b. unmapped_eq( & a) ) ;
135+ }
136+
137+ #[ test]
138+ fn unmapped_eq_test ( ) {
139+ // All of these values should be considered equivalent for local addresses
140+ let ips: & [ IpAddress ] = & [
141+ // if we have an unspecified IP address then don't consider it for equality
142+ IpV4Address :: new ( [ 0 , 0 , 0 , 0 ] ) . into ( ) ,
143+ // a regular IPv4 IP should match the IPv4-mapped into IPv6
144+ IpV4Address :: new ( [ 1 , 1 , 1 , 1 ] ) . into ( ) ,
145+ IpV4Address :: new ( [ 1 , 1 , 1 , 1 ] ) . to_ipv6_mapped ( ) . into ( ) ,
146+ ] ;
147+ let ports = [ 0u16 , 4440 ] ;
148+
149+ for ip_a in ips {
150+ for ip_b in ips {
151+ for port_a in ports {
152+ for port_b in ports {
153+ reflexive_check (
154+ Handle {
155+ remote_address : Default :: default ( ) ,
156+ local_address : LocalAddress :: from ( ip_a. with_port ( port_a) ) ,
157+ } ,
158+ Handle {
159+ remote_address : Default :: default ( ) ,
160+ local_address : LocalAddress :: from ( ip_b. with_port ( port_b) ) ,
161+ } ,
162+ ) ;
163+ }
164+ }
165+ }
166+ }
167+ }
168+ }
0 commit comments