diff --git a/src/tx/retransmission_queue.rs b/src/tx/retransmission_queue.rs index 2716c1c..cc6ae1b 100644 --- a/src/tx/retransmission_queue.rs +++ b/src/tx/retransmission_queue.rs @@ -195,11 +195,7 @@ impl RetransmissionQueue { } fn update_receiver_window(&mut self, a_rwnd: usize) { - self.rwnd = if self.outstanding_data.unacked_bytes() >= a_rwnd { - 0 - } else { - a_rwnd - self.outstanding_data.unacked_bytes() - } + self.rwnd = a_rwnd.saturating_sub(self.outstanding_data.unacked_bytes()); } fn phase(&self) -> CongestionAlgorithmPhase { @@ -334,9 +330,10 @@ impl RetransmissionQueue { let old_unacked_bytes = self.outstanding_data.unacked_bytes(); let old_rwnd = self.rwnd; - let rtt = match sack.gap_ack_blocks.is_empty() { - true => self.outstanding_data.measure_rtt(now, sack.cumulative_tsn_ack), - false => None, + let rtt = if sack.gap_ack_blocks.is_empty() { + self.outstanding_data.measure_rtt(now, sack.cumulative_tsn_ack) + } else { + None }; // Exit fast recovery before continuing processing, in case it needs to go into fast @@ -492,9 +489,10 @@ impl RetransmissionQueue { self.t3_rtx.start(now); } - let bytes_retransmitted = to_be_sent.iter().fold(0, |acc, (_, data)| { - acc + round_up_to_4!(self.data_chunk_header_size + data.payload.len()) - }); + let bytes_retransmitted: usize = to_be_sent + .iter() + .map(|(_, data)| round_up_to_4!(self.data_chunk_header_size + data.payload.len())) + .sum(); self.rtx_packets_count += 1; self.rtx_bytes_count += bytes_retransmitted as u64; @@ -502,8 +500,7 @@ impl RetransmissionQueue { log::debug!( "Fast-retransmitting TSN {} - {} bytes. unacked_bytes={} ({})", to_be_sent.iter().map(|(tsn, _)| tsn.to_string()).collect::>().join(","), - to_be_sent.iter().fold(0, |acc, (_, data)| acc - + round_up_to_4!(self.data_chunk_header_size + data.payload.len())), + bytes_retransmitted, self.unacked_bytes(), old_unacked_bytes ); @@ -541,9 +538,10 @@ impl RetransmissionQueue { let mut max_bytes = round_down_to_4!(min(self.max_bytes_to_send(), bytes_remaining_in_packet)); let mut to_be_sent = self.outstanding_data.get_chunks_to_be_retransmitted(max_bytes); - let bytes_retransmitted = to_be_sent.iter().fold(0, |acc, (_, data)| { - acc + round_up_to_4!(self.data_chunk_header_size + data.payload.len()) - }); + let bytes_retransmitted: usize = to_be_sent + .iter() + .map(|(_, data)| round_up_to_4!(self.data_chunk_header_size + data.payload.len())) + .sum(); max_bytes -= bytes_retransmitted; if !to_be_sent.is_empty() { @@ -583,11 +581,14 @@ impl RetransmissionQueue { if !self.t3_rtx.is_running() { self.t3_rtx.start(now); } + let sent_bytes: usize = to_be_sent + .iter() + .map(|(_, data)| round_up_to_4!(self.data_chunk_header_size + data.payload.len())) + .sum(); log::debug!( "Sending TSN {} - {} bytes. unacked_bytes={} ({}), cwnd={}, rwnd={} ({})", to_be_sent.iter().map(|(tsn, _)| tsn.to_string()).collect::>().join(","), - to_be_sent.iter().fold(0, |acc, (_, data)| acc - + round_up_to_4!(self.data_chunk_header_size + data.payload.len())), + sent_bytes, self.unacked_bytes(), old_unacked_bytes, self.cwnd, @@ -666,8 +667,7 @@ impl RetransmissionQueue { /// Returns the number of bytes that may be sent in a single packet according to the congestion /// control algorithm. fn max_bytes_to_send(&self) -> usize { - let left = - if self.unacked_bytes() >= self.cwnd { 0 } else { self.cwnd - self.unacked_bytes() }; + let left = self.cwnd.saturating_sub(self.unacked_bytes()); if self.unacked_bytes() == 0 { // TODO: Make the implementation compliant with RFC 9260. // @@ -766,11 +766,7 @@ mod tests { } fn get_tsns(chunks: &[(Tsn, Data)]) -> Vec { - let mut tsns: Vec = Vec::new(); - for elem in chunks { - tsns.push(elem.0); - } - tsns + chunks.iter().map(|(tsn, _)| *tsn).collect() } fn get_sid_tsns(chunks: &[(Tsn, Data)]) -> Vec<(StreamId, Tsn)> { @@ -1468,7 +1464,7 @@ mod tests { panic!(); }; assert_eq!(fwd.new_cumulative_tsn, Tsn(13)); - assert_eq!(fwd.skipped_streams, vec!(SkippedStream::ForwardTsn(StreamId(1), Ssn(0)))); + assert_eq!(fwd.skipped_streams, vec![SkippedStream::ForwardTsn(StreamId(1), Ssn(0))]); } #[test] @@ -1527,7 +1523,7 @@ mod tests { panic!(); }; assert_eq!(fwd.new_cumulative_tsn, Tsn(12)); - assert_eq!(fwd.skipped_streams, vec!(SkippedStream::ForwardTsn(StreamId(1), Ssn(0)))); + assert_eq!(fwd.skipped_streams, vec![SkippedStream::ForwardTsn(StreamId(1), Ssn(0))]); } #[test] @@ -1626,11 +1622,11 @@ mod tests { assert_eq!(fwd.new_cumulative_tsn, Tsn(12)); assert_eq!( fwd.skipped_streams, - vec!( + vec![ SkippedStream::IForwardTsn(StreamKey::Ordered(StreamId(1)), Mid(0)), SkippedStream::IForwardTsn(StreamKey::Ordered(StreamId(2)), Mid(0)), SkippedStream::IForwardTsn(StreamKey::Ordered(StreamId(3)), Mid(0)) - ) + ] ); // TODO: Continue migrating this test case.