5 changes: 4 additions & 1 deletion tests/train/test_trainer.py
@@ -54,7 +54,10 @@ def grad_accumulation_steps(self, *args, **kwargs):

def train_step(self, *args, **kwargs):
self.train_step_calls += 1
return {"total_loss": 1.0, "reduced_llm_loss": 0.8}, {"consumed_tokens": 100, "grad_norm": torch.tensor(1.0), "efficient_attn_ratio": 0.5}
return (
{"local_loss": 1.0, "reduced_llm_loss": 0.8},
{"consumed_tokens": 100, "grad_norm": torch.tensor(1.0), "efficient_attn_ratio": 0.5}
)

def save_hf(self, hf_path):
self.save_hf_calls.append(hf_path)
10 changes: 5 additions & 5 deletions xtuner/v1/engine/train_engine.py
@@ -44,7 +44,7 @@

class LossLog(TypedDict):
__pydantic_config__ = ConfigDict(arbitrary_types_allowed=True) # type: ignore[misc]
total_loss: float
local_loss: float
reduced_llm_loss: float
reduced_balancing_loss: NotRequired[float]
reduced_z_loss: NotRequired[float]
@@ -53,7 +53,7 @@ class LossLog(TypedDict):
class OtherLog(TypedDict):
__pydantic_config__ = ConfigDict(arbitrary_types_allowed=True) # type: ignore[misc]
maxvio: NotRequired[float]
consumed_tokens: float
consumed_tokens: int
extra_info: ModelForwardExtraLogInfo
efficient_attn_ratio: float

@@ -252,7 +252,7 @@ def train_step(self, data_batches: list[ModelItem]) -> tuple[LossLog, OtherLog]:
step_llm_loss = torch.tensor(0.0, device=DEVICE)
step_balancing_loss: torch.Tensor | None = None
step_z_loss: torch.Tensor | None = None
step_consumed_tokens = torch.tensor(0.0, device=DEVICE)
step_consumed_tokens = torch.tensor(0, device=DEVICE)

if self._count == 0:
logger.info(f"grad_accumulation_steps: {iters_per_step}")
@@ -336,7 +336,7 @@ def train_step(self, data_batches: list[ModelItem]) -> tuple[LossLog, OtherLog]:
reduced_llm_loss = step_llm_loss
dist.all_reduce(reduced_llm_loss.div_(dist.get_world_size()))

loss_log["total_loss"] = step_loss.item()
loss_log["local_loss"] = step_loss.item()
loss_log["reduced_llm_loss"] = reduced_llm_loss.item()
if step_balancing_loss is not None:
reduced_balancing_loss = step_balancing_loss
@@ -346,7 +346,7 @@ def train_step(self, data_batches: list[ModelItem]) -> tuple[LossLog, OtherLog]:
reduced_z_loss = step_z_loss
dist.all_reduce(reduced_z_loss.div_(dist.get_world_size()))
loss_log["reduced_z_loss"] = reduced_z_loss.item()
other_log["consumed_tokens"] = step_consumed_tokens.item()
other_log["consumed_tokens"] = cast(int, step_consumed_tokens.item())
other_log["extra_info"] = train_engine_extra_info
other_log["efficient_attn_ratio"] = (efficient_forward_tokens / total_forward_tokens).item()
return loss_log, other_log
6 changes: 3 additions & 3 deletions xtuner/v1/engine/vision_compose_train_engine.py
@@ -188,7 +188,7 @@ def train_step(self, data_batches: List[ModelItem]) -> tuple[LossLog, OtherLog]:
step_llm_loss = torch.tensor(0.0, device=DEVICE)
step_balancing_loss: torch.Tensor | None = None
step_z_loss: torch.Tensor | None = None
step_consumed_tokens = torch.tensor(0.0, device=DEVICE)
step_consumed_tokens = torch.tensor(0, device=DEVICE)
efficient_forward_tokens = torch.tensor(0, device=DEVICE, dtype=torch.long)
total_forward_tokens = torch.tensor(0, device=DEVICE, dtype=torch.long)

@@ -252,7 +252,7 @@ def train_step(self, data_batches: List[ModelItem]) -> tuple[LossLog, OtherLog]:
reduced_llm_loss = step_llm_loss
dist.all_reduce(reduced_llm_loss.div_(dist.get_world_size()))

loss_log["total_loss"] = step_loss.item()
loss_log["local_loss"] = step_loss.item()
loss_log["reduced_llm_loss"] = reduced_llm_loss.item()
if step_balancing_loss is not None:
reduced_balancing_loss = step_balancing_loss
@@ -262,7 +262,7 @@ def train_step(self, data_batches: List[ModelItem]) -> tuple[LossLog, OtherLog]:
reduced_z_loss = step_z_loss
dist.all_reduce(reduced_z_loss.div_(dist.get_world_size()))
loss_log["reduced_z_loss"] = reduced_z_loss.item()
other_log["consumed_tokens"] = step_consumed_tokens.item()
other_log["consumed_tokens"] = cast(int, step_consumed_tokens.item())
other_log["extra_info"] = train_engine_extra_info # type: ignore[assignment]
other_log["efficient_attn_ratio"] = (efficient_forward_tokens / total_forward_tokens).item()
return loss_log, other_log
4 changes: 3 additions & 1 deletion xtuner/v1/model/utils/misc.py
@@ -93,5 +93,7 @@ def get(self):
while self["log_rank_loss"].dim() >= 1:
self["log_rank_loss"] = torch.sum(self["log_rank_loss"], dim=-1)
log_rank_loss_value = self["log_rank_loss"].item()
return_dict["loss"] = log_rank_loss_value
# vague keys such as `loss` should be avoided in extra_log_info,
# otherwise they may cause confusion in exp-track logs.
return_dict["local_base_loss"] = log_rank_loss_value
return return_dict
71 changes: 35 additions & 36 deletions xtuner/v1/train/trainer.py
@@ -514,9 +514,9 @@ def __init__(
self._debug = debug
self._seed = seed

self._reduced_consumed_tokens = 0
self._total_consumed_tokens = 0
self._exp_consumed_tokens = 0
self._reduced_consumed_samples = 0
self._total_consumed_samples = 0

self._train_time = 0
self._train_time_offset = 0
@@ -762,26 +762,24 @@ def fit(self):
extra_info_dict = extra_info_updated.get()
loss_log.update(extra_info_dict)

if "maxvio" in other_log:
loss_log["maxvio"] = other_log["maxvio"]
loss_log["efficient_attn_ratio"] = other_log["efficient_attn_ratio"]

internal_metrics = self._maybe_pop_model_internal_metrics(engine_input)

self._cur_step += 1

reduced_step_consumed_tokens = self._reduce_number_across_rank(step_consumed_tokens)
self._reduced_consumed_tokens += reduced_step_consumed_tokens

self._exp_consumed_tokens += step_consumed_tokens
self._total_consumed_tokens += reduced_step_consumed_tokens
Collaborator:

The naming here should stay consistent with how the losses are named in the code and in the logs. For example, total_loss denotes the total loss on a single GPU (llm, balance, z-loss), while reduced_llm_loss denotes the llm loss aggregated across all ranks. I would therefore suggest either keeping the original reduced_consumed_tokens name, or renaming the losses to match the convention used here.

Contributor Author:

There is indeed some ambiguity here. In the current codebase, total_ is used for sums over the time dimension (e.g., total_[step|epoch|iter]) as well as for sums over the spatial dimension (total_GPU_per_node), while total_loss denotes the aggregation of several loss terms, which is a bit confusing. We probably need a good naming convention that tells all of these apart consistently.

The _reduced_consumed_tokens here actually carries both meanings at once, accumulation over time and aggregation across ranks, which feels odd. Let me think about how to improve this PR.

Contributor Author:

I have now renamed total_loss to local_loss, and along the way also renamed a key called loss in the original extra_log_info (for pt/sft that value is the ce_loss before reduce_sum, which is easy to misread on TensorBoard).

Ideally, step vs. the whole training run, single rank vs. all ranks, and a single value vs. a sum would each get their own prefix, but then every variable name would become very verbose. In my view the best we can do is fall back on some reasonable omissions, as long as the context stays clear (see the sketch after this list):

  1. For losses, the sum of several loss terms can in principle just be called loss, with the individual terms distinguished as xx_loss;
  2. For losses and tokens, per-rank values must be distinguished from the global sum or mean. Since per-rank logs are rare, names without a special prefix denote the reduced version, and per-rank values get the local_ prefix;
  3. On the time dimension we still need to distinguish per-step statistics from cumulative values, mainly for tokens (losses do not need this): step denotes the current step, and total denotes the whole training run.
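Below is a minimal, self-contained sketch of that convention (the helpers _all_reduce_mean, _all_reduce_sum, and log_one_step are made up for illustration and are not part of this PR):

```python
import torch
import torch.distributed as dist


def _all_reduce_mean(t: torch.Tensor) -> torch.Tensor:
    """Average a scalar tensor across ranks (no-op in single-process runs)."""
    out = t.clone()
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(out.div_(dist.get_world_size()))
    return out


def _all_reduce_sum(t: torch.Tensor) -> torch.Tensor:
    """Sum a scalar tensor across ranks (no-op in single-process runs)."""
    out = t.clone()
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(out)
    return out


total_consumed_tokens = 0  # `total_`: accumulated over the whole training run, summed over ranks


def log_one_step(
    local_loss: torch.Tensor,                  # per-rank sum of llm/balance/z losses, this step
    local_llm_loss: torch.Tensor,              # per-rank llm loss, this step
    local_step_consumed_tokens: torch.Tensor,  # per-rank token count, this step
) -> dict:
    global total_consumed_tokens
    reduced_llm_loss = _all_reduce_mean(local_llm_loss)  # no prefix = reduced across ranks
    step_consumed_tokens = int(_all_reduce_sum(local_step_consumed_tokens).item())
    total_consumed_tokens += step_consumed_tokens
    return {
        "local_loss": local_loss.item(),                 # this rank only, this step
        "reduced_llm_loss": reduced_llm_loss.item(),     # averaged over ranks, this step
        "step_consumed_tokens": step_consumed_tokens,    # summed over ranks, this step
        "total_consumed_tokens": total_consumed_tokens,  # summed over ranks and over steps
    }
```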

self._exp_consumed_tokens += reduced_step_consumed_tokens
self._train_time = time_after_train_step - train_begin

# TODO: This log should be moved before lr_scheduler.step, but for CI BC, keep it temporarily
self._log_step(
loss_log=loss_log,
step_consumed_tokens=step_consumed_tokens,
local_step_consumed_tokens=step_consumed_tokens,
step_consumed_tokens=reduced_step_consumed_tokens,
exp_consumed_tokens=self._exp_consumed_tokens,
reduced_consumed_tokens=self._reduced_consumed_tokens,
total_consumed_tokens=self._total_consumed_tokens,
data_time=data_time,
step_time=step_time,
train_time=self._train_time,
@@ -1137,8 +1135,8 @@ def _maybe_save(self, is_snapshot: bool = False) -> bool:
{
"cur_step": self.cur_step,
"cur_epoch": self._cur_epoch,
"reduced_consumed_samples": self._reduced_consumed_samples,
"reduced_consumed_tokens": self._reduced_consumed_tokens,
"total_consumed_samples": self._total_consumed_samples,
"total_consumed_tokens": self._total_consumed_tokens,
"train_time_offset": self._train_time + self._train_time_offset,
}
)
@@ -1150,8 +1148,8 @@ def _maybe_save(self, is_snapshot: bool = False) -> bool:
ckp_list.append(str(checkpoint_path))
current_exp.cur_step = self.cur_step
current_exp.cur_epoch = self._cur_epoch
current_exp.consumed_samples = int(self._reduced_consumed_samples)
current_exp.consumed_tokens = int(self._reduced_consumed_tokens)
current_exp.consumed_samples = int(self._total_consumed_samples)
current_exp.consumed_tokens = int(self._total_consumed_tokens)
current_exp.history[-1]["end"] = self.cur_step

# Delete checkpoints and update meta's checkpoint_list
@@ -1188,7 +1186,7 @@ def _maybe_save(self, is_snapshot: bool = False) -> bool:

def _save_dataloader(self, dataloader_path: Path | str):
if self.rank == 0:
dataloader_state = self._dataloader.get_state_dict(self._reduced_consumed_samples)
dataloader_state = self._dataloader.get_state_dict(self._total_consumed_samples)
torch.save(dataloader_state, dataloader_path)

@property
@@ -1243,7 +1241,7 @@ def _data_iter(self):
data_iter = iter(self._dataloader)
data = next(data_iter)

self._reduced_consumed_samples += self._reduce_number_across_rank(len(data))
self._total_consumed_samples += self._reduce_number_across_rank(len(data))
yield data

def _get_checkpoint_path(self, epoch: int, step: int, is_snapshot: bool = False) -> Path:
@@ -1413,10 +1411,11 @@ def _maybe_profiling(self):

def _log_step(
self,
loss_log: dict,
loss_log: LossLog,
local_step_consumed_tokens: int,
step_consumed_tokens: int,
exp_consumed_tokens: int,
reduced_consumed_tokens: int,
total_consumed_tokens: int,
data_time: float,
step_time: float,
train_time: float,
@@ -1426,20 +1425,20 @@
):
"""Log the training step information."""
e2e_train_time = train_time + train_time_offset
tgs = step_consumed_tokens / step_time
rank_consumed_tokens = reduced_consumed_tokens / self.world_size
e2e_tgs = rank_consumed_tokens / e2e_train_time
exp_tgs = exp_consumed_tokens / train_time
total_consumed_tokens_per_rank = total_consumed_tokens / self.world_size
exp_consumed_tokens_per_rank = exp_consumed_tokens / self.world_size

tgs = local_step_consumed_tokens / step_time
e2e_tgs = total_consumed_tokens_per_rank / e2e_train_time
exp_tgs = exp_consumed_tokens_per_rank / train_time
lr = self._lr_scheduler.get_last_lr()[0]

remaining_steps = self.total_step - self.cur_step
avg_tokens_per_step = rank_consumed_tokens / self.cur_step
avg_tokens_per_step = total_consumed_tokens_per_rank / self.cur_step
remaining_tokens = remaining_steps * avg_tokens_per_step
eta_seconds = remaining_tokens / (tgs + 1e-12)
eta_hms = str(timedelta(seconds=int(eta_seconds)))

est_global_batch_tokens = self.data_mesh["dp"].size() * step_consumed_tokens

loss_log_list = [f"{k}: {v:.8f}" for k, v in loss_log.items()]
loss_log_str = ", ".join(loss_log_list)

@@ -1453,16 +1452,16 @@
self.logger.info(
f"Epoch {self._cur_epoch} Step {self.cur_step}/{self.total_step} "
f"data_time: {data_time:.4f} lr: {lr:.6e} time: {step_time:.4f} "
f"text_tokens: {step_consumed_tokens} "
f"reduced_consumed_tokens: {reduced_consumed_tokens} "
f"text_tokens: {local_step_consumed_tokens} "
f"step_consumed_tokens: {step_consumed_tokens} "
f"total_consumed_tokens: {total_consumed_tokens} "
f"{loss_log_str} "
f"grad_norm: {grad_norm:.8f} "
f"max_memory: {max_memory / (1024**3):.2f} GB "
f"reserved_memory: {reserved_memory / (1024**3):.2f} GB "
f"tgs: {tgs:.1f} "
f"exp_tgs: {exp_tgs: .1f} "
f"e2e_tgs: {e2e_tgs:.1f} "
f"est_global_batch_tokens: {est_global_batch_tokens} "
f"eta: {eta_hms} "
)

@@ -1472,9 +1471,9 @@
"time/step_time": round(step_time, 4),
"time/train_time": round(train_time, 4),
"time/eta_seconds": round(eta_seconds, 1),
"runtime_info/text_tokens": step_consumed_tokens,
"runtime_info/est_global_batch_tokens": est_global_batch_tokens,
"runtime_info/reduced_consumed_tokens": reduced_consumed_tokens,
"runtime_info/text_tokens": local_step_consumed_tokens,
"runtime_info/step_consumed_tokens": step_consumed_tokens,
"runtime_info/total_consumed_tokens": total_consumed_tokens,
"runtime_info/tgs": tgs,
"runtime_info/exp_tgs": exp_tgs,
"runtime_info/e2e_tgs": e2e_tgs,
@@ -1679,13 +1678,13 @@ def _load_checkpoint(self):
self._cur_epoch = train_state["cur_epoch"]

if load_checkpoint_cfg.load_dataset:
self._reduced_consumed_tokens = train_state.get("reduced_consumed_tokens", 0) # default 0 for BC
self._total_consumed_tokens = train_state.get("total_consumed_tokens", 0) # default 0 for BC
self._train_time_offset = train_state["train_time_offset"]
# _reduced_consumed_samples affects the state returned by dataloader.get_state_dict when saving dcp.
# 1) If the dataset is loaded, _reduced_consumed_samples should be restored to the value in the checkpoint.
# 2) If the dataset is not loaded, _reduced_consumed_samples should keep its initial value of 0; otherwise, picking up
#    the old dataloader's reduced_consumed_samples would leave an incorrect reduced_consumed_samples when saving the new dataloader.
self._reduced_consumed_samples = train_state.get("reduced_consumed_samples", 0) # default 0 for BC
# _total_consumed_samples affects the state returned by dataloader.get_state_dict when saving dcp.
# 1) If the dataset is loaded, _total_consumed_samples should be restored to the value in the checkpoint.
# 2) If the dataset is not loaded, _total_consumed_samples should keep its initial value of 0; otherwise, picking up
#    the old dataloader's total_consumed_samples would leave an incorrect total_consumed_samples when saving the new dataloader.
self._total_consumed_samples = train_state.get("total_consumed_samples", 0) # default 0 for BC

dataloader_path = resume_from / self._SAVE_DATALOADER_DIR
self._resume_dataloader(dataloader_path)
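As a footnote to the translated comment in _load_checkpoint above, the restore rule for _total_consumed_samples can be summarized by this simplified sketch (restore_counters and its signature are hypothetical, not the actual Trainer API):

```python
def restore_counters(trainer, train_state: dict, load_dataset: bool) -> None:
    """Illustrative restore rule; `trainer` stands in for the real Trainer instance."""
    if load_dataset:
        # Resuming the old dataloader: counters continue from the checkpoint
        # (the 0 default keeps backward compatibility with older checkpoints).
        trainer._total_consumed_tokens = train_state.get("total_consumed_tokens", 0)
        trainer._total_consumed_samples = train_state.get("total_consumed_samples", 0)
    else:
        # Building a fresh dataloader: keep the sample counter at its initial value of 0,
        # otherwise dataloader.get_state_dict(trainer._total_consumed_samples) would later
        # record a stale count inherited from the old dataloader.
        trainer._total_consumed_samples = 0
```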