@@ -239,13 +239,6 @@ class FlatArguments:
             "Useful if tokenization process is long. Default is 1800 seconds (30 minutes)."
         },
     )
-    reduce_loss: str = field(
-        default="mean",
-        metadata={
-            "help": "How to reduce loss over tokens. Options are 'mean' or 'sum'."
-            "Using 'sum' can improve chat model performance."
-        },
-    )
     resume_from_checkpoint: Optional[str] = field(
         default=None, metadata={"help": "If the training should continue from a checkpoint folder."}
     )
@@ -337,8 +330,6 @@ class FlatArguments:
     )
 
     def __post_init__(self):
-        if self.reduce_loss not in ["mean", "sum"]:
-            raise ValueError("reduce_loss must be either 'mean' or 'sum'")
         if self.dataset_name is None and self.dataset_mixer is None and self.dataset_mixer_list is None:
             raise ValueError("Need either a dataset name, dataset mixer, or dataset mixer list.")
         if (
@@ -552,8 +543,7 @@ def main(args: FlatArguments, tc: TokenizerConfig):
     elif args.use_liger_kernel:
         from liger_kernel.transformers import AutoLigerKernelForCausalLM
 
-        fused_linear_cross_entropy = args.reduce_loss == "mean"
-        logger.info(f"Attempting to apply liger-kernel. {fused_linear_cross_entropy=}")
+        logger.info("Attempting to apply liger-kernel. fused_linear_cross_entropy=True")
 
         # Supported models: https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/monkey_patch.py#L948
         model = AutoLigerKernelForCausalLM.from_pretrained(
@@ -565,7 +555,7 @@ def main(args: FlatArguments, tc: TokenizerConfig):
             low_cpu_mem_usage=args.low_cpu_mem_usage,
             attn_implementation="flash_attention_2" if args.use_flash_attn else "eager",
             # liger-kernel specific args
-            fused_linear_cross_entropy=fused_linear_cross_entropy,
+            fused_linear_cross_entropy=True,
         )
     else:
         model = AutoModelForCausalLM.from_pretrained(
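Note: with the 'sum' reduction removed, the liger-kernel path always enables the fused linear cross-entropy (the old code only turned it on when reduce_loss was "mean"). A minimal standalone sketch of how the two loading branches relate; this is not the script's actual control flow (which keys off args.use_liger_kernel), and the checkpoint name is illustrative:

import torch  # noqa: F401  (imported for parity with the training script)
from transformers import AutoModelForCausalLM

MODEL_NAME = "meta-llama/Llama-3.2-1B"  # illustrative checkpoint

try:
    # liger-kernel patches supported architectures with fused Triton kernels;
    # fused_linear_cross_entropy fuses the LM-head projection into the loss.
    from liger_kernel.transformers import AutoLigerKernelForCausalLM

    model = AutoLigerKernelForCausalLM.from_pretrained(
        MODEL_NAME,
        fused_linear_cross_entropy=True,
    )
except ImportError:
    # Plain Hugging Face loader, mirroring the `else:` branch above.
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)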
@@ -783,47 +773,17 @@ def main(args: FlatArguments, tc: TokenizerConfig):
             with accelerator.accumulate(model):
                 if args.load_balancing_loss:
                     outputs = model(**batch, use_cache=False, output_router_logits=True)
+                    total_aux_loss += outputs.aux_loss.detach().float()
                 else:
                     # Standard forward pass
                     outputs = model(**batch, use_cache=False)
 
-                if args.reduce_loss == "mean":
-                    loss = outputs.loss
-                else:
-                    # reduce loss is sum
-                    # this ensures that we weight all tokens in the dataset equally,
-                    # rather than weighting each overall example equally when
-                    # using high amounts of gradient accumulation.
-                    # this can result in > 5 point improvements in AlpacaEval
-                    # see https://github.com/huggingface/transformers/issues/24725 for
-                    # more discussion and details.
-                    logits = outputs.logits
-                    labels = batch["labels"]
-                    # Shift so that tokens < n predict n
-                    shift_logits = logits[..., :-1, :].contiguous()
-                    shift_labels = labels[..., 1:].contiguous()
-                    # Release logits to avoid memory leak
-                    del logits
-
-                    # Flatten the tokens
-                    loss_fct = torch.nn.CrossEntropyLoss(reduction="sum")
-                    shift_logits = shift_logits.view(-1, embedding_size)
-                    shift_labels = shift_labels.view(-1)
-                    # Enable model parallelism
-                    shift_labels = shift_labels.to(shift_logits.device)
-                    loss = loss_fct(shift_logits, shift_labels)
-                    # Release shift_logits to avoid memory leak
-                    del shift_logits
-                    if args.load_balancing_loss:
-                        aux_loss = args.load_balancing_weight * outputs.aux_loss
-                        loss += aux_loss
+                loss = outputs.loss
                 del outputs
 
                 # We keep track of the loss at each logged step
                 total_loss += loss.detach().float()
                 accelerator.backward(loss)
-                if args.load_balancing_loss:
-                    total_aux_loss += aux_loss.detach().float()
                 # clip gradient norm. don't do this with deepspeed
                 if accelerator.sync_gradients and args.clip_grad_norm > 0:
                     accelerator.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
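For reference, the removed 'sum' branch computed token-level cross-entropy with reduction="sum" over shifted logits and labels, so every prediction token in the batch was weighted equally rather than every example. A standalone sketch of the contrast between the two reductions, assuming -100-masked labels as in standard causal-LM collators; the function and variable names are illustrative:

import torch
import torch.nn.functional as F


def mean_vs_sum_loss(logits: torch.Tensor, labels: torch.Tensor):
    """Illustrative contrast of 'mean' vs 'sum' token-level cross-entropy."""
    # Shift so that tokens < n predict token n, as in the removed branch.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    flat_logits = shift_logits.view(-1, shift_logits.size(-1))
    flat_labels = shift_labels.view(-1).to(flat_logits.device)
    # 'mean': average over prediction tokens (what outputs.loss reports).
    mean_loss = F.cross_entropy(flat_logits, flat_labels, ignore_index=-100, reduction="mean")
    # 'sum': total over prediction tokens, so each token is weighted equally
    # even under heavy gradient accumulation.
    sum_loss = F.cross_entropy(flat_logits, flat_labels, ignore_index=-100, reduction="sum")
    return mean_loss, sum_loss


logits = torch.randn(2, 8, 32)        # toy batch: 2 sequences, 8 positions, vocab of 32
labels = torch.randint(0, 32, (2, 8))
labels[0, :4] = -100                  # mask a prompt prefix
mean_loss, sum_loss = mean_vs_sum_loss(logits, labels)
n_pred = (labels[..., 1:] != -100).sum()
assert torch.allclose(sum_loss / n_pred, mean_loss, atol=1e-5)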
@@ -842,7 +802,7 @@ def main(args: FlatArguments, tc: TokenizerConfig):
                 total_tokens_including_padding = accelerator.gather(total_token_including_padding).sum().item()
                 total_tokens_this_log_period = accelerator.gather(local_total_tokens_this_log_period).sum().item()
                 local_total_tokens_this_log_period.zero_()
-                pred_tokens_this_log_period = accelerator.gather(local_pred_tokens_this_log_period).sum().item()
+                accelerator.gather(local_pred_tokens_this_log_period).sum().item()
                 local_pred_tokens_this_log_period.zero_()
 
                 avg_tokens_per_batch = (
@@ -900,22 +860,11 @@ def main(args: FlatArguments, tc: TokenizerConfig):
                     # period. We want the avg over each optimizer step (which scales with the
                     # global batch size), and the average loss per token and per prediction
                     # token (which are roughly independent of global batch size).
-                    if args.reduce_loss == "mean":
-                        total_fwd_passes = (
-                            args.logging_steps * args.gradient_accumulation_steps * accelerator.num_processes
-                        )
-                        avg_loss = sum_loss / total_fwd_passes
-                        metrics_to_log["train_loss"] = avg_loss
-                    else:
-                        avg_loss = sum_loss / total_tokens_this_log_period
-                        # The loss per pred tok is the closest analogue to what we report as the
-                        # avg_loss in the "mean" case
-                        avg_loss_per_pred_tok = sum_loss / pred_tokens_this_log_period
-                        total_optim_steps = args.logging_steps * accelerator.num_processes
-                        avg_sum_loss = sum_loss / total_optim_steps
-                        metrics_to_log["train_sum_loss"] = avg_sum_loss
-                        metrics_to_log["train_loss_per_total_tok"] = avg_loss
-                        metrics_to_log["train_loss_per_pred_tok"] = avg_loss_per_pred_tok
+                    total_fwd_passes = (
+                        args.logging_steps * args.gradient_accumulation_steps * accelerator.num_processes
+                    )
+                    avg_loss = sum_loss / total_fwd_passes
+                    metrics_to_log["train_loss"] = avg_loss
                     if args.verbose:
                         sec_per_step = (time.time() - start_time) / (completed_steps - resume_step)
                         steps_remaining = args.max_train_steps - completed_steps
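With only the 'mean' path left, sum_loss accumulates one already-averaged loss per forward pass over the logging window, so the logged train_loss is that accumulator divided by the number of forward passes. A toy arithmetic check; the formula comes from the hunk above, the numbers are made up:

logging_steps = 4                # optimizer steps per logging period (illustrative)
gradient_accumulation_steps = 8  # micro-batches per optimizer step (illustrative)
num_processes = 2                # data-parallel ranks (illustrative)

total_fwd_passes = logging_steps * gradient_accumulation_steps * num_processes  # 64 forward passes
sum_loss = 70.4                  # pretend each forward pass averaged ~1.1 loss
train_loss = sum_loss / total_fwd_passes
print(f"{train_loss=:.3f}")      # 1.100, logged as metrics_to_log["train_loss"]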