diff --git a/src/net.rs b/src/net.rs index 90a2fdc..160448b 100644 --- a/src/net.rs +++ b/src/net.rs @@ -3,7 +3,7 @@ use std::{ io::{self, BufRead, BufReader, Write}, net::{SocketAddr, TcpListener, TcpStream}, sync::{ - mpsc::{self, SendError}, + mpsc::{self}, Arc, Mutex, }, thread::JoinHandle, @@ -205,22 +205,34 @@ impl Default for TimeoutParams { } } +#[derive(Debug)] +enum WriteRequest { + Shutdown, + SendMessage(NetworkMessage), +} + /// Send messages to an open connection. #[derive(Debug)] pub struct ConnectionWriter { - sender: mpsc::Sender, + sender: mpsc::Sender, task_handle: JoinHandle>, } +#[allow(clippy::result_large_err)] impl ConnectionWriter { /// Send a network message to this peer. Errors indicate that the connection is terminated and /// no further messages will succeed. - #[allow(clippy::result_large_err)] - pub fn send_message( - &self, - network_message: NetworkMessage, - ) -> Result<(), SendError> { - self.sender.send(network_message) + pub fn send_message(&self, network_message: NetworkMessage) -> Result<(), Error> { + self.sender + .send(WriteRequest::SendMessage(network_message)) + .map_err(|_| Error::ChannelClosed) + } + + /// Kill both sides of the connection, erroring if the stream is already closed. + pub fn shutdown(&self) -> Result<(), Error> { + self.sender + .send(WriteRequest::Shutdown) + .map_err(|_| Error::ChannelClosed) } /// In the event of a failed message, investigate IO related failures if the connection was not @@ -234,7 +246,7 @@ impl ConnectionWriter { struct OpenWriter { tcp_stream: TcpStream, transport: WriteTransport, - receiver: mpsc::Receiver, + receiver: mpsc::Receiver, outbound_ping_state: Arc>, ping_interval: Duration, } @@ -244,10 +256,14 @@ impl OpenWriter { loop { let message = self.receiver.recv_timeout(Duration::from_secs(1)); match message { - Ok(network_message) => { - self.transport - .write_message(network_message, &mut self.tcp_stream)?; - } + Ok(request) => match request { + WriteRequest::SendMessage(message) => self + .transport + .write_message(message, &mut self.tcp_stream)?, + WriteRequest::Shutdown => { + self.tcp_stream.shutdown(std::net::Shutdown::Both)?; + } + }, Err(e) => match e { mpsc::RecvTimeoutError::Timeout => (), _ => return Ok(()), @@ -436,6 +452,8 @@ pub enum Error { UnexpectedMagic(Magic), /// The peer did not send a version message. MissingVersion, + /// The channel to the message writing thread was closed. + ChannelClosed, } impl Display for Error { @@ -446,6 +464,7 @@ impl Display for Error { Error::Handshake(e) => e.fmt(f), Error::UnexpectedMagic(magic) => write!(f, "unexpected network magic: {magic}"), Error::MissingVersion => write!(f, "missing version message."), + Error::ChannelClosed => write!(f, "channel closed"), } } }