diff --git a/monarch_hyperactor/src/telemetry.rs b/monarch_hyperactor/src/telemetry.rs index 90871232e..d4ca8a1b5 100644 --- a/monarch_hyperactor/src/telemetry.rs +++ b/monarch_hyperactor/src/telemetry.rs @@ -8,45 +8,12 @@ #![allow(unsafe_op_in_unsafe_fn)] -use std::cell::Cell; - use hyperactor::clock::ClockKind; use hyperactor::clock::RealClock; use hyperactor::clock::SimClock; use hyperactor_telemetry::swap_telemetry_clock; use pyo3::prelude::*; use pyo3::types::PyTraceback; -use tracing::span::EnteredSpan; -// Thread local to store the current span -thread_local! { - static ACTIVE_ACTOR_SPAN: Cell> = const { Cell::new(None) }; -} - -/// Enter the span stored in the thread local -#[pyfunction] -pub fn enter_span(module_name: String, method_name: String, actor_id: String) -> PyResult<()> { - let mut maybe_span = ACTIVE_ACTOR_SPAN.take(); - if maybe_span.is_none() { - maybe_span = Some( - tracing::info_span!( - "py_actor_method", - name = method_name, - target = module_name, - actor_id = actor_id - ) - .entered(), - ); - } - ACTIVE_ACTOR_SPAN.set(maybe_span); - Ok(()) -} - -/// Exit the span stored in the thread local -#[pyfunction] -pub fn exit_span() -> PyResult<()> { - ACTIVE_ACTOR_SPAN.replace(None); - Ok(()) -} /// Get the current span ID from the active span #[pyfunction] @@ -122,8 +89,17 @@ struct PySpan { #[pymethods] impl PySpan { #[new] - fn new(name: &str) -> Self { - let span = tracing::span!(tracing::Level::DEBUG, "python.span", name = name); + fn new(name: &str, actor_id: Option<&str>) -> Self { + let span = if let Some(actor_id) = actor_id { + tracing::span!( + tracing::Level::DEBUG, + "python.span", + name = name, + actor_id = actor_id + ) + } else { + tracing::span!(tracing::Level::DEBUG, "python.span", name = name) + }; let entered_span = span.entered(); Self { span: entered_span } @@ -147,20 +123,6 @@ pub fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult<()> { module.add_function(f)?; // Register the span-related functions - let enter_span_fn = wrap_pyfunction!(enter_span, module)?; - enter_span_fn.setattr( - "__module__", - "monarch._rust_bindings.monarch_hyperactor.telemetry", - )?; - module.add_function(enter_span_fn)?; - - let exit_span_fn = wrap_pyfunction!(exit_span, module)?; - exit_span_fn.setattr( - "__module__", - "monarch._rust_bindings.monarch_hyperactor.telemetry", - )?; - module.add_function(exit_span_fn)?; - let get_current_span_id_fn = wrap_pyfunction!(get_current_span_id, module)?; get_current_span_id_fn.setattr( "__module__", diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/telemetry.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/telemetry.pyi index a9c6fc871..4f42f0cf2 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/telemetry.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/telemetry.pyi @@ -28,35 +28,6 @@ def forward_to_tracing(record: logging.LogRecord) -> None: """ ... -def enter_span(module_name: str, method_name: str, actor_id: str) -> None: - """ - Enter a tracing span for a Python actor method. - - Creates and enters a new tracing span for the current thread that tracks - execution of a Python actor method. The span is stored in thread-local - storage and will be active until exit_span() is called. - - If a span is already active for the current thread, this function will - preserve that span and not create a new one. - - Args: - - module_name (str): The name of the module containing the actor (used as the target). - - method_name (str): The name of the method being called (used as the span name). - - actor_id (str): The ID of the actor instance (included as a field in the span). - """ - ... - -def exit_span() -> None: - """ - Exit the current tracing span for a Python actor method. - - Exits and drops the tracing span that was previously created by enter_span(). - This should be called when the actor method execution is complete. - - If no span is currently active for this thread, this function has no effect. - """ - ... - def get_current_span_id() -> int: """ Get the current span ID from the active span. @@ -87,12 +58,14 @@ def use_sim_clock() -> None: ... class PySpan: - def __init__(self, name: str) -> None: + def __init__(self, name: str, actor_id: str | None = None) -> None: """ Create a new PySpan. Args: - name (str): The name of the span. + - actor_id (str | None, optional): The actor ID associated with the span. + If None, Rust will handle actor identification automatically. """ ... @@ -101,3 +74,13 @@ class PySpan: Exit the span. """ ... + + @property + def actor_id(self) -> str | None: + """ + Get the actor ID associated with this span. + + Returns: + - str | None: The actor ID, or None if not set. + """ + ... diff --git a/python/monarch/_src/actor/actor_mesh.py b/python/monarch/_src/actor/actor_mesh.py index 4cd85d697..143bc1dab 100644 --- a/python/monarch/_src/actor/actor_mesh.py +++ b/python/monarch/_src/actor/actor_mesh.py @@ -55,7 +55,6 @@ from monarch._rust_bindings.monarch_hyperactor.proc import ActorId from monarch._rust_bindings.monarch_hyperactor.shape import Point as HyPoint, Shape -from monarch._rust_bindings.monarch_hyperactor.telemetry import enter_span, exit_span from monarch._src.actor.allocator import LocalAllocator, ProcessAllocator from monarch._src.actor.future import Future from monarch._src.actor.pdb_wrapper import PdbWrapper @@ -63,9 +62,12 @@ from monarch._src.actor.pickle import flatten, unpickle from monarch._src.actor.shape import MeshTrait, NDSlice +from monarch._src.actor.telemetry.rust_span_tracing import get_monarch_tracer logger: logging.Logger = logging.getLogger(__name__) +TRACER = get_monarch_tracer() + Allocator = ProcessAllocator | LocalAllocator try: @@ -668,31 +670,28 @@ async def handle( if inspect.iscoroutinefunction(the_method): async def instrumented(): - enter_span( - the_method.__module__, + with TRACER.start_as_current_span( message.method, - str(ctx.mailbox.actor_id), - ) - try: - result = await the_method(self.instance, *args, **kwargs) - self._maybe_exit_debugger() - except Exception as e: - logging.critical( - "Unhandled exception in actor endpoint", - exc_info=e, - ) - raise e - exit_span() - return result + attributes={"actor_id": str(ctx.mailbox.actor_id)}, + ): + try: + result = await the_method(self.instance, *args, **kwargs) + self._maybe_exit_debugger() + except Exception as e: + logging.critical( + "Unhandled exception in actor endpoint", + exc_info=e, + ) + raise e + return result result = await instrumented() else: - enter_span( - the_method.__module__, message.method, str(ctx.mailbox.actor_id) - ) - result = the_method(self.instance, *args, **kwargs) - self._maybe_exit_debugger() - exit_span() + with TRACER.start_as_current_span( + message.method, attributes={"actor_id": str(ctx.mailbox.actor_id)} + ): + result = the_method(self.instance, *args, **kwargs) + self._maybe_exit_debugger() if port is not None: port.send("result", result) @@ -759,6 +758,10 @@ def logger(cls) -> logging.Logger: lgr.setLevel(logging.DEBUG) return lgr + @property + def tracer(self): + return TRACER + @property def _ndslice(self) -> NDSlice: raise NotImplementedError( diff --git a/python/monarch/_src/actor/telemetry/rust_span_tracing.py b/python/monarch/_src/actor/telemetry/rust_span_tracing.py index b6716dfd9..7f0c1c6e0 100644 --- a/python/monarch/_src/actor/telemetry/rust_span_tracing.py +++ b/python/monarch/_src/actor/telemetry/rust_span_tracing.py @@ -85,7 +85,8 @@ def start_span( record_exception: bool = True, set_status_on_exception: bool = True, ) -> trace.Span: - return SpanWrapper(name) + actor_id = str(attributes.get("actor_id")) if attributes else None + return SpanWrapper(name, actor_id) @contextmanager # pyre-fixme[15]: `start_as_current_span` overrides method defined in `Tracer` @@ -102,7 +103,9 @@ def start_as_current_span( set_status_on_exception: bool = True, end_on_exit: bool = True, ) -> Iterator[trace.Span]: - with SpanWrapper(name) as s: + actor_id = str(attributes.get("actor_id")) if attributes else None + + with SpanWrapper(name, actor_id) as s: with trace.use_span(s): yield s