diff --git a/monarch_hyperactor/src/actor_mesh.rs b/monarch_hyperactor/src/actor_mesh.rs index 2851cc07e..6d8f30c73 100644 --- a/monarch_hyperactor/src/actor_mesh.rs +++ b/monarch_hyperactor/src/actor_mesh.rs @@ -25,6 +25,7 @@ use pyo3::exceptions::PyEOFError; use pyo3::exceptions::PyException; use pyo3::exceptions::PyNotImplementedError; use pyo3::exceptions::PyRuntimeError; +use pyo3::exceptions::PyTypeError; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use pyo3::types::PyBytes; @@ -192,14 +193,28 @@ impl PythonActorMesh { .map(PyActorId::from)) } - // Start monitoring the actor mesh by subscribing to its supervision events. For each supervision - // event, it is consumed by PythonActorMesh first, then gets sent to the monitor for user to consume. - fn monitor<'py>(&self, py: Python<'py>) -> PyResult { - let receiver = self.user_monitor_sender.subscribe(); - let monitor_instance = PyActorMeshMonitor { - receiver: SharedCell::from(Mutex::new(receiver)), - }; - monitor_instance.into_py_any(py) + fn supervise(&self, py: Python<'_>, receiver: Bound<'_, PyAny>) -> PyResult { + if let Ok(r) = receiver.extract::>() { + let rx = SupervisedPythonPortReceiver { + inner: r.inner(), + monitor: ActorMeshMonitor { + receiver: SharedCell::from(Mutex::new(self.user_monitor_sender.subscribe())), + }, + }; + rx.into_py_any(py) + } else if let Ok(r) = receiver.extract::>() { + let rx = SupervisedPythonOncePortReceiver { + inner: r.inner(), + monitor: ActorMeshMonitor { + receiver: SharedCell::from(Mutex::new(self.user_monitor_sender.subscribe())), + }, + }; + rx.into_py_any(py) + } else { + Err(PyTypeError::new_err( + "Expected a PortReceiver or OncePortReceiver", + )) + } } #[pyo3(signature = (**kwargs))] @@ -374,84 +389,46 @@ impl Drop for PythonActorMesh { } } -#[pyclass( - name = "ActorMeshMonitor", - module = "monarch._rust_bindings.monarch_hyperactor.actor_mesh" -)] -pub struct PyActorMeshMonitor { +#[derive(Debug, Clone)] +struct ActorMeshMonitor { receiver: SharedCell>>>, } -#[pymethods] -impl PyActorMeshMonitor { - fn __aiter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { - slf - } - - pub fn __anext__(&self, py: Python<'_>) -> PyResult { +impl ActorMeshMonitor { + pub async fn next(&self) -> PyActorSupervisionEvent { let receiver = self.receiver.clone(); - Ok(pyo3_async_runtimes::tokio::future_into_py(py, get_next(receiver))?.into()) - } -} - -impl PyActorMeshMonitor { - pub async fn next(&self) -> PyResult { - get_next(self.receiver.clone()).await - } -} - -impl Clone for PyActorMeshMonitor { - fn clone(&self) -> Self { - Self { - receiver: self.receiver.clone(), + let receiver = receiver + .borrow() + .expect("`Actor mesh receiver` is shutdown"); + let mut receiver = receiver.lock().await; + let event = receiver.recv().await.unwrap(); + match event { + None => PyActorSupervisionEvent { + // Dummy actor as place holder to indicate the whole mesh is stopped + // TODO(albertli): remove this when pushing all supervision logic to rust. + actor_id: id!(default[0].actor[0]).into(), + actor_status: "actor mesh is stopped due to proc mesh shutdown".to_string(), + }, + Some(event) => PyActorSupervisionEvent::from(event.clone()), } } } -async fn get_next( - receiver: SharedCell>>>, -) -> PyResult { - let receiver = receiver.clone(); - - let receiver = receiver - .borrow() - .expect("`Actor mesh receiver` is shutdown"); - let mut receiver = receiver.lock().await; - let event = receiver.recv().await.unwrap(); - - let supervision_event = match event { - None => PyActorSupervisionEvent { - // Dummy actor as place holder to indicate the whole mesh is stopped - // TODO(albertli): remove this when pushing all supervision logic to rust. - actor_id: id!(default[0].actor[0]).into(), - actor_status: "actor mesh is stopped due to proc mesh shutdown".to_string(), - }, - Some(event) => PyActorSupervisionEvent::from(event.clone()), - }; - tracing::info!("recv supervision event: {supervision_event:?}"); - - Python::with_gil(|py| supervision_event.into_py_any(py)) -} - -// TODO(albertli): this is temporary remove this when pushing all supervision logic to rust. +// Values of this type can only be created by calling +// `PythonActorMesh::supervise()`. #[pyclass( - name = "MonitoredPortReceiver", + name = "SupervisedPortReceiver", module = "monarch._rust_bindings.monarch_hyperactor.actor_mesh" )] -pub(super) struct MonitoredPythonPortReceiver { +struct SupervisedPythonPortReceiver { inner: Arc>>, - monitor: PyActorMeshMonitor, + monitor: ActorMeshMonitor, } #[pymethods] -impl MonitoredPythonPortReceiver { - #[new] - fn new(receiver: &PythonPortReceiver, monitor: &PyActorMeshMonitor) -> Self { - let inner = receiver.inner(); - MonitoredPythonPortReceiver { - inner, - monitor: monitor.clone(), - } +impl SupervisedPythonPortReceiver { + fn __repr__(&self) -> &'static str { + "" } fn recv_task(&mut self) -> PyPythonTask { @@ -464,10 +441,8 @@ impl MonitoredPythonPortReceiver { result.map_err(|err| PyErr::new::(format!("port closed: {}", err))) } event = monitor.next() => { - let event = event.expect("supervision event should not be None"); Python::with_gil(|py| { - let e = event.downcast_bound::(py)?; - Err(PyErr::new::(format!("supervision error: {:?}", e))) + Err(PyErr::new::(format!("supervision error: {:?}", event))) }) } }; @@ -476,24 +451,21 @@ impl MonitoredPythonPortReceiver { } } +// Values of this type can only be created by calling +// `PythonActorMesh::supervise()`. #[pyclass( - name = "MonitoredOncePortReceiver", + name = "SupervisedOncePortReceiver", module = "monarch._rust_bindings.monarch_hyperactor.actor_mesh" )] -pub(super) struct MonitoredPythonOncePortReceiver { +struct SupervisedPythonOncePortReceiver { inner: Arc>>>, - monitor: PyActorMeshMonitor, + monitor: ActorMeshMonitor, } #[pymethods] -impl MonitoredPythonOncePortReceiver { - #[new] - fn new(receiver: &PythonOncePortReceiver, monitor: &PyActorMeshMonitor) -> Self { - let inner = receiver.inner(); - MonitoredPythonOncePortReceiver { - inner, - monitor: monitor.clone(), - } +impl SupervisedPythonOncePortReceiver { + fn __repr__(&self) -> &'static str { + "" } fn recv_task(&mut self) -> PyResult { @@ -507,10 +479,8 @@ impl MonitoredPythonOncePortReceiver { result.map_err(|err| PyErr::new::(format!("port closed: {}", err))) } event = monitor.next() => { - let event = event.expect("supervision event should not be None"); Python::with_gil(|py| { - let e = event.downcast_bound::(py)?; - Err(PyErr::new::(format!("supervision error: {:?}", e))) + Err(PyErr::new::(format!("supervision error: {:?}", event))) }) } }; @@ -556,9 +526,8 @@ impl From for PyActorSupervisionEvent { pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> { hyperactor_mod.add_class::()?; hyperactor_mod.add_class::()?; - hyperactor_mod.add_class::()?; - hyperactor_mod.add_class::()?; - hyperactor_mod.add_class::()?; + hyperactor_mod.add_class::()?; + hyperactor_mod.add_class::()?; hyperactor_mod.add_class::()?; Ok(()) } diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/actor_mesh.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/actor_mesh.pyi index 2813b7e15..75e6bd495 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/actor_mesh.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/actor_mesh.pyi @@ -89,10 +89,17 @@ class PythonActorMesh: """ ... - # TODO(albertli): remove this when pushing all supervision logic to Rust - def monitor(self) -> ActorMeshMonitor: - """ - Returns a supervision monitor for this mesh. + def supervise( + self, r: PortReceiver | OncePortReceiver + ) -> SupervisedPortReceiver | SupervisedOncePortReceiver: + """Return a monitored port receiver. + + A monitored port receiver behaves like a regular port receiver + but also observes the health of the actor mesh associated with + the sender. If the actor mesh becomes unhealthy, the receiver + will yield a supervision error instead of waiting indefinitely + for a message. + """ ... @@ -129,42 +136,16 @@ class PythonActorMesh: ... @final -class ActorMeshMonitor: - def __aiter__(self) -> AsyncIterator["ActorSupervisionEvent"]: - """ - Returns an async iterator for this monitor. - """ - ... - - async def __anext__(self) -> "ActorSupervisionEvent": - """ - Returns the next proc event in the proc mesh. - """ - ... - -@final -class MonitoredPortReceiver(PortReceiverBase): - """ - A monitored receiver to which PythonMessages are sent. +class SupervisedPortReceiver(PortReceiverBase): + """A monitored receiver to which PythonMessages are sent. Values + of this type cannot be constructed directly in Python. """ - def __init__(self, receiver: PortReceiver, monitor: ActorMeshMonitor) -> None: - """ - Create a new monitored receiver from a PortReceiver. - """ - ... - @final -class MonitoredOncePortReceiver(PortReceiverBase): +class SupervisedOncePortReceiver(PortReceiverBase): + """A monitored once receiver to which PythonMessages are sent. + Values of this type cannot be constructed directly in Python. """ - A variant of monitored PortReceiver that can only receive a single message. - """ - - def __init__(self, receiver: OncePortReceiver, monitor: ActorMeshMonitor) -> None: - """ - Create a new monitored receiver from a PortReceiver. - """ - ... @final class ActorSupervisionEvent: diff --git a/python/monarch/_src/actor/actor_mesh.py b/python/monarch/_src/actor/actor_mesh.py index b85930ca2..48b28655b 100644 --- a/python/monarch/_src/actor/actor_mesh.py +++ b/python/monarch/_src/actor/actor_mesh.py @@ -51,16 +51,12 @@ PythonMessage, PythonMessageKind, ) -from monarch._rust_bindings.monarch_hyperactor.actor_mesh import ( - ActorMeshMonitor, - MonitoredOncePortReceiver, - MonitoredPortReceiver, - PythonActorMesh, -) +from monarch._rust_bindings.monarch_hyperactor.actor_mesh import PythonActorMesh from monarch._rust_bindings.monarch_hyperactor.mailbox import ( Mailbox, OncePortReceiver, OncePortRef, + PortReceiver as HyPortReceiver, PortRef, ) @@ -319,6 +315,9 @@ def _send( def _port(self, once: bool = False) -> "PortTuple[R]": pass + def _supervise(self, r: HyPortReceiver | OncePortReceiver) -> Any: + return r + # the following are all 'adverbs' or different ways to handle the # return values of this endpoint. Adverbs should only ever take *args, **kwargs # of the original call. If we want to add syntax sugar for something that needs additional @@ -401,6 +400,10 @@ def __init__( self._signature: inspect.Signature = inspect.signature(impl) self._mailbox = mailbox + def _supervise(self, r: HyPortReceiver | OncePortReceiver) -> Any: + mesh = self._actor_mesh._actor_mesh + return r if mesh is None else mesh.supervise(r) + def _send( self, args: Tuple[Any, ...], @@ -432,12 +435,12 @@ def _send( return Extent(shape.labels, shape.ndslice.sizes) def _port(self, once: bool = False) -> "PortTuple[R]": - monitor = ( - None - if self._actor_mesh._actor_mesh is None - else self._actor_mesh._actor_mesh.monitor() - ) - return PortTuple.create(self._mailbox, monitor, once) + p, r = PortTuple.create(self._mailbox, once) + if TYPE_CHECKING: + assert isinstance( + r._receiver, (HyPortReceiver | OncePortReceiver) + ), "unexpected receiver type" + return PortTuple(p, PortReceiver(self._mailbox, self._supervise(r._receiver))) class Accumulator(Generic[P, R, A]): @@ -591,18 +594,9 @@ class PortTuple(NamedTuple, Generic[R]): receiver: "PortReceiver[R]" @staticmethod - def create( - mailbox: Mailbox, monitor: Optional[ActorMeshMonitor], once: bool = False - ) -> "PortTuple[Any]": + def create(mailbox: Mailbox, once: bool = False) -> "PortTuple[Any]": handle, receiver = mailbox.open_once_port() if once else mailbox.open_port() port_ref = handle.bind() - if monitor is not None: - receiver = ( - MonitoredOncePortReceiver(receiver, monitor) - if isinstance(receiver, OncePortReceiver) - else MonitoredPortReceiver(receiver, monitor) - ) - return PortTuple( Port(port_ref, mailbox, rank=None), PortReceiver(mailbox, receiver), @@ -614,18 +608,9 @@ class PortTuple(NamedTuple): receiver: "PortReceiver[Any]" @staticmethod - def create( - mailbox: Mailbox, monitor: Optional[ActorMeshMonitor], once: bool = False - ) -> "PortTuple[Any]": + def create(mailbox: Mailbox, once: bool = False) -> "PortTuple[Any]": handle, receiver = mailbox.open_once_port() if once else mailbox.open_port() port_ref = handle.bind() - if monitor is not None: - receiver = ( - MonitoredOncePortReceiver(receiver, monitor) - if isinstance(receiver, OncePortReceiver) - else MonitoredPortReceiver(receiver, monitor) - ) - return PortTuple( Port(port_ref, mailbox, rank=None), PortReceiver(mailbox, receiver), diff --git a/python/monarch/common/remote.py b/python/monarch/common/remote.py index be5274cfd..1c19a360b 100644 --- a/python/monarch/common/remote.py +++ b/python/monarch/common/remote.py @@ -144,7 +144,7 @@ def _port(self, once: bool = False) -> "PortTuple[R]": "Cannot create raw port objects with an old-style tensor engine controller." ) mailbox: Mailbox = mesh_controller._mailbox - return PortTuple.create(mailbox, None, once) + return PortTuple.create(mailbox, once) @property def _resolvable(self): diff --git a/python/monarch/mesh_controller.py b/python/monarch/mesh_controller.py index 6545a9268..61f3101db 100644 --- a/python/monarch/mesh_controller.py +++ b/python/monarch/mesh_controller.py @@ -149,9 +149,7 @@ def fetch( defs: Tuple["Tensor", ...], uses: Tuple["Tensor", ...], ) -> "OldFuture": # the OldFuture is a lie - sender, receiver = PortTuple.create( - self._mesh_controller._mailbox, None, once=True - ) + sender, receiver = PortTuple.create(self._mesh_controller._mailbox, once=True) ident = self.new_node(defs, uses, cast("OldFuture", sender)) process = mesh._process(shard) @@ -187,9 +185,7 @@ def shutdown( atexit.unregister(self._atexit) self._shutdown = True - sender, receiver = PortTuple.create( - self._mesh_controller._mailbox, None, once=True - ) + sender, receiver = PortTuple.create(self._mesh_controller._mailbox, once=True) assert sender._port_ref is not None self._mesh_controller.sync_at_exit(sender._port_ref.port_id) receiver.recv().get(timeout=60) diff --git a/python/tests/test_actor_error.py b/python/tests/test_actor_error.py index 22dd463cd..c69c2aeae 100644 --- a/python/tests/test_actor_error.py +++ b/python/tests/test_actor_error.py @@ -611,7 +611,7 @@ async def test_supervision_with_sending_error(): with pytest.raises( SupervisionError, match="supervision error:.*message not delivered:" ): - await actor_mesh.check_with_payload.call(payload="a" * 2000000000) + await actor_mesh.check_with_payload.call(payload="a" * 5000000000) # new call should fail with check of health state of actor mesh with pytest.raises(SupervisionError, match="actor mesh is not in a healthy state"): diff --git a/python/tests/test_python_actors.py b/python/tests/test_python_actors.py index a6fe23d5d..5e69caaca 100644 --- a/python/tests/test_python_actors.py +++ b/python/tests/test_python_actors.py @@ -546,7 +546,7 @@ async def send(self, port: Port[int]): def test_port_as_argument(): proc_mesh = local_proc_mesh(gpus=1).get() s = proc_mesh.spawn("send_alot", SendAlot).get() - send, recv = PortTuple.create(proc_mesh._mailbox, None) + send, recv = PortTuple.create(proc_mesh._mailbox) s.send.broadcast(send)