Skip to content

Commit 4ac7e6d

Browse files
committed
bugfix to group ranks parsing
1 parent bf3ecba commit 4ac7e6d

File tree

1 file changed

+18
-3
lines changed

1 file changed

+18
-3
lines changed

et_replay/comm/profiler_trace_analysis.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,26 @@
66
import pathlib
77
from collections import defaultdict
88
from typing import Any, Callable, Dict
9+
import functools
10+
import time
911

1012
import numpy as np
1113
from intervaltree import Interval, IntervalTree
1214

1315
logger = logging.getLogger(__name__)
1416
logger.setLevel(logging.INFO)
1517

18+
def timer_decorator(func):
19+
"""Decorator that prints the execution time of a function"""
20+
@functools.wraps(func)
21+
def wrapper(*args, **kwargs):
22+
start_time = time.time()
23+
result = func(*args, **kwargs)
24+
end_time = time.time()
25+
print(f"{func.__name__} took {end_time - start_time:.2f} seconds")
26+
return result
27+
return wrapper
28+
1629
# refer to:
1730
# https://github.com/pytorch/pytorch/blob/2cc01cc6d3ad2aff47e8460667ba654b2e4c9f21/c10/core/ScalarType.h#L61
1831
_dtype_size_map: Dict[str, int] = {
@@ -293,14 +306,16 @@ def pick_comm_bw_(trace_data, comm_bw_data):
293306
and i["name"].startswith(("ncclDevKernel_", "ncclKernel_"))
294307
and "algbw (GB/sec)" in i["args"]
295308
]
309+
pg_name2config = {pg["pg_name"]: pg for pg in trace_data["distributedInfo"]["pg_config"]}
296310
for evt in nccl_events:
297311
knl_name = evt["name"][: evt["name"].index("(")]
298312
coll_name = evt["args"]["Collective name"]
299313
data_size = _calculate_event_data_size(evt)
300-
ranks_count = evt["args"]["Group size"]
301314

302-
ranks = _parse_ranks(evt["args"]["Process Group Ranks"], ranks_count)
315+
ranks_count = evt["args"]["Group size"]
303316
pg_id = int(evt["args"]["Process Group Name"])
317+
ranks = pg_name2config[evt["args"]["Process Group Name"]]['ranks']
318+
304319
# If there are multiple process groups with the same ranks, the last element
305320
# of this tuple is the idential index to differentiate them across ranks.
306321
pg = (*ranks, group_ranks_to_pg_id[tuple(ranks)].index(pg_id))
@@ -314,7 +329,7 @@ def pick_comm_bw_(trace_data, comm_bw_data):
314329
]
315330
)
316331

317-
332+
@timer_decorator
318333
def analyze_profiler_trace(trace_dir: str, report_dir: str):
319334
"""
320335
Analyse input PyTorch profiler trace (i.e. Kineto trace) and generate report.

0 commit comments

Comments
 (0)