
Commit bb98dbc

Removes the reduce_loss option, as we don't need to support reduce_loss="sum" any longer. (#1024)
* Removed outdated files
* Added tool files
* Ran tests
* Ran linter
* Always uses the real code execution server
* Reduced nesting
* Removed comments
* Removed reduce_loss argument; the default is now mean
* Removed -n auto from tests
* Fixed linter errors
* Removed reduce_loss arguments from scripts
1 parent 1b695ca commit bb98dbc

7 files changed: +10 −67 lines


open_instruct/finetune.py

Lines changed: 10 additions & 61 deletions
@@ -239,13 +239,6 @@ class FlatArguments:
             "Useful if tokenization process is long. Default is 1800 seconds (30 minutes)."
         },
     )
-    reduce_loss: str = field(
-        default="mean",
-        metadata={
-            "help": "How to reduce loss over tokens. Options are 'mean' or 'sum'."
-            "Using 'sum' can improve chat model performance."
-        },
-    )
     resume_from_checkpoint: Optional[str] = field(
         default=None, metadata={"help": "If the training should continue from a checkpoint folder."}
     )
@@ -337,8 +330,6 @@ class FlatArguments:
     )

     def __post_init__(self):
-        if self.reduce_loss not in ["mean", "sum"]:
-            raise ValueError("reduce_loss must be either 'mean' or 'sum'")
         if self.dataset_name is None and self.dataset_mixer is None and self.dataset_mixer_list is None:
             raise ValueError("Need either a dataset name, dataset mixer, or dataset mixer list.")
         if (
@@ -552,8 +543,7 @@ def main(args: FlatArguments, tc: TokenizerConfig):
     elif args.use_liger_kernel:
         from liger_kernel.transformers import AutoLigerKernelForCausalLM

-        fused_linear_cross_entropy = args.reduce_loss == "mean"
-        logger.info(f"Attempting to apply liger-kernel. {fused_linear_cross_entropy=}")
+        logger.info("Attempting to apply liger-kernel. fused_linear_cross_entropy=True")

         # Supported models: https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/monkey_patch.py#L948
         model = AutoLigerKernelForCausalLM.from_pretrained(
@@ -565,7 +555,7 @@ def main(args: FlatArguments, tc: TokenizerConfig):
             low_cpu_mem_usage=args.low_cpu_mem_usage,
             attn_implementation="flash_attention_2" if args.use_flash_attn else "eager",
             # liger-kernel specific args
-            fused_linear_cross_entropy=fused_linear_cross_entropy,
+            fused_linear_cross_entropy=True,
         )
     else:
         model = AutoModelForCausalLM.from_pretrained(
@@ -783,47 +773,17 @@ def main(args: FlatArguments, tc: TokenizerConfig):
             with accelerator.accumulate(model):
                 if args.load_balancing_loss:
                     outputs = model(**batch, use_cache=False, output_router_logits=True)
+                    total_aux_loss += outputs.aux_loss.detach().float()
                 else:
                     # Standard forward pass
                     outputs = model(**batch, use_cache=False)

-                if args.reduce_loss == "mean":
-                    loss = outputs.loss
-                else:
-                    # reduce loss is sum
-                    # this ensures that we weight all tokens in the dataset equally,
-                    # rather than weighting each overall example equally when
-                    # using high amounts of gradient accumulation.
-                    # this can result in > 5 point improvements in AlpacaEval
-                    # see https://github.com/huggingface/transformers/issues/24725 for
-                    # more discussion and details.
-                    logits = outputs.logits
-                    labels = batch["labels"]
-                    # Shift so that tokens < n predict n
-                    shift_logits = logits[..., :-1, :].contiguous()
-                    shift_labels = labels[..., 1:].contiguous()
-                    # Release logits to avoid memory leak
-                    del logits
-
-                    # Flatten the tokens
-                    loss_fct = torch.nn.CrossEntropyLoss(reduction="sum")
-                    shift_logits = shift_logits.view(-1, embedding_size)
-                    shift_labels = shift_labels.view(-1)
-                    # Enable model parallelism
-                    shift_labels = shift_labels.to(shift_logits.device)
-                    loss = loss_fct(shift_logits, shift_labels)
-                    # Release shift_logits to avoid memory leak
-                    del shift_logits
-                if args.load_balancing_loss:
-                    aux_loss = args.load_balancing_weight * outputs.aux_loss
-                    loss += aux_loss
+                loss = outputs.loss
                 del outputs

                 # We keep track of the loss at each logged step
                 total_loss += loss.detach().float()
                 accelerator.backward(loss)
-                if args.load_balancing_loss:
-                    total_aux_loss += aux_loss.detach().float()
                 # clip gradient norm. don't do this with deepspeed
                 if accelerator.sync_gradients and args.clip_grad_norm > 0:
                     accelerator.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
@@ -842,7 +802,7 @@ def main(args: FlatArguments, tc: TokenizerConfig):
                 total_tokens_including_padding = accelerator.gather(total_token_including_padding).sum().item()
                 total_tokens_this_log_period = accelerator.gather(local_total_tokens_this_log_period).sum().item()
                 local_total_tokens_this_log_period.zero_()
-                pred_tokens_this_log_period = accelerator.gather(local_pred_tokens_this_log_period).sum().item()
+                accelerator.gather(local_pred_tokens_this_log_period).sum().item()
                 local_pred_tokens_this_log_period.zero_()

                 avg_tokens_per_batch = (
@@ -900,22 +860,11 @@ def main(args: FlatArguments, tc: TokenizerConfig):
                     # period. We want the avg over each optimizer step (which scales with the
                     # global batch size), and the average loss per token and per prediction
                     # token (which are roughly independent of global batch size).
-                    if args.reduce_loss == "mean":
-                        total_fwd_passes = (
-                            args.logging_steps * args.gradient_accumulation_steps * accelerator.num_processes
-                        )
-                        avg_loss = sum_loss / total_fwd_passes
-                        metrics_to_log["train_loss"] = avg_loss
-                    else:
-                        avg_loss = sum_loss / total_tokens_this_log_period
-                        # The loss per pred tok is the closest analogue to what we report as the
-                        # avg_loss in the "mean" case
-                        avg_loss_per_pred_tok = sum_loss / pred_tokens_this_log_period
-                        total_optim_steps = args.logging_steps * accelerator.num_processes
-                        avg_sum_loss = sum_loss / total_optim_steps
-                        metrics_to_log["train_sum_loss"] = avg_sum_loss
-                        metrics_to_log["train_loss_per_total_tok"] = avg_loss
-                        metrics_to_log["train_loss_per_pred_tok"] = avg_loss_per_pred_tok
+                    total_fwd_passes = (
+                        args.logging_steps * args.gradient_accumulation_steps * accelerator.num_processes
+                    )
+                    avg_loss = sum_loss / total_fwd_passes
+                    metrics_to_log["train_loss"] = avg_loss
                     if args.verbose:
                         sec_per_step = (time.time() - start_time) / (completed_steps - resume_step)
                         steps_remaining = args.max_train_steps - completed_steps
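
Note on the removed code path. Cross-entropy with reduction="mean" averages within each forward pass, so under gradient accumulation every micro-batch contributes equally regardless of how many label tokens it contains. The deleted reduction="sum" branch instead added losses per token, weighting every token in the dataset equally, which is the behavior the removed comment credited with chat-eval gains. Below is a minimal PyTorch sketch of that difference; it is illustrative only, not taken from the repository, and all shapes and values are made up (label shifting is also omitted for brevity).

# Illustrative sketch: per-example vs. per-token weighting of cross-entropy
# under gradient accumulation. Not repository code; shapes/values are invented.
import torch

vocab_size = 8
ignore_index = -100

# Two micro-batches with very different numbers of label tokens.
logits_a = torch.randn(1, 3, vocab_size)   # 3 label tokens
labels_a = torch.tensor([[1, 2, 3]])
logits_b = torch.randn(1, 30, vocab_size)  # 30 label tokens
labels_b = torch.randint(0, vocab_size, (1, 30))

def ce(logits, labels, reduction):
    # Flatten tokens and compute cross-entropy with the requested reduction.
    return torch.nn.functional.cross_entropy(
        logits.view(-1, vocab_size),
        labels.view(-1),
        reduction=reduction,
        ignore_index=ignore_index,
    )

# "mean": each micro-batch counts once, so the 3-token batch is weighted as
# heavily as the 30-token batch when losses are averaged across accumulation steps.
mean_style = (ce(logits_a, labels_a, "mean") + ce(logits_b, labels_b, "mean")) / 2

# "sum": losses add up per token, so the 30-token batch dominates in proportion
# to its token count; dividing by the total token count gives a true per-token mean.
sum_style = (ce(logits_a, labels_a, "sum") + ce(logits_b, labels_b, "sum")) / (3 + 30)

print(f"per-example weighting ('mean'): {mean_style.item():.4f}")
print(f"per-token weighting ('sum'):    {sum_style.item():.4f}")

With the option removed, the trainer always uses outputs.loss (the per-forward-pass token mean), and train_loss is now averaged over forward passes only.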

scripts/train/debug/finetune.sh

Lines changed: 0 additions & 1 deletion
@@ -15,7 +15,6 @@ uv run accelerate launch \
     --output_dir output/ \
     --report_to wandb \
     --logging_steps 1 \
-    --reduce_loss sum \
     --model_revision main \
     --dataset_mixer_list allenai/tulu-3-sft-personas-algebra 100 \
     --add_bos \

scripts/train/olmo2/finetune_13b.sh

Lines changed: 0 additions & 1 deletion
@@ -30,7 +30,6 @@ python mason.py \
     --warmup_ratio 0.03 \
     --weight_decay 0.0 \
     --num_train_epochs 2 \
-    --reduce_loss sum \
     --gradient_checkpointing \
     --report_to wandb \
     --with_tracking \

scripts/train/olmo2/finetune_32b.sh

Lines changed: 0 additions & 1 deletion
@@ -30,7 +30,6 @@ python mason.py \
     --warmup_ratio 0.03 \
     --weight_decay 0.0 \
     --num_train_epochs 2 \
-    --reduce_loss sum \
     --gradient_checkpointing \
     --report_to wandb \
     --with_tracking \

scripts/train/olmo2/finetune_7b.sh

Lines changed: 0 additions & 1 deletion
@@ -30,7 +30,6 @@ python mason.py \
     --warmup_ratio 0.03 \
     --weight_decay 0.0 \
     --num_train_epochs 2 \
-    --reduce_loss sum \
     --report_to wandb \
     --with_tracking \
     --logging_steps 1 \

scripts/train/qwen/finetune_7b.sh

Lines changed: 0 additions & 1 deletion
@@ -28,7 +28,6 @@ python mason.py \
     --warmup_ratio 0.03 \
     --weight_decay 0.0 \
     --num_train_epochs 2 \
-    --reduce_loss sum \
     --use_flash_attn \
     --gradient_checkpointing \
     --report_to wandb \

scripts/train/tulu3/finetune_8b.sh

Lines changed: 0 additions & 1 deletion
@@ -28,7 +28,6 @@ python mason.py \
     --warmup_ratio 0.03 \
     --weight_decay 0.0 \
     --num_train_epochs 2 \
-    --reduce_loss sum \
     --use_flash_attn \
     --gradient_checkpointing \
     --report_to wandb \
