From 66c4f09c86193f614348a086ad84782f24047b56 Mon Sep 17 00:00:00 2001 From: Peng Zhang Date: Mon, 21 Jul 2025 15:02:14 -0700 Subject: [PATCH] Migrate to PythonActorMesh and PythonActorMeshRef (#557) Summary: This diff swaps `_ActorMeshRefImpl` with `PythonActorMesh[Ref]`. The swap itself should be straightforward since `PythonActorMesh[Ref]` should be drop-in replacements for `_ActorMeshRefImpl`. Most of the complexity in this diff is from how I tried to add a toggle between them, just in case there is any bugs with `PythonActorMesh[Ref]`, so we can quickly switch back to `_ActorMeshRefImpl`. What I did is: 1. Add wrapper classes `EitherPyActorMesh[Ref]`, whose underlying type can be either `PythonActorMesh[Ref]` or `_ActorMeshRefImpl`; 2. a env var `USE_STANDIN_ACTOR_MESH` is used to which one would be used when instantiating `EitherPyActorMesh[Ref]`. The landing of this diff would mean all Python-side mesh API calls should go through Rust-side's `cast` code path, except several usages of `ActorIdRef`. Differential Revision: D78355743 --- python/monarch/_src/actor/actor_mesh.py | 378 ++++++++++++++++++++---- python/monarch/_src/actor/debugger.py | 12 +- python/monarch/_src/actor/proc_mesh.py | 18 +- python/monarch/mesh_controller.py | 4 +- python/monarch/rdma.py | 17 +- python/tests/test_actor_error.py | 9 +- python/tests/test_python_actors.py | 4 + 7 files changed, 346 insertions(+), 96 deletions(-) diff --git a/python/monarch/_src/actor/actor_mesh.py b/python/monarch/_src/actor/actor_mesh.py index 5f71e7ef0..f1ba6b36f 100644 --- a/python/monarch/_src/actor/actor_mesh.py +++ b/python/monarch/_src/actor/actor_mesh.py @@ -14,6 +14,7 @@ import inspect import itertools import logging +import os import random import traceback @@ -39,6 +40,7 @@ Optional, overload, ParamSpec, + Protocol, Sequence, Tuple, Type, @@ -56,6 +58,7 @@ MonitoredOncePortReceiver, MonitoredPortReceiver, PythonActorMesh, + PythonActorMeshRef, ) from monarch._rust_bindings.monarch_hyperactor.mailbox import ( Mailbox, @@ -68,9 +71,9 @@ from monarch._rust_bindings.monarch_hyperactor.mailbox import PortReceiverBase from monarch._rust_bindings.monarch_hyperactor.proc import ActorId +from monarch._rust_bindings.monarch_hyperactor.selection import Selection as HySelection from monarch._rust_bindings.monarch_hyperactor.shape import Point as HyPoint, Shape from monarch._rust_bindings.monarch_hyperactor.supervision import SupervisionError - from monarch._rust_bindings.monarch_hyperactor.telemetry import enter_span, exit_span from monarch._src.actor.allocator import LocalAllocator, ProcessAllocator from monarch._src.actor.future import Future @@ -80,6 +83,7 @@ from monarch._src.actor.shape import MeshTrait, NDSlice from monarch._src.actor.sync_state import fake_sync_state +from typing_extensions import Self if TYPE_CHECKING: from monarch._src.actor.proc_mesh import ProcMesh @@ -149,10 +153,219 @@ def set(debug_context: "DebugContext") -> None: Selection = Literal["all", "choose"] | int # TODO: replace with real selection objects +def to_hy_sel(selection: Selection, shape: Shape) -> HySelection: + if selection == "choose": + dim = len(shape.labels) + assert dim > 0 + query = ",".join(["?"] * dim) + return HySelection.from_string(f"{query}") + elif selection == "all": + return HySelection.from_string("*") + else: + raise ValueError(f"invalid selection: {selection}") + + +# A temporary gate used by the PythonActorMesh/PythonActorMeshRef migration. +# We can use this gate to quickly roll back to using _ActorMeshRefImpl, if we +# encounter any issues with the migration. +# +# This should be removed once we confirm PythonActorMesh/PythonActorMeshRef is +# working correctly in production. +def _use_standin_mesh() -> bool: + return bool(os.getenv("USE_STANDIN_ACTOR_MESH", default=False)) + + +class ActorMeshProtocol(Protocol): + """ + Protocol defining the common interface for actor mesh, mesh ref and _ActorMeshRefImpl. + + Note: We do not want to use ABC because _ActorMeshRefImpl already inherits + from MeshTrait and we want to avoid multiple inheritance, especially when + _ActorMeshRefImpl will be deleted soon. + """ + + @property + def shape(self) -> Shape: ... + + @property + def monitor(self) -> Optional[ActorMeshMonitor]: ... + + @property + def proc_mesh(self) -> Optional["ProcMesh"]: ... + + @property + def inner_name(Self) -> str: ... + + def cast( + self, + message: PythonMessage, + selection: Selection, + mailbox: Optional[Mailbox], + ) -> None: ... + + def slice(self, **kwargs: Any) -> Self: ... + + def bind(self) -> Self: ... + + +class _PythonActorMeshAdapter(ActorMeshProtocol): + """ + Adapter for PythonActorMesh to implement the normalized ActorMeshProtocol + interface. This adapter also provides a convenient way to add states to + the mesh on the python side, without changing the rust side implementation. + + Since PythonActorMesh cannot be pickled, this adapter also provides a + custom pickling logic which bind the mesh to PythonActorMeshRef during + pickling. + """ + + def __init__(self, inner: PythonActorMesh, proc_mesh: "ProcMesh") -> None: + self._inner = inner + self._proc_mesh = proc_mesh + + @property + def shape(self) -> Shape: + return self._inner.shape + + @property + def monitor(self) -> Optional[ActorMeshMonitor]: + return self._inner.monitor() + + @property + def proc_mesh(self) -> Optional["ProcMesh"]: + return self._proc_mesh + + @property + def inner_name(self) -> str: + return self._inner.__class__.__name__ + + def cast( + self, + message: PythonMessage, + selection: Selection, + mailbox: Optional[Mailbox], + ) -> None: + self._inner.cast(to_hy_sel(selection, self.shape), message) + + def slice(self, **kwargs: Any) -> "ActorMeshProtocol": + sliced: PythonActorMeshRef = self._inner.slice(**kwargs) + return _PythonActorMeshRefAdapter(sliced, self.proc_mesh, self.monitor) + + def bind(self) -> "ActorMeshProtocol": + # PythonActorMesh.bind returns PythonActorMeshRef + mesh_ref: PythonActorMeshRef = self._inner.bind() + return _PythonActorMeshRefAdapter(mesh_ref, self.proc_mesh, self.monitor) + + def __reduce_ex__(self, protocol: ...) -> Tuple[Any, Tuple[Any, ...]]: + """ + Automatically pickle as a PythonActorMeshRef by binding the mesh. + Unpicklable states such as proc_mesh and monitor are dropped as well. + """ + mesh_ref = self._inner.bind() + return _PythonActorMeshRefAdapter, (mesh_ref, None, None) + + +class _PythonActorMeshRefAdapter(ActorMeshProtocol): + """ + Adapter for PythonActorMeshRef to implement the normalized ActorMeshProtocol interface. It is + also used to store unpickable states such as proc_mesh and monitor. It is useful to have these + unpickable states when pickling is not needed. For example, slicing a mesh will result in a + mesh ref, and this mesh ref could be used by the same caller. This caller would expect the + mesh ref to have the same supervision behavior as the original mesh. In this case, having the + monitor field will be helpful. + """ + + def __init__( + self, + inner: PythonActorMeshRef, + proc_mesh: "Optional[ProcMesh]", + monitor: Optional[ActorMeshMonitor], + ) -> None: + self._inner = inner + self._proc_mesh = proc_mesh + self._monitor = monitor + + @property + def shape(self) -> Shape: + return self._inner.shape + + @property + def monitor(self) -> Optional[ActorMeshMonitor]: + return self._monitor + + @property + def proc_mesh(self) -> Optional["ProcMesh"]: + return self._proc_mesh + + @property + def inner_name(self) -> str: + return self._inner.__class__.__name__ + + def cast( + self, + message: PythonMessage, + selection: Selection, + mailbox: Optional[Mailbox] = None, + ) -> None: + if mailbox is None: + raise ValueError("mailbox is required for PythonActorMeshRef.cast()") + self._inner.cast(mailbox, to_hy_sel(selection, self.shape), message) + + def slice(self, **kwargs: Any) -> "ActorMeshProtocol": + sliced: PythonActorMeshRef = self._inner.slice(**kwargs) + return _PythonActorMeshRefAdapter(sliced, self._proc_mesh, self._monitor) + + def bind(self) -> "ActorMeshProtocol": + raise NotImplementedError("PythonActorMeshRef.bind() is not supported") + + def __reduce_ex__(self, protocol: ...) -> Tuple[Any, Tuple[Any, ...]]: + """ + Dropping all unpickable states. + """ + return _PythonActorMeshRefAdapter, (self._inner, None, None) + + +class _ActorIdAdapter(ActorMeshProtocol): + def __init__(self, inner: ActorId) -> None: + self._inner: ActorId = inner + + @property + def shape(self) -> Shape: + return singleton_shape + + @property + def monitor(self) -> Optional[ActorMeshMonitor]: + return None + + @property + def proc_mesh(self) -> Optional["ProcMesh"]: + return None + + @property + def inner_name(self) -> str: + return self._inner.__class__.__name__ + + def cast( + self, + message: PythonMessage, + selection: Selection, + mailbox: Optional[Mailbox], + ) -> None: + if mailbox is None: + raise ValueError("mailbox is required for ActorId") + mailbox.post(self._inner, message) + + def slice(self, **kwargs: Any) -> Self: + raise NotImplementedError("ActorId does not support slicing") + + def bind(self) -> Self: + raise NotImplementedError("ActorId does not support binding") + + # standin class for whatever is the serializable python object we use # to name an actor mesh. Hacked up today because ActorMesh # isn't plumbed to non-clients -class _ActorMeshRefImpl: +class _ActorMeshRefImpl(MeshTrait, ActorMeshProtocol): def __init__( self, mailbox: Mailbox, @@ -161,6 +374,10 @@ def __init__( shape: Shape, actor_ids: List[ActorId], ) -> None: + if not _use_standin_mesh(): + raise ValueError( + "ActorMeshRefImpl should only be used when USE_STANDIN_ACTOR_MESH is set" + ) self._mailbox = mailbox self._actor_mesh = hy_actor_mesh # actor meshes do not have a way to look this up at the moment, @@ -182,16 +399,25 @@ def from_hyperactor_mesh( [cast(ActorId, hy_actor_mesh.get(i)) for i in range(len(shape))], ) - @staticmethod - def from_actor_id(mailbox: Mailbox, actor_id: ActorId) -> "_ActorMeshRefImpl": - return _ActorMeshRefImpl(mailbox, None, None, singleton_shape, [actor_id]) + @property + def monitor(self) -> Optional[ActorMeshMonitor]: + return self._actor_mesh.monitor() if self._actor_mesh is not None else None - @staticmethod - def from_actor_ref_with_shape( - ref: "_ActorMeshRefImpl", shape: Shape - ) -> "_ActorMeshRefImpl": + @property + def shape(self) -> Shape: + return self._shape + + @property + def _ndslice(self) -> NDSlice: + return self._shape.ndslice + + @property + def _labels(self) -> Iterable[str]: + return self._shape.labels + + def _new_with_shape(self, shape: Shape) -> "_ActorMeshRefImpl": return _ActorMeshRefImpl( - ref._mailbox, None, None, shape, ref._please_replace_me_actor_ids + self._mailbox, None, None, shape, self._please_replace_me_actor_ids ) def __getstate__( @@ -216,15 +442,11 @@ def _check_state(self) -> None: if event is not None: raise SupervisionError(f"actor mesh is not in a healthy state: {event}") - def send(self, rank: int, message: PythonMessage) -> None: - self._check_state() - actor = self._please_replace_me_actor_ids[rank] - self._mailbox.post(actor, message) - def cast( self, message: PythonMessage, selection: Selection, + mailbox: Optional[Mailbox], ) -> None: self._check_state() @@ -266,14 +488,12 @@ def cast( else: raise ValueError(f"invalid selection: {selection}") + def bind(self) -> Self: + return self + def __len__(self) -> int: return len(self._shape) - @property - def _name_pid(self): - actor_id0 = self._please_replace_me_actor_ids[0] - return actor_id0.actor_name, actor_id0.pid - class Extent(NamedTuple): labels: Sequence[str] @@ -343,7 +563,7 @@ def call(self, *args: P.args, **kwargs: P.kwargs) -> "Future[ValueMesh[R]]": extent = self._send(args, kwargs, port=p) async def process() -> ValueMesh[R]: - results: List[R] = [None] * extent.nelements # pyre-fixme[9] + results: Dict[int, R] = {} for _ in range(extent.nelements): rank, value = await r.recv() results[rank] = value @@ -351,7 +571,8 @@ async def process() -> ValueMesh[R]: extent.labels, NDSlice.new_row_major(extent.sizes), ) - return ValueMesh(call_shape, results) + sorted_values = [results[rank] for rank in sorted(results)] + return ValueMesh(call_shape, sorted_values) return Future(impl=process, requires_loop=False) @@ -383,12 +604,12 @@ def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> None: class ActorEndpoint(Endpoint[P, R]): def __init__( self, - actor_mesh_ref: _ActorMeshRefImpl, + actor_mesh: ActorMeshProtocol, name: str, impl: Callable[Concatenate[Any, P], Awaitable[R]], mailbox: Mailbox, ) -> None: - self._actor_mesh = actor_mesh_ref + self._actor_mesh = actor_mesh self._name = name self._signature: inspect.Signature = inspect.signature(impl) self._mailbox = mailbox @@ -415,20 +636,16 @@ def _send( ), bytes, ) - self._actor_mesh.cast(message, selection) + self._actor_mesh.cast(message, selection, self._mailbox) else: importlib.import_module("monarch." + "mesh_controller").actor_send( self, bytes, refs, port, selection ) - shape = self._actor_mesh._shape + shape = self._actor_mesh.shape return Extent(shape.labels, shape.ndslice.sizes) def _port(self, once: bool = False) -> "PortTuple[R]": - monitor = ( - None - if self._actor_mesh._actor_mesh is None - else self._actor_mesh._actor_mesh.monitor() - ) + monitor = self._actor_mesh.monitor return PortTuple.create(self._mailbox, monitor, once) @@ -875,19 +1092,22 @@ def _labels(self) -> Tuple[str, ...]: "actor implementations are not meshes, but we can't convince the typechecker of it..." ) - def _new_with_shape(self, shape: Shape) -> "ActorMeshRef": + def _new_with_shape(self, shape: Shape) -> Self: raise NotImplementedError( "actor implementations are not meshes, but we can't convince the typechecker of it..." ) -class ActorMeshRef(MeshTrait): +class ActorMesh(MeshTrait, Generic[T]): def __init__( - self, Class: Type[T], actor_mesh_ref: _ActorMeshRefImpl, mailbox: Mailbox + self, + Class: Type[T], + inner: ActorMeshProtocol, + mailbox: Mailbox, ) -> None: self.__name__: str = Class.__name__ self._class: Type[T] = Class - self._actor_mesh_ref: _ActorMeshRefImpl = actor_mesh_ref + self._inner: ActorMeshProtocol = inner self._mailbox: Mailbox = mailbox for attr_name in dir(self._class): attr_value = getattr(self._class, attr_name, None) @@ -896,7 +1116,7 @@ def __init__( self, attr_name, ActorEndpoint( - self._actor_mesh_ref, + self._inner, attr_name, attr_value._method, self._mailbox, @@ -915,7 +1135,7 @@ def __getattr__(self, name: str) -> Any: if isinstance(attr, EndpointProperty): # Dynamically create the endpoint endpoint = ActorEndpoint( - self._actor_mesh_ref, + self._inner, name, attr._method, self._mailbox, @@ -929,48 +1149,90 @@ def __getattr__(self, name: str) -> Any: f"'{self.__class__.__name__}' object has no attribute '{name}'" ) - def _create( - self, - args: Iterable[Any], - kwargs: Dict[str, Any], - ) -> None: + @classmethod + def create( + cls, + Class: Type[T], + actor_mesh: PythonActorMesh, + mailbox: Mailbox, + proc_mesh: "ProcMesh", + # args and kwargs are passed to the __init__ method of the user defined + # python actor object. + *args: Any, + **kwargs: Any, + ) -> "ActorMesh[T]": + if _use_standin_mesh(): + wrapper = _ActorMeshRefImpl.from_hyperactor_mesh( + mailbox, actor_mesh, proc_mesh + ) + else: + wrapper = _PythonActorMeshAdapter(actor_mesh, proc_mesh) + mesh = cls(Class, wrapper, mailbox) + + # send __init__ message to the mesh to initialize the user defined + # python actor object. async def null_func(*_args: Iterable[Any], **_kwargs: Dict[str, Any]) -> None: return None ep = ActorEndpoint( - self._actor_mesh_ref, + mesh._inner, "__init__", null_func, - self._mailbox, + mesh._mailbox, ) - send(ep, (self._class, *args), kwargs) + send(ep, (mesh._class, *args), kwargs) - def __reduce_ex__( - self, protocol: ... - ) -> "Tuple[Type[ActorMeshRef], Tuple[Any, ...]]": - return ActorMeshRef, ( + return mesh + + @classmethod + def from_actor_id( + cls, + Class: Type[T], + actor_id: ActorId, + mailbox: Mailbox, + ) -> "ActorMesh[T]": + return cls(Class, _ActorIdAdapter(actor_id), mailbox) + + def bind(self) -> "ActorMesh[T]": + if not isinstance(self._inner, _PythonActorMeshAdapter) or not isinstance( + self._inner, _ActorMeshRefImpl + ): + raise AttributeError( + "{msg} is only available on PythonActorMesh or _ActorMeshRefImpl, but got {self._inner.inner_name}" + ) + mesh_ref = self._inner.bind() + return ActorMesh(self._class, mesh_ref, self._mailbox) + + def __reduce_ex__(self, protocol: ...) -> "Tuple[Type[ActorMesh], Tuple[Any, ...]]": + return ActorMesh, ( self._class, - self._actor_mesh_ref, + self._inner, self._mailbox, ) + @property + def proc_mesh(self) -> "Optional[ProcMesh]": + return self._inner.proc_mesh + @property def _ndslice(self) -> NDSlice: - return self._actor_mesh_ref._shape.ndslice + return self._inner.shape.ndslice @property def _labels(self) -> Iterable[str]: - return self._actor_mesh_ref._shape.labels + return self._inner.shape.labels - def _new_with_shape(self, shape: Shape) -> "ActorMeshRef": - return ActorMeshRef( - self._class, - _ActorMeshRefImpl.from_actor_ref_with_shape(self._actor_mesh_ref, shape), - self._mailbox, + def _new_with_shape(self, shape: Shape) -> "ActorMesh": + raise NotImplementedError( + "should not be called because def slice is overridden" ) + def slice(self, **kwargs) -> "ActorMesh[T]": + sliced = self._inner.slice(**kwargs) + return ActorMesh(self._class, sliced, self._mailbox) + def __repr__(self) -> str: - return f"ActorMeshRef(class={self._class}, shape={self._actor_mesh_ref._shape})" + return f"ActorMesh(class={self._class}, shape={self._inner.shape}), inner={self._inner.inner_name})" class ActorError(Exception): diff --git a/python/monarch/_src/actor/debugger.py b/python/monarch/_src/actor/debugger.py index cb5386b62..f4904271a 100644 --- a/python/monarch/_src/actor/debugger.py +++ b/python/monarch/_src/actor/debugger.py @@ -16,9 +16,8 @@ from monarch._rust_bindings.monarch_hyperactor.proc import ActorId from monarch._src.actor.actor_mesh import ( - _ActorMeshRefImpl, Actor, - ActorMeshRef, + ActorMesh, DebugContext, endpoint, MonarchContext, @@ -514,14 +513,9 @@ def ref() -> "DebugManager": ctx = MonarchContext.get() return cast( DebugManager, - ActorMeshRef( + ActorMesh.from_actor_id( DebugManager, - _ActorMeshRefImpl.from_actor_id( - ctx.mailbox, - ActorId.from_string( - f"{ctx.proc_id}.{_DEBUG_MANAGER_ACTOR_NAME}[0]" - ), - ), + ActorId.from_string(f"{ctx.proc_id}.{_DEBUG_MANAGER_ACTOR_NAME}[0]"), ctx.mailbox, ), ) diff --git a/python/monarch/_src/actor/proc_mesh.py b/python/monarch/_src/actor/proc_mesh.py index 7765d0434..b7a1ce5ce 100644 --- a/python/monarch/_src/actor/proc_mesh.py +++ b/python/monarch/_src/actor/proc_mesh.py @@ -36,13 +36,7 @@ ProcMeshMonitor, ) from monarch._rust_bindings.monarch_hyperactor.shape import Shape, Slice -from monarch._src.actor.actor_mesh import ( - _Actor, - _ActorMeshRefImpl, - Actor, - ActorMeshRef, - fake_sync_state, -) +from monarch._src.actor.actor_mesh import _Actor, Actor, ActorMesh, fake_sync_state from monarch._src.actor.allocator import LocalAllocator, ProcessAllocator, SimAllocator from monarch._src.actor.code_sync import ( CodeSyncMeshClient, @@ -190,14 +184,14 @@ async def _spawn_nonblocking( f"{Class} must subclass monarch.service.Actor to spawn it." ) actor_mesh = await self._proc_mesh.spawn_nonblocking(name, _Actor) - service = ActorMeshRef( + service = ActorMesh.create( Class, - _ActorMeshRefImpl.from_hyperactor_mesh(self._mailbox, actor_mesh, self), + actor_mesh, self._mailbox, + self._proc_mesh, + *args, + **kwargs, ) - # useful to have this separate, because eventually we can reconstitute ActorMeshRef objects across pickling by - # doing `ActorMeshRef(Class, actor_handle)` but not calling _create. - service._create(args, kwargs) return cast(T, service) @property diff --git a/python/monarch/mesh_controller.py b/python/monarch/mesh_controller.py index 6545a9268..c5e7f37f0 100644 --- a/python/monarch/mesh_controller.py +++ b/python/monarch/mesh_controller.py @@ -293,7 +293,7 @@ def actor_send( # mutates checker.check_permission(()) selected_device_mesh = ( - endpoint._actor_mesh._proc_mesh and endpoint._actor_mesh._proc_mesh._device_mesh + endpoint._actor_mesh.proc_mesh and endpoint._actor_mesh.proc_mesh._device_mesh ) if selected_device_mesh is not checker.mesh: raise ValueError( @@ -325,7 +325,7 @@ def actor_send( ), args_kwargs_tuple, ) - endpoint._actor_mesh.cast(actor_msg, selection) + endpoint._actor_mesh.cast(actor_msg, selection, endpoint._mailbox) worker_msg = SendResultOfActorCall(ident, broker_id, tensors, [], stream_ref) client.send(checker.mesh._ndslice, worker_msg) # we have to ask for status updates diff --git a/python/monarch/rdma.py b/python/monarch/rdma.py index b9cc771a0..1d005146c 100644 --- a/python/monarch/rdma.py +++ b/python/monarch/rdma.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import ctypes from dataclasses import dataclass @@ -12,13 +14,7 @@ import torch from monarch._rust_bindings.monarch_hyperactor.proc import ActorId -from monarch._src.actor.actor_mesh import ( - _ActorMeshRefImpl, - Actor, - ActorMeshRef, - endpoint, - MonarchContext, -) +from monarch._src.actor.actor_mesh import Actor, ActorMesh, endpoint, MonarchContext @dataclass @@ -54,12 +50,9 @@ def on_proc(proc_id: str) -> "RDMAManager": ctx = MonarchContext.get() return cast( RDMAManager, - ActorMeshRef( + ActorMesh.from_actor_id( RDMAManager, - _ActorMeshRefImpl.from_actor_id( - ctx.mailbox, - ActorId.from_string(f"{proc_id}.rdma_manager[0]"), - ), + ActorId.from_string(f"{proc_id}.rdma_manager[0]"), ctx.mailbox, ), ) diff --git a/python/tests/test_actor_error.py b/python/tests/test_actor_error.py index c48186e7f..73502e1e5 100644 --- a/python/tests/test_actor_error.py +++ b/python/tests/test_actor_error.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe import importlib.resources import subprocess @@ -424,7 +425,7 @@ async def test_actor_mesh_supervision_handling(): await e.fail_with_supervision_error.call_one() # new call should fail with check of health state of actor mesh - with pytest.raises(SupervisionError, match="actor mesh is not in a healthy state"): + with pytest.raises(RuntimeError, match="actor mesh is unhealthy with reason"): await e.check.call() # should not be able to spawn actors anymore as proc mesh is unhealthy @@ -475,7 +476,7 @@ async def test_actor_mesh_supervision_handling_chained_error(): await intermediate_actor.forward_error.call() # calling success endpoint should fail with ActorError, but with supervision msg. - with pytest.raises(ActorError, match="actor mesh is not in a healthy state"): + with pytest.raises(ActorError, match="actor mesh is unhealthy with reason"): await intermediate_actor.forward_success.call() # healthy actor should still be working @@ -491,7 +492,9 @@ async def test_supervision_with_proc_mesh_stopped(): await proc.stop() # new call should fail with check of health state of actor mesh - with pytest.raises(SupervisionError, match="actor mesh is not in a healthy state"): + with pytest.raises( + RuntimeError, match="`PythonActorMesh` has already been stopped" + ): await actor_mesh.check.call() # proc mesh cannot spawn new actors anymore diff --git a/python/tests/test_python_actors.py b/python/tests/test_python_actors.py index 57513d46b..6026e5553 100644 --- a/python/tests/test_python_actors.py +++ b/python/tests/test_python_actors.py @@ -90,6 +90,8 @@ async def test_choose(): proc = await local_proc_mesh(gpus=2) v = await proc.spawn("counter", Counter, 3) i = await proc.spawn("indirect", Indirect) + # wait for meshes to be created + asyncio.sleep(1) v.incr.broadcast() result = await v.value.choose() @@ -329,6 +331,8 @@ async def test_sync_actor(): proc = await local_proc_mesh(gpus=2) a = await proc.spawn("actor", SyncActor) c = await proc.spawn("counter", Counter, 5) + # wait for meshes to be created + await asyncio.sleep(1) r = await a.sync_endpoint.choose(c) assert r == 5