66import pathlib
77from collections import defaultdict
88from typing import Any , Callable , Dict
9+ import functools
10+ import time
911
1012import numpy as np
1113from intervaltree import Interval , IntervalTree
1214
# Module-level logger; INFO level is set at import time on this logger only
# (does not touch the root logger's configuration).
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
1517
def timer_decorator(func):
    """Decorator that prints the wall-clock execution time of *func*.

    Prints ``<name> took <seconds>.2f seconds`` to stdout after each call
    and returns the wrapped function's result unchanged.

    Fix: use ``time.perf_counter()`` instead of ``time.time()`` —
    ``time.time()`` follows the system clock and can jump (NTP sync,
    manual adjustment), yielding wrong or even negative durations;
    ``perf_counter`` is monotonic and high-resolution, the documented
    choice for measuring elapsed time.
    """
    @functools.wraps(func)  # preserve __name__/__doc__ of the wrapped function
    def wrapper(*args, **kwargs):
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        end_time = time.perf_counter()
        print(f"{func.__name__} took {end_time - start_time:.2f} seconds")
        return result
    return wrapper
28+
1629# refer to:
1730# https://github.com/pytorch/pytorch/blob/2cc01cc6d3ad2aff47e8460667ba654b2e4c9f21/c10/core/ScalarType.h#L61
1831_dtype_size_map : Dict [str , int ] = {
@@ -293,14 +306,16 @@ def pick_comm_bw_(trace_data, comm_bw_data):
293306 and i ["name" ].startswith (("ncclDevKernel_" , "ncclKernel_" ))
294307 and "algbw (GB/sec)" in i ["args" ]
295308 ]
309+ pg_name2config = {pg ["pg_name" ]: pg for pg in trace_data ["distributedInfo" ]["pg_config" ]}
296310 for evt in nccl_events :
297311 knl_name = evt ["name" ][: evt ["name" ].index ("(" )]
298312 coll_name = evt ["args" ]["Collective name" ]
299313 data_size = _calculate_event_data_size (evt )
300- ranks_count = evt ["args" ]["Group size" ]
301314
302- ranks = _parse_ranks ( evt ["args" ]["Process Group Ranks" ], ranks_count )
315+ ranks_count = evt ["args" ]["Group size" ]
303316 pg_id = int (evt ["args" ]["Process Group Name" ])
317+ ranks = pg_name2config [evt ["args" ]["Process Group Name" ]]['ranks' ]
318+
304319 # If there are multiple process groups with the same ranks, the last element
305320 # of this tuple is the idential index to differentiate them across ranks.
306321 pg = (* ranks , group_ranks_to_pg_id [tuple (ranks )].index (pg_id ))
@@ -314,7 +329,7 @@ def pick_comm_bw_(trace_data, comm_bw_data):
314329 ]
315330 )
316331
317-
332+ @ timer_decorator
318333def analyze_profiler_trace (trace_dir : str , report_dir : str ):
319334 """
320335 Analyse input PyTorch profiler trace (i.e. Kineto trace) and generate report.
0 commit comments