Commit c1726bc (parent 831c142)

OPTIONAL COMMIT: Debugging prints

5 files changed: 25 additions & 5 deletions

megatron/core/inference/data_parallel_inference_coordinator.py

Lines changed: 1 addition & 1 deletion

@@ -148,7 +148,7 @@ def start(self):
                 )
                 continue

-            # print(f"New client connected: {sender_identity}")
+            logging.info(f"New client connected: {sender_identity}")
             known_clients.add(sender_identity)
             self.router_socket.send_multipart(
                 [sender_identity, msgpack.packb([Headers.ACK.value], use_bin_type=True)]
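For context, a minimal self-contained sketch of the ROUTER-side handshake this hunk instruments, assuming a msgpack-over-ZMQ protocol like the one visible above. The address, the ACK constant, and the in-process DEALER client are illustrative, not the Megatron setup:

import logging
import msgpack
import zmq

logging.basicConfig(level=logging.INFO)

ACK = 0  # stands in for Headers.ACK.value

ctx = zmq.Context()
router = ctx.socket(zmq.ROUTER)
router.bind("inproc://coord")      # in-process transport keeps the demo self-contained
client = ctx.socket(zmq.DEALER)
client.connect("inproc://coord")

known_clients = set()
client.send(msgpack.packb([1], use_bin_type=True))  # client announces itself

# ROUTER prepends the sender's identity frame to every message it receives.
sender_identity, _payload = router.recv_multipart()
if sender_identity not in known_clients:
    logging.info(f"New client connected: {sender_identity}")
    known_clients.add(sender_identity)
    router.send_multipart([sender_identity, msgpack.packb([ACK], use_bin_type=True)])

print(msgpack.unpackb(client.recv(), raw=False))  # -> [0], the ACK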

megatron/core/inference/engines/dynamic_engine.py

Lines changed: 13 additions & 2 deletions

@@ -18,6 +18,7 @@
 import torch
 from torch import Tensor
 from torch.cuda.nvtx import range_pop, range_push
+from megatron.core.utils import log_single_rank

 from megatron.core import parallel_state
 from megatron.core.inference.contexts.dynamic_context import (
@@ -488,6 +489,7 @@ async def start_listening_to_data_parallel_coordinator(

         # Finally run the engine infinite loop
         loop = get_asyncio_loop(loop)
+        logging.info(f"Creating engine loop task on loop {id(loop)} on rank {torch.distributed.get_rank()}")
         self.engine_loop_task = loop.create_task(
             self.run_engine_with_coordinator(loop=loop, verbose=verbose)
         )
@@ -1246,6 +1248,7 @@ def schedule_requests(self) -> int:
                 request_id, prompt, sampling_params = data[1:]
                 sampling_params = SamplingParams.deserialize(sampling_params)
                 self.add_request(request_id, prompt, sampling_params)
+                logging.info(f"Added request {request_id} on rank {torch.distributed.get_rank()}")
             elif header == Headers.PAUSE:
                 self.paused = True
             elif header == Headers.UNPAUSE:
@@ -1259,6 +1262,9 @@ def schedule_requests(self) -> int:
             else:
                 raise UnknownHeaderError(header)

+        if len(all_messages) > 0:
+            logging.info(f"Drained {len(all_messages)} messages from coordinator on rank {torch.distributed.get_rank()}")
+
         return len(all_messages)

     def stop(self):
@@ -1308,6 +1314,7 @@ async def run_engine_with_coordinator(
         """Continually steps the engine asynchronously."""
         self._loop = get_asyncio_loop(loop)
         try:
+            logging.info(f"Running engine with coordinator on rank {torch.distributed.get_rank()}")
             while True:
                 self.schedule_requests()
                 if self.stopped:
@@ -1327,6 +1334,7 @@ async def run_engine_with_coordinator(
                 # todo [Siddharth]: Can this hardcoded sleep be avoided
                 # with asyncio zmq sockets?
                 if self.paused:
+                    logging.info(f"Suspending engine on rank {torch.distributed.get_rank()}")
                     await asyncio.sleep(0.02)
                     continue

@@ -1344,11 +1352,14 @@ async def run_engine_with_coordinator(
                     self.context.get_active_request_count() == 0
                     and len(self.waiting_request_ids) == 0
                 ):
+                    logging.info(f"No requests to process on rank {torch.distributed.get_rank()}")
                     await asyncio.sleep(0.02)
                     continue

-                # Step.
-                engine_output = await self.async_step(verbose=verbose)
+                logging.info(f"Processing requests on rank {torch.distributed.get_rank()}")
+                logging.info(f"Active requests: {self.context.get_active_request_count()}")
+                logging.info(f"Waiting requests: {len(self.waiting_request_ids)}")
+                engine_output = await self.async_step(verbose=True)

                 # Send finished requests.
                 is_tp0_and_pp0 = (
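The logging added above traces one pass of the engine's drain/step loop: drain coordinator messages, sleep briefly when paused or idle, otherwise run one step. A minimal sketch of that control flow, with hypothetical names (EngineSketch, pending) standing in for the real engine state:

import asyncio

class EngineSketch:
    def __init__(self):
        self.paused = False
        self.stopped = False
        self.pending: list[str] = []

    def schedule_requests(self) -> int:
        # The real engine drains ZMQ messages here; this sketch just reports queued work.
        return len(self.pending)

    async def step(self):
        print("processing", self.pending.pop(0))

    async def run(self):
        while not self.stopped:
            self.schedule_requests()
            if self.paused or not self.pending:
                await asyncio.sleep(0.02)  # paused or idle: yield briefly, then re-check
                continue
            await self.step()

async def demo():
    engine = EngineSketch()
    engine.pending.append("req-0")
    task = asyncio.create_task(engine.run())
    await asyncio.sleep(0.1)
    engine.stopped = True
    await task

asyncio.run(demo())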

megatron/core/inference/inference_client.py

Lines changed: 4 additions & 0 deletions

@@ -10,6 +10,8 @@
 from megatron.core.inference.sampling_params import SamplingParams
 from megatron.core.utils import get_asyncio_loop, trace_async_exceptions

+import torch.distributed as dist
+
 from .headers import Headers

 try:
@@ -99,6 +101,7 @@ def add_request(
             `DynamicInferenceRequestRecord` object containing the completed result.
         """
         request_id = self.next_request_id
+        logging.info(f"Adding request {request_id}")
         self.next_request_id += 1
         payload = [Headers.SUBMIT_REQUEST.value, request_id, prompt, sampling_params.serialize()]
         payload_serialized = msgpack.packb(payload, use_bin_type=True)
@@ -126,6 +129,7 @@ async def _listen_for_completed_requests(self):
                     request_id
                 )
                 completion_future = self.completion_futures.pop(request_id)
+                logging.info(f"Received reply for request {request_id}")
                 completion_future.set_result(DynamicInferenceRequestRecord.deserialize(reply))
             except zmq.Again:
                 await asyncio.sleep(0.005)
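The two logged events bracket the client's future-based request tracking: add_request registers a future under a fresh request id, and the listener resolves it when the matching reply arrives. A minimal sketch of that pattern, with illustrative names rather than the Megatron API:

import asyncio

class ClientSketch:
    def __init__(self):
        self.next_request_id = 0
        self.completion_futures: dict[int, asyncio.Future] = {}

    def add_request(self, prompt: str) -> asyncio.Future:
        request_id = self.next_request_id
        self.next_request_id += 1
        future = asyncio.get_running_loop().create_future()
        self.completion_futures[request_id] = future
        # ... the real client serializes (request_id, prompt) and sends it here ...
        return future

    def on_reply(self, request_id: int, reply: str):
        # Called by the listener task when the reply for request_id arrives.
        self.completion_futures.pop(request_id).set_result(reply)

async def demo():
    client = ClientSketch()
    fut = client.add_request("hello")
    client.on_reply(0, "world")
    print(await fut)

asyncio.run(demo())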

megatron/core/utils.py

Lines changed: 5 additions & 2 deletions

@@ -2087,14 +2087,18 @@ def maybe_cat(a, b, dim=0, *, required=False):
         return None
     return xs[0] if len(xs) == 1 else torch.cat(xs, dim=dim)

+_ASYNC_IO_LOOP: asyncio.AbstractEventLoop | None = None

 def get_asyncio_loop(loop: asyncio.AbstractEventLoop | None = None) -> asyncio.AbstractEventLoop:
     """Creates an asyncio loop if necessary and then returns the current asyncio loop."""
+    global _ASYNC_IO_LOOP
     if loop is None:
+        if _ASYNC_IO_LOOP is not None:
+            return _ASYNC_IO_LOOP
         try:
             loop = asyncio.get_running_loop()
         except RuntimeError as e:
-            loop = asyncio.new_event_loop()
+            _ASYNC_IO_LOOP = loop = asyncio.new_event_loop()
             asyncio.set_event_loop(loop)
     return loop

@@ -2130,7 +2134,6 @@ async def wrapper(*args, **kwargs):
         try:
             return await fn(*args, **kwargs)
         except Exception as e:
-            logger.error(f"Exception in async function {fn.__name__}: {e}")
             traceback.print_exc()
             sys.exit(1)
         finally:
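The get_asyncio_loop change memoizes the first loop created from synchronous code, so later callers reuse the same loop instead of getting a fresh one per call. A standalone sketch of the same pattern; names here are illustrative:

import asyncio

_LOOP: asyncio.AbstractEventLoop | None = None

def get_loop(loop: asyncio.AbstractEventLoop | None = None) -> asyncio.AbstractEventLoop:
    global _LOOP
    if loop is None:
        if _LOOP is not None:
            return _LOOP  # reuse the cached loop
        try:
            loop = asyncio.get_running_loop()  # inside async code: use the active loop
        except RuntimeError:
            # first call from sync code: create a loop and cache it process-wide
            _LOOP = loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
    return loop

assert get_loop() is get_loop()  # repeated sync calls now return the same loop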

megatron/rl/inference/megatron.py

Lines changed: 2 additions & 0 deletions

@@ -177,6 +177,7 @@ async def base_generate(self, request: InferenceRequest):
         assert self._client is not None, "Client is not initialized"

         tokenizer = get_tokenizer()
+        print(f"Adding request to client on rank {dist.get_rank()}")

         sampling_params = SamplingParams(
             num_tokens_to_generate=None,
@@ -193,6 +194,7 @@ async def base_generate(self, request: InferenceRequest):
             self._client.add_request(prompt=prompt, sampling_params=sampling_params)
             for prompt in request.prompt
         ]
+        print(f"Waiting for responses on rank {dist.get_rank()}")
         responses = await asyncio.gather(
             *requests
         )
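The added prints bracket base_generate's fan-out: one awaitable per prompt, gathered concurrently. A minimal sketch, with fake_add_request standing in for the client round trip:

import asyncio

async def fake_add_request(prompt: str) -> str:
    await asyncio.sleep(0.01)  # stand-in for the request/reply round trip to the engine
    return prompt.upper()

async def base_generate(prompts: list[str]) -> list[str]:
    requests = [fake_add_request(p) for p in prompts]
    return await asyncio.gather(*requests)

print(asyncio.run(base_generate(["a", "b"])))  # -> ['A', 'B']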
