diff --git a/pytorch_translate/evals.py b/pytorch_translate/evals.py
index baf13db1..c224aa2b 100644
--- a/pytorch_translate/evals.py
+++ b/pytorch_translate/evals.py
@@ -160,61 +160,6 @@ def eval_tune_loss(args, trainer, task, subset, extra_state):
     return extra_state, stop_due_to_tune_loss
 
 
-def evaluate_bleu(
-    args, task, extra_state: Dict[str, Any], trainer, averaged_params: OrderedDict
-) -> Tuple[Dict[str, Any], bool, bool, List]:
-    if args.disable_eval_bleu:
-        extra_state["tune_bleu"]["current"] = 0.0
-        return (extra_state, False, True, None)
-    epoch, offset = extra_state["epoch"], extra_state["batch_offset"]
-    if args.log_verbose:
-        print(
-            f"| Preparing to calculate BLEU score for epoch {epoch}, offset {offset}."
-        )
-    extra_state["tune_bleu"]["current"], translation_samples = calculate_bleu_on_subset(
-        args=args,
-        task=task,
-        epoch_str=f"{epoch:03d}",
-        offset=offset,
-        dataset_split=args.valid_subset,
-        trainer=trainer,
-        model_params=averaged_params,
-    )
-    if args.log_verbose:
-        print(f"| Finished calculating BLEU score for epoch {epoch}, offset {offset}.")
-
-    new_best_averaged_checkpoint = False
-    if (
-        extra_state["tune_bleu"]["best"] is None
-        or extra_state["tune_bleu"]["current"] > extra_state["tune_bleu"]["best"]
-    ):
-        extra_state["tune_bleu"]["best"] = extra_state["tune_bleu"]["current"]
-        extra_state["tune_bleu"]["best_epoch"] = epoch
-        extra_state["tune_bleu"]["num_since_best"] = 0
-        new_best_averaged_checkpoint = True
-    else:
-        extra_state["tune_bleu"]["num_since_best"] += 1
-
-    stop_due_to_tune_bleu = False
-    if (
-        args.stop_no_best_bleu_eval >= 0
-        and extra_state["tune_bleu"]["num_since_best"] > args.stop_no_best_bleu_eval
-    ):
-        stop_due_to_tune_bleu = True
-        print(
-            f"Stopping training due to BLEU score stagnation on tune set - "
-            f"last best BLEU score of {extra_state['tune_bleu']['best']} "
-            f"(current score: {extra_state['tune_bleu']['current']}) was "
-            f"{extra_state['tune_bleu']['num_since_best']} evals ago."
-        )
-    return (
-        extra_state,
-        stop_due_to_tune_bleu,
-        new_best_averaged_checkpoint,
-        translation_samples,
-    )
-
-
 def calculate_bleu_on_subset(
     args,
     task,
@@ -336,7 +281,7 @@ def save_and_eval(
         args.save_interval_updates <= 0
         or (extra_state["num_iterations"] % args.save_interval_updates != 0)
     ):
-        return extra_state, stop_due_to_time_limit, None
+        return extra_state, stop_due_to_time_limit
 
     # Update training time before saving the checkpoint.
     time_now: float = time.time()
@@ -369,34 +314,11 @@ def save_and_eval(
             f"have a checkpoint_manager defined."
         )
-    # trick to prepare the task for evaluation, e.g. in latent variable model we need to set eval_key in RoundRobinZipDataset
-    if hasattr(task, "prepare_for_eval") and callable(task.prepare_for_eval):
-        task.prepare_for_eval()
-
     # Only save checkpoints and eval tune BLEU on the master - all other
     # processes will just get the results from the master.
-    translation_samples: Optional[List] = None
     if is_master:
         averaged_params: OrderedDict = checkpoint_manager.get_averaged_params(
             new_params=trainer.get_model().state_dict()
         )
-
-        # TODO: fix after masked lm work completes
-        if "save_only" not in args or not args.save_only:
-            (
-                extra_state,
-                stop_due_to_tune_bleu,
-                new_best_averaged_checkpoint,
-                translation_samples,
-            ) = evaluate_bleu(
-                args=args,
-                task=task,
-                extra_state=extra_state,
-                trainer=trainer,
-                averaged_params=averaged_params,
-            )
-        else:
-            new_best_averaged_checkpoint = True
-            stop_due_to_tune_bleu = False
+        new_best_averaged_checkpoint = extra_state["tune_eval"]["num_since_best"] == 0
         # checkpoint_manager takes ownership of averaged_params.
         extra_state = checkpoint_manager.save(
             args=args,
@@ -408,26 +330,15 @@ def save_and_eval(
             checkpoint_manager.save_best_averaged_checkpoint(
                 args=args, trainer=trainer, extra_state=extra_state
             )
 
-    if hasattr(task, "prepare_for_train") and callable(task.prepare_for_train):
-        task.prepare_for_train()
-    # extra_state["tune_bleu"] needs to be sync'ed between master and workers
-    # since we only do BLEU eval on master, but then need that info for
-    # determining when to do lr_shrink on all workers.
-    master_tune_bleu = None
     master_stop_training = None
     if is_master:
-        master_tune_bleu = extra_state["tune_bleu"]
         master_stop_training = (
-            stop_due_to_time_limit
-            or stop_due_to_tune_loss
-            or stop_due_to_tune_bleu
-            or stop_due_to_max_update
+            stop_due_to_time_limit or stop_due_to_tune_loss or stop_due_to_max_update
         )
-    tune_bleu, stop_training = pytorch_translate_utils.all_gather_from_master(
-        args=args, data=[master_tune_bleu, master_stop_training]
+    stop_training = pytorch_translate_utils.all_gather_from_master(
+        args=args, data=[master_stop_training]
     )
-    extra_state["tune_bleu"] = tune_bleu
 
     # TODO: fix after masked lm work completes
     if "save_only" not in args or not args.save_only:
@@ -435,6 +346,5 @@ def save_and_eval(
         assert (
            extra_state["tune_eval"]["loss"] is not None
            and extra_state["tune_eval"]["perplexity"] is not None
-            and extra_state["tune_bleu"]["current"] is not None
        )
-    return extra_state, stop_training, translation_samples
+    return extra_state, stop_training
diff --git a/pytorch_translate/options.py b/pytorch_translate/options.py
index 1b50a7e6..c20cfe44 100644
--- a/pytorch_translate/options.py
+++ b/pytorch_translate/options.py
@@ -555,12 +555,12 @@ def expand_optimization_args(group):
         "in the first place. A value of < 0 disables this.",
     )
     group.add_argument(
-        "--shrink-lr-no-best-bleu-eval",
+        "--shrink-lr-no-best-tune-loss",
         default=5,
         type=int,
         metavar="N",
         help="Decay learning rate after N evals have been run without "
-        "achieving a better BLEU score than before. This is to achieve "
+        "achieving a lower tune loss than before. This is to achieve "
         "decay lr within an epoch, independent of lr_scheduler. "
         "Note that this is affected by --save-interval-updates in "
         "how frequently we run BLEU eval in the first place. "
diff --git a/pytorch_translate/train.py b/pytorch_translate/train.py
index c3474d65..d8a644c6 100644
--- a/pytorch_translate/train.py
+++ b/pytorch_translate/train.py
@@ -109,12 +109,6 @@ def default_extra_state(args) -> Dict[str, Any]:
             "lowest_loss": None,
             "num_since_best": 0,
         },
-        "tune_bleu": {
-            "current": None,
-            "best": None,
-            "best_epoch": None,
-            "num_since_best": 0,
-        },
         # The list of checkpoint files is actually managed by the
         # CheckpointManager, which overwrites this placeholder when it saves
         # checkpoints.
@@ -136,8 +130,8 @@ def update_output(
             num_updates,
             {
                 "train_ppl": train_ppl,
+                "tune_loss": extra_state["tune_eval"]["loss"],
                 "tune_ppl": extra_state["tune_eval"]["perplexity"],
-                "tune_bleu": extra_state["tune_bleu"]["current"],
                 "wps": wps,
                 # translation_samples isn't currently used by the queue reader,
                 # so just pass None for now until we start needing it.
@@ -159,7 +153,6 @@ def clear_per_step_extra_state(extra_state: Dict[str, Any]) -> Dict[str, Any]:
     """
     extra_state["tune_eval"]["loss"] = None
     extra_state["tune_eval"]["perplexity"] = None
-    extra_state["tune_bleu"]["current"] = None
     return extra_state
 
 
@@ -566,11 +559,7 @@ def train(
             # any case where extra_case does not get populated correctly.
             extra_state = clear_per_step_extra_state(extra_state)
             extra_state["batch_offset"] = i + 1
-            (
-                extra_state,
-                stop_training_mid_epoch,
-                translation_samples,
-            ) = evals.save_and_eval(
+            extra_state, stop_training_mid_epoch = evals.save_and_eval(
                 args=args,
                 trainer=trainer,
                 task=task,
@@ -607,9 +596,9 @@ def train(
                 hasattr(args, "lr_shrink")
                 and args.save_interval_updates > 0
                 and extra_state["num_iterations"] % args.save_interval_updates == 0
-                and args.shrink_lr_no_best_bleu_eval > 0
-                and extra_state["tune_bleu"]["num_since_best"]
-                > args.shrink_lr_no_best_bleu_eval
+                and args.shrink_lr_no_best_tune_loss > 0
+                and extra_state["tune_eval"]["num_since_best"]
+                > args.shrink_lr_no_best_tune_loss
             ):
                 current_lr = trainer.optimizer.get_lr()
                 trainer.optimizer.set_lr(current_lr * args.lr_shrink)
@@ -626,11 +615,7 @@ def train(
 
         # batch_offset being None denotes the end of an epoch.
         extra_state["batch_offset"] = None
-        (
-            extra_state,
-            stop_training_end_of_epoch,
-            translation_samples,
-        ) = evals.save_and_eval(
+        extra_state, stop_training_end_of_epoch = evals.save_and_eval(
            args=args,
            trainer=trainer,
            task=task,
@@ -662,11 +647,6 @@ def train(
     if checkpoint_manager:
         checkpoint_manager.remove_all_checkpoints()
 
-    print(
-        f"| Best BLEU score of {extra_state['tune_bleu']['best']} was from "
-        f"epoch {extra_state['tune_bleu']['best_epoch']}"
-    )
-
 
 def setup_epoch(args, epoch_itr, trainer):
     """Sets up data and progress meters for one epoch."""