Skip to content

RFC: listen on all ports #814

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/multicast.rs
Original file line number Diff line number Diff line change
@@ -111,7 +111,7 @@ fn main() {
}

let socket = sockets.get_mut::<udp::Socket>(udp_handle);
if !socket.is_open() {
if !socket.is_bound() {
socket.bind(MDNS_PORT).unwrap()
}

2 changes: 1 addition & 1 deletion examples/server.rs
Original file line number Diff line number Diff line change
@@ -96,7 +96,7 @@ fn main() {

// udp:6969: respond "hello"
let socket = sockets.get_mut::<udp::Socket>(udp_handle);
if !socket.is_open() {
if !socket.is_bound() {
socket.bind(6969).unwrap()
}

2 changes: 1 addition & 1 deletion examples/sixlowpan.rs
Original file line number Diff line number Diff line change
@@ -113,7 +113,7 @@ fn main() {

// udp:6969: respond "hello"
let socket = sockets.get_mut::<udp::Socket>(udp_handle);
if !socket.is_open() {
if !socket.is_bound() {
socket.bind(6969).unwrap()
}

19 changes: 10 additions & 9 deletions src/socket/tcp.rs
Original file line number Diff line number Diff line change
@@ -745,16 +745,12 @@ impl<'a> Socket<'a> {
/// Start listening on the given endpoint.
///
/// This function returns `Err(Error::Illegal)` if the socket was already open
/// (see [is_open](#method.is_open)), and `Err(Error::Unaddressable)`
/// if the port in the given endpoint is zero.
/// (see [is_open](#method.is_open)).
pub fn listen<T>(&mut self, local_endpoint: T) -> Result<(), ListenError>
where
T: Into<IpListenEndpoint>,
{
let local_endpoint = local_endpoint.into();
if local_endpoint.port == 0 {
return Err(ListenError::Unaddressable);
}

if self.is_open() {
return Err(ListenError::InvalidState);
@@ -1349,7 +1345,9 @@ impl<'a> Socket<'a> {
Some(addr) => ip_repr.dst_addr() == addr,
None => true,
};
addr_ok && repr.dst_port != 0 && repr.dst_port == self.listen_endpoint.port
addr_ok
&& repr.dst_port != 0
&& (self.listen_endpoint.port == 0 || repr.dst_port == self.listen_endpoint.port)
}
}

@@ -1868,7 +1866,10 @@ impl<'a> Socket<'a> {
let assembler_was_empty = self.assembler.is_empty();

// Try adding payload octets to the assembler.
let Ok(contig_len) = self.assembler.add_then_remove_front(payload_offset, payload_len) else {
let Ok(contig_len) = self
.assembler
.add_then_remove_front(payload_offset, payload_len)
else {
net_debug!(
"assembler: too many holes to add {} octets at offset {}",
payload_len,
@@ -2895,9 +2896,9 @@ mod test {
}

#[test]
fn test_listen_validation() {
fn test_listen_any_port() {
let mut s = socket();
assert_eq!(s.listen(0), Err(ListenError::Unaddressable));
assert_eq!(s.listen(0), Ok(()));
}

#[test]
276 changes: 209 additions & 67 deletions src/socket/udp.rs
Original file line number Diff line number Diff line change
@@ -27,6 +27,34 @@ impl<T: Into<IpEndpoint>> From<T> for UdpMetadata {
}
}

/// Extended metadata for a sent or received UDP packet.
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct ExtendedUdpMetadata {
pub local_endpoint: IpListenEndpoint,
pub remote_endpoint: IpEndpoint,
pub meta: PacketMeta,
}

impl ExtendedUdpMetadata {
fn new(local_endpoint: IpListenEndpoint, meta: UdpMetadata) -> Self {
Self {
local_endpoint,
remote_endpoint: meta.endpoint,
meta: meta.meta,
}
}
}

impl From<ExtendedUdpMetadata> for UdpMetadata {
fn from(value: ExtendedUdpMetadata) -> Self {
Self {
endpoint: value.remote_endpoint,
meta: value.meta,
}
}
}

impl core::fmt::Display for UdpMetadata {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
#[cfg(feature = "packetmeta-id")]
@@ -37,11 +65,25 @@ impl core::fmt::Display for UdpMetadata {
}
}

impl core::fmt::Display for ExtendedUdpMetadata {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
#[cfg(feature = "packetmeta-id")]
return write!(
f,
"{}/{}, PacketID: {:?}",
self.local_endpoint, self.remote_endpoint, self.meta
);

#[cfg(not(feature = "packetmeta-id"))]
write!(f, "{}/{}", self.local_endpoint, self.remote_endpoint)
}
}

/// A UDP packet metadata.
pub type PacketMetadata = crate::storage::PacketMetadata<UdpMetadata>;
pub type PacketMetadata = crate::storage::PacketMetadata<ExtendedUdpMetadata>;

/// A UDP packet ring buffer.
pub type PacketBuffer<'a> = crate::storage::PacketBuffer<'a, UdpMetadata>;
pub type PacketBuffer<'a> = crate::storage::PacketBuffer<'a, ExtendedUdpMetadata>;

/// Error returned by [`Socket::bind`]
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
@@ -107,7 +149,7 @@ impl std::error::Error for RecvError {}
/// packet buffers.
#[derive(Debug)]
pub struct Socket<'a> {
endpoint: IpListenEndpoint,
bound_endpoint: IpListenEndpoint,
rx_buffer: PacketBuffer<'a>,
tx_buffer: PacketBuffer<'a>,
/// The time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets.
@@ -122,7 +164,7 @@ impl<'a> Socket<'a> {
/// Create an UDP socket with the given buffers.
pub fn new(rx_buffer: PacketBuffer<'a>, tx_buffer: PacketBuffer<'a>) -> Socket<'a> {
Socket {
endpoint: IpListenEndpoint::default(),
bound_endpoint: IpListenEndpoint::default(),
rx_buffer,
tx_buffer,
hop_limit: None,
@@ -170,8 +212,8 @@ impl<'a> Socket<'a> {

/// Return the bound endpoint.
#[inline]
pub fn endpoint(&self) -> IpListenEndpoint {
self.endpoint
pub fn bound_endpoint(&self) -> IpListenEndpoint {
self.bound_endpoint
}

/// Return the time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets.
@@ -203,20 +245,20 @@ impl<'a> Socket<'a> {

/// Bind the socket to the given endpoint.
///
/// This function returns `Err(Error::Illegal)` if the socket was open
/// (see [is_open](#method.is_open)), and `Err(Error::Unaddressable)`
/// This function returns `Err(Error::Illegal)` if the socket was bound
/// (see [is_bound](#method.is_bound)), and `Err(Error::Unaddressable)`
/// if the port in the given endpoint is zero.
pub fn bind<T: Into<IpListenEndpoint>>(&mut self, endpoint: T) -> Result<(), BindError> {
let endpoint = endpoint.into();
if endpoint.port == 0 {
return Err(BindError::Unaddressable);
}

if self.is_open() {
if self.is_bound() {
return Err(BindError::InvalidState);
}

self.endpoint = endpoint;
self.bound_endpoint = endpoint;

#[cfg(feature = "async")]
{
@@ -230,7 +272,7 @@ impl<'a> Socket<'a> {
/// Close the socket.
pub fn close(&mut self) {
// Clear the bound endpoint of the socket.
self.endpoint = IpListenEndpoint::default();
self.bound_endpoint = IpListenEndpoint::default();

// Reset the RX and TX buffers of the socket.
self.tx_buffer.reset();
@@ -245,8 +287,8 @@ impl<'a> Socket<'a> {

/// Check whether the socket is open.
#[inline]
pub fn is_open(&self) -> bool {
self.endpoint.port != 0
pub fn is_bound(&self) -> bool {
self.bound_endpoint.port != 0
}

/// Check whether the transmit buffer is full.
@@ -292,19 +334,19 @@ impl<'a> Socket<'a> {
/// `Err(Error::Unaddressable)` if local or remote port, or remote address are unspecified,
/// and `Err(Error::Truncated)` if there is not enough transmit buffer capacity
/// to ever send this packet.
pub fn send(
pub fn send_from(
&mut self,
size: usize,
meta: impl Into<UdpMetadata>,
meta: impl Into<ExtendedUdpMetadata>,
) -> Result<&mut [u8], SendError> {
let meta = meta.into();
if self.endpoint.port == 0 {
if meta.local_endpoint.port == 0 {
return Err(SendError::Unaddressable);
}
if meta.endpoint.addr.is_unspecified() {
if meta.remote_endpoint.addr.is_unspecified() {
return Err(SendError::Unaddressable);
}
if meta.endpoint.port == 0 {
if meta.remote_endpoint.port == 0 {
return Err(SendError::Unaddressable);
}

@@ -315,35 +357,53 @@ impl<'a> Socket<'a> {

net_trace!(
"udp:{}:{}: buffer to send {} octets",
self.endpoint,
meta.endpoint,
meta.local_endpoint,
meta.remote_endpoint,
size
);
Ok(payload_buf)
}

/// Enqueue a packet to be sent to a given remote endpoint, and return a pointer
/// to its payload.
///
/// This function returns `Err(Error::Exhausted)` if the transmit buffer is full,
/// `Err(Error::Unaddressable)` if local or remote port, or remote address are unspecified,
/// and `Err(Error::Truncated)` if there is not enough transmit buffer capacity
/// to ever send this packet.
pub fn send(
&mut self,
size: usize,
meta: impl Into<UdpMetadata>,
) -> Result<&mut [u8], SendError> {
self.send_from(
size,
ExtendedUdpMetadata::new(self.bound_endpoint(), meta.into()),
)
}

/// Enqueue a packet to be send to a given remote endpoint and pass the buffer
/// to the provided closure. The closure then returns the size of the data written
/// into the buffer.
///
/// Also see [send](#method.send).
pub fn send_with<F>(
pub fn send_from_with<F>(
&mut self,
max_size: usize,
meta: impl Into<UdpMetadata>,
meta: impl Into<ExtendedUdpMetadata>,
f: F,
) -> Result<usize, SendError>
where
F: FnOnce(&mut [u8]) -> usize,
{
let meta = meta.into();
if self.endpoint.port == 0 {
if meta.local_endpoint.port == 0 {
return Err(SendError::Unaddressable);
}
if meta.endpoint.addr.is_unspecified() {
if meta.remote_endpoint.addr.is_unspecified() {
return Err(SendError::Unaddressable);
}
if meta.endpoint.port == 0 {
if meta.remote_endpoint.port == 0 {
return Err(SendError::Unaddressable);
}

@@ -354,13 +414,46 @@ impl<'a> Socket<'a> {

net_trace!(
"udp:{}:{}: buffer to send {} octets",
self.endpoint,
meta.endpoint,
meta.local_endpoint,
meta.remote_endpoint,
size
);
Ok(size)
}

/// Enqueue a packet to be send to a given remote endpoint and pass the buffer
/// to the provided closure. The closure then returns the size of the data written
/// into the buffer.
///
/// Also see [send](#method.send).
pub fn send_with<F>(
&mut self,
max_size: usize,
meta: impl Into<UdpMetadata>,
f: F,
) -> Result<usize, SendError>
where
F: FnOnce(&mut [u8]) -> usize,
{
self.send_from_with(
max_size,
ExtendedUdpMetadata::new(self.bound_endpoint(), meta.into()),
f,
)
}

/// Enqueue a packet to be sent to a given remote endpoint, and fill it from a slice.
///
/// See also [send](#method.send).
pub fn send_slice_from(
&mut self,
data: &[u8],
meta: impl Into<ExtendedUdpMetadata>,
) -> Result<(), SendError> {
self.send_from(data.len(), meta)?.copy_from_slice(data);
Ok(())
}

/// Enqueue a packet to be sent to a given remote endpoint, and fill it from a slice.
///
/// See also [send](#method.send).
@@ -377,48 +470,76 @@ impl<'a> Socket<'a> {
/// as a pointer to the payload.
///
/// This function returns `Err(Error::Exhausted)` if the receive buffer is empty.
pub fn recv(&mut self) -> Result<(&[u8], UdpMetadata), RecvError> {
let (remote_endpoint, payload_buf) =
self.rx_buffer.dequeue().map_err(|_| RecvError::Exhausted)?;
pub fn recv_to(&mut self) -> Result<(&[u8], ExtendedUdpMetadata), RecvError> {
let (meta, payload_buf) = self.rx_buffer.dequeue().map_err(|_| RecvError::Exhausted)?;

net_trace!(
"udp:{}:{}: receive {} buffered octets",
self.endpoint,
remote_endpoint.endpoint,
meta.local_endpoint,
meta.remote_endpoint,
payload_buf.len()
);
Ok((payload_buf, remote_endpoint))
Ok((payload_buf, meta))
}

/// Dequeue a packet received from a remote endpoint, and return the endpoint as well
/// as a pointer to the payload.
///
/// This function returns `Err(Error::Exhausted)` if the receive buffer is empty.
pub fn recv(&mut self) -> Result<(&[u8], UdpMetadata), RecvError> {
self.recv_to().map(|(buf, meta)| (buf, meta.into()))
}

/// Dequeue a packet received from a remote endpoint, copy the payload into the given slice,
/// and return the amount of octets copied as well as the endpoint.
///
/// See also [recv](#method.recv).
pub fn recv_slice(&mut self, data: &mut [u8]) -> Result<(usize, UdpMetadata), RecvError> {
let (buffer, endpoint) = self.recv().map_err(|_| RecvError::Exhausted)?;
pub fn recv_slice_to(
&mut self,
data: &mut [u8],
) -> Result<(usize, ExtendedUdpMetadata), RecvError> {
let (buffer, endpoint) = self.recv_to().map_err(|_| RecvError::Exhausted)?;
let length = min(data.len(), buffer.len());
data[..length].copy_from_slice(&buffer[..length]);
Ok((length, endpoint))
}

/// Dequeue a packet received from a remote endpoint, copy the payload into the given slice,
/// and return the amount of octets copied as well as the endpoint.
///
/// See also [recv](#method.recv).
pub fn recv_slice(&mut self, data: &mut [u8]) -> Result<(usize, UdpMetadata), RecvError> {
self.recv_slice_to(data)
.map(|(length, meta)| (length, meta.into()))
}

/// Peek at a packet received from a remote endpoint, and return the endpoint as well
/// as a pointer to the payload without removing the packet from the receive buffer.
/// This function otherwise behaves identically to [recv](#method.recv).
///
/// It returns `Err(Error::Exhausted)` if the receive buffer is empty.
pub fn peek(&mut self) -> Result<(&[u8], &UdpMetadata), RecvError> {
let endpoint = self.endpoint;
self.rx_buffer.peek().map_err(|_| RecvError::Exhausted).map(
|(remote_endpoint, payload_buf)| {
pub fn peek_to(&mut self) -> Result<(&[u8], ExtendedUdpMetadata), RecvError> {
self.rx_buffer
.peek()
.map_err(|_| RecvError::Exhausted)
.map(|(meta, payload_buf)| {
net_trace!(
"udp:{}:{}: peek {} buffered octets",
endpoint,
remote_endpoint.endpoint,
meta.local_endpoint,
meta.remote_endpoint,
payload_buf.len()
);
(payload_buf, remote_endpoint)
},
)
(payload_buf, *meta)
})
}

/// Peek at a packet received from a remote endpoint, and return the endpoint as well
/// as a pointer to the payload without removing the packet from the receive buffer.
/// This function otherwise behaves identically to [recv](#method.recv).
///
/// It returns `Err(Error::Exhausted)` if the receive buffer is empty.
pub fn peek(&mut self) -> Result<(&[u8], UdpMetadata), RecvError> {
self.peek_to().map(|(buf, meta)| (buf, meta.into()))
}

/// Peek at a packet received from a remote endpoint, copy the payload into the given slice,
@@ -427,19 +548,36 @@ impl<'a> Socket<'a> {
/// This function otherwise behaves identically to [recv_slice](#method.recv_slice).
///
/// See also [peek](#method.peek).
pub fn peek_slice(&mut self, data: &mut [u8]) -> Result<(usize, &UdpMetadata), RecvError> {
let (buffer, endpoint) = self.peek()?;
pub fn peek_slice_to(
&mut self,
data: &mut [u8],
) -> Result<(usize, ExtendedUdpMetadata), RecvError> {
let (buffer, endpoint) = self.peek_to()?;
let length = min(data.len(), buffer.len());
data[..length].copy_from_slice(&buffer[..length]);
Ok((length, endpoint))
}

/// Peek at a packet received from a remote endpoint, copy the payload into the given slice,
/// and return the amount of octets copied as well as the endpoint without removing the
/// packet from the receive buffer.
/// This function otherwise behaves identically to [recv_slice](#method.recv_slice).
///
/// See also [peek](#method.peek).
pub fn peek_slice(&mut self, data: &mut [u8]) -> Result<(usize, UdpMetadata), RecvError> {
self.peek_slice_to(data)
.map(|(length, meta)| (length, meta.into()))
}

pub(crate) fn accepts(&self, cx: &mut Context, ip_repr: &IpRepr, repr: &UdpRepr) -> bool {
if self.endpoint.port != repr.dst_port {
if !self.is_bound() {
return true;
}
if self.bound_endpoint.port != repr.dst_port {
return false;
}
if self.endpoint.addr.is_some()
&& self.endpoint.addr != Some(ip_repr.dst_addr())
if self.bound_endpoint.addr.is_some()
&& self.bound_endpoint.addr != Some(ip_repr.dst_addr())
&& !cx.is_broadcast(&ip_repr.dst_addr())
&& !ip_repr.dst_addr().is_multicast()
{
@@ -461,28 +599,33 @@ impl<'a> Socket<'a> {

let size = payload.len();

let local_endpoint = IpEndpoint {
addr: ip_repr.dst_addr(),
port: repr.dst_port,
};
let remote_endpoint = IpEndpoint {
addr: ip_repr.src_addr(),
port: repr.src_port,
};

net_trace!(
"udp:{}:{}: receiving {} octets",
self.endpoint,
local_endpoint,
remote_endpoint,
size
);

let metadata = UdpMetadata {
endpoint: remote_endpoint,
let metadata = ExtendedUdpMetadata {
local_endpoint: local_endpoint.into(),
remote_endpoint,
meta,
};

match self.rx_buffer.enqueue(size, metadata) {
Ok(buf) => buf.copy_from_slice(payload),
Err(_) => net_trace!(
"udp:{}:{}: buffer full, dropped incoming packet",
self.endpoint,
local_endpoint,
remote_endpoint
),
}
@@ -495,19 +638,18 @@ impl<'a> Socket<'a> {
where
F: FnOnce(&mut Context, PacketMeta, (IpRepr, UdpRepr, &[u8])) -> Result<(), E>,
{
let endpoint = self.endpoint;
let hop_limit = self.hop_limit.unwrap_or(64);

let res = self.tx_buffer.dequeue_with(|packet_meta, payload_buf| {
let src_addr = match endpoint.addr {
let src_addr = match packet_meta.local_endpoint.addr {
Some(addr) => addr,
None => match cx.get_source_address(packet_meta.endpoint.addr) {
None => match cx.get_source_address(packet_meta.remote_endpoint.addr) {
Some(addr) => addr,
None => {
net_trace!(
"udp:{}:{}: cannot find suitable source address, dropping.",
endpoint,
packet_meta.endpoint
packet_meta.local_endpoint,
packet_meta.remote_endpoint
);
return Ok(());
}
@@ -516,18 +658,18 @@ impl<'a> Socket<'a> {

net_trace!(
"udp:{}:{}: sending {} octets",
endpoint,
packet_meta.endpoint,
packet_meta.local_endpoint,
packet_meta.remote_endpoint,
payload_buf.len()
);

let repr = UdpRepr {
src_port: endpoint.port,
dst_port: packet_meta.endpoint.port,
src_port: packet_meta.local_endpoint.port,
dst_port: packet_meta.remote_endpoint.port,
};
let ip_repr = IpRepr::new(
src_addr,
packet_meta.endpoint.addr,
packet_meta.remote_endpoint.addr,
IpProtocol::Udp,
repr.header_len() + payload_buf.len(),
hop_limit,
@@ -794,7 +936,7 @@ mod test {
&REMOTE_UDP_REPR,
PAYLOAD,
);
assert_eq!(socket.peek(), Ok((&b"abcdef"[..], &REMOTE_END.into(),)));
assert_eq!(socket.peek(), Ok((&b"abcdef"[..], REMOTE_END.into(),)));
assert_eq!(socket.recv(), Ok((&b"abcdef"[..], REMOTE_END.into(),)));
assert_eq!(socket.peek(), Err(RecvError::Exhausted));
}
@@ -841,7 +983,7 @@ mod test {
let mut slice = [0; 4];
assert_eq!(
socket.peek_slice(&mut slice[..]),
Ok((4, &REMOTE_END.into()))
Ok((4, REMOTE_END.into()))
);
assert_eq!(&slice, b"abcd");
assert_eq!(
@@ -943,8 +1085,8 @@ mod test {
let mut socket = socket(recv_buffer, buffer(0));
assert_eq!(socket.bind(LOCAL_PORT), Ok(()));

assert!(socket.is_open());
assert!(socket.is_bound());
socket.close();
assert!(!socket.is_open());
assert!(!socket.is_bound());
}
}