Skip to content

Fix python port serialization #560

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions monarch_hyperactor/src/mailbox.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
use std::hash::DefaultHasher;
use std::hash::Hash;
use std::hash::Hasher;
use std::ops::Deref;
use std::sync::Arc;

use hyperactor::Mailbox;
Expand All @@ -35,11 +36,13 @@ use hyperactor::message::Bindings;
use hyperactor::message::Unbind;
use hyperactor_mesh::comm::multicast::set_cast_info_on_headers;
use monarch_types::PickledPyObject;
use pyo3::IntoPyObjectExt;
use pyo3::exceptions::PyEOFError;
use pyo3::exceptions::PyRuntimeError;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::PyTuple;
use pyo3::types::PyType;
use serde::Deserialize;
use serde::Serialize;

Expand Down Expand Up @@ -326,6 +329,19 @@ pub struct PythonPortRef {

#[pymethods]
impl PythonPortRef {
#[new]
fn new(port: PyPortId) -> Self {
Self {
inner: PortRef::attest(port.into()),
}
}
fn __reduce__<'py>(
slf: Bound<'py, PythonPortRef>,
) -> PyResult<(Bound<'py, PyType>, (PyPortId,))> {
let id: PyPortId = (*slf.borrow()).inner.port_id().clone().into();
Ok((slf.get_type(), (id,)))
}

fn send(&self, mailbox: &PyMailbox, message: PythonMessage) -> PyResult<()> {
self.inner
.send(&mailbox.inner, message)
Expand Down Expand Up @@ -472,6 +488,22 @@ pub struct PythonOncePortRef {

#[pymethods]
impl PythonOncePortRef {
#[new]
fn new(port: Option<PyPortId>) -> Self {
Self {
inner: port.map(|port| PortRef::attest(port.inner).into_once()),
}
}
fn __reduce__<'py>(
slf: Bound<'py, PythonOncePortRef>,
) -> PyResult<(Bound<'py, PyType>, (Option<PyPortId>,))> {
let id: Option<PyPortId> = (*slf.borrow())
.inner
.as_ref()
.map(|x| x.port_id().clone().into());
Ok((slf.get_type(), (id,)))
}

fn send(&mut self, mailbox: &PyMailbox, message: PythonMessage) -> PyResult<()> {
let Some(port_ref) = self.inner.take() else {
return Err(PyErr::new::<PyValueError, _>("OncePortRef is already used"));
Expand Down
20 changes: 20 additions & 0 deletions python/tests/test_python_actors.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

import torch

from monarch._src.actor.actor_mesh import Port, PortTuple

from monarch.actor import (
Accumulator,
Actor,
Expand Down Expand Up @@ -706,6 +708,24 @@ async def test_actor_log_streaming() -> None:
pass


class SendAlot(Actor):
@endpoint
async def send(self, port: Port[int]):
for i in range(100):
port.send(i)


def test_port_as_argument():
proc_mesh = local_proc_mesh(gpus=1).get()
s = proc_mesh.spawn("send_alot", SendAlot).get()
send, recv = PortTuple.create(proc_mesh._mailbox, None)

s.send.broadcast(send)

for i in range(100):
assert i == recv.recv().get()


@pytest.mark.timeout(15)
async def test_same_actor_twice() -> None:
pm = await proc_mesh(gpus=1)
Expand Down
Loading