diff --git a/monarch_hyperactor/src/actor_mesh.rs b/monarch_hyperactor/src/actor_mesh.rs index 6dcb76b68..1f55da3e3 100644 --- a/monarch_hyperactor/src/actor_mesh.rs +++ b/monarch_hyperactor/src/actor_mesh.rs @@ -46,6 +46,7 @@ use crate::proc_mesh::Keepalive; use crate::selection::PySelection; use crate::shape::PyShape; use crate::supervision::SupervisionError; +use crate::supervision::Unhealthy; #[pyclass( name = "PythonActorMesh", @@ -55,7 +56,7 @@ pub struct PythonActorMesh { inner: SharedCell>, client: PyMailbox, _keepalive: Keepalive, - unhealthy_event: Arc>>>, + unhealthy_event: Arc>>, user_monitor_sender: tokio::sync::broadcast::Sender>, monitor: tokio::task::JoinHandle<()>, } @@ -71,11 +72,11 @@ impl PythonActorMesh { ) -> Self { let (user_monitor_sender, _) = tokio::sync::broadcast::channel::>(1); - let unhealthy_event = Arc::new(std::sync::Mutex::new(None)); + let unhealthy_event = Arc::new(std::sync::Mutex::new(Unhealthy::SoFarSoGood)); let monitor = tokio::spawn(Self::actor_mesh_monitor( events, user_monitor_sender.clone(), - unhealthy_event.clone(), + Arc::clone(&unhealthy_event), )); Self { inner, @@ -92,15 +93,19 @@ impl PythonActorMesh { async fn actor_mesh_monitor( mut events: ActorSupervisionEvents, user_sender: tokio::sync::broadcast::Sender>, - unhealthy_event: Arc>>>, + unhealthy_event: Arc>>, ) { loop { let event = events.next().await; let mut inner_unhealthy_event = unhealthy_event.lock().unwrap(); - *inner_unhealthy_event = Some(event.clone()); + match &event { + None => *inner_unhealthy_event = Unhealthy::StreamClosed, + Some(event) => *inner_unhealthy_event = Unhealthy::Crashed(event.clone()), + } - // Ignore the sender error when there is no receiver, which happens when there - // is no active requests to this mesh. + // Ignore the sender error when there is no receiver, + // which happens when there is no active requests to this + // mesh. let _ = user_sender.send(event.clone()); if event.is_none() { @@ -132,11 +137,20 @@ impl PythonActorMesh { .unhealthy_event .lock() .expect("failed to acquire unhealthy_event lock"); - if let Some(ref event) = *unhealthy_event { - return Err(PyRuntimeError::new_err(format!( - "actor mesh is unhealthy with reason: {:?}", - event - ))); + + match &*unhealthy_event { + Unhealthy::SoFarSoGood => (), + Unhealthy::Crashed(event) => { + return Err(PyRuntimeError::new_err(format!( + "actor mesh is unhealthy with reason: {:?}", + event + ))); + } + Unhealthy::StreamClosed => { + return Err(PyRuntimeError::new_err( + "actor mesh is stopped due to proc mesh shutdown".to_string(), + )); + } } self.try_inner()? @@ -156,15 +170,16 @@ impl PythonActorMesh { .lock() .expect("failed to acquire unhealthy_event lock"); - Ok(unhealthy_event.as_ref().map(|event| match event { - None => PyActorSupervisionEvent { + match &*unhealthy_event { + Unhealthy::SoFarSoGood => Ok(None), + Unhealthy::StreamClosed => Ok(Some(PyActorSupervisionEvent { // Dummy actor as place holder to indicate the whole mesh is stopped // TODO(albertli): remove this when pushing all supervision logic to rust. actor_id: id!(default[0].actor[0]).into(), actor_status: "actor mesh is stopped due to proc mesh shutdown".to_string(), - }, - Some(event) => PyActorSupervisionEvent::from(event.clone()), - })) + })), + Unhealthy::Crashed(event) => Ok(Some(PyActorSupervisionEvent::from(event.clone()))), + } } // Consider defining a "PythonActorRef", which carries specifically diff --git a/monarch_hyperactor/src/proc_mesh.rs b/monarch_hyperactor/src/proc_mesh.rs index e96eb4a0c..cc0dac0d9 100644 --- a/monarch_hyperactor/src/proc_mesh.rs +++ b/monarch_hyperactor/src/proc_mesh.rs @@ -46,6 +46,7 @@ use crate::mailbox::PyMailbox; use crate::runtime::signal_safe_block_on; use crate::shape::PyShape; use crate::supervision::SupervisionError; +use crate::supervision::Unhealthy; // A wrapper around `ProcMesh` which keeps track of all `RootActorMesh`s that it spawns. pub struct TrackedProcMesh { @@ -119,7 +120,7 @@ pub struct PyProcMesh { proc_events: SharedCell>, user_monitor_receiver: SharedCell>>, user_monitor_registered: Arc, - unhealthy_event: Arc>>>, + unhealthy_event: Arc>>, } fn allocate_proc_mesh(alloc: &PyAlloc) -> PyResult { @@ -146,15 +147,15 @@ impl PyProcMesh { let proc_events = SharedCell::from(Mutex::new(proc_mesh.events().unwrap())); let (user_sender, user_receiver) = mpsc::unbounded_channel::(); let user_monitor_registered = Arc::new(AtomicBool::new(false)); - let unhealthy_event = Arc::new(Mutex::new(None)); + let unhealthy_event = Arc::new(Mutex::new(Unhealthy::SoFarSoGood)); let monitor = tokio::spawn(Self::default_proc_mesh_monitor( proc_events .borrow() .expect("borrowing immediately after creation"), world_id, user_sender, - user_monitor_registered.clone(), - unhealthy_event.clone(), + Arc::clone(&user_monitor_registered), + Arc::clone(&unhealthy_event), )); Self { inner: SharedCell::from(TrackedProcMesh::from(proc_mesh)), @@ -172,7 +173,7 @@ impl PyProcMesh { world_id: WorldId, user_sender: mpsc::UnboundedSender, user_monitor_registered: Arc, - unhealthy_event: Arc>>>, + unhealthy_event: Arc>>, ) { loop { let mut proc_events = events.lock().await; @@ -181,7 +182,7 @@ impl PyProcMesh { let mut inner_unhealthy_event = unhealthy_event.lock().await; match event { None => { - *inner_unhealthy_event = Some(None); + *inner_unhealthy_event = Unhealthy::StreamClosed; tracing::info!("ProcMesh {}: alloc has stopped", world_id); break; } @@ -189,7 +190,7 @@ impl PyProcMesh { // Graceful stops can be ignored. ProcEvent::Stopped(_, ProcStopReason::Stopped) => continue, event => { - *inner_unhealthy_event = Some(Some(event.clone())); + *inner_unhealthy_event = Unhealthy::Crashed(event.clone()); tracing::info!("ProcMesh {}: {}", world_id, event); if user_monitor_registered.load(std::sync::atomic::Ordering::SeqCst) { if user_sender.send(event).is_err() { @@ -202,7 +203,7 @@ impl PyProcMesh { } _ = events.preempted() => { let mut inner_unhealthy_event = unhealthy_event.lock().await; - *inner_unhealthy_event = Some(None); + *inner_unhealthy_event = Unhealthy::StreamClosed; tracing::info!("ProcMesh {}: is stopped", world_id); break; } @@ -243,19 +244,18 @@ impl PyProcMesh { } } -// Return with error if the mesh is unhealthy. -async fn ensure_mesh_healthy( - unhealthy_event: &Mutex>>, -) -> Result<(), PyErr> { +async fn ensure_mesh_healthy(unhealthy_event: &Mutex>) -> Result<(), PyErr> { let locked = unhealthy_event.lock().await; - if let Some(event) = &*locked { - let msg = match event { - Some(e) => format!("proc mesh is stopped with reason: {:?}", e), - None => "proc mesh is stopped with reason: alloc is stopped".to_string(), - }; - return Err(SupervisionError::new_err(msg)); + match &*locked { + Unhealthy::SoFarSoGood => Ok(()), + Unhealthy::StreamClosed => Err(SupervisionError::new_err( + "proc mesh is stopped with reason: alloc is stopped".to_string(), + )), + Unhealthy::Crashed(event) => Err(SupervisionError::new_err(format!( + "proc mesh is stopped with reason: {:?}", + event + ))), } - Ok(()) } #[pymethods] diff --git a/monarch_hyperactor/src/supervision.rs b/monarch_hyperactor/src/supervision.rs index f3323759e..9ffe92873 100644 --- a/monarch_hyperactor/src/supervision.rs +++ b/monarch_hyperactor/src/supervision.rs @@ -23,3 +23,23 @@ pub fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult<()> { module.add("SupervisionError", py.get_type::())?; Ok(()) } + +// Shared between mesh types. +#[derive(Debug, Clone)] +pub(crate) enum Unhealthy { + SoFarSoGood, // Still healthy + StreamClosed, // Event stream closed + Crashed(Event), // Bad health event received +} + +impl Unhealthy { + #[allow(dead_code)] // No uses yet. + pub(crate) fn is_healthy(&self) -> bool { + matches!(self, Unhealthy::SoFarSoGood) + } + + #[allow(dead_code)] // No uses yet. + pub(crate) fn is_crashed(&self) -> bool { + matches!(self, Unhealthy::Crashed(_)) + } +}