fix loss masking #345

Draft: wants to merge 2 commits into main
24 changes: 13 additions & 11 deletions fast_llm/functional/cross_entropy.py
@@ -35,12 +35,11 @@ def _torch_cross_entropy_forward_backward(
             logits_ if logits_scale_factor == 1 else logits_ * logits_scale_factor, target
         )
     else:
-        loss = (
-            torch.nn.functional.cross_entropy(
-                logits_ if logits_scale_factor == 1 else logits_ * logits_scale_factor, target, reduction="none"
-            )
-            * loss_mask
-        ).mean()
+        per_token_loss = torch.nn.functional.cross_entropy(
+            logits_ if logits_scale_factor == 1 else logits_ * logits_scale_factor, target, reduction="none"
+        )
+        loss = (per_token_loss * loss_mask).sum() / (loss_mask.sum() if loss_mask.sum() > 0 else 1.0)
+
     if grad_output is None:
         grad = None
     else:
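A minimal standalone sketch of the change above (toy values, not the fast_llm API): with the old `.mean()` reduction, masked-out tokens still count toward the denominator and drag the loss down, while the new form averages only over the unmasked tokens.

```python
import torch

# Hypothetical toy example: 4 tokens, 2 of them masked out.
per_token_loss = torch.tensor([2.0, 4.0, 6.0, 8.0])
loss_mask = torch.tensor([1.0, 1.0, 0.0, 0.0])

# Old reduction: divides by the total token count, so masked tokens
# still shrink the average (here: 6.0 / 4 = 1.5).
old_loss = (per_token_loss * loss_mask).mean()

# New reduction: divides by the number of unmasked tokens only
# (here: 6.0 / 2 = 3.0), with a guard against an all-masked batch.
new_loss = (per_token_loss * loss_mask).sum() / (loss_mask.sum() if loss_mask.sum() > 0 else 1.0)

print(old_loss.item(), new_loss.item())  # 1.5 3.0
```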
@@ -133,7 +132,9 @@ def _fused_cross_entropy_forward_backward(
         if logits_scale_factor != 1.0:
             grad *= logits_scale_factor
         if loss_mask is not None:
-            grad *= loss_mask
+            # Take into account the modified denominator due to loss masking.
+            loss_masking_grad_factor = logits.size(0) / loss_mask.sum() if loss_mask.sum() > 0 else 1.0
+            grad *= loss_mask * loss_masking_grad_factor
         grad = grad.to(logits.dtype)
 
     # loss = mean(log(sum_exp_logits) - sum(probabilities * logits))
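A hedged sketch of why the extra factor is needed (plain PyTorch with made-up shapes, not the fused kernel): a gradient computed for a mean over all `logits.size(0)` tokens has to be rescaled per token by `mask * (num_tokens / loss_mask.sum())` to match the gradient of the masked-mean loss.

```python
import torch

# Hypothetical shapes: 5 tokens, vocab of 7, 3 tokens unmasked.
torch.manual_seed(0)
logits = torch.randn(5, 7, requires_grad=True)
target = torch.randint(0, 7, (5,))
loss_mask = torch.tensor([1.0, 1.0, 1.0, 0.0, 0.0])

# Reference: autograd on the new masked-mean loss.
per_token = torch.nn.functional.cross_entropy(logits, target, reduction="none")
loss = (per_token * loss_mask).sum() / loss_mask.sum()
loss.backward()
reference_grad = logits.grad.clone()

# A gradient written for a plain mean over all tokens (reduction="mean")
# must be rescaled per token by mask * (num_tokens / num_unmasked_tokens)
# to match the masked-mean loss above.
logits2 = logits.detach().requires_grad_()
torch.nn.functional.cross_entropy(logits2, target).backward()
grad_mean = logits2.grad
rescaled = grad_mean * loss_mask[:, None] * (logits.size(0) / loss_mask.sum())

print(torch.allclose(reference_grad, rescaled, atol=1e-6))  # True
```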
@@ -149,9 +150,10 @@
     if loss_mask is not None:
         per_sample_loss = per_sample_loss * loss_mask
 
-    loss = per_sample_loss.mean()
+    loss_mask_sum = loss_mask.sum() if loss_mask is not None else torch.tensor(per_sample_loss.numel())
+    loss = per_sample_loss.sum() / (loss_mask_sum if loss_mask_sum > 0 else 1.0)
     if target_format != TargetFormat.labels and group is not None:
-        all_reduce(loss, op=ReduceOp.MEAN, group=group)
+        all_reduce(loss, op=ReduceOp.AVG, group=group)
 
     return loss, grad

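For the MEAN to AVG switch, here is a sketch of what `op=ReduceOp.AVG` does, assuming the `all_reduce`/`ReduceOp` used here are thin wrappers around `torch.distributed` (AVG requires the NCCL backend); the two-rank setup and loss values are illustrative only.

```python
# Hypothetical two-rank illustration. Run with e.g.
#   torchrun --nproc_per_node=2 avg_reduce_sketch.py
import torch
import torch.distributed as dist


def main() -> None:
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank())

    # Each rank holds its own masked-mean loss (values are made up).
    loss = torch.tensor(1.0 + dist.get_rank(), device="cuda")

    # ReduceOp.AVG averages in a single collective: every rank ends up with
    # (1.0 + 2.0) / 2 = 1.5, i.e. the mean of the per-rank losses.
    dist.all_reduce(loss, op=dist.ReduceOp.AVG)
    print(f"rank {dist.get_rank()}: {loss.item()}")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```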
@@ -274,10 +276,10 @@ def _torch_reverse_kl_forward_backward(
         loss_per_sample = torch.nn.functional.kl_div(
             teacher_log_probs, student_log_probs, reduction="none", log_target=True
         ).sum(dim=-1)
-        loss = (loss_per_sample * loss_mask).mean()
+        loss = (loss_per_sample * loss_mask).sum() / (loss_mask.sum() if loss_mask.sum() > 0 else 1.0)
 
     if group is not None and target_format != TargetFormat.labels:
-        all_reduce(loss, op=ReduceOp.MEAN, group=group)
+        all_reduce(loss, op=ReduceOp.AVG, group=group)
 
     if grad_output is not None:
         loss.backward(torch.full_like(loss, grad_output))
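A standalone sketch of the reverse-KL path touched above (random toy tensors, not the project's fixtures): it spells out the `kl_div(..., log_target=True)` argument order that yields KL(student || teacher) and applies the same masked normalization as this PR.

```python
import torch

torch.manual_seed(0)
student_logits = torch.randn(3, 5)
teacher_logits = torch.randn(3, 5)
loss_mask = torch.tensor([1.0, 0.0, 1.0])

student_log_probs = torch.log_softmax(student_logits, dim=-1)
teacher_log_probs = torch.log_softmax(teacher_logits, dim=-1)

# F.kl_div(input, target, log_target=True) computes KL(target || input), so
# passing the teacher as `input` and the student as `target` gives the
# reverse KL divergence KL(student || teacher) per sample.
loss_per_sample = torch.nn.functional.kl_div(
    teacher_log_probs, student_log_probs, reduction="none", log_target=True
).sum(dim=-1)

# Same quantity written out explicitly: sum_v p_student * (log p_student - log p_teacher).
student_probs = student_log_probs.exp()
explicit = (student_probs * (student_log_probs - teacher_log_probs)).sum(dim=-1)
print(torch.allclose(loss_per_sample, explicit))  # True

# Masked average over unmasked samples only, as in the diff above.
loss = (loss_per_sample * loss_mask).sum() / (loss_mask.sum() if loss_mask.sum() > 0 else 1.0)
print(loss.item())
```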
6 changes: 3 additions & 3 deletions tests/layers/test_lm_head.py
@@ -23,7 +23,7 @@ def _reverse_kl_loss(
 ):
     scaled_target = target / teacher_softmax_temperature
 
-    scaled_target = torch.clamp(target, min=-50, max=50)
+    scaled_target = torch.clamp(scaled_target, min=-50, max=50)
     teacher_log_probs = torch.log_softmax(scaled_target, dim=-1)
 
     with torch.enable_grad():
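A quick illustration of the clamp fix in this test helper (toy tensors, hypothetical temperature value): clamping `target` instead of `scaled_target` silently discarded the temperature scaling whenever the temperature was not 1.

```python
import torch

torch.manual_seed(0)
target = torch.randn(2, 5)
teacher_softmax_temperature = 2.0

scaled_target = target / teacher_softmax_temperature

# Before the fix: the clamp re-reads `target`, so the division by the
# temperature is thrown away (for logits in [-50, 50] the clamp is a no-op).
buggy = torch.clamp(target, min=-50, max=50)

# After the fix: the clamp operates on the scaled logits.
fixed = torch.clamp(scaled_target, min=-50, max=50)

print(torch.equal(buggy, target))         # True: temperature had no effect
print(torch.allclose(fixed, target / 2))  # True: temperature is preserved
```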
@@ -42,7 +42,7 @@ def _reverse_kl_loss(
         loss_per_sample = torch.nn.functional.kl_div(
             teacher_log_probs, student_log_probs, reduction="none", log_target=True
         ).sum(dim=-1)
-        loss = (loss_per_sample * loss_mask.flatten()).mean()
+        loss = (loss_per_sample * loss_mask.flatten()).sum() / loss_mask.sum()
     return loss
 
 
@@ -84,7 +84,7 @@ def _lm_head(
         )
         if loss_mask is not None:
             loss = loss * loss_mask.flatten()
-        loss = loss.mean()
+        loss = loss.sum() / (loss_mask.sum() if loss_mask is not None else torch.tensor(loss.numel()))
     else:
         loss = torch.nn.functional.cross_entropy(logits.flatten(0, -2), target.flatten())
     loss.backward(torch.full_like(loss, grad_output))