|
| 1 | +/* |
| 2 | + * Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | + * All rights reserved. |
| 4 | + * |
| 5 | + * This source code is licensed under the BSD-style license found in the |
| 6 | + * LICENSE file in the root directory of this source tree. |
| 7 | + */ |
| 8 | + |
| 9 | +use anyhow::Result; |
| 10 | +use async_trait::async_trait; |
| 11 | +use hyperactor::Actor; |
| 12 | +use hyperactor::Context; |
| 13 | +use hyperactor::Handler; |
| 14 | +use hyperactor::Named; |
| 15 | +use hyperactor::PortRef; |
| 16 | +use monarch_types::SerializablePyErr; |
| 17 | +use pyo3::prelude::*; |
| 18 | +use serde::Deserialize; |
| 19 | +use serde::Serialize; |
| 20 | + |
| 21 | +/// Message to trigger module reloading |
| 22 | +#[derive(Debug, Clone, Named, Serialize, Deserialize)] |
| 23 | +pub struct AutoReloadMessage { |
| 24 | + pub result: PortRef<Result<(), String>>, |
| 25 | +} |
| 26 | + |
| 27 | +/// Parameters for creating an AutoReloadActor |
| 28 | +#[derive(Debug, Clone, Named, Serialize, Deserialize)] |
| 29 | +pub struct AutoReloadParams {} |
| 30 | + |
| 31 | +/// Simple Rust Actor that wraps the Python AutoReloader class via pyo3 |
| 32 | +#[derive(Debug)] |
| 33 | +#[hyperactor::export(spawn = true, handlers = [AutoReloadMessage])] |
| 34 | +pub struct AutoReloadActor { |
| 35 | + state: Result<(PyObject, PyObject), SerializablePyErr>, |
| 36 | +} |
| 37 | + |
| 38 | +#[async_trait] |
| 39 | +impl Actor for AutoReloadActor { |
| 40 | + type Params = AutoReloadParams; |
| 41 | + |
| 42 | + async fn new(Self::Params {}: Self::Params) -> Result<Self> { |
| 43 | + Ok(Self { |
| 44 | + state: tokio::task::spawn_blocking(move || { |
| 45 | + Python::with_gil(|py| { |
| 46 | + Self::create_state(py).map_err(SerializablePyErr::from_fn(py)) |
| 47 | + }) |
| 48 | + }) |
| 49 | + .await?, |
| 50 | + }) |
| 51 | + } |
| 52 | +} |
| 53 | + |
| 54 | +impl AutoReloadActor { |
| 55 | + fn create_state(py: Python) -> PyResult<(PyObject, PyObject)> { |
| 56 | + // Import the Python AutoReloader class |
| 57 | + let auto_reload_module = py.import("monarch._src.actor.code_sync.auto_reload")?; |
| 58 | + let auto_reloader_class = auto_reload_module.getattr("AutoReloader")?; |
| 59 | + |
| 60 | + let reloader = auto_reloader_class.call0()?; |
| 61 | + |
| 62 | + // Install the audit import hook: SysAuditImportHook.install(reloader.import_callback) |
| 63 | + let sys_audit_import_hook_class = auto_reload_module.getattr("SysAuditImportHook")?; |
| 64 | + let import_callback = reloader.getattr("import_callback")?; |
| 65 | + let hook_guard = sys_audit_import_hook_class.call_method1("install", (import_callback,))?; |
| 66 | + |
| 67 | + Ok((reloader.into(), hook_guard.into())) |
| 68 | + } |
| 69 | + |
| 70 | + fn reload(py: Python, py_reloader: PyObject) -> PyResult<()> { |
| 71 | + let reloader = py_reloader.bind(py); |
| 72 | + let changed_modules: Vec<String> = reloader.call_method0("reload_changes")?.extract()?; |
| 73 | + if !changed_modules.is_empty() { |
| 74 | + eprintln!("reloaded modules: {:?}", changed_modules); |
| 75 | + } |
| 76 | + Ok(()) |
| 77 | + } |
| 78 | +} |
| 79 | + |
| 80 | +#[async_trait] |
| 81 | +impl Handler<AutoReloadMessage> for AutoReloadActor { |
| 82 | + async fn handle( |
| 83 | + &mut self, |
| 84 | + cx: &Context<Self>, |
| 85 | + AutoReloadMessage { result }: AutoReloadMessage, |
| 86 | + ) -> Result<()> { |
| 87 | + // Call the Python reloader's reload_changes method |
| 88 | + let res = async { |
| 89 | + let py_reloader: PyObject = self.state.as_ref().map_err(Clone::clone)?.0.clone(); |
| 90 | + tokio::task::spawn_blocking(move || { |
| 91 | + Python::with_gil(|py| { |
| 92 | + Self::reload(py, py_reloader).map_err(SerializablePyErr::from_fn(py)) |
| 93 | + }) |
| 94 | + }) |
| 95 | + .await??; |
| 96 | + anyhow::Ok(()) |
| 97 | + } |
| 98 | + .await; |
| 99 | + result.send(cx, res.map_err(|e| format!("{:#?}", e)))?; |
| 100 | + Ok(()) |
| 101 | + } |
| 102 | +} |
| 103 | + |
| 104 | +#[cfg(test)] |
| 105 | +mod tests { |
| 106 | + use anyhow::anyhow; |
| 107 | + use hyperactor_mesh::actor_mesh::ActorMesh; |
| 108 | + use hyperactor_mesh::alloc::AllocSpec; |
| 109 | + use hyperactor_mesh::alloc::Allocator; |
| 110 | + use hyperactor_mesh::alloc::local::LocalAllocator; |
| 111 | + use hyperactor_mesh::mesh::Mesh; |
| 112 | + use hyperactor_mesh::proc_mesh::ProcMesh; |
| 113 | + use ndslice::shape; |
| 114 | + use pyo3::ffi::c_str; |
| 115 | + use tempfile::TempDir; |
| 116 | + use tokio::fs; |
| 117 | + |
| 118 | + use super::*; |
| 119 | + |
| 120 | + #[tokio::test] |
| 121 | + async fn test_auto_reload_actor() -> Result<()> { |
| 122 | + pyo3::prepare_freethreaded_python(); |
| 123 | + Python::with_gil(|py| py.run(c_str!("import monarch._rust_bindings"), None, None))?; |
| 124 | + |
| 125 | + // Create a temporary directory for Python files |
| 126 | + let temp_dir = TempDir::new()?; |
| 127 | + let py_file_path = temp_dir.path().join("test_module.py"); |
| 128 | + |
| 129 | + // Create initial Python file content |
| 130 | + let initial_content = r#" |
| 131 | +# Test module for auto-reload |
| 132 | +def get_value(): |
| 133 | + return "initial_value" |
| 134 | +
|
| 135 | +CONSTANT = "initial_constant" |
| 136 | +"#; |
| 137 | + fs::write(&py_file_path, initial_content).await?; |
| 138 | + |
| 139 | + // Set up a single AutoReloadActor |
| 140 | + let alloc = LocalAllocator |
| 141 | + .allocate(AllocSpec { |
| 142 | + shape: shape! { replica = 1 }, |
| 143 | + constraints: Default::default(), |
| 144 | + }) |
| 145 | + .await?; |
| 146 | + |
| 147 | + let proc_mesh = ProcMesh::allocate(alloc).await?; |
| 148 | + let params = AutoReloadParams {}; |
| 149 | + let actor_mesh = proc_mesh |
| 150 | + .spawn::<AutoReloadActor>("auto_reload_test", ¶ms) |
| 151 | + .await?; |
| 152 | + |
| 153 | + // Get a reference to the single actor |
| 154 | + let actor_ref = actor_mesh |
| 155 | + .get(0) |
| 156 | + .ok_or_else(|| anyhow!("No actor at index 0"))?; |
| 157 | + let mailbox = actor_mesh.proc_mesh().client(); |
| 158 | + |
| 159 | + // First, we need to import the module to get it tracked by the AutoReloader |
| 160 | + // We'll do this by running Python code that imports our test module |
| 161 | + let temp_path = temp_dir.path().to_path_buf(); |
| 162 | + let import_result = tokio::task::spawn_blocking({ |
| 163 | + move || { |
| 164 | + Python::with_gil(|py| -> PyResult<String> { |
| 165 | + // Add the temp directory to Python path |
| 166 | + let sys = py.import("sys")?; |
| 167 | + let path = sys.getattr("path")?; |
| 168 | + let path_list = path.downcast::<pyo3::types::PyList>()?; |
| 169 | + path_list.insert(0, temp_path.to_string_lossy())?; |
| 170 | + |
| 171 | + // Import the test module |
| 172 | + let test_module = py.import("test_module")?; |
| 173 | + let get_value_func = test_module.getattr("get_value")?; |
| 174 | + let initial_value: String = get_value_func.call0()?.extract()?; |
| 175 | + |
| 176 | + Ok(initial_value) |
| 177 | + }) |
| 178 | + } |
| 179 | + }) |
| 180 | + .await??; |
| 181 | + |
| 182 | + // Verify we got the initial value |
| 183 | + assert_eq!(import_result, "initial_value"); |
| 184 | + println!("Initial import successful, got: {}", import_result); |
| 185 | + |
| 186 | + // Now modify the Python file |
| 187 | + let modified_content = r#" |
| 188 | +# Test module for auto-reload (MODIFIED) |
| 189 | +def get_value(): |
| 190 | + return "modified_value" |
| 191 | +
|
| 192 | +CONSTANT = "modified_constant" |
| 193 | +"#; |
| 194 | + fs::write(&py_file_path, modified_content).await?; |
| 195 | + println!("Modified Python file"); |
| 196 | + |
| 197 | + // Send AutoReloadMessage to trigger reload |
| 198 | + let (result_tx, mut result_rx) = mailbox.open_port::<Result<(), String>>(); |
| 199 | + actor_ref.send( |
| 200 | + &mailbox, |
| 201 | + AutoReloadMessage { |
| 202 | + result: result_tx.bind(), |
| 203 | + }, |
| 204 | + )?; |
| 205 | + |
| 206 | + // Wait for reload to complete |
| 207 | + let reload_result = result_rx.recv().await?; |
| 208 | + reload_result.map_err(|e| anyhow!("Reload failed: {}", e))?; |
| 209 | + println!("Auto-reload completed successfully"); |
| 210 | + |
| 211 | + // Now import the module again and verify the changes were propagated |
| 212 | + let final_result = tokio::task::spawn_blocking({ |
| 213 | + move || { |
| 214 | + Python::with_gil(|py| -> PyResult<String> { |
| 215 | + // Re-import the test module (it should be reloaded now) |
| 216 | + let test_module = py.import("test_module")?; |
| 217 | + let get_value_func = test_module.getattr("get_value")?; |
| 218 | + let final_value: String = get_value_func.call0()?.extract()?; |
| 219 | + |
| 220 | + Ok(final_value) |
| 221 | + }) |
| 222 | + } |
| 223 | + }) |
| 224 | + .await??; |
| 225 | + |
| 226 | + // Verify that the changes were propagated |
| 227 | + assert_eq!(final_result, "modified_value"); |
| 228 | + println!("Final import successful, got: {}", final_result); |
| 229 | + |
| 230 | + // Verify that the module was actually reloaded by checking if we get the new value |
| 231 | + assert_ne!(import_result, final_result); |
| 232 | + println!("Auto-reload test completed successfully - module was reloaded!"); |
| 233 | + |
| 234 | + Ok(()) |
| 235 | + } |
| 236 | +} |
0 commit comments