Skip to content

Commit 440df2e

Browse files
thomasywangfacebook-github-bot
authored andcommitted
Stop actor meshes (#537)
Summary: Pull Request resolved: #537 Reviewed By: moonli Differential Revision: D78283408
1 parent 2d91d14 commit 440df2e

File tree

10 files changed

+251
-19
lines changed

10 files changed

+251
-19
lines changed

hyperactor/src/config.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ declare_attrs! {
4343

4444
/// Maximum buffer size for split port messages
4545
pub attr SPLIT_MAX_BUFFER_SIZE: usize = 5;
46+
47+
/// Timeout used by proc mesh for stopping an actor.
48+
pub attr STOP_ACTOR_TIMEOUT: Duration = Duration::from_secs(1);
4649
}
4750

4851
/// Load configuration from environment variables

hyperactor/src/proc.rs

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -490,7 +490,7 @@ impl Proc {
490490
/// Call `abort` on the `JoinHandle` associated with the given
491491
/// root actor. If successful return `Some(root.clone())` else
492492
/// `None`.
493-
pub fn abort_root_actor(&mut self, root: &ActorId) -> Option<ActorId> {
493+
pub fn abort_root_actor(&self, root: &ActorId) -> Option<ActorId> {
494494
self.state()
495495
.ledger
496496
.roots
@@ -511,17 +511,12 @@ impl Proc {
511511
.next()
512512
}
513513

514-
// Iterating over a proc's root actors signaling each to stop.
515-
// Return the root actor IDs and status observers.
516-
async fn destroy(
517-
&mut self,
518-
) -> Result<HashMap<ActorId, watch::Receiver<ActorStatus>>, anyhow::Error> {
519-
tracing::debug!("{}: proc stopping", self.proc_id());
520-
521-
let mut statuses = HashMap::new();
522-
for entry in self.state().ledger.roots.iter() {
514+
/// Signals to a root actor to stop,
515+
/// returning a status observer if successful.
516+
pub fn stop_actor(&self, actor_id: &ActorId) -> Option<watch::Receiver<ActorStatus>> {
517+
if let Some(entry) = self.state().ledger.roots.get(actor_id) {
523518
match entry.value().upgrade() {
524-
None => (), // the root's cell has been dropped
519+
None => None, // the root's cell has been dropped
525520
Some(cell) => {
526521
tracing::info!("sending stop signal to {}", cell.actor_id());
527522
if let Err(err) = cell.signal(Signal::DrainAndStop) {
@@ -531,15 +526,16 @@ impl Proc {
531526
cell.pid(),
532527
err
533528
);
534-
continue;
529+
None
530+
} else {
531+
Some(cell.status().clone())
535532
}
536-
statuses.insert(cell.actor_id().clone(), cell.status().clone());
537533
}
538534
}
535+
} else {
536+
tracing::error!("no actor {} found in {} roots", actor_id, self.proc_id());
537+
None
539538
}
540-
541-
tracing::debug!("{}: proc stopped", self.proc_id());
542-
Ok(statuses)
543539
}
544540

545541
/// Stop the proc. Returns a pair of:
@@ -553,7 +549,23 @@ impl Proc {
553549
timeout: Duration,
554550
skip_waiting: Option<&ActorId>,
555551
) -> Result<(Vec<ActorId>, Vec<ActorId>), anyhow::Error> {
556-
let mut statuses = self.destroy().await?;
552+
tracing::debug!("{}: proc stopping", self.proc_id());
553+
554+
let mut statuses = HashMap::new();
555+
for actor_id in self
556+
.state()
557+
.ledger
558+
.roots
559+
.iter()
560+
.map(|entry| entry.key().clone())
561+
.collect::<Vec<_>>()
562+
{
563+
if let Some(status) = self.stop_actor(&actor_id) {
564+
statuses.insert(actor_id, status);
565+
}
566+
}
567+
tracing::debug!("{}: proc stopped", self.proc_id());
568+
557569
let waits: Vec<_> = statuses
558570
.iter_mut()
559571
.filter(|(actor_id, _)| Some(*actor_id) != skip_waiting)

hyperactor_mesh/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ tracing-subscriber = { version = "0.3.19", features = ["chrono", "env-filter", "
5959
[dev-dependencies]
6060
maplit = "1.0"
6161
timed_test = { version = "0.0.0", path = "../timed_test" }
62+
tracing-test = { version = "0.2.3", features = ["no-env-filter"] }
6263

6364
[lints]
6465
rust = { unexpected_cfgs = { check-cfg = ["cfg(fbcode_build)"], level = "warn" } }

hyperactor_mesh/src/actor_mesh.rs

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,13 @@ pub trait ActorMesh: Mesh<Id = ActorMeshId> {
178178
self.shape().slice().iter().map(move |rank| gang.rank(rank))
179179
}
180180

181+
fn stop(&self) -> impl std::future::Future<Output = Result<(), anyhow::Error>> + Send
182+
where
183+
Self: Sync,
184+
{
185+
async { self.proc_mesh().stop_actor_by_name(self.name()).await }
186+
}
187+
181188
/// Get a serializeable reference to this mesh similar to ActorHandle::bind
182189
fn bind(&self) -> ActorMeshRef<Self::Actor> {
183190
ActorMeshRef::attest(
@@ -1023,6 +1030,69 @@ mod tests {
10231030
);
10241031
assert!(events.next().await.is_none());
10251032
}
1033+
1034+
#[tracing_test::traced_test]
1035+
#[tokio::test]
1036+
async fn test_stop_actor_mesh() {
1037+
use hyperactor::test_utils::pingpong::PingPongActor;
1038+
use hyperactor::test_utils::pingpong::PingPongActorParams;
1039+
use hyperactor::test_utils::pingpong::PingPongMessage;
1040+
1041+
let config = hyperactor::config::global::lock();
1042+
let _guard = config.override_key(
1043+
hyperactor::config::MESSAGE_DELIVERY_TIMEOUT,
1044+
tokio::time::Duration::from_secs(1),
1045+
);
1046+
1047+
let alloc = LocalAllocator
1048+
.allocate(AllocSpec {
1049+
shape: shape! { replica = 2 },
1050+
constraints: Default::default(),
1051+
})
1052+
.await
1053+
.unwrap();
1054+
let mesh = ProcMesh::allocate(alloc).await.unwrap();
1055+
1056+
let ping_pong_actor_params = PingPongActorParams::new(
1057+
PortRef::attest_message_port(mesh.client().actor_id()),
1058+
None,
1059+
);
1060+
let mesh_one: RootActorMesh<PingPongActor> = mesh
1061+
.spawn::<PingPongActor>("mesh_one", &ping_pong_actor_params)
1062+
.await
1063+
.unwrap();
1064+
1065+
let mesh_two: RootActorMesh<PingPongActor> = mesh
1066+
.spawn::<PingPongActor>("mesh_two", &ping_pong_actor_params)
1067+
.await
1068+
.unwrap();
1069+
1070+
mesh_two.stop().await.unwrap();
1071+
1072+
let ping_two: ActorRef<PingPongActor> = mesh_two.get(0).unwrap();
1073+
let pong_two: ActorRef<PingPongActor> = mesh_two.get(1).unwrap();
1074+
1075+
assert!(logs_contain(&format!(
1076+
"stopped actor {}",
1077+
ping_two.actor_id()
1078+
)));
1079+
assert!(logs_contain(&format!(
1080+
"stopped actor {}",
1081+
pong_two.actor_id()
1082+
)));
1083+
1084+
// Other actor meshes on this proc mesh should still be up and running
1085+
let ping_one: ActorRef<PingPongActor> = mesh_one.get(0).unwrap();
1086+
let pong_one: ActorRef<PingPongActor> = mesh_one.get(1).unwrap();
1087+
let (done_tx, done_rx) = mesh.client().open_once_port();
1088+
pong_one
1089+
.send(
1090+
mesh.client(),
1091+
PingPongMessage(1, ping_one.clone(), done_tx.bind()),
1092+
)
1093+
.unwrap();
1094+
assert!(done_rx.recv().await.is_ok());
1095+
}
10261096
} // mod local
10271097

10281098
mod process {

hyperactor_mesh/src/proc_mesh.rs

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@ use std::sync::Arc;
1414

1515
use async_trait::async_trait;
1616
use dashmap::DashMap;
17+
use futures::future::join_all;
1718
use hyperactor::Actor;
19+
use hyperactor::ActorId;
1820
use hyperactor::ActorRef;
1921
use hyperactor::Mailbox;
2022
use hyperactor::Named;
@@ -56,6 +58,7 @@ use crate::comm::CommActorMode;
5658
use crate::proc_mesh::mesh_agent::GspawnResult;
5759
use crate::proc_mesh::mesh_agent::MeshAgent;
5860
use crate::proc_mesh::mesh_agent::MeshAgentMessageClient;
61+
use crate::proc_mesh::mesh_agent::StopActorResult;
5962
use crate::reference::ProcMeshId;
6063

6164
pub mod mesh_agent;
@@ -449,6 +452,40 @@ impl ProcMesh {
449452
pub fn shape(&self) -> &Shape {
450453
&self.shape
451454
}
455+
456+
/// Send stop actors message to all mesh agents for a specific mesh name
457+
pub async fn stop_actor_by_name(&self, mesh_name: &str) -> Result<(), anyhow::Error> {
458+
let timeout = hyperactor::config::global::get(hyperactor::config::STOP_ACTOR_TIMEOUT);
459+
let results = join_all(self.agents().map(|agent| async move {
460+
let actor_id = ActorId(agent.actor_id().proc_id().clone(), mesh_name.to_string(), 0);
461+
(
462+
actor_id.clone(),
463+
agent
464+
.clone()
465+
.stop_actor(&self.client, actor_id, timeout.as_millis() as u64)
466+
.await,
467+
)
468+
}))
469+
.await;
470+
471+
for (actor_id, result) in results {
472+
match result {
473+
Ok(StopActorResult::Timeout) => {
474+
tracing::error!("timed out while stopping actor {}", actor_id);
475+
}
476+
Ok(StopActorResult::NotFound) => {
477+
tracing::error!("no actor {} on proc {}", actor_id, actor_id.proc_id());
478+
}
479+
Ok(StopActorResult::Success) => {
480+
tracing::info!("stopped actor {}", actor_id);
481+
}
482+
Err(e) => {
483+
tracing::error!("error stopping actor {}: {}", actor_id, e);
484+
}
485+
}
486+
}
487+
Ok(())
488+
}
452489
}
453490

454491
/// Proc lifecycle events.

hyperactor_mesh/src/proc_mesh/mesh_agent.rs

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,17 @@ use hyperactor::HandleClient;
2525
use hyperactor::Handler;
2626
use hyperactor::Instance;
2727
use hyperactor::Named;
28+
use hyperactor::OncePortRef;
2829
use hyperactor::PortHandle;
2930
use hyperactor::PortRef;
3031
use hyperactor::ProcId;
3132
use hyperactor::RefClient;
33+
use hyperactor::actor::ActorStatus;
3234
use hyperactor::actor::remote::Remote;
3335
use hyperactor::channel;
3436
use hyperactor::channel::ChannelAddr;
37+
use hyperactor::clock::Clock;
38+
use hyperactor::clock::RealClock;
3539
use hyperactor::mailbox::BoxedMailboxSender;
3640
use hyperactor::mailbox::DeliveryError;
3741
use hyperactor::mailbox::DialMailboxRouter;
@@ -52,11 +56,17 @@ pub enum GspawnResult {
5256
Error(String),
5357
}
5458

59+
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Named)]
60+
pub enum StopActorResult {
61+
Success,
62+
Timeout,
63+
NotFound,
64+
}
65+
5566
#[derive(
5667
Debug,
5768
Clone,
5869
PartialEq,
59-
Eq,
6070
Serialize,
6171
Deserialize,
6272
Handler,
@@ -91,6 +101,17 @@ pub(crate) enum MeshAgentMessage {
91101
/// reply port; the proc should send its rank to indicated a spawned actor
92102
status_port: PortRef<GspawnResult>,
93103
},
104+
105+
/// Stop actors of a specific mesh name
106+
StopActor {
107+
/// The actor to stop
108+
actor_id: ActorId,
109+
/// The timeout for waiting for the actor to stop
110+
timeout_ms: u64,
111+
/// The result when trying to stop the actor
112+
#[reply]
113+
stopped: OncePortRef<StopActorResult>,
114+
},
94115
}
95116

96117
/// A mesh agent is responsible for managing procs in a [`ProcMesh`].
@@ -224,6 +245,30 @@ impl MeshAgentMessageHandler for MeshAgent {
224245
status_port.send(cx, GspawnResult::Success { rank, actor_id })?;
225246
Ok(())
226247
}
248+
249+
async fn stop_actor(
250+
&mut self,
251+
_cx: &Context<Self>,
252+
actor_id: ActorId,
253+
timeout_ms: u64,
254+
) -> Result<StopActorResult, anyhow::Error> {
255+
tracing::info!("Stopping actor: {}", actor_id);
256+
257+
if let Some(mut status) = self.proc.stop_actor(&actor_id) {
258+
match RealClock
259+
.timeout(
260+
tokio::time::Duration::from_millis(timeout_ms),
261+
status.wait_for(|state: &ActorStatus| matches!(*state, ActorStatus::Stopped)),
262+
)
263+
.await
264+
{
265+
Ok(_) => Ok(StopActorResult::Success),
266+
Err(_) => Ok(StopActorResult::Timeout),
267+
}
268+
} else {
269+
Ok(StopActorResult::NotFound)
270+
}
271+
}
227272
}
228273

229274
#[async_trait]

monarch_hyperactor/src/actor_mesh.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,25 @@ impl PythonActorMesh {
225225
fn __reduce_ex__(&self, _proto: u8) -> PyResult<()> {
226226
Err(self.pickling_err())
227227
}
228+
229+
fn stop<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
230+
let actor_mesh = self.inner.clone();
231+
pyo3_async_runtimes::tokio::future_into_py(py, async move {
232+
let actor_mesh = actor_mesh
233+
.take()
234+
.await
235+
.map_err(|_| PyRuntimeError::new_err("`ActorMesh` has already been stopped"))?;
236+
actor_mesh.stop().await.map_err(|err| {
237+
PyException::new_err(format!("Failed to stop actor mesh: {}", err))
238+
})?;
239+
Ok(())
240+
})
241+
}
242+
243+
#[getter]
244+
fn stopped(&self) -> PyResult<bool> {
245+
Ok(self.inner.borrow().is_err())
246+
}
228247
}
229248

230249
#[pyclass(

python/monarch/_rust_bindings/monarch_hyperactor/actor_mesh.pyi

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,20 @@ class PythonActorMesh:
114114
"""
115115
...
116116

117+
async def stop(self) -> None:
118+
"""
119+
Stop all actors that are part of this mesh.
120+
Using this mesh after stop() is called will raise an Exception.
121+
"""
122+
...
123+
124+
@property
125+
def stopped(self) -> bool:
126+
"""
127+
If the mesh has been stopped.
128+
"""
129+
...
130+
117131
@final
118132
class ActorMeshMonitor:
119133
def __aiter__(self) -> AsyncIterator["ActorSupervisionEvent"]:

0 commit comments

Comments
 (0)