Skip to content

Commit f3eb13a

Browse files
committed
Fix debugger test
D78585722 broke test_debug and somehow the tests don't run on diffs. This adds the necessary annotations to make it ok for the debugger code to block the existing event loop while it is doing debugging stuff. Differential Revision: [D78634155](https://our.internmc.facebook.com/intern/diff/D78634155/) **NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D78634155/)! ghstack-source-id: 297383325 Pull Request resolved: #590
1 parent 1b0d02f commit f3eb13a

File tree

5 files changed

+54
-39
lines changed

5 files changed

+54
-39
lines changed

python/monarch/_src/actor/actor_mesh.py

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import traceback
1919

2020
from abc import ABC, abstractmethod
21-
from contextlib import contextmanager
2221

2322
from dataclasses import dataclass
2423
from operator import mul
@@ -80,6 +79,7 @@
8079
from monarch._src.actor.pickle import flatten, unflatten
8180

8281
from monarch._src.actor.shape import MeshTrait, NDSlice
82+
from monarch._src.actor.sync_state import fake_sync_state
8383

8484
if TYPE_CHECKING:
8585
from monarch._src.actor.proc_mesh import ProcMesh
@@ -686,16 +686,6 @@ def _process(self, msg: PythonMessage) -> Tuple[int, R]:
686686
# We do this by blanking out the running event loop during the call to the synchronous actor function.
687687

688688

689-
@contextmanager
690-
def fake_sync_state():
691-
prev_loop = asyncio.events._get_running_loop()
692-
asyncio._set_running_loop(None)
693-
try:
694-
yield
695-
finally:
696-
asyncio._set_running_loop(prev_loop)
697-
698-
699689
class _Actor:
700690
"""
701691
This is the message handling implementation of a Python actor.
@@ -828,16 +818,17 @@ def _post_mortem_debug(self, exc_tb) -> None:
828818
from monarch._src.actor.debugger import DebugManager
829819

830820
if (pdb_wrapper := DebugContext.get().pdb_wrapper) is not None:
831-
ctx = MonarchContext.get()
832-
pdb_wrapper = PdbWrapper(
833-
ctx.point.rank,
834-
ctx.point.shape.coordinates(ctx.point.rank),
835-
ctx.mailbox.actor_id,
836-
DebugManager.ref().get_debug_client.call_one().get(),
837-
)
838-
DebugContext.set(DebugContext(pdb_wrapper))
839-
pdb_wrapper.post_mortem(exc_tb)
840-
self._maybe_exit_debugger(do_continue=False)
821+
with fake_sync_state():
822+
ctx = MonarchContext.get()
823+
pdb_wrapper = PdbWrapper(
824+
ctx.point.rank,
825+
ctx.point.shape.coordinates(ctx.point.rank),
826+
ctx.mailbox.actor_id,
827+
DebugManager.ref().get_debug_client.call_one().get(),
828+
)
829+
DebugContext.set(DebugContext(pdb_wrapper))
830+
pdb_wrapper.post_mortem(exc_tb)
831+
self._maybe_exit_debugger(do_continue=False)
841832

842833

843834
def _is_mailbox(x: object) -> bool:

python/monarch/_src/actor/debugger.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
MonarchContext,
2525
)
2626
from monarch._src.actor.pdb_wrapper import DebuggerWrite, PdbWrapper
27+
from monarch._src.actor.sync_state import fake_sync_state
2728
from tabulate import tabulate
2829

2930

@@ -549,12 +550,14 @@ def remote_breakpointhook():
549550
"exists on both your client and worker processes."
550551
)
551552

553+
with fake_sync_state():
554+
manager = DebugManager.ref().get_debug_client.call_one().get()
552555
ctx = MonarchContext.get()
553556
pdb_wrapper = PdbWrapper(
554557
ctx.point.rank,
555558
ctx.point.shape.coordinates(ctx.point.rank),
556559
ctx.mailbox.actor_id,
557-
DebugManager.ref().get_debug_client.call_one().get(),
560+
manager,
558561
)
559562
DebugContext.set(DebugContext(pdb_wrapper))
560563
pdb_wrapper.set_trace(frame)

python/monarch/_src/actor/pdb_wrapper.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from typing import Dict, TYPE_CHECKING
1717

1818
from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
19+
from monarch._src.actor.sync_state import fake_sync_state
1920

2021
if TYPE_CHECKING:
2122
from monarch._src.actor.debugger import DebugClient
@@ -47,9 +48,9 @@ def __init__(
4748
self._first = True
4849

4950
def set_trace(self, frame):
50-
self.client_ref.debugger_session_start.call_one(
51+
self.client_ref.debugger_session_start.broadcast(
5152
self.rank, self.coords, socket.getfqdn(socket.gethostname()), self.actor_id
52-
).get()
53+
)
5354
if self.header:
5455
self.message(self.header)
5556
super().set_trace(frame)
@@ -66,7 +67,7 @@ def do_clear(self, arg):
6667
super().do_clear(arg)
6768

6869
def end_debug_session(self):
69-
self.client_ref.debugger_session_end.call_one(self.rank).get()
70+
self.client_ref.debugger_session_end.broadcast(self.rank)
7071
# Once the debug client actor is notified of the session being over,
7172
# we need to prevent any additional requests being sent for the session
7273
# by redirecting stdin and stdout.
@@ -85,16 +86,19 @@ def __init__(self, session: "PdbWrapper"):
8586
self.session = session
8687

8788
def readinto(self, b):
88-
response = self.session.client_ref.debugger_read.call_one(
89-
self.session.rank, len(b)
90-
).get()
91-
if response == "detach":
92-
# this gets injected by the worker event loop to
93-
# get the worker thread to exit on an Exit command.
94-
raise bdb.BdbQuit
95-
assert isinstance(response, DebuggerWrite) and len(response.payload) <= len(b)
96-
b[: len(response.payload)] = response.payload
97-
return len(response.payload)
89+
with fake_sync_state():
90+
response = self.session.client_ref.debugger_read.call_one(
91+
self.session.rank, len(b)
92+
).get()
93+
if response == "detach":
94+
# this gets injected by the worker event loop to
95+
# get the worker thread to exit on an Exit command.
96+
raise bdb.BdbQuit
97+
assert isinstance(response, DebuggerWrite) and len(response.payload) <= len(
98+
b
99+
)
100+
b[: len(response.payload)] = response.payload
101+
return len(response.payload)
98102

99103
def readable(self) -> bool:
100104
return True
@@ -119,14 +123,14 @@ def write(self, s: str):
119123
function = f"{inspect.getmodulename(self.session.curframe.f_code.co_filename)}.{self.session.curframe.f_code.co_name}"
120124
# pyre-ignore
121125
lineno = self.session.curframe.f_lineno
122-
self.session.client_ref.debugger_write.call_one(
126+
self.session.client_ref.debugger_write.broadcast(
123127
self.session.rank,
124128
DebuggerWrite(
125129
s.encode(),
126130
function,
127131
lineno,
128132
),
129-
).get()
133+
)
130134

131135
def flush(self):
132136
pass
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import asyncio
8+
from contextlib import contextmanager
9+
10+
11+
@contextmanager
12+
def fake_sync_state():
13+
prev_loop = asyncio.events._get_running_loop()
14+
asyncio._set_running_loop(None)
15+
try:
16+
yield
17+
finally:
18+
asyncio._set_running_loop(prev_loop)

python/tests/test_allocator.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,13 @@
3333
ChannelTransport,
3434
)
3535

36-
from monarch._src.actor.actor_mesh import fake_sync_state
37-
3836
from monarch._src.actor.allocator import (
3937
ALLOC_LABEL_PROC_MESH_NAME,
4038
RemoteAllocator,
4139
StaticRemoteAllocInitializer,
4240
TorchXRemoteAllocInitializer,
4341
)
42+
from monarch._src.actor.sync_state import fake_sync_state
4443
from monarch.actor import (
4544
Actor,
4645
current_rank,

0 commit comments

Comments
 (0)