Commit c1ab42b

Allow KD loss during eval
Signed-off-by: Asha Anoosheh <[email protected]>
1 parent 99184ff commit c1ab42b

File tree

2 files changed: +8 −3 lines changed


megatron/post_training/loss_func.py

Lines changed: 4 additions & 2 deletions
@@ -55,16 +55,18 @@ def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor, model: GPTModel
     num_tokens = loss_mask.sum().clone().detach().to(torch.int)
     report = {'lm loss': torch.cat([loss_lm.clone().detach().view(1), num_tokens.view(1)])}

-    if model.training and args.export_kd_teacher_load:
+    if args.export_kd_teacher_load:
         # [ModelOpt]: Handle knowledge distillation
         losses = model.compute_kd_loss(
             student_loss=loss_lm,
             loss_reduction_fn=lambda x: _mask_loss(x, loss_mask),
         )
-        loss = losses["kd_loss"]

         report["total loss"] = torch.cat([losses["kd_loss"].clone().detach().view(1), num_tokens.view(1)])
         report["logits distillation loss"] = torch.cat([losses["logits_loss"].clone().detach().view(1), num_tokens.view(1)])
         report["intermediate distillation loss"] = torch.cat([losses["intermediate_loss"].clone().detach().view(1), num_tokens.view(1)])

+        if model.training:
+            loss = losses["kd_loss"]
+
     return loss, num_tokens, report
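
To make the behavioral change concrete, here is a minimal standalone sketch of the patched control flow. The helper name pick_loss_and_report and its signature are hypothetical, not Megatron code: once a teacher checkpoint is loaded, the KD metrics are computed and reported in both training and evaluation, but the returned loss is only replaced by the KD loss while training.

import torch

def pick_loss_and_report(loss_lm, kd_losses, num_tokens, training):
    """Hypothetical restatement of the patched loss_func logic: always
    report KD metrics when a teacher is present, but only let the KD loss
    replace the LM loss during training."""
    report = {"lm loss": torch.cat([loss_lm.detach().view(1), num_tokens.view(1)])}
    loss = loss_lm
    if kd_losses is not None:  # stands in for `args.export_kd_teacher_load`
        for name, key in (("total loss", "kd_loss"),
                          ("logits distillation loss", "logits_loss"),
                          ("intermediate distillation loss", "intermediate_loss")):
            report[name] = torch.cat([kd_losses[key].detach().view(1), num_tokens.view(1)])
        if training:  # new in this commit: eval keeps the plain LM loss
            loss = kd_losses["kd_loss"]
    return loss, num_tokens, report

# Eval-mode example: the KD metrics land in the report, but the returned
# loss is still the LM loss.
lm = torch.tensor(2.3)
kd = {"kd_loss": torch.tensor(1.7), "logits_loss": torch.tensor(1.1),
      "intermediate_loss": torch.tensor(0.6)}
n = torch.tensor(128, dtype=torch.int)
loss, _, report = pick_loss_and_report(lm, kd, n, training=False)
assert loss is lm and "total loss" in report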

megatron/training/training.py

Lines changed: 4 additions & 1 deletion
@@ -1300,7 +1300,10 @@ def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_scheduler
     if has_nvidia_modelopt:
         # [ModelOpt]: Pipeline-parallel Distillation stacks student and teacher tensors
         adjust_tensor_shapes_fn = get_tensor_shapes_adjust_fn_for_distillation(
-            model, args.seq_length, args.micro_batch_size, args.decoder_seq_length
+            model,
+            seq_length=args.seq_length,
+            micro_batch_size=args.micro_batch_size,
+            decoder_seq_length=args.decoder_seq_length,
         )
     else:
         adjust_tensor_shapes_fn = None
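
This second hunk is a pure readability refactor: seq_length, micro_batch_size, and decoder_seq_length are all plain ints, so a positional call could silently transpose them. A small hypothetical sketch of the pattern (not Megatron's get_tensor_shapes_adjust_fn_for_distillation API) shows how keyword-only parameters harden such a factory against argument swaps:

# Hypothetical factory illustrating the keyword-argument pattern.
def make_adjust_fn(model, *, seq_length, micro_batch_size, decoder_seq_length):
    # The bare `*` makes every size keyword-only: passing the ints
    # positionally raises a TypeError instead of silently running with
    # swapped values, and each argument is named at the call site.
    def adjust(shapes):
        return shapes  # placeholder body for the sketch
    return adjust

adjust_fn = make_adjust_fn(
    None,  # stand-in for the model
    seq_length=4096,
    micro_batch_size=2,
    decoder_seq_length=4096,
)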
