Skip to content

Commit 6919663

Browse files
thomasywangfacebook-github-bot
authored andcommitted
Signal handler for RemoteProcessAlloc (#540)
Summary: Pull Request resolved: #540 What's going on here: 1. `RemoteProcessAlloc` is instantiated in the client code (ex. https://fburl.com/code/p4t5aewo) 2. `RemoteProcessAlloc::new()` now spawns a signal handler and holds onto the `JoinHandle`. A tx-rx pair is created so that the signal handler task is aware of the addresses of hosts in `RemoteProcessAlloc::host_states` as they are added and removed 3. `RemoteProcessAlloc::host_states` is now wrapped in a struct `HostStates` which contains the tx side and aims to have the same interface as a `HashMap` but sends updates to the map over the tx. When a `RemoteProcessAllocHostState` is inserted, the address and `HostId` is sent over the tx. When a `RemoteProcessAllocHostState` is removed, the `HostId` is sent over the tx (address is None). 4. When the handler receives a `HostId` and `Some(ChannelAddr)` it will dial this address, and insert the `ChannelTx` into it's own `HashMap` with the `HostId` as the key 5. When the handler receives a `HostId` and `None`, it will remove the corresponding entry from it's `HashMap` 6. When the handler receives a signal, it will iterate over all `ChannelTx`s in the `HashMap` and send `RemoteProcessAllocatorMessage::Signal(signal)` over each `ChannelTx` to the `RemoteProcessAllocator` running on a remote machine 7. The`RemoteProcessAllocator` receives the message. If the signal == SIGINT, it calls `ensure_previous_alloc_stopped` to stop gracefully, then reraises the signal Reviewed By: moonli Differential Revision: D78097380
1 parent 2ff2097 commit 6919663

File tree

5 files changed

+389
-4
lines changed

5 files changed

+389
-4
lines changed

hyperactor_mesh/Cargo.toml

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# @generated by autocargo from //monarch/hyperactor_mesh:[hyperactor_mesh,hyperactor_mesh_test_bootstrap,process_allocator_cleanup,process_allocator_test_bin,process_allocator_test_bootstrap]
1+
# @generated by autocargo from //monarch/hyperactor_mesh:[hyperactor_mesh,hyperactor_mesh_test_bootstrap,hyperactor_mesh_test_remote_process_alloc,hyperactor_mesh_test_remote_process_allocator,process_allocator_cleanup,process_allocator_test_bin,process_allocator_test_bootstrap]
22

33
[package]
44
name = "hyperactor_mesh"
@@ -11,6 +11,14 @@ license = "BSD-3-Clause"
1111
name = "hyperactor_mesh_test_bootstrap"
1212
path = "test/bootstrap.rs"
1313

14+
[[bin]]
15+
name = "hyperactor_mesh_test_remote_process_alloc"
16+
path = "test/remote_process_alloc.rs"
17+
18+
[[bin]]
19+
name = "hyperactor_mesh_test_remote_process_allocator"
20+
path = "test/remote_process_allocator.rs"
21+
1422
[[bin]]
1523
name = "process_allocator_test_bin"
1624
path = "test/process_allocator_cleanup/process_allocator_test_bin.rs"
@@ -29,6 +37,7 @@ async-trait = "0.1.86"
2937
bincode = "1.3.3"
3038
bitmaps = "3.2.1"
3139
buck-resources = "1"
40+
clap = { version = "4.5.38", features = ["derive", "env", "string", "unicode", "wrap_help"] }
3241
dashmap = { version = "5.5.3", features = ["rayon", "serde"] }
3342
enum-as-inner = "0.6.0"
3443
erased-serde = "0.3.27"
@@ -59,6 +68,7 @@ tracing-subscriber = { version = "0.3.19", features = ["chrono", "env-filter", "
5968
[dev-dependencies]
6069
maplit = "1.0"
6170
timed_test = { version = "0.0.0", path = "../timed_test" }
71+
tracing-test = { version = "0.2.3", features = ["no-env-filter"] }
6272

6373
[lints]
6474
rust = { unexpected_cfgs = { check-cfg = ["cfg(fbcode_build)"], level = "warn" } }

hyperactor_mesh/src/alloc/process.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,8 @@ impl Alloc for ProcessAlloc {
466466
let (_stderr, _) = stderr.join().await;
467467
}
468468

469+
tracing::info!("child stopped with ProcStopReason::{:?}", reason);
470+
469471
break Some(ProcState::Stopped {
470472
proc_id: ProcId(WorldId(self.name.to_string()), index),
471473
reason

hyperactor_mesh/src/alloc/remoteprocess.rs

Lines changed: 199 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ use anyhow::Context;
1616
use async_trait::async_trait;
1717
use futures::FutureExt;
1818
use futures::future::select_all;
19+
use futures::future::try_join_all;
1920
use hyperactor::Named;
2021
use hyperactor::ProcId;
2122
use hyperactor::WorldId;
@@ -37,6 +38,7 @@ use hyperactor::reference::Reference;
3738
use hyperactor::serde_json;
3839
use mockall::automock;
3940
use ndslice::Shape;
41+
use nix::sys::signal;
4042
use serde::Deserialize;
4143
use serde::Serialize;
4244
use tokio::io::AsyncWriteExt;
@@ -78,6 +80,8 @@ pub enum RemoteProcessAllocatorMessage {
7880
/// Heartbeat message to check if remote process allocator and its
7981
/// host are alive.
8082
HeartBeat,
83+
/// Stop allocation and terminate
84+
Terminate,
8185
}
8286

8387
/// Control message sent from local allocator to remote allocator
@@ -221,6 +225,9 @@ impl RemoteProcessAllocator {
221225
}
222226
}
223227
}
228+
Ok(RemoteProcessAllocatorMessage::Terminate) => {
229+
self.terminate();
230+
}
224231
Ok(RemoteProcessAllocatorMessage::Stop) => {
225232
tracing::info!("received stop request");
226233

@@ -475,6 +482,58 @@ pub trait RemoteProcessAllocInitializer {
475482
async fn initialize_alloc(&mut self) -> Result<Vec<RemoteProcessAllocHost>, anyhow::Error>;
476483
}
477484

485+
/// Wrapper struct around `HashMap<HostId, RemoteProcessAllocHostState>`
486+
/// to ensure that host addresses are synced with the signal handler
487+
struct HostStates {
488+
inner: HashMap<HostId, RemoteProcessAllocHostState>,
489+
host_address_tx: UnboundedSender<(HostId, Option<ChannelAddr>)>,
490+
}
491+
492+
impl HostStates {
493+
fn new(host_address_tx: UnboundedSender<(HostId, Option<ChannelAddr>)>) -> HostStates {
494+
Self {
495+
inner: HashMap::new(),
496+
host_address_tx,
497+
}
498+
}
499+
500+
fn insert(
501+
&mut self,
502+
host_id: HostId,
503+
state: RemoteProcessAllocHostState,
504+
address: ChannelAddr,
505+
) {
506+
let _ = self.host_address_tx.send((host_id.clone(), Some(address)));
507+
self.inner.insert(host_id, state);
508+
}
509+
510+
fn get(&self, host_id: &HostId) -> Option<&RemoteProcessAllocHostState> {
511+
self.inner.get(host_id)
512+
}
513+
514+
fn get_mut(&mut self, host_id: &HostId) -> Option<&mut RemoteProcessAllocHostState> {
515+
self.inner.get_mut(host_id)
516+
}
517+
518+
fn remove(&mut self, host_id: &HostId) -> Option<RemoteProcessAllocHostState> {
519+
let _ = self.host_address_tx.send((host_id.clone(), None));
520+
self.inner.remove(host_id)
521+
}
522+
523+
fn iter(&self) -> impl Iterator<Item = (&HostId, &RemoteProcessAllocHostState)> {
524+
self.inner.iter()
525+
}
526+
527+
fn iter_mut(&mut self) -> impl Iterator<Item = (&HostId, &mut RemoteProcessAllocHostState)> {
528+
self.inner.iter_mut()
529+
}
530+
531+
fn is_empty(&self) -> bool {
532+
self.inner.is_empty()
533+
}
534+
// Any missing HashMap methods should be added here as needed
535+
}
536+
478537
/// A generalized implementation of an Alloc using one or more hosts running
479538
/// RemoteProcessAlloc for process allocation.
480539
pub struct RemoteProcessAlloc {
@@ -494,14 +553,15 @@ pub struct RemoteProcessAlloc {
494553
// Inidicates that the allocation process has permanently failed.
495554
failed: bool,
496555
hosts_by_offset: HashMap<usize, HostId>,
497-
host_states: HashMap<HostId, RemoteProcessAllocHostState>,
556+
host_states: HostStates,
498557
world_shapes: HashMap<WorldId, Shape>,
499558
event_queue: VecDeque<ProcState>,
500559
comm_watcher_tx: UnboundedSender<HostId>,
501560
comm_watcher_rx: UnboundedReceiver<HostId>,
502561

503562
bootstrap_addr: ChannelAddr,
504563
rx: ChannelRx<RemoteProcessProcStateMessage>,
564+
signal_listener_handler: JoinHandle<()>,
505565
}
506566

507567
impl RemoteProcessAlloc {
@@ -529,6 +589,63 @@ impl RemoteProcessAlloc {
529589

530590
let (comm_watcher_tx, comm_watcher_rx) = unbounded_channel();
531591

592+
let (host_address_tx, mut host_address_rx) =
593+
unbounded_channel::<(HostId, Option<ChannelAddr>)>();
594+
let signal_listener_handler = tokio::spawn({
595+
async move {
596+
let mut signals = signal_hook_tokio::Signals::new([signal::SIGINT as i32]).unwrap();
597+
let mut host_txs = HashMap::new();
598+
loop {
599+
tokio::select! {
600+
Some((host_id, remote_addr)) = host_address_rx.recv() => {
601+
match remote_addr {
602+
Some(addr) => {
603+
let Ok(tx) = channel::dial(addr.clone()) else {
604+
tracing::error!(
605+
"failed to dial remote {} for host {}",
606+
addr, host_id
607+
);
608+
return;
609+
};
610+
host_txs.insert(host_id, tx);
611+
}
612+
None => {
613+
host_txs.remove(&host_id);
614+
}
615+
}
616+
}
617+
signal = signals.next() => {
618+
if let Some(signal) = signal {
619+
if let Err(e) = try_join_all(
620+
// send instead of post to ensure message has been delivered to the remote end of the
621+
// channel before reraising signal and terminating this process
622+
host_txs.values().map(|tx| tx.send(RemoteProcessAllocatorMessage::Terminate))
623+
)
624+
.await {
625+
tracing::error!("error sending RemoteProcessAllocatorMessage: {}", e);
626+
}
627+
628+
match signal::Signal::try_from(signal) {
629+
Ok(sig @ signal::SIGINT) => {
630+
// SAFETY: We're setting the handle to SigDfl (default system behaviour)
631+
if let Err(err) = unsafe {
632+
signal::signal(sig, signal::SigHandler::SigDfl)
633+
} {
634+
tracing::error!("failed to signal {}: {}", sig, err);
635+
}
636+
if let Err(err) = signal::raise(sig) {
637+
tracing::error!("failed to raise {}: {}", sig, err);
638+
}
639+
}
640+
_ => {}
641+
}
642+
}
643+
}
644+
}
645+
}
646+
}
647+
});
648+
532649
Ok(Self {
533650
spec,
534651
world_id,
@@ -539,7 +656,7 @@ impl RemoteProcessAlloc {
539656
world_shapes: HashMap::new(),
540657
ordered_hosts: Vec::new(),
541658
hosts_by_offset: HashMap::new(),
542-
host_states: HashMap::new(),
659+
host_states: HostStates::new(host_address_tx),
543660
bootstrap_addr,
544661
event_queue: VecDeque::new(),
545662
comm_watcher_tx,
@@ -548,6 +665,7 @@ impl RemoteProcessAlloc {
548665
started: false,
549666
running: true,
550667
failed: false,
668+
signal_listener_handler,
551669
})
552670
}
553671

@@ -661,7 +779,8 @@ impl RemoteProcessAlloc {
661779
};
662780

663781
tracing::debug!("dialing remote: {} for host {}", remote_addr, host.id);
664-
let tx = channel::dial(remote_addr.parse()?)
782+
let remote_addr = remote_addr.parse::<ChannelAddr>()?;
783+
let tx = channel::dial(remote_addr.clone())
665784
.map_err(anyhow::Error::from)
666785
.context(format!(
667786
"failed to dial remote {} for host {}",
@@ -690,6 +809,7 @@ impl RemoteProcessAlloc {
690809
failed: false,
691810
allocated: false,
692811
},
812+
remote_addr,
693813
);
694814
}
695815

@@ -817,6 +937,12 @@ impl RemoteProcessAlloc {
817937
}
818938
}
819939

940+
impl Drop for RemoteProcessAlloc {
941+
fn drop(&mut self) {
942+
self.signal_listener_handler.abort();
943+
}
944+
}
945+
820946
#[async_trait]
821947
impl Alloc for RemoteProcessAlloc {
822948
async fn next(&mut self) -> Option<ProcState> {
@@ -1636,8 +1762,11 @@ mod test {
16361762

16371763
#[cfg(test)]
16381764
mod test_alloc {
1765+
use std::os::unix::process::ExitStatusExt;
1766+
16391767
use hyperactor::clock::ClockKind;
16401768
use ndslice::shape;
1769+
use nix::unistd::Pid;
16411770
use timed_test::async_timed_test;
16421771

16431772
use super::*;
@@ -2014,4 +2143,71 @@ mod test_alloc {
20142143
task2_allocator.terminate();
20152144
task2_allocator_handle.await.unwrap();
20162145
}
2146+
2147+
#[tracing_test::traced_test]
2148+
#[async_timed_test(timeout_secs = 60)]
2149+
async fn test_remote_process_alloc_signal_handler() {
2150+
let num_proc_meshes = 5;
2151+
let hosts_per_proc_mesh = 5;
2152+
2153+
let addresses = (0..(num_proc_meshes * hosts_per_proc_mesh))
2154+
.map(|_| ChannelAddr::any(ChannelTransport::Unix).to_string())
2155+
.collect::<Vec<_>>();
2156+
2157+
let remote_process_allocators = addresses
2158+
.iter()
2159+
.map(|addr| {
2160+
Command::new(
2161+
buck_resources::get("monarch/hyperactor_mesh/remote_process_allocator")
2162+
.unwrap(),
2163+
)
2164+
.env("RUST_LOG", "info")
2165+
.arg(format!("--addr={addr}"))
2166+
.stdout(std::process::Stdio::piped())
2167+
.spawn()
2168+
.unwrap()
2169+
})
2170+
.collect::<Vec<_>>();
2171+
2172+
let done_allocating_addr = ChannelAddr::any(ChannelTransport::Unix);
2173+
let (done_allocating_addr, mut done_allocating_rx) =
2174+
channel::serve::<()>(done_allocating_addr).await.unwrap();
2175+
let mut remote_process_alloc = Command::new(
2176+
buck_resources::get("monarch/hyperactor_mesh/remote_process_alloc").unwrap(),
2177+
)
2178+
.arg(format!("--done-allocating-addr={}", done_allocating_addr))
2179+
.arg(format!("--addresses={}", addresses.join(",")))
2180+
.arg(format!("--num-proc-meshes={}", num_proc_meshes))
2181+
.arg(format!("--hosts-per-proc-mesh={}", hosts_per_proc_mesh))
2182+
.spawn()
2183+
.unwrap();
2184+
2185+
done_allocating_rx.recv().await.unwrap();
2186+
2187+
signal::kill(
2188+
Pid::from_raw(remote_process_alloc.id().unwrap() as i32),
2189+
signal::SIGINT,
2190+
)
2191+
.unwrap();
2192+
2193+
assert_eq!(
2194+
remote_process_alloc.wait().await.unwrap().signal(),
2195+
Some(signal::SIGINT as i32)
2196+
);
2197+
2198+
RealClock.sleep(tokio::time::Duration::from_secs(5)).await;
2199+
2200+
for remote_process_allocator in remote_process_allocators {
2201+
let output = remote_process_allocator.wait_with_output().await.unwrap();
2202+
assert!(output.status.success());
2203+
assert!(
2204+
String::from_utf8_lossy(&output.stdout)
2205+
.contains("child stopped with ProcStopReason::Stopped")
2206+
);
2207+
assert!(
2208+
!String::from_utf8_lossy(&output.stdout)
2209+
.contains("child stopped with ProcStopReason::Watchdog")
2210+
);
2211+
}
2212+
}
20172213
}

0 commit comments

Comments
 (0)