diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py
index d56dce98..f4499970 100644
--- a/fast_llm/functional/cross_entropy.py
+++ b/fast_llm/functional/cross_entropy.py
@@ -35,12 +35,11 @@ def _torch_cross_entropy_forward_backward(
             logits_ if logits_scale_factor == 1 else logits_ * logits_scale_factor, target
         )
     else:
-        loss = (
-            torch.nn.functional.cross_entropy(
-                logits_ if logits_scale_factor == 1 else logits_ * logits_scale_factor, target, reduction="none"
-            )
-            * loss_mask
-        ).mean()
+        per_token_loss = torch.nn.functional.cross_entropy(
+            logits_ if logits_scale_factor == 1 else logits_ * logits_scale_factor, target, reduction="none"
+        )
+        loss = (per_token_loss * loss_mask).sum() / (loss_mask.sum() if loss_mask.sum() > 0 else 1.0)
+
     if grad_output is None:
         grad = None
     else:
@@ -133,7 +132,9 @@ def _fused_cross_entropy_forward_backward(
         if logits_scale_factor != 1.0:
             grad *= logits_scale_factor
         if loss_mask is not None:
-            grad *= loss_mask
+            # Take into account the modified denominator due to loss masking.
+            loss_masking_grad_factor = logits.size(0) / loss_mask.sum() if loss_mask.sum() > 0 else 1.0
+            grad *= loss_mask * loss_masking_grad_factor
         grad = grad.to(logits.dtype)
 
     # loss = mean(log(sum_exp_logits) - sum(probabilities * logits))
@@ -149,9 +150,10 @@ def _fused_cross_entropy_forward_backward(
     if loss_mask is not None:
         per_sample_loss = per_sample_loss * loss_mask
 
-    loss = per_sample_loss.mean()
+    loss_mask_sum = loss_mask.sum() if loss_mask is not None else torch.tensor(per_sample_loss.numel())
+    loss = per_sample_loss.sum() / (loss_mask_sum if loss_mask_sum > 0 else 1.0)
     if target_format != TargetFormat.labels and group is not None:
-        all_reduce(loss, op=ReduceOp.MEAN, group=group)
+        all_reduce(loss, op=ReduceOp.AVG, group=group)
 
     return loss, grad
 
@@ -274,10 +276,10 @@ def _torch_reverse_kl_forward_backward(
             loss_per_sample = torch.nn.functional.kl_div(
                 teacher_log_probs, student_log_probs, reduction="none", log_target=True
             ).sum(dim=-1)
-            loss = (loss_per_sample * loss_mask).mean()
+            loss = (loss_per_sample * loss_mask).sum() / (loss_mask.sum() if loss_mask.sum() > 0 else 1.0)
 
         if group is not None and target_format != TargetFormat.labels:
-            all_reduce(loss, op=ReduceOp.MEAN, group=group)
+            all_reduce(loss, op=ReduceOp.AVG, group=group)
 
         if grad_output is not None:
             loss.backward(torch.full_like(loss, grad_output))
diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py
index 9a878c49..5c1acaa3 100644
--- a/tests/layers/test_lm_head.py
+++ b/tests/layers/test_lm_head.py
@@ -23,7 +23,7 @@ def _reverse_kl_loss(
 ):
     scaled_target = target / teacher_softmax_temperature
 
-    scaled_target = torch.clamp(target, min=-50, max=50)
+    scaled_target = torch.clamp(scaled_target, min=-50, max=50)
     teacher_log_probs = torch.log_softmax(scaled_target, dim=-1)
 
     with torch.enable_grad():
@@ -42,7 +42,7 @@ def _reverse_kl_loss(
         loss_per_sample = torch.nn.functional.kl_div(
             teacher_log_probs, student_log_probs, reduction="none", log_target=True
         ).sum(dim=-1)
-        loss = (loss_per_sample * loss_mask.flatten()).mean()
+        loss = (loss_per_sample * loss_mask.flatten()).sum() / loss_mask.sum()
 
     return loss
 
@@ -84,7 +84,7 @@ def _lm_head(
         )
         if loss_mask is not None:
             loss = loss * loss_mask.flatten()
-        loss = loss.mean()
+        loss = loss.sum() / (loss_mask.sum() if loss_mask is not None else torch.tensor(loss.numel()))
     else:
         loss = torch.nn.functional.cross_entropy(logits.flatten(0, -2), target.flatten())
         loss.backward(torch.full_like(loss, grad_output))
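
A minimal standalone sketch of the normalization change (illustrative values only, not taken from the patch): the old code averaged the masked per-token loss over all tokens, while the new code averages only over unmasked tokens, which is also what motivates the `logits.size(0) / loss_mask.sum()` gradient rescaling in the fused path.

```python
# Illustration only, with hypothetical values; the patch applies the same idea
# inside _torch_cross_entropy_forward_backward and the fused/reverse-KL paths.
import torch

per_token_loss = torch.tensor([2.0, 4.0, 6.0, 8.0])
loss_mask = torch.tensor([1.0, 1.0, 0.0, 0.0])  # only the first two tokens count

# Old behavior: masked tokens still contribute to the denominator (divides by 4).
old_loss = (per_token_loss * loss_mask).mean()  # -> 1.5

# New behavior: normalize by the number of unmasked tokens (divides by 2).
new_loss = (per_token_loss * loss_mask).sum() / loss_mask.sum()  # -> 3.0

print(old_loss.item(), new_loss.item())
```

The switch from `ReduceOp.MEAN` to `ReduceOp.AVG` keeps the cross-rank reduction an average of the per-rank masked means, consistent with the new denominator.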