Merged
71 changes: 10 additions & 61 deletions open_instruct/finetune.py
@@ -239,13 +239,6 @@ class FlatArguments:
             "Useful if tokenization process is long. Default is 1800 seconds (30 minutes)."
         },
     )
-    reduce_loss: str = field(
-        default="mean",
-        metadata={
-            "help": "How to reduce loss over tokens. Options are 'mean' or 'sum'."
-            "Using 'sum' can improve chat model performance."
-        },
-    )
     resume_from_checkpoint: Optional[str] = field(
         default=None, metadata={"help": "If the training should continue from a checkpoint folder."}
     )
@@ -337,8 +330,6 @@ class FlatArguments:
     )

     def __post_init__(self):
-        if self.reduce_loss not in ["mean", "sum"]:
-            raise ValueError("reduce_loss must be either 'mean' or 'sum'")
         if self.dataset_name is None and self.dataset_mixer is None and self.dataset_mixer_list is None:
             raise ValueError("Need either a dataset name, dataset mixer, or dataset mixer list.")
         if (
@@ -552,8 +543,7 @@ def main(args: FlatArguments, tc: TokenizerConfig):
     elif args.use_liger_kernel:
         from liger_kernel.transformers import AutoLigerKernelForCausalLM

-        fused_linear_cross_entropy = args.reduce_loss == "mean"
-        logger.info(f"Attempting to apply liger-kernel. {fused_linear_cross_entropy=}")
+        logger.info("Attempting to apply liger-kernel. fused_linear_cross_entropy=True")

         # Supported models: https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/monkey_patch.py#L948
         model = AutoLigerKernelForCausalLM.from_pretrained(
@@ -565,7 +555,7 @@ def main(args: FlatArguments, tc: TokenizerConfig):
             low_cpu_mem_usage=args.low_cpu_mem_usage,
             attn_implementation="flash_attention_2" if args.use_flash_attn else "eager",
             # liger-kernel specific args
-            fused_linear_cross_entropy=fused_linear_cross_entropy,
+            fused_linear_cross_entropy=True,
         )
     else:
         model = AutoModelForCausalLM.from_pretrained(
@@ -783,47 +773,17 @@ def main(args: FlatArguments, tc: TokenizerConfig):
             with accelerator.accumulate(model):
                 if args.load_balancing_loss:
                     outputs = model(**batch, use_cache=False, output_router_logits=True)
+                    total_aux_loss += outputs.aux_loss.detach().float()
                 else:
                     # Standard forward pass
                     outputs = model(**batch, use_cache=False)

-                if args.reduce_loss == "mean":
-                    loss = outputs.loss
-                else:
-                    # reduce loss is sum
-                    # this ensures that we weight all tokens in the dataset equally,
-                    # rather than weighting each overall example equally when
-                    # using high amounts of gradient accumulation.
-                    # this can result in > 5 point improvements in AlpacaEval
-                    # see https://github.com/huggingface/transformers/issues/24725 for
-                    # more discussion and details.
-                    logits = outputs.logits
-                    labels = batch["labels"]
-                    # Shift so that tokens < n predict n
-                    shift_logits = logits[..., :-1, :].contiguous()
-                    shift_labels = labels[..., 1:].contiguous()
-                    # Release logits to avoid memory leak
-                    del logits
-
-                    # Flatten the tokens
-                    loss_fct = torch.nn.CrossEntropyLoss(reduction="sum")
-                    shift_logits = shift_logits.view(-1, embedding_size)
-                    shift_labels = shift_labels.view(-1)
-                    # Enable model parallelism
-                    shift_labels = shift_labels.to(shift_logits.device)
-                    loss = loss_fct(shift_logits, shift_labels)
-                    # Release shift_logits to avoid memory leak
-                    del shift_logits
-                    if args.load_balancing_loss:
-                        aux_loss = args.load_balancing_weight * outputs.aux_loss
-                        loss += aux_loss
+                loss = outputs.loss
+                del outputs

                 # We keep track of the loss at each logged step
                 total_loss += loss.detach().float()
                 accelerator.backward(loss)
-                if args.load_balancing_loss:
-                    total_aux_loss += aux_loss.detach().float()
                 # clip gradient norm. don't do this with deepspeed
                 if accelerator.sync_gradients and args.clip_grad_norm > 0:
                     accelerator.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
@@ -842,7 +802,7 @@ def main(args: FlatArguments, tc: TokenizerConfig):
                 total_tokens_including_padding = accelerator.gather(total_token_including_padding).sum().item()
                 total_tokens_this_log_period = accelerator.gather(local_total_tokens_this_log_period).sum().item()
                 local_total_tokens_this_log_period.zero_()
-                pred_tokens_this_log_period = accelerator.gather(local_pred_tokens_this_log_period).sum().item()
+                accelerator.gather(local_pred_tokens_this_log_period).sum().item()
                 local_pred_tokens_this_log_period.zero_()

                 avg_tokens_per_batch = (
@@ -900,22 +860,11 @@ def main(args: FlatArguments, tc: TokenizerConfig):
                 # period. We want the avg over each optimizer step (which scales with the
                 # global batch size), and the average loss per token and per prediction
                 # token (which are roughly independent of global batch size).
-                if args.reduce_loss == "mean":
-                    total_fwd_passes = (
-                        args.logging_steps * args.gradient_accumulation_steps * accelerator.num_processes
-                    )
-                    avg_loss = sum_loss / total_fwd_passes
-                    metrics_to_log["train_loss"] = avg_loss
-                else:
-                    avg_loss = sum_loss / total_tokens_this_log_period
-                    # The loss per pred tok is the closest analogue to what we report as the
-                    # avg_loss in the "mean" case
-                    avg_loss_per_pred_tok = sum_loss / pred_tokens_this_log_period
-                    total_optim_steps = args.logging_steps * accelerator.num_processes
-                    avg_sum_loss = sum_loss / total_optim_steps
-                    metrics_to_log["train_sum_loss"] = avg_sum_loss
-                    metrics_to_log["train_loss_per_total_tok"] = avg_loss
-                    metrics_to_log["train_loss_per_pred_tok"] = avg_loss_per_pred_tok
+                total_fwd_passes = (
+                    args.logging_steps * args.gradient_accumulation_steps * accelerator.num_processes
+                )
+                avg_loss = sum_loss / total_fwd_passes
+                metrics_to_log["train_loss"] = avg_loss
                 if args.verbose:
                     sec_per_step = (time.time() - start_time) / (completed_steps - resume_step)
                     steps_remaining = args.max_train_steps - completed_steps
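For context, the deleted reduce_loss == "sum" branch summed token-level cross-entropy instead of averaging it, so that every target token in the dataset carried equal weight regardless of how examples were split across micro-batches under gradient accumulation (see the removed comment above). The sketch below is illustrative only; it is not code from this repository, and the shapes, values, and variable names are invented. It contrasts the two reductions:

import torch

vocab_size = 8
batch_size, seq_len = 2, 5
logits = torch.randn(batch_size, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch_size, seq_len))
labels[0, :2] = -100  # masked positions (e.g. prompt tokens), ignored by the loss

# Shift so that tokens < n predict n, as the removed branch did.
shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
shift_labels = labels[..., 1:].contiguous().view(-1)

mean_loss = torch.nn.CrossEntropyLoss(reduction="mean")(shift_logits, shift_labels)
sum_loss = torch.nn.CrossEntropyLoss(reduction="sum")(shift_logits, shift_labels)

# "mean" averages over the unmasked prediction tokens of this micro-batch, so each
# micro-batch contributes equally after gradient accumulation; "sum" scales with the
# number of prediction tokens, so longer examples carry proportionally more weight.
num_pred_tokens = (shift_labels != -100).sum()
assert torch.allclose(mean_loss, sum_loss / num_pred_tokens)

With only the mean-reduced outputs.loss path remaining, the Liger kernel's fused linear cross-entropy no longer needs to be gated on the reduction mode, which is why fused_linear_cross_entropy is now passed as True unconditionally in the hunk above.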
1 change: 0 additions & 1 deletion scripts/train/debug/finetune.sh
@@ -15,7 +15,6 @@ uv run accelerate launch \
     --output_dir output/ \
     --report_to wandb \
     --logging_steps 1 \
-    --reduce_loss sum \
     --model_revision main \
     --dataset_mixer_list allenai/tulu-3-sft-personas-algebra 100 \
     --add_bos \
1 change: 0 additions & 1 deletion scripts/train/olmo2/finetune_13b.sh
@@ -30,7 +30,6 @@ python mason.py \
     --warmup_ratio 0.03 \
     --weight_decay 0.0 \
     --num_train_epochs 2 \
-    --reduce_loss sum \
     --gradient_checkpointing \
     --report_to wandb \
     --with_tracking \
1 change: 0 additions & 1 deletion scripts/train/olmo2/finetune_32b.sh
@@ -30,7 +30,6 @@ python mason.py \
     --warmup_ratio 0.03 \
     --weight_decay 0.0 \
     --num_train_epochs 2 \
-    --reduce_loss sum \
     --gradient_checkpointing \
     --report_to wandb \
     --with_tracking \
1 change: 0 additions & 1 deletion scripts/train/olmo2/finetune_7b.sh
@@ -30,7 +30,6 @@ python mason.py \
     --warmup_ratio 0.03 \
     --weight_decay 0.0 \
     --num_train_epochs 2 \
-    --reduce_loss sum \
     --report_to wandb \
     --with_tracking \
     --logging_steps 1 \
1 change: 0 additions & 1 deletion scripts/train/qwen/finetune_7b.sh
@@ -28,7 +28,6 @@ python mason.py \
     --warmup_ratio 0.03 \
     --weight_decay 0.0 \
     --num_train_epochs 2 \
-    --reduce_loss sum \
     --use_flash_attn \
     --gradient_checkpointing \
     --report_to wandb \
1 change: 0 additions & 1 deletion scripts/train/tulu3/finetune_8b.sh
@@ -28,7 +28,6 @@ python mason.py \
     --warmup_ratio 0.03 \
     --weight_decay 0.0 \
     --num_train_epochs 2 \
-    --reduce_loss sum \
     --use_flash_attn \
     --gradient_checkpointing \
     --report_to wandb \