Skip to content

Use async functions in sync context to avoid code duplication. #565

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: gh/zdevito/44/base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 89 additions & 4 deletions monarch_hyperactor/src/actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ use monarch_types::SerializablePyErr;
use pyo3::conversion::IntoPyObjectExt;
use pyo3::exceptions::PyBaseException;
use pyo3::exceptions::PyRuntimeError;
use pyo3::exceptions::PyStopIteration;
use pyo3::prelude::*;
use pyo3::types::PyBytes;
use pyo3::types::PyDict;
Expand Down Expand Up @@ -616,18 +617,19 @@ impl Handler<PythonMessage> for PythonActor {
/// Helper struct to make a Python future passable in an actor message.
///
/// Also so that we don't have to write this massive type signature everywhere
struct PythonTask {

pub(crate) struct PythonTask {
future: Mutex<Pin<Box<dyn Future<Output = PyResult<PyObject>> + Send + 'static>>>,
}

impl PythonTask {
fn new(fut: impl Future<Output = PyResult<PyObject>> + Send + 'static) -> Self {
pub(crate) fn new(fut: impl Future<Output = PyResult<PyObject>> + Send + 'static) -> Self {
Self {
future: Mutex::new(Box::pin(fut)),
}
}

async fn take(self) -> Pin<Box<dyn Future<Output = PyResult<PyObject>> + Send + 'static>> {
fn take(self) -> Pin<Box<dyn Future<Output = PyResult<PyObject>> + Send + 'static>> {
self.future.into_inner()
}
}
Expand All @@ -640,6 +642,88 @@ impl fmt::Debug for PythonTask {
}
}

#[pyclass(
name = "PythonTask",
module = "monarch._rust_bindings.monarch_hyperactor.actor"
)]
pub struct PyPythonTask {
inner: Option<PythonTask>,
}

impl From<PythonTask> for PyPythonTask {
fn from(task: PythonTask) -> Self {
Self { inner: Some(task) }
}
}

#[pyclass(
name = "JustStopWithValueIterator",
module = "monarch._rust_bindings.monarch_hyperactor.actor"
)]
struct JustStopWithValueIterator {
value: Option<PyObject>,
}

#[pymethods]
impl JustStopWithValueIterator {
fn __next__(&mut self) -> PyResult<PyObject> {
Err(PyStopIteration::new_err(self.value.take().unwrap()))
}
}

impl PyPythonTask {
pub fn new<F, T>(fut: F) -> PyResult<Self>
where
F: Future<Output = PyResult<T>> + Send + 'static,
T: for<'py> IntoPyObject<'py>,
{
Ok(PythonTask::new(async {
fut.await
.and_then(|t| Python::with_gil(|py| t.into_py_any(py)))
})
.into())
}
}

#[pymethods]
impl PyPythonTask {
fn into_future(&mut self, py: Python<'_>) -> PyResult<PyObject> {
let task = self
.inner
.take()
.map(|task| task.take())
.expect("PythonTask already consumed");
Ok(pyo3_async_runtimes::tokio::future_into_py(py, task)?.unbind())
}
fn block_on(&mut self, py: Python<'_>) -> PyResult<PyObject> {
let task = self
.inner
.take()
.map(|task| task.take())
.expect("PythonTask already consumed");
signal_safe_block_on(py, task)?
}

/// In an async context this turns the tokio::Future into
/// an asyncio Future and awaits it.
/// In a synchronous context, this just blocks on the future and
/// immediately returns the value without pausing caller coroutine.
/// See [avoiding async code duplication] for justitifcation.
fn __await__(&mut self, py: Python<'_>) -> PyResult<PyObject> {
let lp = py
.import("asyncio.events")
.unwrap()
.call_method0("_get_running_loop")
.unwrap();
if lp.is_none() {
let value = self.block_on(py)?;
Ok(JustStopWithValueIterator { value: Some(value) }.into_py_any(py)?)
} else {
self.into_future(py)?.call_method0(py, "__await__")
}
}
}

async fn handle_async_endpoint_panic(
panic_sender: UnboundedSender<anyhow::Result<(), SerializablePyErr>>,
task: PythonTask,
Expand All @@ -664,7 +748,7 @@ async fn handle_async_endpoint_panic(
Err(_) => pending().await,
}
};
let future = task.take().await;
let future = task.take();
let result: anyhow::Result<(), SerializablePyErr> = tokio::select! {
result = future => {
match result {
Expand All @@ -689,6 +773,7 @@ pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResul
hyperactor_mod.add_class::<PythonMessageKind>()?;
hyperactor_mod.add_class::<UnflattenArg>()?;
hyperactor_mod.add_class::<PanicFlag>()?;
hyperactor_mod.add_class::<PyPythonTask>()?;
Ok(())
}

Expand Down
58 changes: 14 additions & 44 deletions monarch_hyperactor/src/actor_mesh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,15 @@ use serde::Deserialize;
use serde::Serialize;
use tokio::sync::Mutex;

use crate::actor::PyPythonTask;
use crate::actor::PythonActor;
use crate::actor::PythonMessage;
use crate::actor::PythonTask;
use crate::mailbox::PyMailbox;
use crate::mailbox::PythonOncePortReceiver;
use crate::mailbox::PythonPortReceiver;
use crate::proc::PyActorId;
use crate::proc_mesh::Keepalive;
use crate::runtime::signal_safe_block_on;
use crate::selection::PySelection;
use crate::shape::PyShape;
use crate::supervision::SupervisionError;
Expand Down Expand Up @@ -413,36 +414,21 @@ impl MonitoredPythonPortReceiver {
}
}

fn recv<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
fn recv_task<'py>(&mut self) -> PyPythonTask {
let receiver = self.inner.clone();
let monitor = self.monitor.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
PythonTask::new(async move {
let mut receiver = receiver.lock().await;
tokio::select! {
let result = tokio::select! {
result = receiver.recv() => {
result.map_err(|err| PyErr::new::<PyEOFError, _>(format!("port closed: {}", err)))
}
event = monitor.next() => {
Err(PyErr::new::<SupervisionError, _>(format!("supervision error: {:?}", event.unwrap())))
}
}
})
}

fn blocking_recv<'py>(&mut self, py: Python<'py>) -> PyResult<PythonMessage> {
let receiver = self.inner.clone();
let monitor = self.monitor.clone();
signal_safe_block_on(py, async move {
let mut receiver = receiver.lock().await;
tokio::select! {
result = receiver.recv() => {
result.map_err(|err| PyErr::new::<PyEOFError, _>(format!("port closed: {}", err)))
}
event = monitor.next() => {
Err(PyErr::new::<SupervisionError, _>(format!("supervision error: {:?}", event.unwrap())))
}
}
})?
};
result.and_then(|message: PythonMessage| Python::with_gil(|py| message.into_py_any(py)))
}).into()
}
}

Expand All @@ -466,38 +452,22 @@ impl MonitoredPythonOncePortReceiver {
}
}

fn recv<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
fn recv_task<'py>(&mut self) -> PyResult<PyPythonTask> {
let Some(receiver) = self.inner.lock().unwrap().take() else {
return Err(PyErr::new::<PyValueError, _>("OncePort is already used"));
};
let monitor = self.monitor.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
tokio::select! {
Ok(PythonTask::new(async move {
let result = tokio::select! {
result = receiver.recv() => {
result.map_err(|err| PyErr::new::<PyEOFError, _>(format!("port closed: {}", err)))
}
event = monitor.next() => {
Err(PyErr::new::<SupervisionError, _>(format!("supervision error: {:?}", event.unwrap())))
}
}
})
}

fn blocking_recv<'py>(&mut self, py: Python<'py>) -> PyResult<PythonMessage> {
let Some(receiver) = self.inner.lock().unwrap().take() else {
return Err(PyErr::new::<PyValueError, _>("OncePort is already used"));
};
let monitor = self.monitor.clone();
signal_safe_block_on(py, async move {
tokio::select! {
result = receiver.recv() => {
result.map_err(|err| PyErr::new::<PyEOFError, _>(format!("port closed: {}", err)))
}
event = monitor.next() => {
Err(PyErr::new::<SupervisionError, _>(format!("supervision error: {:?}", event.unwrap())))
}
}
})?
};
result.and_then(|message: PythonMessage| Python::with_gil(|py| message.into_py_any(py)))
}).into())
}
}

Expand Down
45 changes: 20 additions & 25 deletions monarch_hyperactor/src/mailbox.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,10 @@ use pyo3::types::PyType;
use serde::Deserialize;
use serde::Serialize;

use crate::actor::PyPythonTask;
use crate::actor::PythonMessage;
use crate::actor::PythonMessageKind;
use crate::actor::PythonTask;
use crate::proc::PyActorId;
use crate::runtime::signal_safe_block_on;
use crate::shape::PyShape;
Expand Down Expand Up @@ -380,23 +382,23 @@ pub(super) struct PythonPortReceiver {
inner: Arc<tokio::sync::Mutex<PortReceiver<PythonMessage>>>,
}

async fn recv_async(
receiver: Arc<tokio::sync::Mutex<PortReceiver<PythonMessage>>>,
) -> PyResult<PyObject> {
receiver
.lock()
.await
.recv()
.await
.map_err(|err| PyErr::new::<PyEOFError, _>(format!("Port closed: {}", err)))
.and_then(|message| Python::with_gil(|py| message.into_py_any(py)))
}

#[pymethods]
impl PythonPortReceiver {
fn recv<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
fn recv_task<'py>(&mut self) -> PyPythonTask {
let receiver = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
receiver
.lock()
.await
.recv()
.await
.map_err(|err| PyErr::new::<PyEOFError, _>(format!("Port closed: {}", err)))
})
}
fn blocking_recv<'py>(&mut self, py: Python<'py>) -> PyResult<PythonMessage> {
let receiver = self.inner.clone();
signal_safe_block_on(py, async move { receiver.lock().await.recv().await })?
.map_err(|err| PyErr::new::<PyEOFError, _>(format!("Port closed: {}", err)))
PythonTask::new(recv_async(receiver)).into()
}
}

Expand Down Expand Up @@ -551,24 +553,18 @@ pub(super) struct PythonOncePortReceiver {

#[pymethods]
impl PythonOncePortReceiver {
fn recv<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
fn recv_task<'py>(&mut self) -> PyResult<PyPythonTask> {
let Some(receiver) = self.inner.lock().unwrap().take() else {
return Err(PyErr::new::<PyValueError, _>("OncePort is already used"));
};

pyo3_async_runtimes::tokio::future_into_py(py, async move {
let fut = async move {
receiver
.recv()
.await
.map_err(|err| PyErr::new::<PyEOFError, _>(format!("Port closed: {}", err)))
})
}
fn blocking_recv<'py>(&mut self, py: Python<'py>) -> PyResult<PythonMessage> {
let Some(receiver) = self.inner.lock().unwrap().take() else {
return Err(PyErr::new::<PyValueError, _>("OncePort is already used"));
.and_then(|message| Python::with_gil(|py| message.into_py_any(py)))
};
signal_safe_block_on(py, async move { receiver.recv().await })?
.map_err(|err| PyErr::new::<PyEOFError, _>(format!("Port closed: {}", err)))
Ok(PythonTask::new(fut).into())
}
}

Expand Down Expand Up @@ -710,6 +706,5 @@ pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResul
hyperactor_mod.add_class::<PythonOncePortHandle>()?;
hyperactor_mod.add_class::<PythonOncePortRef>()?;
hyperactor_mod.add_class::<PythonOncePortReceiver>()?;

Ok(())
}
19 changes: 3 additions & 16 deletions python/monarch/_rust_bindings/monarch_hyperactor/actor_mesh.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ from monarch._rust_bindings.monarch_hyperactor.mailbox import (
Mailbox,
OncePortReceiver,
PortReceiver,
PortReceiverBase,
)
from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
from monarch._rust_bindings.monarch_hyperactor.selection import Selection
Expand Down Expand Up @@ -128,7 +129,7 @@ class ActorMeshMonitor:
...

@final
class MonitoredPortReceiver:
class MonitoredPortReceiver(PortReceiverBase):
"""
A monitored receiver to which PythonMessages are sent.
"""
Expand All @@ -139,15 +140,8 @@ class MonitoredPortReceiver:
"""
...

async def recv(self) -> PythonMessage:
"""Receive a PythonMessage from the port's sender."""
...
def blocking_recv(self) -> PythonMessage:
"""Receive a single PythonMessage from the port's sender."""
...

@final
class MonitoredOncePortReceiver:
class MonitoredOncePortReceiver(PortReceiverBase):
"""
A variant of monitored PortReceiver that can only receive a single message.
"""
Expand All @@ -158,13 +152,6 @@ class MonitoredOncePortReceiver:
"""
...

async def recv(self) -> PythonMessage:
"""Receive a single PythonMessage from the port's sender."""
...
def blocking_recv(self) -> PythonMessage:
"""Receive a single PythonMessage from the port's sender."""
...

@final
class ActorSupervisionEvent:
@property
Expand Down
Loading
Loading