Skip to content

[monarch][net] Replace tuples with named structures #631

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 38 additions & 25 deletions hyperactor/src/channel/net.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,13 +209,21 @@ impl<M: RemoteMessage> NetTx<M> {
// If we can't deliver a message within this limit consider
// `link` broken and return.

// A single entry of the `Outbox` / `Unacked` deques. It replaces the
// positional `(u64, Bytes, Instant, oneshot::Sender<M>)` tuple so call
// sites can refer to each part by name.
#[derive(Debug)]
struct QueuedMessage<M: RemoteMessage> {
// Sequence number of this message. `Outbox` assigns it from `next_seq`
// when the message is first queued, and both queues assert that
// entries are stored in strictly ascending `seq` order.
seq: u64,
// The serialized form of the message. This is what gets written to
// the sink, and its length is what is recorded as the send size.
data: Bytes,
// Timestamp used as the starting point when checking the entry
// against `config::MESSAGE_DELIVERY_TIMEOUT`.
received_at: Instant,
// One-shot sender paired with the message when queued entries are
// drained and deserialized back into messages.
// NOTE(review): presumably used to hand an undeliverable message back
// to its original sender -- confirm against the draining code path.
return_channel: oneshot::Sender<M>,
}

#[derive(Debug)]
struct Outbox<'a, M: RemoteMessage> {
// The seq number of the next new message put into outbox. Requeued
// unacked messages should still use their already assigned seq
// numbers.
next_seq: u64,
deque: VecDeque<(u64, Bytes, Instant, oneshot::Sender<M>)>,
deque: VecDeque<QueuedMessage<M>>,
log_id: &'a str,
}

Expand All @@ -231,8 +239,9 @@ impl<M: RemoteMessage> NetTx<M> {
fn is_expired(&self) -> bool {
match self.deque.front() {
None => false,
Some((_, _, since, _)) => {
since.elapsed() > config::global::get(config::MESSAGE_DELIVERY_TIMEOUT)
Some(msg) => {
msg.received_at.elapsed()
> config::global::get(config::MESSAGE_DELIVERY_TIMEOUT)
}
}
}
Expand All @@ -256,17 +265,17 @@ impl<M: RemoteMessage> NetTx<M> {
self.log_id,
)
})?
.1
.data
.clone();
sink.send(data).await.map_err(|e| e.to_string())?;
Ok(())
}

fn front_size(&self) -> Option<usize> {
self.deque.front().map(|(_, bytes, _, _)| bytes.len())
self.deque.front().map(|msg| msg.data.len())
}

fn pop_front(&mut self) -> Option<(u64, Bytes, Instant, oneshot::Sender<M>)> {
fn pop_front(&mut self) -> Option<QueuedMessage<M>> {
self.deque.pop_front()
}

Expand All @@ -275,7 +284,7 @@ impl<M: RemoteMessage> NetTx<M> {
(message, return_channel, received_at): (M, oneshot::Sender<M>, Instant),
) -> Result<(), String> {
assert!(
self.deque.back().is_none_or(|msg| msg.0 < self.next_seq),
self.deque.back().is_none_or(|msg| msg.seq < self.next_seq),
"{}: unexpected: seq should be in ascending order, but got {:?} vs {}",
self.log_id,
self.deque.back(),
Expand All @@ -287,8 +296,12 @@ impl<M: RemoteMessage> NetTx<M> {
.map_err(|e| format!("serialization error: {e}"))?
.into();
REMOTE_MESSAGE_SEND_SIZE.record(data.len() as f64, &[]);
self.deque
.push_back((self.next_seq, data, received_at, return_channel));
self.deque.push_back(QueuedMessage {
seq: self.next_seq,
data,
received_at,
return_channel,
});
self.next_seq += 1;
Ok(())
}
Expand All @@ -297,11 +310,11 @@ impl<M: RemoteMessage> NetTx<M> {
match (unacked.deque.back(), self.deque.front()) {
(Some(last), Some(first)) => {
assert!(
last.0 < first.0,
last.seq < first.seq,
"{}: seq should be in ascending order, but got {} vs {:?}",
self.log_id,
last.0,
first.0,
last.seq,
first.seq,
);
}
_ => (),
Expand All @@ -315,7 +328,7 @@ impl<M: RemoteMessage> NetTx<M> {

#[derive(Debug)]
struct Unacked<'a, M: RemoteMessage> {
deque: VecDeque<(u64, Bytes, Instant, oneshot::Sender<M>)>,
deque: VecDeque<QueuedMessage<M>>,
largest_acked: Option<u64>,
log_id: &'a str,
}
Expand All @@ -329,13 +342,13 @@ impl<M: RemoteMessage> NetTx<M> {
}
}

fn push_back(&mut self, message: (u64, Bytes, Instant, oneshot::Sender<M>)) {
fn push_back(&mut self, message: QueuedMessage<M>) {
assert!(
self.deque.back().is_none_or(|msg| msg.0 < message.0),
self.deque.back().is_none_or(|msg| msg.seq < message.seq),
"{}: seq should be in ascending order, but got {:?} vs {}",
self.log_id,
self.deque.back(),
message.0
message.seq
);

if let Some(largest) = self.largest_acked {
Expand Down Expand Up @@ -370,7 +383,7 @@ impl<M: RemoteMessage> NetTx<M> {
// Tx resends. As a result, this message's ack would be
// recorded already by `largest_acked` before it is put into
// unacked queue.
if message.0 <= largest {
if message.seq <= largest {
// Since the message is already delivered and acked, it
// does not need to be put in the queue again.
return;
Expand All @@ -392,8 +405,8 @@ impl<M: RemoteMessage> NetTx<M> {

self.largest_acked = Some(acked);
let deque = &mut self.deque;
while let Some((seq, _, _, _)) = deque.front() {
if *seq <= acked {
while let Some(msg) = deque.front() {
if msg.seq <= acked {
deque.pop_front();
} else {
// Messages in the deque are ordered by seq in ascending
Expand All @@ -407,7 +420,7 @@ impl<M: RemoteMessage> NetTx<M> {
fn is_expired(&self) -> bool {
matches!(
self.deque.front(),
Some((_, _, received_at, _)) if received_at.elapsed() > config::global::get(config::MESSAGE_DELIVERY_TIMEOUT)
Some(msg) if msg.received_at.elapsed() > config::global::get(config::MESSAGE_DELIVERY_TIMEOUT)
)
}

Expand All @@ -416,10 +429,10 @@ impl<M: RemoteMessage> NetTx<M> {
/// branches.
async fn wait_for_timeout(&self) {
match self.deque.front() {
Some((_, _, received_at, _)) => {
Some(msg) => {
RealClock
.sleep_until(
received_at.clone()
msg.received_at.clone()
+ config::global::get(config::MESSAGE_DELIVERY_TIMEOUT),
)
.await
Expand Down Expand Up @@ -728,11 +741,11 @@ impl<M: RemoteMessage> NetTx<M> {
.deque
.drain(..)
.chain(outbox.deque.drain(..))
.filter_map(|(_, bytes, _, return_channel)| {
bincode::deserialize(&bytes)
.filter_map(|queued_msg| {
bincode::deserialize(&queued_msg.data)
.ok()
.and_then(|frame| match frame {
Frame::Message(_, msg) => Some((return_channel, msg)),
Frame::Message(_, msg) => Some((queued_msg.return_channel, msg)),
_ => None,
})
})
Expand Down