Skip to content

Commit 168b40a

Browse files
committed
Use async functions in sync context to avoid code duplication.
We currently have a lot of code duplication between the blocking and nonblocking variants. This PR lets the blocking variants be defined in terms of the async ones. It first directly exposes the tokio Future object to Python, letting Python choose either to synchronously block on it or to turn it into an asyncio.Future. Then, when we are running async code to do pipe receives in a sync context, we can skip running a real event loop and just step the coroutine manually. TODO: expand writeup. Currently there is jank around the fact that we are running synchronous actor code on a thread that already has a running event loop. I put in a workaround with a comment, but we should resolve this by not running the sync code on the asyncio loop. Differential Revision: [D78466520](https://our.internmc.facebook.com/intern/diff/D78466520/) **NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D78466520/)! ghstack-source-id: 296756465 Pull Request resolved: #565
1 parent b926ee4 commit 168b40a

File tree

7 files changed

+169
-125
lines changed

7 files changed

+169
-125
lines changed

monarch_hyperactor/src/actor.rs

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -480,18 +480,19 @@ impl Handler<PythonMessage> for PythonActor {
480480
/// Helper struct to make a Python future passable in an actor message.
481481
///
482482
/// Also so that we don't have to write this massive type signature everywhere
483-
struct PythonTask {
483+
484+
pub(crate) struct PythonTask {
484485
future: Mutex<Pin<Box<dyn Future<Output = PyResult<PyObject>> + Send + 'static>>>,
485486
}
486487

487488
impl PythonTask {
488-
fn new(fut: impl Future<Output = PyResult<PyObject>> + Send + 'static) -> Self {
489+
pub(crate) fn new(fut: impl Future<Output = PyResult<PyObject>> + Send + 'static) -> Self {
489490
Self {
490491
future: Mutex::new(Box::pin(fut)),
491492
}
492493
}
493494

494-
async fn take(self) -> Pin<Box<dyn Future<Output = PyResult<PyObject>> + Send + 'static>> {
495+
fn take(self) -> Pin<Box<dyn Future<Output = PyResult<PyObject>> + Send + 'static>> {
495496
self.future.into_inner()
496497
}
497498
}
@@ -504,6 +505,40 @@ impl fmt::Debug for PythonTask {
504505
}
505506
}
506507

508+
#[pyclass(
509+
name = "PythonTask",
510+
module = "monarch._rust_bindings.monarch_hyperactor.actor"
511+
)]
512+
pub(crate) struct PyPythonTask {
513+
inner: Option<PythonTask>,
514+
}
515+
516+
impl From<PythonTask> for PyPythonTask {
517+
fn from(task: PythonTask) -> Self {
518+
Self { inner: Some(task) }
519+
}
520+
}
521+
522+
#[pymethods]
523+
impl PyPythonTask {
524+
fn into_future(&mut self, py: Python) -> PyResult<PyObject> {
525+
let task = self
526+
.inner
527+
.take()
528+
.map(|task| task.take())
529+
.expect("PythonTask already consumed");
530+
Ok(pyo3_async_runtimes::tokio::future_into_py(py, task)?.unbind())
531+
}
532+
fn block_on(&mut self, py: Python) -> PyResult<PyObject> {
533+
let task = self
534+
.inner
535+
.take()
536+
.map(|task| task.take())
537+
.expect("PythonTask already consumed");
538+
signal_safe_block_on(py, task)?
539+
}
540+
}
541+
507542
/// An [`Actor`] used to monitor the result of an async endpoint. We use an
508543
/// actor so that:
509544
/// - Actually waiting on the async endpoint can happen concurrently with other endpoints.
@@ -557,7 +592,7 @@ impl AsyncEndpointInvocationHandler for AsyncEndpointTask {
557592
Err(_) => pending().await,
558593
}
559594
};
560-
let future = task.take().await;
595+
let future = task.take();
561596
let result: Result<(), SerializablePyErr> = tokio::select! {
562597
result = future => {
563598
match result {
@@ -595,6 +630,7 @@ pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResul
595630
hyperactor_mod.add_class::<PythonMessage>()?;
596631
hyperactor_mod.add_class::<PythonMessageKind>()?;
597632
hyperactor_mod.add_class::<PanicFlag>()?;
633+
hyperactor_mod.add_class::<PyPythonTask>()?;
598634
Ok(())
599635
}
600636

monarch_hyperactor/src/actor_mesh.rs

Lines changed: 15 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ use hyperactor_mesh::actor_mesh::ActorSupervisionEvents;
2020
use hyperactor_mesh::reference::ActorMeshRef;
2121
use hyperactor_mesh::shared_cell::SharedCell;
2222
use hyperactor_mesh::shared_cell::SharedCellRef;
23+
use pyo3::IntoPyObjectExt;
2324
use pyo3::exceptions::PyEOFError;
2425
use pyo3::exceptions::PyException;
2526
use pyo3::exceptions::PyNotImplementedError;
@@ -31,8 +32,10 @@ use serde::Deserialize;
3132
use serde::Serialize;
3233
use tokio::sync::Mutex;
3334

35+
use crate::actor::PyPythonTask;
3436
use crate::actor::PythonActor;
3537
use crate::actor::PythonMessage;
38+
use crate::actor::PythonTask;
3639
use crate::mailbox::PyMailbox;
3740
use crate::mailbox::PythonOncePortReceiver;
3841
use crate::mailbox::PythonPortReceiver;
@@ -332,36 +335,21 @@ impl MonitoredPythonPortReceiver {
332335
}
333336
}
334337

335-
fn recv<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
338+
fn recv_task<'py>(&mut self) -> PyPythonTask {
336339
let receiver = self.inner.clone();
337340
let monitor = self.monitor.clone();
338-
pyo3_async_runtimes::tokio::future_into_py(py, async move {
341+
PythonTask::new(async move {
339342
let mut receiver = receiver.lock().await;
340-
tokio::select! {
343+
let result = tokio::select! {
341344
result = receiver.recv() => {
342345
result.map_err(|err| PyErr::new::<PyEOFError, _>(format!("port closed: {}", err)))
343346
}
344347
event = monitor.next() => {
345348
Err(PyErr::new::<SupervisionError, _>(format!("supervision error: {:?}", event.unwrap())))
346349
}
347-
}
348-
})
349-
}
350-
351-
fn blocking_recv<'py>(&mut self, py: Python<'py>) -> PyResult<PythonMessage> {
352-
let receiver = self.inner.clone();
353-
let monitor = self.monitor.clone();
354-
signal_safe_block_on(py, async move {
355-
let mut receiver = receiver.lock().await;
356-
tokio::select! {
357-
result = receiver.recv() => {
358-
result.map_err(|err| PyErr::new::<PyEOFError, _>(format!("port closed: {}", err)))
359-
}
360-
event = monitor.next() => {
361-
Err(PyErr::new::<SupervisionError, _>(format!("supervision error: {:?}", event.unwrap())))
362-
}
363-
}
364-
})?
350+
};
351+
result.and_then(|message: PythonMessage| Python::with_gil(|py| message.into_py_any(py)))
352+
}).into()
365353
}
366354
}
367355

@@ -385,38 +373,22 @@ impl MonitoredPythonOncePortReceiver {
385373
}
386374
}
387375

388-
fn recv<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
376+
fn recv_task<'py>(&mut self) -> PyResult<PyPythonTask> {
389377
let Some(receiver) = self.inner.lock().unwrap().take() else {
390378
return Err(PyErr::new::<PyValueError, _>("OncePort is already used"));
391379
};
392380
let monitor = self.monitor.clone();
393-
pyo3_async_runtimes::tokio::future_into_py(py, async move {
394-
tokio::select! {
381+
Ok(PythonTask::new(async move {
382+
let result = tokio::select! {
395383
result = receiver.recv() => {
396384
result.map_err(|err| PyErr::new::<PyEOFError, _>(format!("port closed: {}", err)))
397385
}
398386
event = monitor.next() => {
399387
Err(PyErr::new::<SupervisionError, _>(format!("supervision error: {:?}", event.unwrap())))
400388
}
401-
}
402-
})
403-
}
404-
405-
fn blocking_recv<'py>(&mut self, py: Python<'py>) -> PyResult<PythonMessage> {
406-
let Some(receiver) = self.inner.lock().unwrap().take() else {
407-
return Err(PyErr::new::<PyValueError, _>("OncePort is already used"));
408-
};
409-
let monitor = self.monitor.clone();
410-
signal_safe_block_on(py, async move {
411-
tokio::select! {
412-
result = receiver.recv() => {
413-
result.map_err(|err| PyErr::new::<PyEOFError, _>(format!("port closed: {}", err)))
414-
}
415-
event = monitor.next() => {
416-
Err(PyErr::new::<SupervisionError, _>(format!("supervision error: {:?}", event.unwrap())))
417-
}
418-
}
419-
})?
389+
};
390+
result.and_then(|message: PythonMessage| Python::with_gil(|py| message.into_py_any(py)))
391+
}).into())
420392
}
421393
}
422394

monarch_hyperactor/src/mailbox.rs

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9+
use std::future::Future;
910
use std::hash::DefaultHasher;
1011
use std::hash::Hash;
1112
use std::hash::Hasher;
@@ -45,9 +46,12 @@ use pyo3::types::PyTuple;
4546
use pyo3::types::PyType;
4647
use serde::Deserialize;
4748
use serde::Serialize;
49+
use tokio::sync::mpsc::UnboundedReceiver;
4850

51+
use crate::actor::PyPythonTask;
4952
use crate::actor::PythonMessage;
5053
use crate::actor::PythonMessageKind;
54+
use crate::actor::PythonTask;
5155
use crate::proc::PyActorId;
5256
use crate::runtime::signal_safe_block_on;
5357
use crate::shape::PyShape;
@@ -374,23 +378,23 @@ pub(super) struct PythonPortReceiver {
374378
inner: Arc<tokio::sync::Mutex<PortReceiver<PythonMessage>>>,
375379
}
376380

381+
async fn recv_async(
382+
receiver: Arc<tokio::sync::Mutex<PortReceiver<PythonMessage>>>,
383+
) -> PyResult<PyObject> {
384+
receiver
385+
.lock()
386+
.await
387+
.recv()
388+
.await
389+
.map_err(|err| PyErr::new::<PyEOFError, _>(format!("Port closed: {}", err)))
390+
.and_then(|message| Python::with_gil(|py| message.into_py_any(py)))
391+
}
392+
377393
#[pymethods]
378394
impl PythonPortReceiver {
379-
fn recv<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
395+
fn recv_task<'py>(&mut self) -> PyPythonTask {
380396
let receiver = self.inner.clone();
381-
pyo3_async_runtimes::tokio::future_into_py(py, async move {
382-
receiver
383-
.lock()
384-
.await
385-
.recv()
386-
.await
387-
.map_err(|err| PyErr::new::<PyEOFError, _>(format!("Port closed: {}", err)))
388-
})
389-
}
390-
fn blocking_recv<'py>(&mut self, py: Python<'py>) -> PyResult<PythonMessage> {
391-
let receiver = self.inner.clone();
392-
signal_safe_block_on(py, async move { receiver.lock().await.recv().await })?
393-
.map_err(|err| PyErr::new::<PyEOFError, _>(format!("Port closed: {}", err)))
397+
PythonTask::new(recv_async(receiver)).into()
394398
}
395399
}
396400

@@ -545,24 +549,18 @@ pub(super) struct PythonOncePortReceiver {
545549

546550
#[pymethods]
547551
impl PythonOncePortReceiver {
548-
fn recv<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
552+
fn recv_task<'py>(&mut self) -> PyResult<PyPythonTask> {
549553
let Some(receiver) = self.inner.lock().unwrap().take() else {
550554
return Err(PyErr::new::<PyValueError, _>("OncePort is already used"));
551555
};
552-
553-
pyo3_async_runtimes::tokio::future_into_py(py, async move {
556+
let fut = async move {
554557
receiver
555558
.recv()
556559
.await
557560
.map_err(|err| PyErr::new::<PyEOFError, _>(format!("Port closed: {}", err)))
558-
})
559-
}
560-
fn blocking_recv<'py>(&mut self, py: Python<'py>) -> PyResult<PythonMessage> {
561-
let Some(receiver) = self.inner.lock().unwrap().take() else {
562-
return Err(PyErr::new::<PyValueError, _>("OncePort is already used"));
561+
.and_then(|message| Python::with_gil(|py| message.into_py_any(py)))
563562
};
564-
signal_safe_block_on(py, async move { receiver.recv().await })?
565-
.map_err(|err| PyErr::new::<PyEOFError, _>(format!("Port closed: {}", err)))
563+
Ok(PythonTask::new(fut).into())
566564
}
567565
}
568566

@@ -704,6 +702,5 @@ pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResul
704702
hyperactor_mod.add_class::<PythonOncePortHandle>()?;
705703
hyperactor_mod.add_class::<PythonOncePortRef>()?;
706704
hyperactor_mod.add_class::<PythonOncePortReceiver>()?;
707-
708705
Ok(())
709706
}

python/monarch/_rust_bindings/monarch_hyperactor/actor_mesh.pyi

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ from monarch._rust_bindings.monarch_hyperactor.mailbox import (
1414
Mailbox,
1515
OncePortReceiver,
1616
PortReceiver,
17+
PortReceiverBase,
1718
)
1819
from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
1920
from monarch._rust_bindings.monarch_hyperactor.selection import Selection
@@ -106,7 +107,7 @@ class ActorMeshMonitor:
106107
...
107108

108109
@final
109-
class MonitoredPortReceiver:
110+
class MonitoredPortReceiver(PortReceiverBase):
110111
"""
111112
A monitored receiver to which PythonMessages are sent.
112113
"""
@@ -117,15 +118,8 @@ class MonitoredPortReceiver:
117118
"""
118119
...
119120

120-
async def recv(self) -> PythonMessage:
121-
"""Receive a PythonMessage from the port's sender."""
122-
...
123-
def blocking_recv(self) -> PythonMessage:
124-
"""Receive a single PythonMessage from the port's sender."""
125-
...
126-
127121
@final
128-
class MonitoredOncePortReceiver:
122+
class MonitoredOncePortReceiver(PortReceiverBase):
129123
"""
130124
A variant of monitored PortReceiver that can only receive a single message.
131125
"""
@@ -136,13 +130,6 @@ class MonitoredOncePortReceiver:
136130
"""
137131
...
138132

139-
async def recv(self) -> PythonMessage:
140-
"""Receive a single PythonMessage from the port's sender."""
141-
...
142-
def blocking_recv(self) -> PythonMessage:
143-
"""Receive a single PythonMessage from the port's sender."""
144-
...
145-
146133
@final
147134
class ActorSupervisionEvent:
148135
@property

0 commit comments

Comments
 (0)