diff --git a/pytorch_translate/generate.py b/pytorch_translate/generate.py index c232f269..057f35dd 100644 --- a/pytorch_translate/generate.py +++ b/pytorch_translate/generate.py @@ -210,7 +210,10 @@ def _generate_score(models, args, task, dataset): for trans_info in _iter_translations( args, task, dataset, translations, align_dict, rescorer ): - scorer.add(trans_info.target_tokens, trans_info.hypo_tokens) + if hasattr(scorer, "add_string"): + scorer.add_string(trans_info.target_str, trans_info.hypo_str) + else: + scorer.add(trans_info.target_tokens, trans_info.hypo_tokens) if oracle_scorer is not None: oracle_scorer.add(trans_info.target_tokens, trans_info.best_hypo_tokens) if rescoring_bleu_scorer is not None: