@@ -514,9 +514,9 @@ def __init__(
         self._debug = debug
         self._seed = seed
 
-        self._consumed_tokens = 0
+        self._reduced_consumed_tokens = 0
         self._exp_consumed_tokens = 0
-        self._consumed_samples = 0
+        self._reduced_consumed_samples = 0
 
         self._train_time = 0
         self._train_time_offset = 0
@@ -769,7 +769,10 @@ def fit(self):
             internal_metrics = self._maybe_pop_model_internal_metrics(engine_input)
 
             self._cur_step += 1
-            self._consumed_tokens += step_consumed_tokens
+
+            reduced_step_consumed_tokens = self._reduce_number_across_rank(step_consumed_tokens)
+            self._reduced_consumed_tokens += reduced_step_consumed_tokens
+
             self._exp_consumed_tokens += step_consumed_tokens
             self._train_time = time_after_train_step - train_begin
 
@@ -778,7 +781,7 @@ def fit(self):
                 loss_log=loss_log,
                 step_consumed_tokens=step_consumed_tokens,
                 exp_consumed_tokens=self._exp_consumed_tokens,
-                total_consumed_tokens=self._consumed_tokens,
+                reduced_consumed_tokens=self._reduced_consumed_tokens,
                 data_time=data_time,
                 step_time=step_time,
                 train_time=self._train_time,
@@ -805,6 +808,12 @@ def fit(self):
         self._metrics_recorder.close()
         self.logger.info(f"Training finished in {time.time() - train_begin:.2f} seconds")
 
+    def _reduce_number_across_rank(self, rank_number: int) -> int:
+        _gathered_list = [None for _ in range(self.world_size)]
+        dist.all_gather_object(_gathered_list, rank_number)
+        reduced_number = sum(_gathered_list)  # type: ignore[arg-type]
+        return reduced_number
+
     def _maybe_init_model_metrics_recorder(
         self,
         internal_metrics_cfg: InternalMetricsConfig | None,
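The new _reduce_number_across_rank helper gathers the per-rank Python int from every rank and sums it, so callers get the global total. For illustration only, a minimal sketch of an equivalent tensor-based reduction (not part of this change; it assumes the default process group is already initialized, and with the NCCL backend the tensor would additionally need to live on the rank's GPU):

import torch
import torch.distributed as dist

def reduce_number_across_rank(rank_number: int) -> int:
    # Sum a per-rank scalar across all ranks; same result as
    # all_gather_object + sum, but avoids pickling Python objects.
    value = torch.tensor(rank_number, dtype=torch.int64)
    dist.all_reduce(value, op=dist.ReduceOp.SUM)
    return int(value.item())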
@@ -1128,8 +1137,8 @@ def _maybe_save(self, is_snapshot: bool = False) -> bool:
                 {
                     "cur_step": self.cur_step,
                     "cur_epoch": self._cur_epoch,
-                    "consumed_samples": self._consumed_samples,
-                    "consumed_tokens": self._consumed_tokens,
+                    "reduced_consumed_samples": self._reduced_consumed_samples,
+                    "reduced_consumed_tokens": self._reduced_consumed_tokens,
                     "train_time_offset": self._train_time + self._train_time_offset,
                 }
             )
@@ -1141,8 +1150,8 @@ def _maybe_save(self, is_snapshot: bool = False) -> bool:
             ckp_list.append(str(checkpoint_path))
             current_exp.cur_step = self.cur_step
             current_exp.cur_epoch = self._cur_epoch
-            current_exp.consumed_samples = self._consumed_samples
-            current_exp.consumed_tokens = int(self._consumed_tokens)
+            current_exp.consumed_samples = int(self._reduced_consumed_samples)
+            current_exp.consumed_tokens = int(self._reduced_consumed_tokens)
             current_exp.history[-1]["end"] = self.cur_step
 
             # Delete checkpoints and update meta's checkpoint_list
@@ -1178,12 +1187,8 @@ def _maybe_save(self, is_snapshot: bool = False) -> bool:
         return True
 
     def _save_dataloader(self, dataloader_path: Path | str):
-        _gathered_list = [None for _ in range(self.data_mesh["dp"].size())]
-        dist.all_gather_object(_gathered_list, self._consumed_samples, group=self.data_mesh["dp"].get_group())
-        global_consumed_samples = sum(_gathered_list)  # type: ignore[arg-type]
-
         if self.rank == 0:
-            dataloader_state = self._dataloader.get_state_dict(global_consumed_samples)
+            dataloader_state = self._dataloader.get_state_dict(self._reduced_consumed_samples)
             torch.save(dataloader_state, dataloader_path)
 
     @property
@@ -1232,12 +1237,13 @@ def _data_iter(self):
             # dist.breakpoint(skip=14)
             try:
                 data = next(data_iter)
-                self._consumed_samples += len(data)
             except StopIteration:
                 self._cur_epoch += 1
                 self._dataloader.set_epoch(self._cur_epoch)
                 data_iter = iter(self._dataloader)
                 data = next(data_iter)
+
+            self._reduced_consumed_samples += self._reduce_number_across_rank(len(data))
             yield data
 
     def _get_checkpoint_path(self, epoch: int, step: int, is_snapshot: bool = False) -> Path:
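With this change, each yield advances _reduced_consumed_samples by the globally summed batch size rather than the per-rank count. As a purely illustrative example: with a world size of 4 and next(data_iter) returning 8 samples on every rank, _reduce_number_across_rank(len(data)) returns 32, which is why _save_dataloader above no longer needs its own all_gather_object over the dp group.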
@@ -1410,7 +1416,7 @@ def _log_step(
         loss_log: dict,
         step_consumed_tokens: int,
         exp_consumed_tokens: int,
-        total_consumed_tokens: int,
+        reduced_consumed_tokens: int,
         data_time: float,
         step_time: float,
         train_time: float,
@@ -1421,12 +1427,13 @@ def _log_step(
         """Log the training step information."""
         e2e_train_time = train_time + train_time_offset
         tgs = step_consumed_tokens / step_time
-        e2e_tgs = total_consumed_tokens / e2e_train_time
+        rank_consumed_tokens = reduced_consumed_tokens / self.world_size
+        e2e_tgs = rank_consumed_tokens / e2e_train_time
         exp_tgs = exp_consumed_tokens / train_time
         lr = self._lr_scheduler.get_last_lr()[0]
 
         remaining_steps = self.total_step - self.cur_step
-        avg_tokens_per_step = total_consumed_tokens / self.cur_step
+        avg_tokens_per_step = rank_consumed_tokens / self.cur_step
         remaining_tokens = remaining_steps * avg_tokens_per_step
         eta_seconds = remaining_tokens / (tgs + 1e-12)
         eta_hms = str(timedelta(seconds=int(eta_seconds)))
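Because reduced_consumed_tokens is now an all-rank sum, dividing by world_size recovers a per-rank token count, keeping e2e_tgs on the same per-rank scale as tgs. With illustrative numbers (assumed, not taken from the change): world_size = 8, reduced_consumed_tokens = 8,000,000 and e2e_train_time = 1,000 s give rank_consumed_tokens = 1,000,000 and e2e_tgs = 1,000 tokens/s, directly comparable to tgs = step_consumed_tokens / step_time measured on a single rank.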
@@ -1447,7 +1454,7 @@ def _log_step(
             f"Epoch {self._cur_epoch} Step {self.cur_step}/{self.total_step} "
             f"data_time: {data_time:.4f} lr: {lr:.6e} time: {step_time:.4f} "
             f"text_tokens: {step_consumed_tokens} "
-            f"total_consumed_tokens: {total_consumed_tokens} "
+            f"reduced_consumed_tokens: {reduced_consumed_tokens} "
             f"{loss_log_str} "
             f"grad_norm: {grad_norm:.8f} "
             f"max_memory: {max_memory / (1024**3):.2f} GB "
@@ -1467,7 +1474,7 @@ def _log_step(
             "time/eta_seconds": round(eta_seconds, 1),
             "runtime_info/text_tokens": step_consumed_tokens,
             "runtime_info/est_global_batch_tokens": est_global_batch_tokens,
-            "runtime_info/total_consumed_tokens": total_consumed_tokens,
+            "runtime_info/reduced_consumed_tokens": reduced_consumed_tokens,
             "runtime_info/tgs": tgs,
             "runtime_info/exp_tgs": exp_tgs,
             "runtime_info/e2e_tgs": e2e_tgs,
@@ -1663,20 +1670,25 @@ def _load_checkpoint(self):
                 load_args=load_checkpoint_cfg.load_optimizer_args,
             )
 
-        if load_checkpoint_cfg.load_dataset:
-            dataloader_path = resume_from / self._SAVE_DATALOADER_DIR
-            self._resume_dataloader(dataloader_path)
-
         train_state_path = resume_from / self._SAVE_TRAIN_STATE_PATH
 
         with train_state_path.open("r") as f:
             train_state = json.load(f)
 
         self._cur_step = train_state["cur_step"]
         self._cur_epoch = train_state["cur_epoch"]
-        self._consumed_samples = train_state["consumed_samples"]
-        self._consumed_tokens = train_state["consumed_tokens"]
-        self._train_time_offset = train_state["train_time_offset"]
+
+        if load_checkpoint_cfg.load_dataset:
+            self._reduced_consumed_tokens = train_state.get("reduced_consumed_tokens", 0)  # default 0 for BC
+            self._train_time_offset = train_state["train_time_offset"]
+            # _reduced_consumed_samples affects the state returned by dataloader.get_state_dict when the DCP checkpoint is saved.
+            # 1) If the dataset is loaded, _reduced_consumed_samples should be restored to the value stored in the checkpoint.
+            # 2) If the dataset is not loaded, _reduced_consumed_samples should stay at its initial value 0; otherwise,
+            #    carrying over the old dataloader's reduced_consumed_samples would make the value saved with the new dataloader incorrect.
+            self._reduced_consumed_samples = train_state.get("reduced_consumed_samples", 0)  # default 0 for BC
+
+            dataloader_path = resume_from / self._SAVE_DATALOADER_DIR
+            self._resume_dataloader(dataloader_path)
 
         if load_checkpoint_cfg.load_scheduler:
             scheduler_path = resume_from / self._SAVE_SCHEDULER_DIR
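For reference, a hypothetical train state file written by _maybe_save after this change might look roughly like the sketch below (all values made up); checkpoints produced before the rename lack the reduced_* keys, which is why the resume path falls back to 0 via train_state.get(..., 0):

# Hypothetical train_state.json contents (illustrative values only)
{
    "cur_step": 1200,
    "cur_epoch": 1,
    "reduced_consumed_samples": 614400,
    "reduced_consumed_tokens": 2516582400,
    "train_time_offset": 43521.7
}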