Skip to content

Commit d152300

Browse files
committed
add poll_read_early_data
1 parent 7448a86 commit d152300

File tree

3 files changed

+103
-36
lines changed

3 files changed

+103
-36
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ exclude = ["/.github", "/examples", "/scripts"]
1515
[dependencies]
1616
tokio = "1.0"
1717
rustls = { version = "0.23.5", default-features = false, features = ["std"] }
18+
pin-project-lite = "0.2.14"
1819
pki-types = { package = "rustls-pki-types", version = "1" }
1920

2021
[features]

src/server.rs

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use std::io;
2+
use std::io::Read;
23
#[cfg(unix)]
34
use std::os::unix::io::{AsRawFd, RawFd};
45
#[cfg(windows)]
@@ -96,6 +97,61 @@ where
9697
}
9798
}
9899

100+
#[cfg(feature = "early-data")]
101+
impl<IO> TlsStream<IO>
102+
where
103+
IO: AsyncRead + AsyncWrite + Unpin,
104+
{
105+
pub fn poll_read_early_data(
106+
self: Pin<&mut Self>,
107+
cx: &mut Context<'_>,
108+
buf: &mut ReadBuf<'_>,
109+
) -> Poll<io::Result<()>> {
110+
let this = self.get_mut();
111+
let mut stream =
112+
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
113+
114+
match &this.state {
115+
TlsState::Stream | TlsState::WriteShutdown => {
116+
{
117+
let mut stream = stream.as_mut_pin();
118+
119+
while !stream.eof && stream.session.wants_read() {
120+
match stream.read_io(cx) {
121+
Poll::Ready(Ok(0)) => {
122+
break;
123+
}
124+
Poll::Ready(Ok(_)) => (),
125+
Poll::Pending => {
126+
break;
127+
}
128+
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
129+
}
130+
}
131+
}
132+
133+
if let Some(mut early_data) = stream.session.early_data() {
134+
match early_data.read(buf.initialize_unfilled()) {
135+
Ok(n) => if n > 0 {
136+
buf.advance(n);
137+
return Poll::Ready(Ok(()));
138+
}
139+
Err(err) => return Poll::Ready(Err(err))
140+
}
141+
}
142+
143+
if stream.session.is_handshaking() {
144+
return Poll::Pending;
145+
}
146+
147+
return Poll::Ready(Ok(()));
148+
}
149+
TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(())),
150+
s => unreachable!("server TLS can not hit this state: {:?}", s),
151+
}
152+
}
153+
}
154+
99155
impl<IO> AsyncWrite for TlsStream<IO>
100156
where
101157
IO: AsyncRead + AsyncWrite + Unpin,

tests/early-data.rs

Lines changed: 46 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
#![cfg(feature = "early-data")]
22

3-
use std::io::{self, BufReader, Cursor, Read, Write};
4-
use std::net::{SocketAddr, TcpListener};
3+
use std::io::{self, BufReader, Cursor};
4+
use std::net::SocketAddr;
55
use std::pin::Pin;
66
use std::sync::Arc;
77
use std::task::{Context, Poll};
8-
use std::thread;
98

109
use futures_util::{future::Future, ready};
11-
use rustls::{self, ClientConfig, RootCertStore, ServerConfig, ServerConnection, Stream};
12-
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, ReadBuf};
13-
use tokio::net::TcpStream;
14-
use tokio_rustls::{client::TlsStream, TlsConnector};
10+
use pin_project_lite::pin_project;
11+
use rustls::{self, ClientConfig, RootCertStore, ServerConfig};
12+
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
13+
use tokio::net::{TcpListener, TcpStream};
14+
use tokio_rustls::{client, server, TlsAcceptor, TlsConnector};
1515

1616
struct Read1<T>(T);
1717

@@ -33,12 +33,27 @@ impl<T: AsyncRead + Unpin> Future for Read1<T> {
3333
}
3434
}
3535

36+
pin_project! {
37+
struct TlsStreamEarlyWrapper<IO> {
38+
#[pin]
39+
inner: server::TlsStream<IO>
40+
}
41+
}
42+
43+
impl<IO> AsyncRead for TlsStreamEarlyWrapper<IO>
44+
where
45+
IO: AsyncRead + AsyncWrite + Unpin {
46+
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
47+
return self.project().inner.poll_read_early_data(cx, buf);
48+
}
49+
}
50+
3651
async fn send(
3752
config: Arc<ClientConfig>,
3853
addr: SocketAddr,
3954
data: &[u8],
4055
vectored: bool,
41-
) -> io::Result<(TlsStream<TcpStream>, Vec<u8>)> {
56+
) -> io::Result<(client::TlsStream<TcpStream>, Vec<u8>)> {
4257
let connector = TlsConnector::from(config).early_data(true);
4358
let stream = TcpStream::connect(&addr).await?;
4459
let domain = pki_types::ServerName::try_from("foobar.com").unwrap();
@@ -75,38 +90,33 @@ async fn test_0rtt_impl(vectored: bool) -> io::Result<()> {
7590
.unwrap();
7691
server.max_early_data_size = 8192;
7792
let server = Arc::new(server);
93+
let acceptor = Arc::new(TlsAcceptor::from(server));
7894

79-
let listener = TcpListener::bind("127.0.0.1:0")?;
95+
let listener = TcpListener::bind("127.0.0.1:0").await?;
8096
let server_port = listener.local_addr().unwrap().port();
81-
thread::spawn(move || loop {
82-
let (mut sock, _addr) = listener.accept().unwrap();
97+
tokio::spawn(async move {
98+
loop {
99+
let (mut sock, _addr) = listener.accept().await.unwrap();
100+
101+
let acceptor = acceptor.clone();
102+
tokio::spawn(async move {
103+
let stream = acceptor.accept(&mut sock).await.unwrap();
83104

84-
let server = Arc::clone(&server);
85-
thread::spawn(move || {
86-
let mut conn = ServerConnection::new(server).unwrap();
87-
conn.complete_io(&mut sock).unwrap();
105+
let mut buf = Vec::new();
106+
let mut stream_wrapper = TlsStreamEarlyWrapper{ inner: stream };
107+
stream_wrapper.read_to_end(&mut buf).await.unwrap();
108+
let mut stream = stream_wrapper.inner;
109+
stream.write_all(b"EARLY:").await.unwrap();
110+
stream.write_all(&buf).await.unwrap();
88111

89-
if let Some(mut early_data) = conn.early_data() {
90112
let mut buf = Vec::new();
91-
early_data.read_to_end(&mut buf).unwrap();
92-
let mut stream = Stream::new(&mut conn, &mut sock);
93-
stream.write_all(b"EARLY:").unwrap();
94-
stream.write_all(&buf).unwrap();
95-
}
96-
97-
let mut stream = Stream::new(&mut conn, &mut sock);
98-
stream.write_all(b"LATE:").unwrap();
99-
loop {
100-
let mut buf = [0; 1024];
101-
let n = stream.read(&mut buf).unwrap();
102-
if n == 0 {
103-
conn.send_close_notify();
104-
conn.complete_io(&mut sock).unwrap();
105-
break;
106-
}
107-
stream.write_all(&buf[..n]).unwrap();
108-
}
109-
});
113+
stream.read_to_end(&mut buf).await.unwrap();
114+
stream.write_all(b"LATE:").await.unwrap();
115+
stream.write_all(&buf).await.unwrap();
116+
117+
stream.shutdown().await.unwrap();
118+
});
119+
}
110120
});
111121

112122
let mut chain = BufReader::new(Cursor::new(include_str!("end.chain")));
@@ -125,7 +135,7 @@ async fn test_0rtt_impl(vectored: bool) -> io::Result<()> {
125135

126136
let (io, buf) = send(config.clone(), addr, b"hello", vectored).await?;
127137
assert!(!io.get_ref().1.is_early_data_accepted());
128-
assert_eq!("LATE:hello", String::from_utf8_lossy(&buf));
138+
assert_eq!("EARLY:LATE:hello", String::from_utf8_lossy(&buf));
129139

130140
let (io, buf) = send(config, addr, b"world!", vectored).await?;
131141
assert!(io.get_ref().1.is_early_data_accepted());

0 commit comments

Comments
 (0)