@@ -139,7 +139,18 @@ def _get_event_busbw_factor(evt):
139139
140140 return correction_factor_func (group_size )
141141
142- def _calculate_busbw_for_uneven_all_to_all (evt , global_rank ):
def _is_uneven_all_to_all_evt(evt):
    """Return True when ``evt`` is an all-to-all collective with split sizes.

    An event qualifies when its "Collective name" is ``all_to_all`` or
    ``all_to_allv`` and at least one of its "In split size" /
    "Out split size" arguments evaluates to a non-empty (truthy) value.

    Args:
        evt: A trace event mapping with an ``"args"`` entry holding the
            collective's metadata.

    Returns:
        bool: ``True`` for an uneven all-to-all event, ``False`` otherwise.
        Always a real ``bool``, never the parsed split-size value itself.

    Note:
        The split-size entries are indexed directly, so a missing key is not
        reported through ``_get_dict_value`` the way a missing
        "Collective name" is.
    """
    coll_name = _get_dict_value(
        evt["args"],
        "Collective name",
        f'Missing "Collective name" in event: {evt}'
    )
    if coll_name not in ("all_to_all", "all_to_allv"):
        return False

    args = evt["args"]
    # Split sizes are parsed with ast.literal_eval (presumably string-encoded
    # lists such as "[]" or "[2, 3]" -- confirm against the trace format).
    # `or` short-circuits, so "Out split size" is only parsed when
    # "In split size" is empty.
    return bool(
        ast.literal_eval(args["In split size"])
        or ast.literal_eval(args["Out split size"])
    )
152+
153+ def _get_uneven_all_to_all_data_size (evt , global_rank ):
143154 group_size = evt ["args" ]["Group size" ]
144155 local_rank = _parse_ranks (evt ["args" ]["Process Group Ranks" ], group_size ).index (global_rank )
145156 in_elems_count = evt ["args" ]["In msg nelems" ]
@@ -158,7 +169,10 @@ def _calculate_busbw_for_uneven_all_to_all(evt, global_rank):
158169 else :
159170 recv_elems = out_elems_count / group_size * (group_size - 1 )
160171
161- return round (max (send_elems , recv_elems ) * dtype_size / evt ["dur" ] * 1e-3 , 2 )
172+ return max (send_elems , recv_elems ) * dtype_size
173+
def _calculate_busbw_for_uneven_all_to_all(evt, global_rank):
    """Compute the bus bandwidth of an uneven all-to-all event for one rank.

    Args:
        evt: Trace event mapping; its ``"dur"`` entry is used as the divisor.
        global_rank: Rank forwarded to ``_get_uneven_all_to_all_data_size``.

    Returns:
        The data size divided by ``evt["dur"]``, scaled by ``1e-3`` and
        rounded to two decimals (presumably GB/s if the size is in bytes and
        ``dur`` in microseconds -- confirm against the trace format).
    """
    data_size = _get_uneven_all_to_all_data_size(evt, global_rank)
    busbw = data_size / evt["dur"] * 1e-3
    return round(busbw, 2)
162176
163177def calculate_bw_ (trace_data , global_rank ):
164178 nccl_events = [
@@ -184,10 +198,7 @@ def calculate_bw_(trace_data, global_rank):
184198
185199 algbw = _calculate_algbw (evt )
186200 busbw_factor = _get_event_busbw_factor (evt )
187- if (coll_name in ["all_to_all" , "all_to_allv" ]
188- and (ast .literal_eval (evt ['args' ]['In split size' ])
189- or ast .literal_eval (evt ['args' ]['Out split size' ]))
190- ):
201+ if _is_uneven_all_to_all_evt (evt ):
191202 # calculate busbw for uneven all_to_all
192203 busbw = _calculate_busbw_for_uneven_all_to_all (evt , global_rank )
193204 else :
@@ -206,7 +217,7 @@ def calculate_bw_(trace_data, global_rank):
206217 logger .error (f"- Error: { err_msg } " )
207218
208219
209- def calculate_sbw (trace_data ):
220+ def calculate_sbw (trace_data , global_rank ):
210221 # calculate shared bw per rank
211222 nccl_events = [
212223 i
@@ -221,6 +232,8 @@ def calculate_sbw(trace_data):
221232
222233 total_data_size = sum (
223234 _calculate_event_data_size (evt ) * _get_event_busbw_factor (evt )
235+ if not _is_uneven_all_to_all_evt (evt )
236+ else _get_uneven_all_to_all_data_size (evt , global_rank )
224237 for evt in nccl_events
225238 )
226239
@@ -336,7 +349,7 @@ def analyze_profiler_trace(trace_dir: str, report_dir: str):
336349 ) as f :
337350 json .dump (trace , f )
338351
339- sbw_lst .append (calculate_sbw (trace ))
352+ sbw_lst .append (calculate_sbw (trace , global_rank ))
340353
341354 pick_iter_e2e_time_ (trace , iter_e2e_time )
342355 pick_comm_bw_ (trace , comm_bw_data )
@@ -367,7 +380,7 @@ def analyze_profiler_trace(trace_dir: str, report_dir: str):
367380 f"avg. E2ETime of iters among all ranks: { sum (iter_e2e_time ) / len (iter_e2e_time ) / 1e3 :.3f} ms\n "
368381 )
369382 f .write (
370- f"avg. SharedBW (i.e. sum(data_size * busbw_factor ) / GPU_comm_busy_time per rank) among all ranks: { sum (sbw_lst ) / len (sbw_lst ) :.3f} GB/s\n "
383+ f"avg. SharedBW (i.e. sum(busbw_data_size ) / GPU_comm_busy_time per rank) among all ranks: { sum (sbw_lst ) / len (sbw_lst ) :.3f} GB/s\n "
371384 )
372385
373386 f .write (
0 commit comments