diff --git a/monarch_hyperactor/src/mailbox.rs b/monarch_hyperactor/src/mailbox.rs index 6e7654039..219eb4513 100644 --- a/monarch_hyperactor/src/mailbox.rs +++ b/monarch_hyperactor/src/mailbox.rs @@ -9,6 +9,7 @@ use std::hash::DefaultHasher; use std::hash::Hash; use std::hash::Hasher; +use std::ops::Deref; use std::sync::Arc; use hyperactor::Mailbox; @@ -35,11 +36,13 @@ use hyperactor::message::Bindings; use hyperactor::message::Unbind; use hyperactor_mesh::comm::multicast::set_cast_info_on_headers; use monarch_types::PickledPyObject; +use pyo3::IntoPyObjectExt; use pyo3::exceptions::PyEOFError; use pyo3::exceptions::PyRuntimeError; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use pyo3::types::PyTuple; +use pyo3::types::PyType; use serde::Deserialize; use serde::Serialize; @@ -326,6 +329,19 @@ pub struct PythonPortRef { #[pymethods] impl PythonPortRef { + #[new] + fn new(port: PyPortId) -> Self { + Self { + inner: PortRef::attest(port.into()), + } + } + fn __reduce__<'py>( + slf: Bound<'py, PythonPortRef>, + ) -> PyResult<(Bound<'py, PyType>, (PyPortId,))> { + let id: PyPortId = (*slf.borrow()).inner.port_id().clone().into(); + Ok((slf.get_type(), (id,))) + } + fn send(&self, mailbox: &PyMailbox, message: PythonMessage) -> PyResult<()> { self.inner .send(&mailbox.inner, message) @@ -472,6 +488,22 @@ pub struct PythonOncePortRef { #[pymethods] impl PythonOncePortRef { + #[new] + fn new(port: Option) -> Self { + Self { + inner: port.map(|port| PortRef::attest(port.inner).into_once()), + } + } + fn __reduce__<'py>( + slf: Bound<'py, PythonOncePortRef>, + ) -> PyResult<(Bound<'py, PyType>, (Option,))> { + let id: Option = (*slf.borrow()) + .inner + .as_ref() + .map(|x| x.port_id().clone().into()); + Ok((slf.get_type(), (id,))) + } + fn send(&mut self, mailbox: &PyMailbox, message: PythonMessage) -> PyResult<()> { let Some(port_ref) = self.inner.take() else { return Err(PyErr::new::("OncePortRef is already used")); diff --git a/python/tests/test_python_actors.py b/python/tests/test_python_actors.py index 9ea581161..2608b7f13 100644 --- a/python/tests/test_python_actors.py +++ b/python/tests/test_python_actors.py @@ -20,6 +20,8 @@ import torch +from monarch._src.actor.actor_mesh import Port, PortTuple + from monarch.actor import ( Accumulator, Actor, @@ -706,6 +708,24 @@ async def test_actor_log_streaming() -> None: pass +class SendAlot(Actor): + @endpoint + async def send(self, port: Port[int]): + for i in range(100): + port.send(i) + + +def test_port_as_argument(): + proc_mesh = local_proc_mesh(gpus=1).get() + s = proc_mesh.spawn("send_alot", SendAlot).get() + send, recv = PortTuple.create(proc_mesh._mailbox, None) + + s.send.broadcast(send) + + for i in range(100): + assert i == recv.recv().get() + + @pytest.mark.timeout(15) async def test_same_actor_twice() -> None: pm = await proc_mesh(gpus=1)