46 changes: 43 additions & 3 deletions fastseq_cli/transformers_generate.py
@@ -2,6 +2,7 @@
import argparse
import json
from pathlib import Path
from typing import List

import torch
from tqdm import tqdm
@@ -17,6 +18,37 @@ def chunks(lst, n):
    for i in range(0, len(lst), n):
        yield lst[i:i + n]

def sort_sentences(sents: List[str], reverse: bool = False):
    """Sort the input sentences by their length.

    Args:
        sents (List[str]): input sentences.
        reverse (bool): sort in ascending (False) or descending (True) order.

    Returns:
        Tuple[List[str], List[int]]: the sorted sentences and
        their indices in the original input list.
    """
    is_ascending = -1 if reverse else 1
    sorted_idx = sorted(
Contributor:

Any perf for large data?

Contributor Author:

@yuyan2do Do you mean the benchmarking results on a larger dataset than the data in our benchmark script?

Contributor:

I also have a similar worry, as it loads all the data into memory and then sorts. Please check perf on a larger dataset.

Contributor Author:

Here are the test results on a larger dataset:

  • 5,242,880 examples (~16 GB): sort took ~4.4 seconds; unsort took ~1.4 seconds
  • 15,728,640 examples (~48 GB): sort took ~14.0 seconds; unsort took ~4.6 seconds
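
For reference, a minimal sketch of how such timings could be measured (a hypothetical script, not the actual benchmark code; the file path is an assumption):

    import time

    def time_sort_unsort(path):
        # Load all examples into memory, as the CLI does.
        with open(path, encoding="utf-8") as f:
            examples = [line.rstrip("\n") for line in f]

        start = time.perf_counter()
        sorted_sents, sorted_idx = sort_sentences(examples, reverse=True)
        print(f"sort took ~{time.perf_counter() - start:.1f} seconds")

        start = time.perf_counter()
        restored = unsort_sentences(sorted_sents, sorted_idx)
        print(f"unsort took ~{time.perf_counter() - start:.1f} seconds")
        assert restored == examples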

        range(len(sents)), key=lambda i: len(sents[i]) * is_ascending)
    sorted_sents = [sents[i] for i in sorted_idx]
    return sorted_sents, sorted_idx

def unsort_sentences(sents: List[str], sorted_idx: List[int]):
    """Restore sents to the original order recorded in sorted_idx.

    Args:
        sents (List[str]): a list of sorted input strings.
        sorted_idx (List[int]): the original index of each sorted string.

    Returns:
        List[str]: the strings restored to their original order.
    """
    result = [''] * len(sents)
    for cur_idx, org_idx in enumerate(sorted_idx):
        result[org_idx] = sents[cur_idx]
    return result
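
A quick round-trip illustration of the two helpers (hypothetical inputs, not part of the diff):

    sents = ["a bb ccc", "a", "bb cc"]
    sorted_sents, sorted_idx = sort_sentences(sents, reverse=True)
    # sorted_sents == ["a bb ccc", "bb cc", "a"]; sorted_idx == [0, 2, 1]
    assert unsort_sentences(sorted_sents, sorted_idx) == sents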

def generate_summaries_or_translations(
    examples: list,
@@ -34,6 +66,8 @@ def generate_summaries_or_translations(
"""Run generation"""
if fastseq_opt:
import fastseq #pylint: disable=import-outside-toplevel
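        # Sort inputs longest-first so each batch contains examples of
        # similar length (assumed motivation: less padding work per batch).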
        examples, sorted_idx = sort_sentences(examples, reverse=True)

    fout = Path(out_file).open("w", encoding="utf-8")
    model_name = str(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
@@ -47,6 +81,7 @@
    # update config with summarization specific params
    use_task_specific_params(model, task)

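    # Collect all decoded outputs so their original order can be
    # restored before they are written out.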
    hypothesis = []
    for batch in tqdm(list(chunks(examples, batch_size))):
        if "t5" in model_name:
            batch = [model.config.prefix + text for text in batch]
@@ -66,9 +101,14 @@
        dec = tokenizer.batch_decode(summaries,
                                     skip_special_tokens=True,
                                     clean_up_tokenization_spaces=False)
-        for hypothesis in dec:
-            fout.write(hypothesis + "\n")
-            fout.flush()
+        hypothesis.extend(dec)
+
+    if fastseq_opt:
+        hypothesis = unsort_sentences(hypothesis, sorted_idx)
+
+    for hypo in hypothesis:
+        fout.write(hypo + "\n")
+    fout.flush()


def run_generate():