
Commit 0f9f751

jayhenry authored and HAOCHENYE committed
1) Add reduced consumed tokens/samples. 2) Do not resume consumed samples if not load dataset (#1326)
* [Feature] Add reduced consumed tokens and samples
* [Fix] Do not resume consumed samples and tokens if not load_dataset
1 parent b659fd1 commit 0f9f751
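
In short, the counters that the trainer logs and checkpoints now track the sum across all ranks rather than a single rank's count. A small illustrative example, assuming 4 data-parallel ranks and made-up token counts:

```python
# Per-rank token counts for one training step (illustrative numbers only).
step_consumed_tokens_per_rank = [2048, 2048, 1792, 2048]

# Old behaviour: each rank only accumulated its own count, e.g. rank 0 added 2048.
# New behaviour: every rank accumulates the reduced (summed) count.
reduced_step_consumed_tokens = sum(step_consumed_tokens_per_rank)  # 7936 on every rank
```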

File tree: 1 file changed

xtuner/v1/train/trainer.py

Lines changed: 38 additions & 26 deletions
@@ -514,9 +514,9 @@ def __init__(
         self._debug = debug
         self._seed = seed
 
-        self._consumed_tokens = 0
+        self._reduced_consumed_tokens = 0
         self._exp_consumed_tokens = 0
-        self._consumed_samples = 0
+        self._reduced_consumed_samples = 0
 
         self._train_time = 0
         self._train_time_offset = 0
@@ -769,7 +769,10 @@ def fit(self):
             internal_metrics = self._maybe_pop_model_internal_metrics(engine_input)
 
             self._cur_step += 1
-            self._consumed_tokens += step_consumed_tokens
+
+            reduced_step_consumed_tokens = self._reduce_number_across_rank(step_consumed_tokens)
+            self._reduced_consumed_tokens += reduced_step_consumed_tokens
+
             self._exp_consumed_tokens += step_consumed_tokens
             self._train_time = time_after_train_step - train_begin
 
@@ -778,7 +781,7 @@ def fit(self):
                 loss_log=loss_log,
                 step_consumed_tokens=step_consumed_tokens,
                 exp_consumed_tokens=self._exp_consumed_tokens,
-                total_consumed_tokens=self._consumed_tokens,
+                reduced_consumed_tokens=self._reduced_consumed_tokens,
                 data_time=data_time,
                 step_time=step_time,
                 train_time=self._train_time,
@@ -805,6 +808,12 @@ def fit(self):
         self._metrics_recorder.close()
         self.logger.info(f"Training finished in {time.time() - train_begin:.2f} seconds")
 
+    def _reduce_number_across_rank(self, rank_number: int) -> int:
+        _gathered_list = [None for _ in range(self.world_size)]
+        dist.all_gather_object(_gathered_list, rank_number)
+        reduced_number = sum(_gathered_list)  # type: ignore[arg-type]
+        return reduced_number
+
     def _maybe_init_model_metrics_recorder(
         self,
         internal_metrics_cfg: InternalMetricsConfig | None,
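
The new helper gathers each rank's Python integer and sums the result, so every rank ends up with the same global count. A standalone sketch of that gather-and-sum pattern, runnable on CPU with the gloo backend (world size, port, and token counts below are made up for illustration):

```python
# Sketch of the pattern behind _reduce_number_across_rank, using dist.all_gather_object.
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int) -> None:
    dist.init_process_group(
        backend="gloo",
        init_method="tcp://127.0.0.1:29500",
        rank=rank,
        world_size=world_size,
    )
    step_consumed_tokens = 1000 * (rank + 1)  # pretend rank 0 saw 1000 tokens, rank 1 saw 2000
    gathered: list[int | None] = [None for _ in range(world_size)]
    dist.all_gather_object(gathered, step_consumed_tokens)  # every rank receives [1000, 2000]
    reduced = sum(gathered)  # type: ignore[arg-type]
    print(f"rank {rank}: reduced_step_consumed_tokens = {reduced}")  # 3000 on both ranks
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```

Note that `all_gather_object` is a blocking collective, so the helper must be called on every rank each step, which the training loop above does.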
@@ -1128,8 +1137,8 @@ def _maybe_save(self, is_snapshot: bool = False) -> bool:
                {
                    "cur_step": self.cur_step,
                    "cur_epoch": self._cur_epoch,
-                    "consumed_samples": self._consumed_samples,
-                    "consumed_tokens": self._consumed_tokens,
+                    "reduced_consumed_samples": self._reduced_consumed_samples,
+                    "reduced_consumed_tokens": self._reduced_consumed_tokens,
                    "train_time_offset": self._train_time + self._train_time_offset,
                }
            )
@@ -1141,8 +1150,8 @@ def _maybe_save(self, is_snapshot: bool = False) -> bool:
             ckp_list.append(str(checkpoint_path))
             current_exp.cur_step = self.cur_step
             current_exp.cur_epoch = self._cur_epoch
-            current_exp.consumed_samples = self._consumed_samples
-            current_exp.consumed_tokens = int(self._consumed_tokens)
+            current_exp.consumed_samples = int(self._reduced_consumed_samples)
+            current_exp.consumed_tokens = int(self._reduced_consumed_tokens)
             current_exp.history[-1]["end"] = self.cur_step
 
             # Delete checkpoints and update meta's checkpoint_list
@@ -1178,12 +1187,8 @@ def _maybe_save(self, is_snapshot: bool = False) -> bool:
         return True
 
     def _save_dataloader(self, dataloader_path: Path | str):
-        _gathered_list = [None for _ in range(self.data_mesh["dp"].size())]
-        dist.all_gather_object(_gathered_list, self._consumed_samples, group=self.data_mesh["dp"].get_group())
-        global_consumed_samples = sum(_gathered_list)  # type: ignore[arg-type]
-
         if self.rank == 0:
-            dataloader_state = self._dataloader.get_state_dict(global_consumed_samples)
+            dataloader_state = self._dataloader.get_state_dict(self._reduced_consumed_samples)
             torch.save(dataloader_state, dataloader_path)
 
     @property
@@ -1232,12 +1237,13 @@ def _data_iter(self):
             # dist.breakpoint(skip=14)
             try:
                 data = next(data_iter)
-                self._consumed_samples += len(data)
             except StopIteration:
                 self._cur_epoch += 1
                 self._dataloader.set_epoch(self._cur_epoch)
                 data_iter = iter(self._dataloader)
                 data = next(data_iter)
+
+            self._reduced_consumed_samples += self._reduce_number_across_rank(len(data))
             yield data
 
     def _get_checkpoint_path(self, epoch: int, step: int, is_snapshot: bool = False) -> Path:
@@ -1410,7 +1416,7 @@ def _log_step(
         loss_log: dict,
         step_consumed_tokens: int,
         exp_consumed_tokens: int,
-        total_consumed_tokens: int,
+        reduced_consumed_tokens: int,
         data_time: float,
         step_time: float,
         train_time: float,
@@ -1421,12 +1427,13 @@ def _log_step(
         """Log the training step information."""
         e2e_train_time = train_time + train_time_offset
         tgs = step_consumed_tokens / step_time
-        e2e_tgs = total_consumed_tokens / e2e_train_time
+        rank_consumed_tokens = reduced_consumed_tokens / self.world_size
+        e2e_tgs = rank_consumed_tokens / e2e_train_time
         exp_tgs = exp_consumed_tokens / train_time
         lr = self._lr_scheduler.get_last_lr()[0]
 
         remaining_steps = self.total_step - self.cur_step
-        avg_tokens_per_step = total_consumed_tokens / self.cur_step
+        avg_tokens_per_step = rank_consumed_tokens / self.cur_step
         remaining_tokens = remaining_steps * avg_tokens_per_step
         eta_seconds = remaining_tokens / (tgs + 1e-12)
         eta_hms = str(timedelta(seconds=int(eta_seconds)))
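
Because `step_consumed_tokens` (and hence `tgs`) is still a per-rank number, dividing the globally reduced token count by `world_size` keeps `e2e_tgs` and `avg_tokens_per_step` on the same per-rank scale. A short numeric illustration with made-up values:

```python
world_size = 8
reduced_consumed_tokens = 9_600_000  # summed over all 8 ranks since training began
e2e_train_time = 1200.0              # seconds

rank_consumed_tokens = reduced_consumed_tokens / world_size  # 1_200_000 tokens per rank on average
e2e_tgs = rank_consumed_tokens / e2e_train_time              # 1000.0 tokens/s per rank
```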
@@ -1447,7 +1454,7 @@ def _log_step(
             f"Epoch {self._cur_epoch} Step {self.cur_step}/{self.total_step} "
             f"data_time: {data_time:.4f} lr: {lr:.6e} time: {step_time:.4f} "
             f"text_tokens: {step_consumed_tokens} "
-            f"total_consumed_tokens: {total_consumed_tokens} "
+            f"reduced_consumed_tokens: {reduced_consumed_tokens} "
             f"{loss_log_str} "
             f"grad_norm: {grad_norm:.8f} "
             f"max_memory: {max_memory / (1024**3):.2f} GB "
@@ -1467,7 +1474,7 @@ def _log_step(
             "time/eta_seconds": round(eta_seconds, 1),
             "runtime_info/text_tokens": step_consumed_tokens,
             "runtime_info/est_global_batch_tokens": est_global_batch_tokens,
-            "runtime_info/total_consumed_tokens": total_consumed_tokens,
+            "runtime_info/reduced_consumed_tokens": reduced_consumed_tokens,
             "runtime_info/tgs": tgs,
             "runtime_info/exp_tgs": exp_tgs,
             "runtime_info/e2e_tgs": e2e_tgs,
@@ -1663,20 +1670,25 @@ def _load_checkpoint(self):
             load_args=load_checkpoint_cfg.load_optimizer_args,
         )
 
-        if load_checkpoint_cfg.load_dataset:
-            dataloader_path = resume_from / self._SAVE_DATALOADER_DIR
-            self._resume_dataloader(dataloader_path)
-
         train_state_path = resume_from / self._SAVE_TRAIN_STATE_PATH
 
         with train_state_path.open("r") as f:
             train_state = json.load(f)
 
         self._cur_step = train_state["cur_step"]
         self._cur_epoch = train_state["cur_epoch"]
-        self._consumed_samples = train_state["consumed_samples"]
-        self._consumed_tokens = train_state["consumed_tokens"]
-        self._train_time_offset = train_state["train_time_offset"]
+
+        if load_checkpoint_cfg.load_dataset:
+            self._reduced_consumed_tokens = train_state.get("reduced_consumed_tokens", 0)  # default 0 for BC
+            self._train_time_offset = train_state["train_time_offset"]
+            # _reduced_consumed_samples affects the state returned by dataloader.get_state_dict when saving the DCP checkpoint.
+            # 1) If the dataset is loaded, _reduced_consumed_samples should be restored to the value in the checkpoint.
+            # 2) If the dataset is not loaded, _reduced_consumed_samples should stay at its initial value 0; otherwise the old
+            #    dataloader's reduced_consumed_samples would be carried over and make the value saved for the new dataloader incorrect.
+            self._reduced_consumed_samples = train_state.get("reduced_consumed_samples", 0)  # default 0 for BC
+
+            dataloader_path = resume_from / self._SAVE_DATALOADER_DIR
+            self._resume_dataloader(dataloader_path)
 
         if load_checkpoint_cfg.load_scheduler:
             scheduler_path = resume_from / self._SAVE_SCHEDULER_DIR
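
With this commit the saved train state carries the reduced counters, and the `.get(..., 0)` defaults keep resuming from checkpoints written before the rename working. A sketch of what the train-state file might contain and how the backward-compatible lookup behaves (all field values below are illustrative, not taken from a real run):

```python
import json

# Illustrative contents of the train-state JSON written by _maybe_save.
train_state = json.loads("""
{
    "cur_step": 500,
    "cur_epoch": 1,
    "reduced_consumed_samples": 64000,
    "reduced_consumed_tokens": 262144000,
    "train_time_offset": 3600.5
}
""")
assert train_state.get("reduced_consumed_samples", 0) == 64000

# An older checkpoint has only the pre-rename keys; the default 0 keeps resume working.
old_state = {"cur_step": 500, "cur_epoch": 1, "consumed_samples": 64000, "consumed_tokens": 262144000}
assert old_state.get("reduced_consumed_samples", 0) == 0
```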
