Skip to content

Commit 075ec07

Browse files
thomasywangfacebook-github-bot
authored andcommitted
Signal handler for RemoteProcessAlloc (#540)
Summary: What's going on here: 1. `RemoteProcessAlloc` is instantiated in the client code (ex. https://fburl.com/code/p4t5aewo) 2. `RemoteProcessAlloc::new()` now spawns a signal handler and holds onto the `JoinHandle`. A tx-rx pair is created so that the signal handler task is aware of the addresses of hosts in `RemoteProcessAlloc::host_states` as they are added and removed 3. `RemoteProcessAlloc::host_states` is now wrapped in a struct `HostStates` which contains the tx side and aims to have the same interface as a `HashMap` but sends updates to the map over the tx. When a `RemoteProcessAllocHostState` is inserted, the address and `HostId` is sent over the tx. When a `RemoteProcessAllocHostState` is removed, the `HostId` is sent over the tx (address is None). 4. When the handler receives a `HostId` and `Some(ChannelAddr)` it will dial this address, and insert the `ChannelTx` into it's own `HashMap` with the `HostId` as the key 5. When the handler receives a `HostId` and `None`, it will remove the corresponding entry from it's `HashMap` 6. When the handler receives a signal, it will iterate over all `ChannelTx`s in the `HashMap` and send `RemoteProcessAllocatorMessage::Signal(signal)` over each `ChannelTx` to the `RemoteProcessAllocator` running on a remote machine 7. The`RemoteProcessAllocator` receives the message. If the signal == SIGINT, it calls `ensure_previous_alloc_stopped` to stop gracefully, then reraises the signal Reviewed By: moonli Differential Revision: D78097380
1 parent 83ea22e commit 075ec07

File tree

8 files changed

+527
-4
lines changed

8 files changed

+527
-4
lines changed

hyperactor/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ serde_bytes = "0.11"
4646
serde_json = { version = "1.0.140", features = ["alloc", "float_roundtrip", "unbounded_depth"] }
4747
serde_with = { version = "3", features = ["hex", "json"] }
4848
serde_yaml = "0.9.25"
49+
signal-hook-tokio = { version = "0.3", features = ["futures-v0_3"] }
4950
thiserror = "2.0.12"
5051
tokio = { version = "1.45.0", features = ["full", "test-util", "tracing"] }
5152
tokio-rustls = { version = "0.24.1", features = ["dangerous_configuration"] }

hyperactor/src/lib.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ pub mod panic_handler;
8484
mod parse;
8585
pub mod proc;
8686
pub mod reference;
87+
mod signal_handler;
8788
pub mod simnet;
8889
pub mod supervision;
8990
pub mod sync;
@@ -167,6 +168,14 @@ pub use reference::WorldId;
167168
// Re-exported to support tracing in hyperactor_macros codegen.
168169
#[doc(hidden)]
169170
pub use serde_json;
171+
#[doc(inline)]
172+
pub use signal_handler::SignalCleanupGuard;
173+
#[doc(inline)]
174+
pub use signal_handler::register_signal_cleanup;
175+
#[doc(inline)]
176+
pub use signal_handler::register_signal_cleanup_scoped;
177+
#[doc(inline)]
178+
pub use signal_handler::unregister_signal_cleanup;
170179
// Re-exported to support tracing in hyperactor_macros codegen.
171180
#[doc(hidden)]
172181
pub use tracing;

hyperactor/src/signal_handler.rs

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
use std::collections::HashMap;
10+
use std::future::Future;
11+
use std::pin::Pin;
12+
use std::sync::Arc;
13+
use std::sync::Mutex;
14+
use std::sync::OnceLock;
15+
16+
use nix::sys::signal;
17+
use tokio_stream::StreamExt;
18+
19+
type AsyncCleanupCallback = Pin<Box<dyn Future<Output = ()> + Send>>;
20+
21+
/// Global signal manager that coordinates cleanup across all signal handlers
22+
pub(crate) struct GlobalSignalManager {
23+
cleanup_callbacks: Arc<Mutex<HashMap<u64, AsyncCleanupCallback>>>,
24+
next_id: Arc<Mutex<u64>>,
25+
_listener: tokio::task::JoinHandle<()>,
26+
}
27+
28+
impl GlobalSignalManager {
29+
fn new() -> Self {
30+
let listener = tokio::spawn(async move {
31+
if let Ok(mut signals) =
32+
signal_hook_tokio::Signals::new([signal::SIGINT as i32, signal::SIGTERM as i32])
33+
{
34+
if let Some(signal) = signals.next().await {
35+
tracing::info!("received signal: {}", signal);
36+
37+
get_signal_manager().execute_all_cleanups().await;
38+
39+
match signal::Signal::try_from(signal) {
40+
Ok(sig) => {
41+
if let Err(err) =
42+
// SAFETY: We're setting the handle to SigDfl (default system behaviour)
43+
unsafe { signal::signal(sig, signal::SigHandler::SigDfl) }
44+
{
45+
tracing::error!(
46+
"failed to restore default signal handler for {}: {}",
47+
sig,
48+
err
49+
);
50+
}
51+
52+
// Re-raise the signal to trigger default behavior (process termination)
53+
if let Err(err) = signal::raise(sig) {
54+
tracing::error!("failed to re-raise signal {}: {}", sig, err);
55+
}
56+
}
57+
Err(err) => {
58+
tracing::error!("failed to convert signal {}: {}", signal, err);
59+
}
60+
}
61+
}
62+
}
63+
});
64+
Self {
65+
cleanup_callbacks: Arc::new(Mutex::new(HashMap::new())),
66+
next_id: Arc::new(Mutex::new(0)),
67+
_listener: listener,
68+
}
69+
}
70+
71+
/// Register a cleanup callback and return a unique ID for later unregistration
72+
fn register_cleanup(&self, callback: AsyncCleanupCallback) -> u64 {
73+
let mut next_id = self.next_id.lock().unwrap_or_else(|e| e.into_inner());
74+
let id = *next_id;
75+
*next_id += 1;
76+
drop(next_id);
77+
78+
let mut callbacks = self
79+
.cleanup_callbacks
80+
.lock()
81+
.unwrap_or_else(|e| e.into_inner());
82+
callbacks.insert(id, callback);
83+
tracing::info!("registered signal cleanup callback with ID: {}", id);
84+
id
85+
}
86+
87+
/// Unregister a cleanup callback by ID
88+
fn unregister_cleanup(&self, id: u64) {
89+
let mut callbacks = self
90+
.cleanup_callbacks
91+
.lock()
92+
.unwrap_or_else(|e| e.into_inner());
93+
if callbacks.remove(&id).is_some() {
94+
tracing::info!("unregistered signal cleanup callback with ID: {}", id);
95+
} else {
96+
tracing::warn!(
97+
"attempted to unregister non-existent cleanup callback with ID: {}",
98+
id
99+
);
100+
}
101+
}
102+
103+
/// Execute all registered cleanup callbacks asynchronously
104+
async fn execute_all_cleanups(&self) {
105+
let callbacks = {
106+
let mut callbacks = self
107+
.cleanup_callbacks
108+
.lock()
109+
.unwrap_or_else(|e| e.into_inner());
110+
std::mem::take(&mut *callbacks)
111+
};
112+
113+
let futures = callbacks.into_iter().map(|(id, future)| async move {
114+
tracing::debug!("executing cleanup callback with ID: {}", id);
115+
future.await;
116+
});
117+
118+
futures::future::join_all(futures).await;
119+
}
120+
}
121+
122+
/// Global instance of the signal manager
123+
static SIGNAL_MANAGER: OnceLock<GlobalSignalManager> = OnceLock::new();
124+
125+
/// Get the global signal manager instance
126+
pub(crate) fn get_signal_manager() -> &'static GlobalSignalManager {
127+
SIGNAL_MANAGER.get_or_init(GlobalSignalManager::new)
128+
}
129+
130+
/// RAII guard that automatically unregisters a signal cleanup callback when dropped
131+
pub struct SignalCleanupGuard {
132+
id: u64,
133+
}
134+
135+
impl SignalCleanupGuard {
136+
fn new(id: u64) -> Self {
137+
Self { id }
138+
}
139+
140+
/// Get the ID of the registered cleanup callback
141+
pub fn id(&self) -> u64 {
142+
self.id
143+
}
144+
}
145+
146+
impl Drop for SignalCleanupGuard {
147+
fn drop(&mut self) {
148+
get_signal_manager().unregister_cleanup(self.id);
149+
}
150+
}
151+
152+
/// Register a cleanup callback to be executed on SIGINT/SIGTERM
153+
/// Returns a unique ID that can be used to unregister the callback
154+
pub fn register_signal_cleanup(callback: AsyncCleanupCallback) -> u64 {
155+
get_signal_manager().register_cleanup(callback)
156+
}
157+
158+
/// Register a scoped cleanup callback to be executed on SIGINT/SIGTERM
159+
/// Returns a guard that automatically unregisters the callback when dropped
160+
pub fn register_signal_cleanup_scoped(callback: AsyncCleanupCallback) -> SignalCleanupGuard {
161+
let id = get_signal_manager().register_cleanup(callback);
162+
SignalCleanupGuard::new(id)
163+
}
164+
165+
/// Unregister a previously registered cleanup callback
166+
pub fn unregister_signal_cleanup(id: u64) {
167+
get_signal_manager().unregister_cleanup(id);
168+
}

hyperactor_mesh/Cargo.toml

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# @generated by autocargo from //monarch/hyperactor_mesh:[hyperactor_mesh,hyperactor_mesh_test_bootstrap,process_allocator_cleanup,process_allocator_test_bin,process_allocator_test_bootstrap]
1+
# @generated by autocargo from //monarch/hyperactor_mesh:[hyperactor_mesh,hyperactor_mesh_test_bootstrap,hyperactor_mesh_test_remote_process_alloc,hyperactor_mesh_test_remote_process_allocator,process_allocator_cleanup,process_allocator_test_bin,process_allocator_test_bootstrap]
22

33
[package]
44
name = "hyperactor_mesh"
@@ -11,6 +11,14 @@ license = "BSD-3-Clause"
1111
name = "hyperactor_mesh_test_bootstrap"
1212
path = "test/bootstrap.rs"
1313

14+
[[bin]]
15+
name = "hyperactor_mesh_test_remote_process_alloc"
16+
path = "test/remote_process_alloc.rs"
17+
18+
[[bin]]
19+
name = "hyperactor_mesh_test_remote_process_allocator"
20+
path = "test/remote_process_allocator.rs"
21+
1422
[[bin]]
1523
name = "process_allocator_test_bin"
1624
path = "test/process_allocator_cleanup/process_allocator_test_bin.rs"
@@ -29,6 +37,7 @@ async-trait = "0.1.86"
2937
bincode = "1.3.3"
3038
bitmaps = "3.2.1"
3139
buck-resources = "1"
40+
clap = { version = "4.5.38", features = ["derive", "env", "string", "unicode", "wrap_help"] }
3241
dashmap = { version = "5.5.3", features = ["rayon", "serde"] }
3342
enum-as-inner = "0.6.0"
3443
erased-serde = "0.3.27"

hyperactor_mesh/src/alloc/process.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,8 @@ impl Alloc for ProcessAlloc {
466466
let (_stderr, _) = stderr.join().await;
467467
}
468468

469+
tracing::info!("child stopped with ProcStopReason::{:?}", reason);
470+
469471
break Some(ProcState::Stopped {
470472
proc_id: ProcId(WorldId(self.name.to_string()), index),
471473
reason

0 commit comments

Comments
 (0)