70 changes: 44 additions & 26 deletions torchtitan/train.py
@@ -466,16 +466,9 @@ def forward_backward_step(

return loss

def train_step(
self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]
):
self.optimizers.zero_grad()
# Save the current step learning rate for logging
lr = self.lr_schedulers.schedulers[0].get_last_lr()[0]
def gradient_computation(
self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]
) -> tuple[torch.Tensor, torch.Tensor]:

# Keep these variables local to shorten the code as these are
# the major variables that are used in the training loop.
parallel_dims = self.parallel_dims
self.optimizers.zero_grad()

accumulated_losses = []
# If data runs out during gradient accumulation, that
@@ -490,48 +483,70 @@ def train_step
self.job_config.training.max_norm,
foreach=True,
pp_mesh=(
parallel_dims.world_mesh["pp"] if parallel_dims.pp_enabled else None
self.parallel_dims.world_mesh["pp"] if self.parallel_dims.pp_enabled else None
),
ep_enabled=parallel_dims.ep_enabled,
ep_enabled=self.parallel_dims.ep_enabled,
)
self.checkpointer.maybe_wait_for_staging()
self.optimizers.step()
self.lr_schedulers.step()

# Reduce the data collected over gradient accumulation steps.
loss = torch.sum(torch.stack(accumulated_losses))

# log metrics
if not self.metrics_processor.should_log(self.step):
return
return loss, grad_norm

def run_optimizer_step(self) -> None:
self.checkpointer.maybe_wait_for_staging()
self.optimizers.step()
self.lr_schedulers.step()

if parallel_dims.dp_cp_enabled:
def compute_global_loss(self, loss: torch.Tensor) -> tuple[float, float, int]:
if self.parallel_dims.dp_cp_enabled:
loss = loss.detach()
ft_pg = self.ft_manager.loss_sync_pg
global_avg_loss, global_max_loss, global_ntokens_seen = (
dist_utils.dist_mean(loss, parallel_dims.world_mesh["dp_cp"], ft_pg),
dist_utils.dist_max(loss, parallel_dims.world_mesh["dp_cp"], ft_pg),
dist_utils.dist_mean(loss, self.parallel_dims.world_mesh["dp_cp"], ft_pg),
dist_utils.dist_max(loss, self.parallel_dims.world_mesh["dp_cp"], ft_pg),
dist_utils.dist_sum(
torch.tensor(
self.ntokens_seen, dtype=torch.int64, device=self.device
),
parallel_dims.world_mesh["dp_cp"],
self.parallel_dims.world_mesh["dp_cp"],
ft_pg,
),
)
else:
global_avg_loss = global_max_loss = loss.detach().item()
global_ntokens_seen = self.ntokens_seen

return global_avg_loss, global_max_loss, global_ntokens_seen

def train_step(
self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]
):

# Run the gradient computation
loss, grad_norm = self.gradient_computation(data_iterator=data_iterator)

# Save the current step learning rate for logging
lr = self.lr_schedulers.schedulers[0].get_last_lr()[0]

# Run the optimizer step
self.run_optimizer_step()

# log metrics
if not self.metrics_processor.should_log(self.step):
return

global_avg_loss, global_max_loss, global_ntokens_seen = self.compute_global_loss(loss=loss)

extra_metrics = {
"n_tokens_seen": global_ntokens_seen,
"lr": lr,
}
self.metrics_processor.log(
self.step,
global_avg_loss,
global_max_loss,
grad_norm.item(),
step=self.step,
global_avg_loss=global_avg_loss,
global_max_loss=global_max_loss,
grad_norm=grad_norm.item(),
extra_metrics=extra_metrics,
)
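
Side note on compute_global_loss above: it reduces the detached local loss across the dp_cp mesh with dist_utils.dist_mean / dist_utils.dist_max and sums ntokens_seen with dist_utils.dist_sum. The sketch below shows the same mean/max/sum reduction pattern using raw torch.distributed collectives on a single-process gloo group; it is an illustration with dummy values, not the torchtitan dist_utils helpers themselves.

# Illustration only: mean/max/sum reductions analogous to compute_global_loss,
# done with plain torch.distributed collectives on a 1-rank gloo group.
import os

import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

loss = torch.tensor(2.5)                          # stands in for the detached local loss
ntokens = torch.tensor(1024, dtype=torch.int64)   # stands in for self.ntokens_seen

avg = loss.clone()
dist.all_reduce(avg, op=dist.ReduceOp.SUM)
avg /= dist.get_world_size()                      # mean loss over the group

mx = loss.clone()
dist.all_reduce(mx, op=dist.ReduceOp.MAX)         # max loss across ranks

total_tokens = ntokens.clone()
dist.all_reduce(total_tokens, op=dist.ReduceOp.SUM)  # global token count

print(avg.item(), mx.item(), total_tokens.item())
dist.destroy_process_group()
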

@@ -578,7 +593,7 @@ def train(self):
),
):
data_iterator = self.batch_generator(self.dataloader)
while self.step < job_config.training.steps:
while self.should_continue_training():
self.step += 1
self.gc_handler.run(self.step)
try:
@@ -620,6 +635,9 @@ def train(self)

logger.info("Training completed")

def should_continue_training(self) -> bool:
return self.step < self.job_config.training.steps

def state_dict(self) -> dict[str, Any]:
return {"step": self.step, "ntokens_seen": self.ntokens_seen}

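For orientation, here is a minimal self-contained sketch of the control flow this change gives train_step: gradient computation, learning-rate capture, optimizer/scheduler step, then logging, with the outer loop guarded by should_continue_training. The ToyTrainer class, its attributes, and the simplified loss below are made-up stand-ins for illustration only, not torchtitan APIs.

# Toy sketch mirroring the decomposition introduced in this PR:
# gradient_computation -> lr capture -> run_optimizer_step -> log,
# looped while should_continue_training() is true.
import torch


class ToyTrainer:
    def __init__(self) -> None:
        self.step = 0
        self.max_steps = 3
        self.model = torch.nn.Linear(4, 1)
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1)
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=1)

    def gradient_computation(self, batch: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Zero grads, run forward/backward, clip, and return (loss, grad_norm),
        # mirroring the new gradient_computation() in the diff.
        self.optimizer.zero_grad()
        loss = self.model(batch).pow(2).mean()
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        return loss, grad_norm

    def run_optimizer_step(self) -> None:
        # Optimizer plus LR scheduler step, as in the new run_optimizer_step().
        self.optimizer.step()
        self.scheduler.step()

    def should_continue_training(self) -> bool:
        return self.step < self.max_steps

    def train_step(self, batch: torch.Tensor) -> None:
        loss, grad_norm = self.gradient_computation(batch)
        # Capture the current LR before the scheduler steps, as the diff does.
        lr = self.scheduler.get_last_lr()[0]
        self.run_optimizer_step()
        print(f"step={self.step} loss={loss.item():.4f} grad_norm={grad_norm.item():.4f} lr={lr}")


if __name__ == "__main__":
    trainer = ToyTrainer()
    while trainer.should_continue_training():
        trainer.step += 1
        trainer.train_step(torch.randn(8, 4))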