Skip to content

Commit 847e8f7

Browse files
pzhan9facebook-github-bot
authored andcommitted
Update Handler macros to use context::Actor (#1349)
Summary: Pull Request resolved: #1349 Motivation: In `Rust V1`, we will make the sequence assignment logic an actor-level logic. Specifically, the sender actor is responsible for assigning its message's seq numbers. The implementation plan is to encapsulate that logic inside sender's `context::Actor` object (see D82983004). As a result, we need to surface `context::Actor` to all the send call sites, so the sender's sequencing capability is accessible. These callsites include `OncePort/PortRef.send`, and `PortHandle.send`, `ActorHandle.send`, `ActorRef.send`, and the corresponding callsites on the python side. For this diff, it specifically updates the macros. Reviewed By: mariusae Differential Revision: D83371710 fbshipit-source-id: 994c776c93b449e57e04ef83c38ce9badaefc5f3
1 parent 7cd4311 commit 847e8f7

File tree

8 files changed

+136
-65
lines changed

8 files changed

+136
-65
lines changed

hyperactor_macros/src/lib.rs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -737,14 +737,14 @@ pub fn derive_handler(input: TokenStream) -> TokenStream {
737737
#[doc = "The generated client method for this enum variant."]
738738
async fn #variant_name_snake(
739739
&self,
740-
cx: &impl hyperactor::context::Mailbox,
740+
cx: &impl hyperactor::context::Actor,
741741
#(#arg_names: #arg_types),*)
742742
-> Result<#return_type, hyperactor::anyhow::Error>;
743743

744744
#[doc = "The DEPRECATED DO NOT USE generated client method for this enum variant."]
745745
async fn #variant_name_snake_deprecated(
746746
&self,
747-
cx: &impl hyperactor::context::Mailbox,
747+
cx: &impl hyperactor::context::Actor,
748748
#(#arg_names: #arg_types),*)
749749
-> Result<#return_type, hyperactor::anyhow::Error>;
750750
});
@@ -811,14 +811,14 @@ pub fn derive_handler(input: TokenStream) -> TokenStream {
811811
#[doc = "The generated client method for this enum variant."]
812812
async fn #variant_name_snake(
813813
&self,
814-
cx: &impl hyperactor::context::Mailbox,
814+
cx: &impl hyperactor::context::Actor,
815815
#(#arg_names: #arg_types),*)
816816
-> Result<(), hyperactor::anyhow::Error>;
817817

818818
#[doc = "The DEPRECATED DO NOT USE generated client method for this enum variant."]
819819
async fn #variant_name_snake_deprecated(
820820
&self,
821-
cx: &impl hyperactor::context::Mailbox,
821+
cx: &impl hyperactor::context::Actor,
822822
#(#arg_names: #arg_types),*)
823823
-> Result<(), hyperactor::anyhow::Error>;
824824
});
@@ -963,7 +963,7 @@ fn derive_client(input: TokenStream, is_handle: bool) -> TokenStream {
963963
#[hyperactor::instrument(level=#log_level, rpc = "call", message_type=#name)]
964964
async fn #variant_name_snake(
965965
&self,
966-
cx: &impl hyperactor::context::Mailbox,
966+
cx: &impl hyperactor::context::Actor,
967967
#(#arg_names: #arg_types),*)
968968
-> Result<#return_type, hyperactor::anyhow::Error> {
969969
let (#reply_port_arg, #rx_mod reply_receiver) =
@@ -977,7 +977,7 @@ fn derive_client(input: TokenStream, is_handle: bool) -> TokenStream {
977977
#[hyperactor::instrument(level=#log_level, rpc = "call", message_type=#name)]
978978
async fn #variant_name_snake_deprecated(
979979
&self,
980-
cx: &impl hyperactor::context::Mailbox,
980+
cx: &impl hyperactor::context::Actor,
981981
#(#arg_names: #arg_types),*)
982982
-> Result<#return_type, hyperactor::anyhow::Error> {
983983
let (#reply_port_arg, #rx_mod reply_receiver) =
@@ -993,7 +993,7 @@ fn derive_client(input: TokenStream, is_handle: bool) -> TokenStream {
993993
#[hyperactor::instrument(level=#log_level, rpc="call", message_type=#name)]
994994
async fn #variant_name_snake(
995995
&self,
996-
cx: &impl hyperactor::context::Mailbox,
996+
cx: &impl hyperactor::context::Actor,
997997
#(#arg_names: #arg_types),*)
998998
-> Result<#return_type, hyperactor::anyhow::Error> {
999999
let (#reply_port_arg, #rx_mod reply_receiver) =
@@ -1008,7 +1008,7 @@ fn derive_client(input: TokenStream, is_handle: bool) -> TokenStream {
10081008
#[hyperactor::instrument(level=#log_level, rpc="call", message_type=#name)]
10091009
async fn #variant_name_snake_deprecated(
10101010
&self,
1011-
cx: &impl hyperactor::context::Mailbox,
1011+
cx: &impl hyperactor::context::Actor,
10121012
#(#arg_names: #arg_types),*)
10131013
-> Result<#return_type, hyperactor::anyhow::Error> {
10141014
let (#reply_port_arg, #rx_mod reply_receiver) =
@@ -1054,7 +1054,7 @@ fn derive_client(input: TokenStream, is_handle: bool) -> TokenStream {
10541054
impl_methods.push(quote! {
10551055
async fn #variant_name_snake(
10561056
&self,
1057-
cx: &impl hyperactor::context::Mailbox,
1057+
cx: &impl hyperactor::context::Actor,
10581058
#(#arg_names: #arg_types),*)
10591059
-> Result<(), hyperactor::anyhow::Error> {
10601060
let message = #constructor;
@@ -1065,7 +1065,7 @@ fn derive_client(input: TokenStream, is_handle: bool) -> TokenStream {
10651065

10661066
async fn #variant_name_snake_deprecated(
10671067
&self,
1068-
cx: &impl hyperactor::context::Mailbox,
1068+
cx: &impl hyperactor::context::Actor,
10691069
#(#arg_names: #arg_types),*)
10701070
-> Result<(), hyperactor::anyhow::Error> {
10711071
let message = #constructor;

hyperactor_macros/tests/basic.rs

Lines changed: 69 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,46 @@ enum TestVariantForms {
7878
},
7979
}
8080

81+
#[derive(Debug, Default, Actor)]
82+
#[hyperactor::export(handlers = [TestVariantForms])]
83+
struct TestVariantFormsActor {}
84+
85+
#[async_trait]
86+
#[forward(TestVariantForms)]
87+
impl TestVariantFormsHandler for TestVariantFormsActor {
88+
async fn one_way_struct(&mut self, _cx: &Context<Self>, _a: u64, _b: u64) -> Result<()> {
89+
Ok(())
90+
}
91+
92+
async fn one_way_tuple(&mut self, _cx: &Context<Self>, _a: u64, _b: u64) -> Result<()> {
93+
Ok(())
94+
}
95+
96+
async fn one_way_tuple_no_args(&mut self, _cx: &Context<Self>) -> Result<()> {
97+
Ok(())
98+
}
99+
100+
async fn one_way_struct_no_args(&mut self, _cx: &Context<Self>) -> Result<()> {
101+
Ok(())
102+
}
103+
104+
async fn call_struct(&mut self, _cx: &Context<Self>, a: u64) -> Result<u64> {
105+
Ok(a)
106+
}
107+
108+
async fn call_tuple(&mut self, _cx: &Context<Self>, a: u64) -> Result<u64> {
109+
Ok(a)
110+
}
111+
112+
async fn call_tuple_no_args(&mut self, _cx: &Context<Self>) -> Result<u64> {
113+
Ok(0)
114+
}
115+
116+
async fn call_struct_no_args(&mut self, _cx: &Context<Self>) -> Result<u64> {
117+
Ok(0)
118+
}
119+
}
120+
81121
#[instrument(fields(name = 4))]
82122
async fn yolo() -> Result<i32, i32> {
83123
Ok(10)
@@ -88,11 +128,6 @@ async fn yeet() -> String {
88128
String::from("cawwww")
89129
}
90130

91-
#[test]
92-
fn basic() {
93-
// nothing, just checks whether this file will compile
94-
}
95-
96131
#[derive(Debug, Handler, HandleClient)]
97132
enum GenericArgMessage<A: Clone + Sync + Send + Debug + 'static> {
98133
Variant(A),
@@ -146,3 +181,32 @@ struct SimpleStructMessage {
146181
field1: String,
147182
field2: u32,
148183
}
184+
185+
#[cfg(test)]
186+
mod tests {
187+
use hyperactor::proc::Proc;
188+
use timed_test::async_timed_test;
189+
190+
use super::*;
191+
192+
#[test]
193+
fn basic() {
194+
// nothing, just checks whether this file will compile
195+
}
196+
197+
// Verify it compiles
198+
#[async_timed_test(timeout_secs = 30)]
199+
async fn test_client_macros() {
200+
let proc = Proc::local();
201+
let (client, _) = proc.instance("client").unwrap();
202+
let actor_handle = proc
203+
.spawn::<TestVariantFormsActor>("foo", ())
204+
.await
205+
.unwrap();
206+
207+
assert_eq!(actor_handle.call_struct(&client, 10).await.unwrap(), 10,);
208+
209+
let actor_ref = actor_handle.bind::<TestVariantFormsActor>();
210+
assert_eq!(actor_ref.call_struct(&client, 10).await.unwrap(), 10,);
211+
}
212+
}

monarch_hyperactor/src/context.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ use crate::proc::PyActorId;
1818
use crate::runtime;
1919
use crate::shape::PyPoint;
2020

21-
pub(crate) enum ContextInstance {
21+
pub enum ContextInstance {
2222
Client(hyperactor::Instance<()>),
2323
PythonActor(hyperactor::Instance<PythonActor>),
2424
}
@@ -100,7 +100,7 @@ impl PyInstance {
100100
}
101101

102102
impl PyInstance {
103-
pub(crate) fn context_instance(&self) -> &ContextInstance {
103+
pub fn context_instance(&self) -> &ContextInstance {
104104
&self.inner
105105
}
106106
}

monarch_rdma/examples/parameter_server/src/parameter_server.rs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -368,9 +368,7 @@ impl Handler<WorkerStep> for WorkerActor {
368368
)
369369
.await?;
370370

371-
buffer
372-
.read_into(cx.mailbox(), ps_grad_handle.clone(), 5)
373-
.await?;
371+
buffer.read_into(cx, ps_grad_handle.clone(), 5).await?;
374372

375373
self.local_gradients.fill(0);
376374

@@ -409,9 +407,7 @@ impl Handler<WorkerUpdate> for WorkerActor {
409407
.ps_weights_handle
410408
.as_ref()
411409
.expect("worker_actor should be initialized");
412-
buffer
413-
.write_from(cx.mailbox(), ps_weights_handle.clone(), 5)
414-
.await?;
410+
buffer.write_from(cx, ps_weights_handle.clone(), 5).await?;
415411
reply.send(cx, true)?;
416412
Ok(())
417413
}

monarch_rdma/extension/lib.rs

Lines changed: 40 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ use hyperactor::Named;
1313
use hyperactor::ProcId;
1414
use hyperactor_mesh::RootActorMesh;
1515
use hyperactor_mesh::shared_cell::SharedCell;
16-
use monarch_hyperactor::mailbox::PyMailbox;
16+
use monarch_hyperactor::context::PyInstance;
17+
use monarch_hyperactor::instance_dispatch;
1718
use monarch_hyperactor::proc_mesh::PyProcMesh;
1819
use monarch_hyperactor::pytokio::PyPythonTask;
1920
use monarch_hyperactor::runtime::signal_safe_block_on;
@@ -52,17 +53,18 @@ async fn create_rdma_buffer(
5253
addr: usize,
5354
size: usize,
5455
proc_id: ProcId,
55-
client: PyMailbox,
56+
client: PyInstance,
5657
) -> PyResult<PyRdmaBuffer> {
5758
// Get the owning RdmaManagerActor's ActorRef
5859
let owner_id = ActorId(proc_id, "rdma_manager".to_string(), 0);
5960
let owner_ref: ActorRef<RdmaManagerActor> = ActorRef::attest(owner_id);
6061

61-
let caps = client.get_inner();
6262
// Create the RdmaBuffer
63-
let buffer = owner_ref
64-
.request_buffer_deprecated(caps, addr, size)
65-
.await?;
63+
let buffer = instance_dispatch!(client, |cx_instance| {
64+
owner_ref
65+
.request_buffer_deprecated(&cx_instance, addr, size)
66+
.await?
67+
});
6668
Ok(PyRdmaBuffer { buffer, owner_ref })
6769
}
6870

@@ -75,7 +77,7 @@ impl PyRdmaBuffer {
7577
addr: usize,
7678
size: usize,
7779
proc_id: String,
78-
client: PyMailbox,
80+
client: PyInstance,
7981
) -> PyResult<PyPythonTask> {
8082
if !ibverbs_supported() {
8183
return Err(PyException::new_err(
@@ -97,7 +99,7 @@ impl PyRdmaBuffer {
9799
addr: usize,
98100
size: usize,
99101
proc_id: String,
100-
client: PyMailbox,
102+
client: PyInstance,
101103
) -> PyResult<PyRdmaBuffer> {
102104
if !ibverbs_supported() {
103105
return Err(PyException::new_err(
@@ -135,7 +137,7 @@ impl PyRdmaBuffer {
135137
/// * `addr` - The address of the local buffer to read from
136138
/// * `size` - The size of the data to transfer
137139
/// * `local_proc_id` - The process ID where the local buffer resides
138-
/// * `client` - The mailbox for communication
140+
/// * `client` - The actor who does the reading.
139141
/// * `timeout` - Maximum time in milliseconds to wait for the operation
140142
#[pyo3(signature = (addr, size, local_proc_id, client, timeout))]
141143
fn read_into<'py>(
@@ -144,19 +146,24 @@ impl PyRdmaBuffer {
144146
addr: usize,
145147
size: usize,
146148
local_proc_id: String,
147-
client: PyMailbox,
149+
client: PyInstance,
148150
timeout: u64,
149151
) -> PyResult<PyPythonTask> {
150152
let (local_owner_ref, buffer) = setup_rdma_context(self, local_proc_id);
151153
PyPythonTask::new(async move {
152-
let caps = client.get_inner();
153-
let local_buffer = local_owner_ref
154-
.request_buffer_deprecated(caps, addr, size)
155-
.await?;
156-
let _result_ = local_buffer
157-
.write_from(caps, buffer, timeout)
158-
.await
159-
.map_err(|e| PyException::new_err(format!("failed to read into buffer: {}", e)))?;
154+
let local_buffer = instance_dispatch!(client, |cx_instance| {
155+
local_owner_ref
156+
.request_buffer_deprecated(cx_instance, addr, size)
157+
.await?
158+
});
159+
let _result_ = instance_dispatch!(client, |cx_instance| {
160+
local_buffer
161+
.write_from(cx_instance, buffer, timeout)
162+
.await
163+
.map_err(|e| {
164+
PyException::new_err(format!("failed to read into buffer: {}", e))
165+
})?
166+
});
160167
Ok(())
161168
})
162169
}
@@ -171,7 +178,7 @@ impl PyRdmaBuffer {
171178
/// * `addr` - The address of the local buffer to write to
172179
/// * `size` - The size of the data to transfer
173180
/// * `local_proc_id` - The process ID where the local buffer resides
174-
/// * `client` - The mailbox for communication
181+
/// * `client` - The actor who does the writing
175182
/// * `timeout` - Maximum time in milliseconds to wait for the operation
176183
#[pyo3(signature = (addr, size, local_proc_id, client, timeout))]
177184
fn write_from<'py>(
@@ -180,19 +187,24 @@ impl PyRdmaBuffer {
180187
addr: usize,
181188
size: usize,
182189
local_proc_id: String,
183-
client: PyMailbox,
190+
client: PyInstance,
184191
timeout: u64,
185192
) -> PyResult<PyPythonTask> {
186193
let (local_owner_ref, buffer) = setup_rdma_context(self, local_proc_id);
187194
PyPythonTask::new(async move {
188-
let caps = client.get_inner();
189-
let local_buffer = local_owner_ref
190-
.request_buffer_deprecated(caps, addr, size)
191-
.await?;
192-
let _result_ = local_buffer
193-
.read_into(caps, buffer, timeout)
194-
.await
195-
.map_err(|e| PyException::new_err(format!("failed to write from buffer: {}", e)))?;
195+
let local_buffer = instance_dispatch!(client, |cx_instance| {
196+
local_owner_ref
197+
.request_buffer_deprecated(cx_instance, addr, size)
198+
.await?
199+
});
200+
let _result_ = instance_dispatch!(&client, |cx_instance| {
201+
local_buffer
202+
.read_into(cx_instance, buffer, timeout)
203+
.await
204+
.map_err(|e| {
205+
PyException::new_err(format!("failed to write from buffer: {}", e))
206+
})?
207+
});
196208
Ok(())
197209
})
198210
}

0 commit comments

Comments
 (0)