Skip to content

Commit 6df1873

Browse files
committed
Use async functions in sync context to avoid code duplication.
Pull Request resolved: #565 We currently have a lot of code duplication between the blocking and nonblocking variants. This PR lets the blocking variants be defined in terms of the async ones. It first exposes the tokio Future object directly to Python, letting Python choose either to block on it synchronously or to turn it into an asyncio.Future. Then, when we are running async code to do pipe receives in a sync context, we can skip running a real event loop and just step the coroutine manually. See [avoiding async code duplication] in the diff for a description of the approach. Currently there is jank around the fact that we are running synchronous actor code on a thread that already has a running event loop. To indicate to the synchronous actor that it is ok to `get()` despite this setup, we temporarily blank out the running asyncio loop while an asynchronous actor is running a message. We eventually need to figure out a way to get the synchronous actors running on a thread that truly is not the asyncio event loop. ghstack-source-id: 297098923 Differential Revision: [D78466520](https://our.internmc.facebook.com/intern/diff/D78466520/)
1 parent cef93c7 commit 6df1873

File tree

9 files changed

+224
-129
lines changed

9 files changed

+224
-129
lines changed

monarch_hyperactor/src/actor.rs

Lines changed: 76 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,11 @@ use monarch_types::SerializablePyErr;
3333
use pyo3::conversion::IntoPyObjectExt;
3434
use pyo3::exceptions::PyBaseException;
3535
use pyo3::exceptions::PyRuntimeError;
36+
use pyo3::exceptions::PyStopIteration;
3637
use pyo3::prelude::*;
3738
use pyo3::types::PyBytes;
3839
use pyo3::types::PyDict;
40+
use pyo3::types::PyIterator;
3941
use pyo3::types::PyList;
4042
use pyo3::types::PyType;
4143
use serde::Deserialize;
@@ -616,18 +618,19 @@ impl Handler<PythonMessage> for PythonActor {
616618
/// Helper struct to make a Python future passable in an actor message.
617619
///
618620
/// Also so that we don't have to write this massive type signature everywhere
619-
struct PythonTask {
621+
622+
pub(crate) struct PythonTask {
620623
future: Mutex<Pin<Box<dyn Future<Output = PyResult<PyObject>> + Send + 'static>>>,
621624
}
622625

623626
impl PythonTask {
624-
fn new(fut: impl Future<Output = PyResult<PyObject>> + Send + 'static) -> Self {
627+
pub(crate) fn new(fut: impl Future<Output = PyResult<PyObject>> + Send + 'static) -> Self {
625628
Self {
626629
future: Mutex::new(Box::pin(fut)),
627630
}
628631
}
629632

630-
async fn take(self) -> Pin<Box<dyn Future<Output = PyResult<PyObject>> + Send + 'static>> {
633+
fn take(self) -> Pin<Box<dyn Future<Output = PyResult<PyObject>> + Send + 'static>> {
631634
self.future.into_inner()
632635
}
633636
}
@@ -640,6 +643,74 @@ impl fmt::Debug for PythonTask {
640643
}
641644
}
642645

646+
#[pyclass(
647+
name = "PythonTask",
648+
module = "monarch._rust_bindings.monarch_hyperactor.actor"
649+
)]
650+
pub(crate) struct PyPythonTask {
651+
inner: Option<PythonTask>,
652+
}
653+
654+
impl From<PythonTask> for PyPythonTask {
655+
fn from(task: PythonTask) -> Self {
656+
Self { inner: Some(task) }
657+
}
658+
}
659+
660+
#[pyclass(
661+
name = "JustStopWithValueIterator",
662+
module = "monarch._rust_bindings.monarch_hyperactor.actor"
663+
)]
664+
struct JustStopWithValueIterator {
665+
value: Option<PyObject>,
666+
}
667+
668+
#[pymethods]
669+
impl JustStopWithValueIterator {
670+
fn __next__<'py>(&mut self) -> PyResult<PyObject> {
671+
Err(PyStopIteration::new_err(self.value.take().unwrap()))
672+
}
673+
}
674+
675+
#[pymethods]
676+
impl PyPythonTask {
677+
fn into_future(&mut self, py: Python<'_>) -> PyResult<PyObject> {
678+
let task = self
679+
.inner
680+
.take()
681+
.map(|task| task.take())
682+
.expect("PythonTask already consumed");
683+
Ok(pyo3_async_runtimes::tokio::future_into_py(py, task)?.unbind())
684+
}
685+
fn block_on(&mut self, py: Python<'_>) -> PyResult<PyObject> {
686+
let task = self
687+
.inner
688+
.take()
689+
.map(|task| task.take())
690+
.expect("PythonTask already consumed");
691+
signal_safe_block_on(py, task)?
692+
}
693+
694+
/// In an async context this turns the tokio::Future into
695+
/// an asyncio Future and awaits it.
696+
/// In a synchronous context, this just blocks on the future and
697+
/// immediately returns the value without pausing caller coroutine.
698+
/// See [avoiding async code duplication] for justitifcation.
699+
fn __await__(&mut self, py: Python<'_>) -> PyResult<PyObject> {
700+
let lp = py
701+
.import("asyncio.events")
702+
.unwrap()
703+
.call_method0("_get_running_loop")
704+
.unwrap();
705+
if lp.is_none() {
706+
let value = self.block_on(py)?;
707+
Ok(JustStopWithValueIterator { value: Some(value) }.into_py_any(py)?)
708+
} else {
709+
self.into_future(py)?.call_method0(py, "__await__")
710+
}
711+
}
712+
}
713+
643714
async fn handle_async_endpoint_panic(
644715
panic_sender: UnboundedSender<anyhow::Result<(), SerializablePyErr>>,
645716
task: PythonTask,
@@ -664,7 +735,7 @@ async fn handle_async_endpoint_panic(
664735
Err(_) => pending().await,
665736
}
666737
};
667-
let future = task.take().await;
738+
let future = task.take();
668739
let result: anyhow::Result<(), SerializablePyErr> = tokio::select! {
669740
result = future => {
670741
match result {
@@ -689,6 +760,7 @@ pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResul
689760
hyperactor_mod.add_class::<PythonMessageKind>()?;
690761
hyperactor_mod.add_class::<UnflattenArg>()?;
691762
hyperactor_mod.add_class::<PanicFlag>()?;
763+
hyperactor_mod.add_class::<PyPythonTask>()?;
692764
Ok(())
693765
}
694766

monarch_hyperactor/src/actor_mesh.rs

Lines changed: 15 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ use hyperactor_mesh::actor_mesh::ActorSupervisionEvents;
2020
use hyperactor_mesh::reference::ActorMeshRef;
2121
use hyperactor_mesh::shared_cell::SharedCell;
2222
use hyperactor_mesh::shared_cell::SharedCellRef;
23+
use pyo3::IntoPyObjectExt;
2324
use pyo3::exceptions::PyEOFError;
2425
use pyo3::exceptions::PyException;
2526
use pyo3::exceptions::PyNotImplementedError;
@@ -33,8 +34,10 @@ use serde::Deserialize;
3334
use serde::Serialize;
3435
use tokio::sync::Mutex;
3536

37+
use crate::actor::PyPythonTask;
3638
use crate::actor::PythonActor;
3739
use crate::actor::PythonMessage;
40+
use crate::actor::PythonTask;
3841
use crate::mailbox::PyMailbox;
3942
use crate::mailbox::PythonOncePortReceiver;
4043
use crate::mailbox::PythonPortReceiver;
@@ -413,36 +416,21 @@ impl MonitoredPythonPortReceiver {
413416
}
414417
}
415418

416-
fn recv<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
419+
fn recv_task<'py>(&mut self) -> PyPythonTask {
417420
let receiver = self.inner.clone();
418421
let monitor = self.monitor.clone();
419-
pyo3_async_runtimes::tokio::future_into_py(py, async move {
422+
PythonTask::new(async move {
420423
let mut receiver = receiver.lock().await;
421-
tokio::select! {
424+
let result = tokio::select! {
422425
result = receiver.recv() => {
423426
result.map_err(|err| PyErr::new::<PyEOFError, _>(format!("port closed: {}", err)))
424427
}
425428
event = monitor.next() => {
426429
Err(PyErr::new::<SupervisionError, _>(format!("supervision error: {:?}", event.unwrap())))
427430
}
428-
}
429-
})
430-
}
431-
432-
fn blocking_recv<'py>(&mut self, py: Python<'py>) -> PyResult<PythonMessage> {
433-
let receiver = self.inner.clone();
434-
let monitor = self.monitor.clone();
435-
signal_safe_block_on(py, async move {
436-
let mut receiver = receiver.lock().await;
437-
tokio::select! {
438-
result = receiver.recv() => {
439-
result.map_err(|err| PyErr::new::<PyEOFError, _>(format!("port closed: {}", err)))
440-
}
441-
event = monitor.next() => {
442-
Err(PyErr::new::<SupervisionError, _>(format!("supervision error: {:?}", event.unwrap())))
443-
}
444-
}
445-
})?
431+
};
432+
result.and_then(|message: PythonMessage| Python::with_gil(|py| message.into_py_any(py)))
433+
}).into()
446434
}
447435
}
448436

@@ -466,38 +454,22 @@ impl MonitoredPythonOncePortReceiver {
466454
}
467455
}
468456

469-
fn recv<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
457+
fn recv_task<'py>(&mut self) -> PyResult<PyPythonTask> {
470458
let Some(receiver) = self.inner.lock().unwrap().take() else {
471459
return Err(PyErr::new::<PyValueError, _>("OncePort is already used"));
472460
};
473461
let monitor = self.monitor.clone();
474-
pyo3_async_runtimes::tokio::future_into_py(py, async move {
475-
tokio::select! {
462+
Ok(PythonTask::new(async move {
463+
let result = tokio::select! {
476464
result = receiver.recv() => {
477465
result.map_err(|err| PyErr::new::<PyEOFError, _>(format!("port closed: {}", err)))
478466
}
479467
event = monitor.next() => {
480468
Err(PyErr::new::<SupervisionError, _>(format!("supervision error: {:?}", event.unwrap())))
481469
}
482-
}
483-
})
484-
}
485-
486-
fn blocking_recv<'py>(&mut self, py: Python<'py>) -> PyResult<PythonMessage> {
487-
let Some(receiver) = self.inner.lock().unwrap().take() else {
488-
return Err(PyErr::new::<PyValueError, _>("OncePort is already used"));
489-
};
490-
let monitor = self.monitor.clone();
491-
signal_safe_block_on(py, async move {
492-
tokio::select! {
493-
result = receiver.recv() => {
494-
result.map_err(|err| PyErr::new::<PyEOFError, _>(format!("port closed: {}", err)))
495-
}
496-
event = monitor.next() => {
497-
Err(PyErr::new::<SupervisionError, _>(format!("supervision error: {:?}", event.unwrap())))
498-
}
499-
}
500-
})?
470+
};
471+
result.and_then(|message: PythonMessage| Python::with_gil(|py| message.into_py_any(py)))
472+
}).into())
501473
}
502474
}
503475

monarch_hyperactor/src/mailbox.rs

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9+
use std::future::Future;
910
use std::hash::DefaultHasher;
1011
use std::hash::Hash;
1112
use std::hash::Hasher;
@@ -45,9 +46,12 @@ use pyo3::types::PyTuple;
4546
use pyo3::types::PyType;
4647
use serde::Deserialize;
4748
use serde::Serialize;
49+
use tokio::sync::mpsc::UnboundedReceiver;
4850

51+
use crate::actor::PyPythonTask;
4952
use crate::actor::PythonMessage;
5053
use crate::actor::PythonMessageKind;
54+
use crate::actor::PythonTask;
5155
use crate::proc::PyActorId;
5256
use crate::runtime::signal_safe_block_on;
5357
use crate::shape::PyShape;
@@ -380,23 +384,23 @@ pub(super) struct PythonPortReceiver {
380384
inner: Arc<tokio::sync::Mutex<PortReceiver<PythonMessage>>>,
381385
}
382386

387+
async fn recv_async(
388+
receiver: Arc<tokio::sync::Mutex<PortReceiver<PythonMessage>>>,
389+
) -> PyResult<PyObject> {
390+
receiver
391+
.lock()
392+
.await
393+
.recv()
394+
.await
395+
.map_err(|err| PyErr::new::<PyEOFError, _>(format!("Port closed: {}", err)))
396+
.and_then(|message| Python::with_gil(|py| message.into_py_any(py)))
397+
}
398+
383399
#[pymethods]
384400
impl PythonPortReceiver {
385-
fn recv<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
401+
fn recv_task<'py>(&mut self) -> PyPythonTask {
386402
let receiver = self.inner.clone();
387-
pyo3_async_runtimes::tokio::future_into_py(py, async move {
388-
receiver
389-
.lock()
390-
.await
391-
.recv()
392-
.await
393-
.map_err(|err| PyErr::new::<PyEOFError, _>(format!("Port closed: {}", err)))
394-
})
395-
}
396-
fn blocking_recv<'py>(&mut self, py: Python<'py>) -> PyResult<PythonMessage> {
397-
let receiver = self.inner.clone();
398-
signal_safe_block_on(py, async move { receiver.lock().await.recv().await })?
399-
.map_err(|err| PyErr::new::<PyEOFError, _>(format!("Port closed: {}", err)))
403+
PythonTask::new(recv_async(receiver)).into()
400404
}
401405
}
402406

@@ -551,24 +555,18 @@ pub(super) struct PythonOncePortReceiver {
551555

552556
#[pymethods]
553557
impl PythonOncePortReceiver {
554-
fn recv<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
558+
fn recv_task<'py>(&mut self) -> PyResult<PyPythonTask> {
555559
let Some(receiver) = self.inner.lock().unwrap().take() else {
556560
return Err(PyErr::new::<PyValueError, _>("OncePort is already used"));
557561
};
558-
559-
pyo3_async_runtimes::tokio::future_into_py(py, async move {
562+
let fut = async move {
560563
receiver
561564
.recv()
562565
.await
563566
.map_err(|err| PyErr::new::<PyEOFError, _>(format!("Port closed: {}", err)))
564-
})
565-
}
566-
fn blocking_recv<'py>(&mut self, py: Python<'py>) -> PyResult<PythonMessage> {
567-
let Some(receiver) = self.inner.lock().unwrap().take() else {
568-
return Err(PyErr::new::<PyValueError, _>("OncePort is already used"));
567+
.and_then(|message| Python::with_gil(|py| message.into_py_any(py)))
569568
};
570-
signal_safe_block_on(py, async move { receiver.recv().await })?
571-
.map_err(|err| PyErr::new::<PyEOFError, _>(format!("Port closed: {}", err)))
569+
Ok(PythonTask::new(fut).into())
572570
}
573571
}
574572

@@ -710,6 +708,5 @@ pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResul
710708
hyperactor_mod.add_class::<PythonOncePortHandle>()?;
711709
hyperactor_mod.add_class::<PythonOncePortRef>()?;
712710
hyperactor_mod.add_class::<PythonOncePortReceiver>()?;
713-
714711
Ok(())
715712
}

python/monarch/_rust_bindings/monarch_hyperactor/actor_mesh.pyi

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ from monarch._rust_bindings.monarch_hyperactor.mailbox import (
1313
Mailbox,
1414
OncePortReceiver,
1515
PortReceiver,
16+
PortReceiverBase,
1617
)
1718
from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
1819
from monarch._rust_bindings.monarch_hyperactor.selection import Selection
@@ -128,7 +129,7 @@ class ActorMeshMonitor:
128129
...
129130

130131
@final
131-
class MonitoredPortReceiver:
132+
class MonitoredPortReceiver(PortReceiverBase):
132133
"""
133134
A monitored receiver to which PythonMessages are sent.
134135
"""
@@ -139,15 +140,8 @@ class MonitoredPortReceiver:
139140
"""
140141
...
141142

142-
async def recv(self) -> PythonMessage:
143-
"""Receive a PythonMessage from the port's sender."""
144-
...
145-
def blocking_recv(self) -> PythonMessage:
146-
"""Receive a single PythonMessage from the port's sender."""
147-
...
148-
149143
@final
150-
class MonitoredOncePortReceiver:
144+
class MonitoredOncePortReceiver(PortReceiverBase):
151145
"""
152146
A variant of monitored PortReceiver that can only receive a single message.
153147
"""
@@ -158,13 +152,6 @@ class MonitoredOncePortReceiver:
158152
"""
159153
...
160154

161-
async def recv(self) -> PythonMessage:
162-
"""Receive a single PythonMessage from the port's sender."""
163-
...
164-
def blocking_recv(self) -> PythonMessage:
165-
"""Receive a single PythonMessage from the port's sender."""
166-
...
167-
168155
@final
169156
class ActorSupervisionEvent:
170157
@property

0 commit comments

Comments
 (0)