1
1
#![ cfg( feature = "early-data" ) ]
2
2
3
- use std:: io:: { self , BufReader , Cursor , Read , Write } ;
4
- use std:: net:: { SocketAddr , TcpListener } ;
3
+ use std:: io:: { self , BufReader , Cursor } ;
4
+ use std:: net:: SocketAddr ;
5
5
use std:: pin:: Pin ;
6
6
use std:: sync:: Arc ;
7
7
use std:: task:: { Context , Poll } ;
8
- use std:: thread;
9
8
10
9
use futures_util:: { future:: Future , ready} ;
11
- use rustls:: { self , ClientConfig , RootCertStore , ServerConfig , ServerConnection , Stream } ;
12
- use tokio:: io:: { AsyncRead , AsyncReadExt , AsyncWriteExt , ReadBuf } ;
13
- use tokio:: net:: TcpStream ;
14
- use tokio_rustls:: { client:: TlsStream , TlsConnector } ;
10
+ use pin_project_lite:: pin_project;
11
+ use rustls:: { self , ClientConfig , RootCertStore , ServerConfig } ;
12
+ use tokio:: io:: { AsyncRead , AsyncReadExt , AsyncWrite , AsyncWriteExt , ReadBuf } ;
13
+ use tokio:: net:: { TcpListener , TcpStream } ;
14
+ use tokio_rustls:: { client, server, TlsAcceptor , TlsConnector } ;
15
15
16
16
struct Read1 < T > ( T ) ;
17
17
@@ -33,12 +33,27 @@ impl<T: AsyncRead + Unpin> Future for Read1<T> {
33
33
}
34
34
}
35
35
36
+ pin_project ! {
37
+ struct TlsStreamEarlyWrapper <IO > {
38
+ #[ pin]
39
+ inner: server:: TlsStream <IO >
40
+ }
41
+ }
42
+
43
+ impl < IO > AsyncRead for TlsStreamEarlyWrapper < IO >
44
+ where
45
+ IO : AsyncRead + AsyncWrite + Unpin {
46
+ fn poll_read ( self : Pin < & mut Self > , cx : & mut Context < ' _ > , buf : & mut ReadBuf < ' _ > ) -> Poll < io:: Result < ( ) > > {
47
+ return self . project ( ) . inner . poll_read_early_data ( cx, buf) ;
48
+ }
49
+ }
50
+
36
51
async fn send (
37
52
config : Arc < ClientConfig > ,
38
53
addr : SocketAddr ,
39
54
data : & [ u8 ] ,
40
55
vectored : bool ,
41
- ) -> io:: Result < ( TlsStream < TcpStream > , Vec < u8 > ) > {
56
+ ) -> io:: Result < ( client :: TlsStream < TcpStream > , Vec < u8 > ) > {
42
57
let connector = TlsConnector :: from ( config) . early_data ( true ) ;
43
58
let stream = TcpStream :: connect ( & addr) . await ?;
44
59
let domain = pki_types:: ServerName :: try_from ( "foobar.com" ) . unwrap ( ) ;
@@ -75,38 +90,33 @@ async fn test_0rtt_impl(vectored: bool) -> io::Result<()> {
75
90
. unwrap ( ) ;
76
91
server. max_early_data_size = 8192 ;
77
92
let server = Arc :: new ( server) ;
93
+ let acceptor = Arc :: new ( TlsAcceptor :: from ( server) ) ;
78
94
79
- let listener = TcpListener :: bind ( "127.0.0.1:0" ) ?;
95
+ let listener = TcpListener :: bind ( "127.0.0.1:0" ) . await ?;
80
96
let server_port = listener. local_addr ( ) . unwrap ( ) . port ( ) ;
81
- thread:: spawn ( move || loop {
82
- let ( mut sock, _addr) = listener. accept ( ) . unwrap ( ) ;
97
+ tokio:: spawn ( async move {
98
+ loop {
99
+ let ( mut sock, _addr) = listener. accept ( ) . await . unwrap ( ) ;
100
+
101
+ let acceptor = acceptor. clone ( ) ;
102
+ tokio:: spawn ( async move {
103
+ let stream = acceptor. accept ( & mut sock) . await . unwrap ( ) ;
83
104
84
- let server = Arc :: clone ( & server) ;
85
- thread:: spawn ( move || {
86
- let mut conn = ServerConnection :: new ( server) . unwrap ( ) ;
87
- conn. complete_io ( & mut sock) . unwrap ( ) ;
105
+ let mut buf = Vec :: new ( ) ;
106
+ let mut stream_wrapper = TlsStreamEarlyWrapper { inner : stream } ;
107
+ stream_wrapper. read_to_end ( & mut buf) . await . unwrap ( ) ;
108
+ let mut stream = stream_wrapper. inner ;
109
+ stream. write_all ( b"EARLY:" ) . await . unwrap ( ) ;
110
+ stream. write_all ( & buf) . await . unwrap ( ) ;
88
111
89
- if let Some ( mut early_data) = conn. early_data ( ) {
90
112
let mut buf = Vec :: new ( ) ;
91
- early_data. read_to_end ( & mut buf) . unwrap ( ) ;
92
- let mut stream = Stream :: new ( & mut conn, & mut sock) ;
93
- stream. write_all ( b"EARLY:" ) . unwrap ( ) ;
94
- stream. write_all ( & buf) . unwrap ( ) ;
95
- }
96
-
97
- let mut stream = Stream :: new ( & mut conn, & mut sock) ;
98
- stream. write_all ( b"LATE:" ) . unwrap ( ) ;
99
- loop {
100
- let mut buf = [ 0 ; 1024 ] ;
101
- let n = stream. read ( & mut buf) . unwrap ( ) ;
102
- if n == 0 {
103
- conn. send_close_notify ( ) ;
104
- conn. complete_io ( & mut sock) . unwrap ( ) ;
105
- break ;
106
- }
107
- stream. write_all ( & buf[ ..n] ) . unwrap ( ) ;
108
- }
109
- } ) ;
113
+ stream. read_to_end ( & mut buf) . await . unwrap ( ) ;
114
+ stream. write_all ( b"LATE:" ) . await . unwrap ( ) ;
115
+ stream. write_all ( & buf) . await . unwrap ( ) ;
116
+
117
+ stream. shutdown ( ) . await . unwrap ( ) ;
118
+ } ) ;
119
+ }
110
120
} ) ;
111
121
112
122
let mut chain = BufReader :: new ( Cursor :: new ( include_str ! ( "end.chain" ) ) ) ;
@@ -125,7 +135,7 @@ async fn test_0rtt_impl(vectored: bool) -> io::Result<()> {
125
135
126
136
let ( io, buf) = send ( config. clone ( ) , addr, b"hello" , vectored) . await ?;
127
137
assert ! ( !io. get_ref( ) . 1 . is_early_data_accepted( ) ) ;
128
- assert_eq ! ( "LATE:hello" , String :: from_utf8_lossy( & buf) ) ;
138
+ assert_eq ! ( "EARLY: LATE:hello" , String :: from_utf8_lossy( & buf) ) ;
129
139
130
140
let ( io, buf) = send ( config, addr, b"world!" , vectored) . await ?;
131
141
assert ! ( io. get_ref( ) . 1 . is_early_data_accepted( ) ) ;
0 commit comments