Skip to content

Push MeshTrait implementation to _ActorMeshRefImpl #553

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 54 additions & 25 deletions hyperactor_mesh/src/actor_mesh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ use ndslice::Range;
use ndslice::Selection;
use ndslice::Shape;
use ndslice::ShapeError;
use ndslice::SliceError;
use ndslice::selection;
use ndslice::selection::EvalOpts;
use ndslice::selection::ReifyView;
Expand Down Expand Up @@ -95,6 +96,47 @@ where
Ok(())
}

#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `CastError`.
pub(crate) fn cast_to_sliced_mesh<A, M>(
caps: &impl cap::CanSend,
actor_mesh_id: ActorMeshId,
sender: &ActorId,
comm_actor_ref: &ActorRef<CommActor>,
sel_of_sliced: &Selection,
message: M,
sliced_shape: &Shape,
base_shape: &Shape,
) -> Result<(), CastError>
where
A: RemoteActor + RemoteHandles<IndexedErasedUnbound<M>>,
M: Castable + RemoteMessage,
{
let base_slice = base_shape.slice();

// Casting to `*`?
let sel_of_base = if selection::normalize(sel_of_sliced) == normal::NormalizedSelection::True {
// Reify this view into base.
base_slice.reify_view(sliced_shape.slice())?
} else {
// No, fall back on `of_ranks`.
let ranks = sel_of_sliced
.eval(&EvalOpts::strict(), sliced_shape.slice())?
.collect::<BTreeSet<_>>();
Selection::of_ranks(base_slice, &ranks)?
};

// Cast.
actor_mesh_cast::<A, M>(
caps,
actor_mesh_id,
base_shape,
sender,
comm_actor_ref,
sel_of_base,
message,
)
}

/// A mesh of actors, all of which reside on the same [`ProcMesh`].
pub trait ActorMesh: Mesh<Id = ActorMeshId> {
/// The type of actor in the mesh.
Expand Down Expand Up @@ -350,31 +392,15 @@ impl<A: RemoteActor> ActorMesh for SlicedActorMesh<'_, A> {
Self::Actor: RemoteHandles<IndexedErasedUnbound<M>>,
M: Castable + RemoteMessage,
{
let base_shape = self.0.shape();
let base_slice = base_shape.slice();

// Casting to `*`?
let selection = if selection::normalize(&sel) == normal::NormalizedSelection::True {
// Reify this view into base.
base_slice.reify_view(self.shape().slice()).unwrap()
} else {
// No, fall back on `of_ranks`.
let ranks = sel
.eval(&EvalOpts::strict(), self.shape().slice())
.unwrap()
.collect::<BTreeSet<_>>();
Selection::of_ranks(base_slice, &ranks).unwrap()
};

// Cast.
actor_mesh_cast::<A, M>(
self.proc_mesh().client(), // send capability
self.id(), // actor mesh id (destination mesh)
base_shape, // actor mesh shape
self.proc_mesh().client().actor_id(), // sender
self.proc_mesh().comm_actor(), // comm actor
selection, // the selected actors
message, // the message
cast_to_sliced_mesh::<A, M>(
/*caps=*/ self.proc_mesh().client(),
/*actor_mesh_id=*/ self.id(),
/*sender=*/ self.proc_mesh().client().actor_id(),
/*comm_actor_ref*/ self.proc_mesh().comm_actor(),
/*sel_of_sliced=*/ &sel,
/*message=*/ message,
/*sliced_shape=*/ self.shape(),
/*base_shape=*/ self.0.shape(),
)
}
}
Expand All @@ -394,6 +420,9 @@ pub enum CastError {
#[error(transparent)]
ShapeError(#[from] ShapeError),

#[error(transparent)]
SliceError(#[from] SliceError),

#[error(transparent)]
SerializationError(#[from] bincode::Error),

Expand Down
74 changes: 51 additions & 23 deletions hyperactor_mesh/src/reference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,17 @@ use hyperactor::actor::RemoteActor;
use hyperactor::cap;
use hyperactor::message::Castable;
use hyperactor::message::IndexedErasedUnbound;
use ndslice::Range;
use ndslice::Selection;
use ndslice::Shape;
use ndslice::ShapeError;
use serde::Deserialize;
use serde::Serialize;

use crate::CommActor;
use crate::actor_mesh::CastError;
use crate::actor_mesh::actor_mesh_cast;
use crate::actor_mesh::cast_to_sliced_mesh;

#[macro_export]
macro_rules! mesh_id {
Expand Down Expand Up @@ -71,10 +74,15 @@ pub struct ProcMeshId(pub String);
pub struct ActorMeshId(pub ProcMeshId, pub String);

/// Types references to Actor Meshes.
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct ActorMeshRef<A: RemoteActor> {
pub(crate) mesh_id: ActorMeshId,
shape: Shape,
/// The shape of the root mesh.
root: Shape,
/// If some, it mean this mesh ref points to a sliced mesh, and this field
/// is this sliced mesh's shape. If None, it means this mesh ref points to
/// the root mesh.
sliced: Option<Shape>,
/// The reference to the comm actor of the underlying Proc Mesh.
comm_actor_ref: ActorRef<CommActor>,
phantom: PhantomData<A>,
Expand All @@ -87,12 +95,13 @@ impl<A: RemoteActor> ActorMeshRef<A> {
/// line argument) is a valid reference.
pub(crate) fn attest(
mesh_id: ActorMeshId,
shape: Shape,
root: Shape,
comm_actor_ref: ActorRef<CommActor>,
) -> Self {
Self {
mesh_id,
shape,
root,
sliced: None,
comm_actor_ref,
phantom: PhantomData,
}
Expand All @@ -105,7 +114,10 @@ impl<A: RemoteActor> ActorMeshRef<A> {

/// Shape of the Actor Mesh.
pub fn shape(&self) -> &Shape {
&self.shape
match &self.sliced {
Some(s) => s,
None => &self.root,
}
}

/// Cast an [`M`]-typed message to the ranks selected by `sel`
Expand All @@ -121,37 +133,53 @@ impl<A: RemoteActor> ActorMeshRef<A> {
A: RemoteHandles<M> + RemoteHandles<IndexedErasedUnbound<M>>,
M: Castable + RemoteMessage,
{
actor_mesh_cast::<A, M>(
caps,
self.mesh_id.clone(),
self.shape(),
caps.mailbox().actor_id(),
&self.comm_actor_ref,
selection,
message,
)
match &self.sliced {
Some(sliced_shape) => cast_to_sliced_mesh::<A, M>(
caps,
self.mesh_id.clone(),
caps.mailbox().actor_id(),
&self.comm_actor_ref,
&selection,
message,
sliced_shape,
&self.root,
),
None => actor_mesh_cast::<A, M>(
caps,
self.mesh_id.clone(),
&self.root,
caps.mailbox().actor_id(),
&self.comm_actor_ref,
selection,
message,
),
}
}

pub fn select<R: Into<Range>>(&self, label: &str, range: R) -> Result<Self, ShapeError> {
let sliced = self.shape().select(label, range)?;
Ok(Self {
mesh_id: self.mesh_id.clone(),
root: self.root.clone(),
sliced: Some(sliced),
comm_actor_ref: self.comm_actor_ref.clone(),
phantom: PhantomData,
})
}
}

impl<A: RemoteActor> Clone for ActorMeshRef<A> {
fn clone(&self) -> Self {
Self {
mesh_id: self.mesh_id.clone(),
shape: self.shape.clone(),
root: self.root.clone(),
sliced: self.sliced.clone(),
comm_actor_ref: self.comm_actor_ref.clone(),
phantom: PhantomData,
}
}
}

impl<A: RemoteActor> PartialEq for ActorMeshRef<A> {
fn eq(&self, other: &Self) -> bool {
self.mesh_id == other.mesh_id && self.shape == other.shape
}
}

impl<A: RemoteActor> Eq for ActorMeshRef<A> {}

#[cfg(test)]
mod tests {
use async_trait::async_trait;
Expand Down
76 changes: 76 additions & 0 deletions monarch_hyperactor/src/actor_mesh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ use pyo3::exceptions::PyRuntimeError;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::PyBytes;
use pyo3::types::PyDict;
use pyo3::types::PySlice;
use serde::Deserialize;
use serde::Serialize;
use tokio::sync::Mutex;
Expand Down Expand Up @@ -178,6 +180,11 @@ impl PythonActorMesh {
Ok(monitor_instance.into_py(py))
}

#[pyo3(signature = (**kwargs))]
fn slice(&self, kwargs: Option<&Bound<'_, PyDict>>) -> PyResult<PythonActorMeshRef> {
self.bind()?.slice(kwargs)
}

#[getter]
pub fn client(&self) -> PyMailbox {
self.client.clone()
Expand Down Expand Up @@ -222,6 +229,75 @@ impl PythonActorMeshRef {
Ok(())
}

#[pyo3(signature = (**kwargs))]
fn slice(&self, kwargs: Option<&Bound<'_, PyDict>>) -> PyResult<Self> {
// When the input type is `int`, convert it into `ndslice::Range`.
fn convert_int(index: isize) -> PyResult<ndslice::Range> {
if index < 0 {
return Err(PyException::new_err(format!(
"does not support negative index in selection: {}",
index
)));
}
Ok(ndslice::Range::from(index as usize))
}

// When the input type is `slice`, convert it into `ndslice::Range`.
fn convert_py_slice<'py>(s: &Bound<'py, PySlice>) -> PyResult<ndslice::Range> {
fn get_attr<'py>(s: &Bound<'py, PySlice>, attr: &str) -> PyResult<Option<isize>> {
let v = s.getattr(attr)?.extract::<Option<isize>>()?;
if v.is_some() && v.unwrap() < 0 {
return Err(PyException::new_err(format!(
"does not support negative {} in slice: {}",
attr,
v.unwrap(),
)));
}
Ok(v)
}

let start = get_attr(s, "start")?.unwrap_or(0);
let stop: Option<isize> = get_attr(s, "stop")?;
let step = get_attr(s, "step")?.unwrap_or(1);
Ok(ndslice::Range(
start as usize,
stop.map(|s| s as usize),
step as usize,
))
}

if kwargs.is_none() || kwargs.unwrap().is_empty() {
return Err(PyException::new_err("selection cannot be empty"));
}

let mut sliced = self.inner.clone();

for entry in kwargs.unwrap().items() {
let label = entry.get_item(0)?.str()?;
let label_str = label.to_str()?;

let value = entry.get_item(1)?;

let range = if let Ok(index) = value.extract::<isize>() {
convert_int(index)?
} else if let Ok(s) = value.downcast::<PySlice>() {
convert_py_slice(s)?
} else {
return Err(PyException::new_err(
"selection only supports type int or slice",
));
};
sliced = sliced.select(label_str, range).map_err(|err| {
PyException::new_err(format!(
"failed to select label {}; error is: {}",
label_str, err
))
})?;
}

Ok(Self { inner: sliced })
}

#[getter]
fn shape(&self) -> PyShape {
PyShape::from(self.inner.shape().clone())
Expand Down
24 changes: 23 additions & 1 deletion python/monarch/_rust_bindings/monarch_hyperactor/actor_mesh.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

# pyre-strict

from collections.abc import Mapping
from typing import AsyncIterator, final

from monarch._rust_bindings.monarch_hyperactor.actor import PythonMessage
Expand All @@ -18,6 +17,7 @@ from monarch._rust_bindings.monarch_hyperactor.mailbox import (
from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
from monarch._rust_bindings.monarch_hyperactor.selection import Selection
from monarch._rust_bindings.monarch_hyperactor.shape import Shape
from typing_extensions import Self

@final
class PythonActorMeshRef:
Expand All @@ -31,6 +31,12 @@ class PythonActorMeshRef:
"""Cast a message to the selected actors in the mesh."""
...

def slice(self, **kwargs: int | slice[int | None, int | None, int | None]) -> Self:
"""
See PythonActorMeshRef.slice for documentation.
"""
...

@property
def shape(self) -> Shape:
"""
Expand All @@ -53,6 +59,22 @@ class PythonActorMesh:
"""
Cast a message to the selected actors in the mesh.
"""
...

def slice(
self, **kwargs: int | slice[int | None, int | None, int | None]
) -> PythonActorMeshRef:
"""
Slice the mesh into a new mesh ref with the given selection. The reason
it returns a mesh ref, rather than the mesh object itself, is because
sliced mesh is a view of the original mesh, and does not own the mesh's
resources.

Arguments:
- `kwargs`: argument name is the label, and argument value is how to
slice the mesh along the dimension of that label.
"""
...

def get_supervision_event(self) -> ActorSupervisionEvent | None:
"""
Expand Down
Loading
Loading