diff --git a/tests/train/test_trainer.py b/tests/train/test_trainer.py
index 6adf4e0a9..47f2aaecf 100644
--- a/tests/train/test_trainer.py
+++ b/tests/train/test_trainer.py
@@ -54,7 +54,10 @@ def grad_accumulation_steps(self, *args, **kwargs):
 
     def train_step(self, *args, **kwargs):
         self.train_step_calls += 1
-        return {"total_loss": 1.0, "reduced_llm_loss": 0.8}, {"consumed_tokens": 100, "grad_norm": torch.tensor(1.0), "efficient_attn_ratio": 0.5}
+        return (
+            {"local_loss": 1.0, "reduced_llm_loss": 0.8},
+            {"consumed_tokens": 100, "grad_norm": torch.tensor(1.0), "efficient_attn_ratio": 0.5}
+        )
 
     def save_hf(self, hf_path):
         self.save_hf_calls.append(hf_path)
diff --git a/xtuner/v1/engine/train_engine.py b/xtuner/v1/engine/train_engine.py
index b1633e773..e78454117 100644
--- a/xtuner/v1/engine/train_engine.py
+++ b/xtuner/v1/engine/train_engine.py
@@ -44,7 +44,7 @@ class LossLog(TypedDict):
     __pydantic_config__ = ConfigDict(arbitrary_types_allowed=True)  # type: ignore[misc]
 
-    total_loss: float
+    local_loss: float
     reduced_llm_loss: float
     reduced_balancing_loss: NotRequired[float]
     reduced_z_loss: NotRequired[float]
 
@@ -53,7 +53,7 @@ class LossLog(TypedDict):
 class OtherLog(TypedDict):
     __pydantic_config__ = ConfigDict(arbitrary_types_allowed=True)  # type: ignore[misc]
 
     maxvio: NotRequired[float]
-    consumed_tokens: float
+    consumed_tokens: int
     extra_info: ModelForwardExtraLogInfo
     efficient_attn_ratio: float
@@ -252,7 +252,7 @@ def train_step(self, data_batches: list[ModelItem]) -> tuple[LossLog, OtherLog]:
         step_llm_loss = torch.tensor(0.0, device=DEVICE)
         step_balancing_loss: torch.Tensor | None = None
         step_z_loss: torch.Tensor | None = None
-        step_consumed_tokens = torch.tensor(0.0, device=DEVICE)
+        step_consumed_tokens = torch.tensor(0, device=DEVICE)
 
         if self._count == 0:
             logger.info(f"grad_accumulation_steps: {iters_per_step}")
@@ -336,7 +336,7 @@ def train_step(self, data_batches: list[ModelItem]) -> tuple[LossLog, OtherLog]:
             reduced_llm_loss = step_llm_loss
             dist.all_reduce(reduced_llm_loss.div_(dist.get_world_size()))
 
-        loss_log["total_loss"] = step_loss.item()
+        loss_log["local_loss"] = step_loss.item()
         loss_log["reduced_llm_loss"] = reduced_llm_loss.item()
         if step_balancing_loss is not None:
             reduced_balancing_loss = step_balancing_loss
@@ -346,7 +346,7 @@ def train_step(self, data_batches: list[ModelItem]) -> tuple[LossLog, OtherLog]:
             reduced_z_loss = step_z_loss
             dist.all_reduce(reduced_z_loss.div_(dist.get_world_size()))
             loss_log["reduced_z_loss"] = reduced_z_loss.item()
-        other_log["consumed_tokens"] = step_consumed_tokens.item()
+        other_log["consumed_tokens"] = cast(int, step_consumed_tokens.item())
         other_log["extra_info"] = train_engine_extra_info
         other_log["efficient_attn_ratio"] = (efficient_forward_tokens / total_forward_tokens).item()
         return loss_log, other_log
diff --git a/xtuner/v1/engine/vision_compose_train_engine.py b/xtuner/v1/engine/vision_compose_train_engine.py
index 15bba4079..4251f8641 100644
--- a/xtuner/v1/engine/vision_compose_train_engine.py
+++ b/xtuner/v1/engine/vision_compose_train_engine.py
@@ -188,7 +188,7 @@ def train_step(self, data_batches: List[ModelItem]) -> tuple[LossLog, OtherLog]:
         step_llm_loss = torch.tensor(0.0, device=DEVICE)
         step_balancing_loss: torch.Tensor | None = None
         step_z_loss: torch.Tensor | None = None
-        step_consumed_tokens = torch.tensor(0.0, device=DEVICE)
+        step_consumed_tokens = torch.tensor(0, device=DEVICE)
         efficient_forward_tokens = torch.tensor(0, device=DEVICE, dtype=torch.long)
         total_forward_tokens = torch.tensor(0, device=DEVICE, dtype=torch.long)
 
@@ -252,7 +252,7 @@ def train_step(self, data_batches: List[ModelItem]) -> tuple[LossLog, OtherLog]:
             reduced_llm_loss = step_llm_loss
             dist.all_reduce(reduced_llm_loss.div_(dist.get_world_size()))
 
-        loss_log["total_loss"] = step_loss.item()
+        loss_log["local_loss"] = step_loss.item()
         loss_log["reduced_llm_loss"] = reduced_llm_loss.item()
         if step_balancing_loss is not None:
             reduced_balancing_loss = step_balancing_loss
@@ -262,7 +262,7 @@ def train_step(self, data_batches: List[ModelItem]) -> tuple[LossLog, OtherLog]:
             reduced_z_loss = step_z_loss
             dist.all_reduce(reduced_z_loss.div_(dist.get_world_size()))
             loss_log["reduced_z_loss"] = reduced_z_loss.item()
-        other_log["consumed_tokens"] = step_consumed_tokens.item()
+        other_log["consumed_tokens"] = cast(int, step_consumed_tokens.item())
         other_log["extra_info"] = train_engine_extra_info  # type: ignore[assignment]
         other_log["efficient_attn_ratio"] = (efficient_forward_tokens / total_forward_tokens).item()
         return loss_log, other_log
diff --git a/xtuner/v1/model/utils/misc.py b/xtuner/v1/model/utils/misc.py
index 6ab8281aa..2b8ca33eb 100644
--- a/xtuner/v1/model/utils/misc.py
+++ b/xtuner/v1/model/utils/misc.py
@@ -93,5 +93,7 @@ def get(self):
         while self["log_rank_loss"].dim() >= 1:
             self["log_rank_loss"] = torch.sum(self["log_rank_loss"], dim=-1)
         log_rank_loss_value = self["log_rank_loss"].item()
-        return_dict["loss"] = log_rank_loss_value
+        # vague keys such as `loss` should be avoided in extra_log_info,
+        # otherwise it may cause confusion in exp-track logs.
+        return_dict["local_base_loss"] = log_rank_loss_value
         return return_dict
diff --git a/xtuner/v1/train/trainer.py b/xtuner/v1/train/trainer.py
index 60d60e537..c692eefc0 100644
--- a/xtuner/v1/train/trainer.py
+++ b/xtuner/v1/train/trainer.py
@@ -514,9 +514,9 @@ def __init__(
         self._debug = debug
         self._seed = seed
 
-        self._reduced_consumed_tokens = 0
+        self._total_consumed_tokens = 0
         self._exp_consumed_tokens = 0
-        self._reduced_consumed_samples = 0
+        self._total_consumed_samples = 0
         self._train_time = 0
         self._train_time_offset = 0
 
@@ -762,8 +762,6 @@ def fit(self):
                 extra_info_dict = extra_info_updated.get()
                 loss_log.update(extra_info_dict)
 
-            if "maxvio" in other_log:
-                loss_log["maxvio"] = other_log["maxvio"]
             loss_log["efficient_attn_ratio"] = other_log["efficient_attn_ratio"]
 
             internal_metrics = self._maybe_pop_model_internal_metrics(engine_input)
@@ -771,17 +769,17 @@ def fit(self):
             self._cur_step += 1
 
             reduced_step_consumed_tokens = self._reduce_number_across_rank(step_consumed_tokens)
-            self._reduced_consumed_tokens += reduced_step_consumed_tokens
-
-            self._exp_consumed_tokens += step_consumed_tokens
+            self._total_consumed_tokens += reduced_step_consumed_tokens
+            self._exp_consumed_tokens += reduced_step_consumed_tokens
 
             self._train_time = time_after_train_step - train_begin
             # TODO: This log should be move before lr_scheduler.step, but for CI BC, keep it temporarily
             self._log_step(
                 loss_log=loss_log,
-                step_consumed_tokens=step_consumed_tokens,
+                local_step_consumed_tokens=step_consumed_tokens,
+                step_consumed_tokens=reduced_step_consumed_tokens,
                 exp_consumed_tokens=self._exp_consumed_tokens,
-                reduced_consumed_tokens=self._reduced_consumed_tokens,
+                total_consumed_tokens=self._total_consumed_tokens,
                 data_time=data_time,
                 step_time=step_time,
                 train_time=self._train_time,
@@ -1137,8 +1135,8 @@ def _maybe_save(self, is_snapshot: bool = False) -> bool:
                 {
                     "cur_step": self.cur_step,
                     "cur_epoch": self._cur_epoch,
-                    "reduced_consumed_samples": self._reduced_consumed_samples,
-                    "reduced_consumed_tokens": self._reduced_consumed_tokens,
+                    "total_consumed_samples": self._total_consumed_samples,
+                    "total_consumed_tokens": self._total_consumed_tokens,
                     "train_time_offset": self._train_time + self._train_time_offset,
                 }
             )
@@ -1150,8 +1148,8 @@ def _maybe_save(self, is_snapshot: bool = False) -> bool:
             ckp_list.append(str(checkpoint_path))
             current_exp.cur_step = self.cur_step
             current_exp.cur_epoch = self._cur_epoch
-            current_exp.consumed_samples = int(self._reduced_consumed_samples)
-            current_exp.consumed_tokens = int(self._reduced_consumed_tokens)
+            current_exp.consumed_samples = int(self._total_consumed_samples)
+            current_exp.consumed_tokens = int(self._total_consumed_tokens)
             current_exp.history[-1]["end"] = self.cur_step
 
             # Delete checkpoints and update meta's checkpoint_list
@@ -1188,7 +1186,7 @@ def _maybe_save(self, is_snapshot: bool = False) -> bool:
 
     def _save_dataloader(self, dataloader_path: Path | str):
         if self.rank == 0:
-            dataloader_state = self._dataloader.get_state_dict(self._reduced_consumed_samples)
+            dataloader_state = self._dataloader.get_state_dict(self._total_consumed_samples)
             torch.save(dataloader_state, dataloader_path)
 
     @property
@@ -1243,7 +1241,7 @@ def _data_iter(self):
                 data_iter = iter(self._dataloader)
                 data = next(data_iter)
 
-            self._reduced_consumed_samples += self._reduce_number_across_rank(len(data))
+            self._total_consumed_samples += self._reduce_number_across_rank(len(data))
             yield data
 
     def _get_checkpoint_path(self, epoch: int, step: int, is_snapshot: bool = False) -> Path:
@@ -1413,10 +1411,11 @@ def _maybe_profiling(self):
 
     def _log_step(
         self,
-        loss_log: dict,
+        loss_log: LossLog,
+        local_step_consumed_tokens: int,
         step_consumed_tokens: int,
         exp_consumed_tokens: int,
-        reduced_consumed_tokens: int,
+        total_consumed_tokens: int,
         data_time: float,
         step_time: float,
         train_time: float,
@@ -1426,20 +1425,20 @@
     ):
         """Log the training step information."""
         e2e_train_time = train_time + train_time_offset
 
-        tgs = step_consumed_tokens / step_time
-        rank_consumed_tokens = reduced_consumed_tokens / self.world_size
-        e2e_tgs = rank_consumed_tokens / e2e_train_time
-        exp_tgs = exp_consumed_tokens / train_time
+        total_consumed_tokens_per_rank = total_consumed_tokens / self.world_size
+        exp_consumed_tokens_per_rank = exp_consumed_tokens / self.world_size
+
+        tgs = local_step_consumed_tokens / step_time
+        e2e_tgs = total_consumed_tokens_per_rank / e2e_train_time
+        exp_tgs = exp_consumed_tokens_per_rank / train_time
 
         lr = self._lr_scheduler.get_last_lr()[0]
         remaining_steps = self.total_step - self.cur_step
-        avg_tokens_per_step = rank_consumed_tokens / self.cur_step
+        avg_tokens_per_step = total_consumed_tokens_per_rank / self.cur_step
         remaining_tokens = remaining_steps * avg_tokens_per_step
         eta_seconds = remaining_tokens / (tgs + 1e-12)
         eta_hms = str(timedelta(seconds=int(eta_seconds)))
 
-        est_global_batch_tokens = self.data_mesh["dp"].size() * step_consumed_tokens
-
         loss_log_list = [f"{k}: {v:.8f}" for k, v in loss_log.items()]
         loss_log_str = ", ".join(loss_log_list)
@@ -1453,8 +1452,9 @@
         self.logger.info(
             f"Epoch {self._cur_epoch} Step {self.cur_step}/{self.total_step} "
             f"data_time: {data_time:.4f} lr: {lr:.6e} time: {step_time:.4f} "
-            f"text_tokens: {step_consumed_tokens} "
-            f"reduced_consumed_tokens: {reduced_consumed_tokens} "
+            f"text_tokens: {local_step_consumed_tokens} "
+            f"step_consumed_tokens: {step_consumed_tokens} "
+            f"total_consumed_tokens: {total_consumed_tokens} "
            f"{loss_log_str} "
             f"grad_norm: {grad_norm:.8f} "
" f"max_memory: {max_memory / (1024**3):.2f} GB " @@ -1462,7 +1462,6 @@ def _log_step( f"tgs: {tgs:.1f} " f"exp_tgs: {exp_tgs: .1f} " f"e2e_tgs: {e2e_tgs:.1f} " - f"est_global_batch_tokens: {est_global_batch_tokens} " f"eta: {eta_hms} " ) @@ -1472,9 +1471,9 @@ def _log_step( "time/step_time": round(step_time, 4), "time/train_time": round(train_time, 4), "time/eta_seconds": round(eta_seconds, 1), - "runtime_info/text_tokens": step_consumed_tokens, - "runtime_info/est_global_batch_tokens": est_global_batch_tokens, - "runtime_info/reduced_consumed_tokens": reduced_consumed_tokens, + "runtime_info/text_tokens": local_step_consumed_tokens, + "runtime_info/step_consumed_tokens": step_consumed_tokens, + "runtime_info/total_consumed_tokens": total_consumed_tokens, "runtime_info/tgs": tgs, "runtime_info/exp_tgs": exp_tgs, "runtime_info/e2e_tgs": e2e_tgs, @@ -1679,13 +1678,13 @@ def _load_checkpoint(self): self._cur_epoch = train_state["cur_epoch"] if load_checkpoint_cfg.load_dataset: - self._reduced_consumed_tokens = train_state.get("reduced_consumed_tokens", 0) # default 0 for BC + self._total_consumed_tokens = train_state.get("total_consumed_tokens", 0) # default 0 for BC self._train_time_offset = train_state["train_time_offset"] - # _reduced_consumed_samples 会影响 save dcp时 dataloader.get_state_dict的状态。 - # 1) 如果加载 dataset,应该恢复_reduced_consumed_samples为checkpoint中的值。 - # 2) 如果不加载 dataset,应该保持_reduced_consumed_samples为初始值0,否则如果加载上旧dataloader的reduced_consumed_samples - # 会导致存储新dataloader时 reduced_consumed_samples 是不正确的值。 - self._reduced_consumed_samples = train_state.get("reduced_consumed_samples", 0) # default 0 for BC + # _total_consumed_samples 会影响 save dcp时 dataloader.get_state_dict的状态。 + # 1) 如果加载 dataset,应该恢复_total_consumed_samples为checkpoint中的值。 + # 2) 如果不加载 dataset,应该保持_total_consumed_samples为初始值0,否则如果加载上旧dataloader的total_consumed_samples + # 会导致存储新dataloader时 total_consumed_samples 是不正确的值。 + self._total_consumed_samples = train_state.get("total_consumed_samples", 0) # default 0 for BC dataloader_path = resume_from / self._SAVE_DATALOADER_DIR self._resume_dataloader(dataloader_path)