Skip to content

Commit cffa2c2

Browse files
actor: move more port receiver supervision to rust (#578)
Summary:

Pull Request resolved: #578

This diff moves supervision logic from Python into Rust, aligning with the goal of eliminating complex supervision wiring in Python. The essential change is that:

```
class ActorEndpoint(...):
    def _port(self, once: bool = False) -> "PortTuple[R]":
        monitor = (
            None
            if self._actor_mesh._actor_mesh is None
            else self._actor_mesh._actor_mesh.monitor()
        )
        return PortTuple.create(self._mailbox, monitor, once)
```

becomes:

```
class ActorEndpoint(...):
    def _port(self, once: bool = False) -> PortTuple[R]:
        p, r = PortTuple.create(self._mailbox, once)
        return PortTuple(p, PortReceiver(self._mailbox, self._supervise(r._receiver)))
```

`_supervise(...)` dispatches to the new Rust helpers `mesh.supervise_port(...)` and `mesh.supervise_once_port(...)`, which wrap the receivers with supervision logic (including selection between message arrival and supervision events), completely eliminating the need for Python-side constructs like `ActorMeshMonitor`.

Most of the Python complexity introduced in D77434080 is removed. The only meaningful addition is `_supervise(...)`, a small overridable hook that defaults to a no-op and cleanly delegates to Rust when supervision is desired.

- The creation and wiring of the monitor stream is now fully in Rust.
- The responsibility of wrapping receivers with supervision is now fully in Rust.
- Python no longer constructs or passes supervision monitors; Rust now owns the full wiring, and Python receives already-wrapped receivers with supervision behavior embedded.

This is a strict improvement: lower complexity, cleaner override points, and supervision managed entirely in Rust.

Differential Revision: D78528860
1 parent 1b0d02f commit cffa2c2

File tree

6 files changed

+99
-154
lines changed

6 files changed

+99
-154
lines changed

monarch_hyperactor/src/actor_mesh.rs

Lines changed: 57 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -177,14 +177,32 @@ impl PythonActorMesh {
177177
.map(PyActorId::from))
178178
}
179179

180-
// Start monitoring the actor mesh by subscribing to its supervision events. For each supervision
181-
// event, it is consumed by PythonActorMesh first, then gets sent to the monitor for user to consume.
182-
fn monitor<'py>(&self, py: Python<'py>) -> PyResult<PyObject> {
183-
let receiver = self.user_monitor_sender.subscribe();
184-
let monitor_instance = PyActorMeshMonitor {
185-
receiver: SharedCell::from(Mutex::new(receiver)),
180+
fn supervise_port<'py>(
181+
&self,
182+
py: Python<'py>,
183+
receiver: &PythonPortReceiver,
184+
) -> PyResult<PyObject> {
185+
let rx = MonitoredPythonPortReceiver {
186+
inner: receiver.inner(),
187+
monitor: ActorMeshMonitor {
188+
receiver: SharedCell::from(Mutex::new(self.user_monitor_sender.subscribe())),
189+
},
186190
};
187-
Ok(monitor_instance.into_py(py))
191+
rx.into_py_any(py)
192+
}
193+
194+
fn supervise_once_port<'py>(
195+
&self,
196+
py: Python<'py>,
197+
receiver: &PythonOncePortReceiver,
198+
) -> PyResult<PyObject> {
199+
let rx = MonitoredPythonOncePortReceiver {
200+
inner: receiver.inner(),
201+
monitor: ActorMeshMonitor {
202+
receiver: SharedCell::from(Mutex::new(self.user_monitor_sender.subscribe())),
203+
},
204+
};
205+
rx.into_py_any(py)
188206
}
189207

190208
#[pyo3(signature = (**kwargs))]
@@ -336,83 +354,46 @@ impl Drop for PythonActorMesh {
336354
}
337355
}
338356

339-
#[pyclass(
340-
name = "ActorMeshMonitor",
341-
module = "monarch._rust_bindings.monarch_hyperactor.actor_mesh"
342-
)]
343-
pub struct PyActorMeshMonitor {
357+
#[derive(Debug, Clone)]
358+
struct ActorMeshMonitor {
344359
receiver: SharedCell<Mutex<tokio::sync::broadcast::Receiver<Option<ActorSupervisionEvent>>>>,
345360
}
346361

347-
#[pymethods]
348-
impl PyActorMeshMonitor {
349-
fn __aiter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
350-
slf
351-
}
352-
353-
pub fn __anext__(&self, py: Python<'_>) -> PyResult<PyObject> {
362+
impl ActorMeshMonitor {
363+
pub async fn next(&self) -> PyActorSupervisionEvent {
354364
let receiver = self.receiver.clone();
355-
Ok(pyo3_async_runtimes::tokio::future_into_py(py, get_next(receiver))?.into())
356-
}
357-
}
358-
359-
impl PyActorMeshMonitor {
360-
pub async fn next(&self) -> PyResult<PyObject> {
361-
get_next(self.receiver.clone()).await
362-
}
363-
}
364-
365-
impl Clone for PyActorMeshMonitor {
366-
fn clone(&self) -> Self {
367-
Self {
368-
receiver: self.receiver.clone(),
365+
let receiver = receiver
366+
.borrow()
367+
.expect("`Actor mesh receiver` is shutdown");
368+
let mut receiver = receiver.lock().await;
369+
let event = receiver.recv().await.unwrap();
370+
match event {
371+
None => PyActorSupervisionEvent {
372+
// Dummy actor as place holder to indicate the whole mesh is stopped
373+
// TODO(albertli): remove this when pushing all supervision logic to rust.
374+
actor_id: id!(default[0].actor[0]).into(),
375+
actor_status: "actor mesh is stopped due to proc mesh shutdown".to_string(),
376+
},
377+
Some(event) => PyActorSupervisionEvent::from(event.clone()),
369378
}
370379
}
371380
}
372381

373-
async fn get_next(
374-
receiver: SharedCell<Mutex<tokio::sync::broadcast::Receiver<Option<ActorSupervisionEvent>>>>,
375-
) -> PyResult<PyObject> {
376-
let receiver = receiver.clone();
377-
378-
let receiver = receiver
379-
.borrow()
380-
.expect("`Actor mesh receiver` is shutdown");
381-
let mut receiver = receiver.lock().await;
382-
let event = receiver.recv().await.unwrap();
383-
384-
let supervision_event = match event {
385-
None => PyActorSupervisionEvent {
386-
// Dummy actor as place holder to indicate the whole mesh is stopped
387-
// TODO(albertli): remove this when pushing all supervision logic to rust.
388-
actor_id: id!(default[0].actor[0]).into(),
389-
actor_status: "actor mesh is stopped due to proc mesh shutdown".to_string(),
390-
},
391-
Some(event) => PyActorSupervisionEvent::from(event.clone()),
392-
};
393-
394-
Python::with_gil(|py| supervision_event.into_py_any(py))
395-
}
396-
397-
// TODO(albertli): this is temporary remove this when pushing all supervision logic to rust.
382+
// Values of this (private) type can only be created by calling
383+
// `PythonActorMesh::supervise_port()`.
398384
#[pyclass(
399385
name = "MonitoredPortReceiver",
400386
module = "monarch._rust_bindings.monarch_hyperactor.actor_mesh"
401387
)]
402-
pub(super) struct MonitoredPythonPortReceiver {
388+
struct MonitoredPythonPortReceiver {
403389
inner: Arc<tokio::sync::Mutex<PortReceiver<PythonMessage>>>,
404-
monitor: PyActorMeshMonitor,
390+
monitor: ActorMeshMonitor,
405391
}
406392

407393
#[pymethods]
408394
impl MonitoredPythonPortReceiver {
409-
#[new]
410-
fn new(receiver: &PythonPortReceiver, monitor: &PyActorMeshMonitor) -> Self {
411-
let inner = receiver.inner();
412-
MonitoredPythonPortReceiver {
413-
inner,
414-
monitor: monitor.clone(),
415-
}
395+
fn __repr__(&self) -> &'static str {
396+
"<MonitoredPortReceiver>"
416397
}
417398

418399
fn recv_task<'py>(&mut self) -> PyPythonTask {
@@ -425,32 +406,29 @@ impl MonitoredPythonPortReceiver {
425406
result.map_err(|err| PyErr::new::<PyEOFError, _>(format!("port closed: {}", err)))
426407
}
427408
event = monitor.next() => {
428-
Err(PyErr::new::<SupervisionError, _>(format!("supervision error: {:?}", event.unwrap())))
409+
Err(PyErr::new::<SupervisionError, _>(format!("supervision error: {:?}", event)))
429410
}
430411
};
431412
result.and_then(|message: PythonMessage| Python::with_gil(|py| message.into_py_any(py)))
432413
}).into()
433414
}
434415
}
435416

417+
// Values of this (private) type can only be created by calling
418+
// `PythonActorMesh::supervise_once_port()`.
436419
#[pyclass(
437420
name = "MonitoredOncePortReceiver",
438421
module = "monarch._rust_bindings.monarch_hyperactor.actor_mesh"
439422
)]
440-
pub(super) struct MonitoredPythonOncePortReceiver {
423+
struct MonitoredPythonOncePortReceiver {
441424
inner: Arc<std::sync::Mutex<Option<OncePortReceiver<PythonMessage>>>>,
442-
monitor: PyActorMeshMonitor,
425+
monitor: ActorMeshMonitor,
443426
}
444427

445428
#[pymethods]
446429
impl MonitoredPythonOncePortReceiver {
447-
#[new]
448-
fn new(receiver: &PythonOncePortReceiver, monitor: &PyActorMeshMonitor) -> Self {
449-
let inner = receiver.inner();
450-
MonitoredPythonOncePortReceiver {
451-
inner,
452-
monitor: monitor.clone(),
453-
}
430+
fn __repr__(&self) -> &'static str {
431+
"<MonitoredOncePortReceiver>"
454432
}
455433

456434
fn recv_task<'py>(&mut self) -> PyResult<PyPythonTask> {
@@ -464,7 +442,7 @@ impl MonitoredPythonOncePortReceiver {
464442
result.map_err(|err| PyErr::new::<PyEOFError, _>(format!("port closed: {}", err)))
465443
}
466444
event = monitor.next() => {
467-
Err(PyErr::new::<SupervisionError, _>(format!("supervision error: {:?}", event.unwrap())))
445+
Err(PyErr::new::<SupervisionError, _>(format!("supervision error: {:?}", event)))
468446
}
469447
};
470448
result.and_then(|message: PythonMessage| Python::with_gil(|py| message.into_py_any(py)))
@@ -476,6 +454,7 @@ impl MonitoredPythonOncePortReceiver {
476454
name = "ActorSupervisionEvent",
477455
module = "monarch._rust_bindings.monarch_hyperactor.actor_mesh"
478456
)]
457+
#[derive(Debug)]
479458
pub struct PyActorSupervisionEvent {
480459
/// Actor ID of the actor where supervision event originates from.
481460
#[pyo3(get)]
@@ -508,7 +487,6 @@ impl From<ActorSupervisionEvent> for PyActorSupervisionEvent {
508487
pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> {
509488
hyperactor_mod.add_class::<PythonActorMesh>()?;
510489
hyperactor_mod.add_class::<PythonActorMeshRef>()?;
511-
hyperactor_mod.add_class::<PyActorMeshMonitor>()?;
512490
hyperactor_mod.add_class::<MonitoredPythonPortReceiver>()?;
513491
hyperactor_mod.add_class::<MonitoredPythonOncePortReceiver>()?;
514492
hyperactor_mod.add_class::<PyActorSupervisionEvent>()?;

python/monarch/_rust_bindings/monarch_hyperactor/actor_mesh.pyi

Lines changed: 12 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,15 @@ class PythonActorMesh:
8989
"""
9090
...
9191

92-
# TODO(albertli): remove this when pushing all supervision logic to Rust
93-
def monitor(self) -> ActorMeshMonitor:
92+
def supervise_port(self, r: PortReceiver) -> MonitoredPortReceiver:
9493
"""
95-
Returns a supervision monitor for this mesh.
94+
Return a monitored port receiver.
95+
"""
96+
...
97+
98+
def supervise_once_port(self, r: OncePortReceiver) -> MonitoredOncePortReceiver:
99+
"""
100+
Return a monitored once port receiver.
96101
"""
97102
...
98103

@@ -114,43 +119,17 @@ class PythonActorMesh:
114119
"""
115120
...
116121

117-
@final
118-
class ActorMeshMonitor:
119-
def __aiter__(self) -> AsyncIterator["ActorSupervisionEvent"]:
120-
"""
121-
Returns an async iterator for this monitor.
122-
"""
123-
...
124-
125-
async def __anext__(self) -> "ActorSupervisionEvent":
126-
"""
127-
Returns the next proc event in the proc mesh.
128-
"""
129-
...
130-
131122
@final
132123
class MonitoredPortReceiver(PortReceiverBase):
124+
"""A monitored receiver to which PythonMessages are sent. Values
125+
of this type cannot be constructed directly in Python.
133126
"""
134-
A monitored receiver to which PythonMessages are sent.
135-
"""
136-
137-
def __init__(self, receiver: PortReceiver, monitor: ActorMeshMonitor) -> None:
138-
"""
139-
Create a new monitored receiver from a PortReceiver.
140-
"""
141-
...
142127

143128
@final
144129
class MonitoredOncePortReceiver(PortReceiverBase):
130+
"""A monitored once receiver to which PythonMessages are sent.
131+
Values of this type cannot be constructed directly in Python.
145132
"""
146-
A variant of monitored PortReceiver that can only receive a single message.
147-
"""
148-
149-
def __init__(self, receiver: OncePortReceiver, monitor: ActorMeshMonitor) -> None:
150-
"""
151-
Create a new monitored receiver from a PortReceiver.
152-
"""
153-
...
154133

155134
@final
156135
class ActorSupervisionEvent:

python/monarch/_src/actor/actor_mesh.py

Lines changed: 26 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@
4040
Optional,
4141
overload,
4242
ParamSpec,
43+
Protocol,
44+
runtime_checkable,
4345
Sequence,
4446
Tuple,
4547
Type,
@@ -52,16 +54,12 @@
5254
PythonMessage,
5355
PythonMessageKind,
5456
)
55-
from monarch._rust_bindings.monarch_hyperactor.actor_mesh import (
56-
ActorMeshMonitor,
57-
MonitoredOncePortReceiver,
58-
MonitoredPortReceiver,
59-
PythonActorMesh,
60-
)
57+
from monarch._rust_bindings.monarch_hyperactor.actor_mesh import PythonActorMesh
6158
from monarch._rust_bindings.monarch_hyperactor.mailbox import (
6259
Mailbox,
6360
OncePortReceiver,
6461
OncePortRef,
62+
PortReceiver as HyPortReceiver,
6563
PortRef,
6664
)
6765

@@ -311,6 +309,9 @@ def _send(
311309
def _port(self, once: bool = False) -> "PortTuple[R]":
312310
pass
313311

312+
def _supervise(self, r: HyPortReceiver | OncePortReceiver) -> Any:
313+
return r
314+
314315
# the following are all 'adverbs' or different ways to handle the
315316
# return values of this endpoint. Adverbs should only ever take *args, **kwargs
316317
# of the original call. If we want to add syntax sugar for something that needs additional
@@ -393,6 +394,16 @@ def __init__(
393394
self._signature: inspect.Signature = inspect.signature(impl)
394395
self._mailbox = mailbox
395396

397+
def _supervise(self, r: HyPortReceiver | OncePortReceiver) -> Any:
398+
mesh = self._actor_mesh._actor_mesh
399+
if mesh is None:
400+
return r
401+
return (
402+
mesh.supervise_once_port(r)
403+
if isinstance(r, OncePortReceiver)
404+
else mesh.supervise_port(r)
405+
)
406+
396407
def _send(
397408
self,
398409
args: Tuple[Any, ...],
@@ -424,12 +435,12 @@ def _send(
424435
return Extent(shape.labels, shape.ndslice.sizes)
425436

426437
def _port(self, once: bool = False) -> "PortTuple[R]":
427-
monitor = (
428-
None
429-
if self._actor_mesh._actor_mesh is None
430-
else self._actor_mesh._actor_mesh.monitor()
431-
)
432-
return PortTuple.create(self._mailbox, monitor, once)
438+
p, r = PortTuple.create(self._mailbox, once)
439+
if TYPE_CHECKING:
440+
assert isinstance(
441+
r._receiver, (HyPortReceiver | OncePortReceiver)
442+
), "unexpected receiver type"
443+
return PortTuple(p, PortReceiver(self._mailbox, self._supervise(r._receiver)))
433444

434445

435446
class Accumulator(Generic[P, R, A]):
@@ -583,21 +594,11 @@ class PortTuple(NamedTuple, Generic[R]):
583594
receiver: "PortReceiver[R]"
584595

585596
@staticmethod
586-
def create(
587-
mailbox: Mailbox, monitor: Optional[ActorMeshMonitor], once: bool = False
588-
) -> "PortTuple[Any]":
597+
def create(mailbox: Mailbox, once: bool = False) -> "PortTuple[Any]":
589598
handle, receiver = mailbox.open_once_port() if once else mailbox.open_port()
590599
port_ref = handle.bind()
591-
if monitor is not None:
592-
receiver = (
593-
MonitoredOncePortReceiver(receiver, monitor)
594-
if isinstance(receiver, OncePortReceiver)
595-
else MonitoredPortReceiver(receiver, monitor)
596-
)
597-
598600
return PortTuple(
599-
Port(port_ref, mailbox, rank=None),
600-
PortReceiver(mailbox, receiver),
601+
Port(port_ref, mailbox, rank=None), PortReceiver(mailbox, receiver)
601602
)
602603
else:
603604

@@ -606,18 +607,9 @@ class PortTuple(NamedTuple):
606607
receiver: "PortReceiver[Any]"
607608

608609
@staticmethod
609-
def create(
610-
mailbox: Mailbox, monitor: Optional[ActorMeshMonitor], once: bool = False
611-
) -> "PortTuple[Any]":
610+
def create(mailbox: Mailbox, once: bool = False) -> "PortTuple[Any]":
612611
handle, receiver = mailbox.open_once_port() if once else mailbox.open_port()
613612
port_ref = handle.bind()
614-
if monitor is not None:
615-
receiver = (
616-
MonitoredOncePortReceiver(receiver, monitor)
617-
if isinstance(receiver, OncePortReceiver)
618-
else MonitoredPortReceiver(receiver, monitor)
619-
)
620-
621613
return PortTuple(
622614
Port(port_ref, mailbox, rank=None),
623615
PortReceiver(mailbox, receiver),

0 commit comments

Comments
 (0)