101 changes: 53 additions & 48 deletions fastseq/optimizer/fairseq/beam_search_optimizer.py
@@ -15,7 +15,6 @@
from fairseq.modules.multihead_attention import MultiheadAttention
from fairseq.search import BeamSearch
from fairseq.sequence_generator import SequenceGenerator
from fastseq.ops.ngram_repeat_block import NGramRepeatBlock
from fastseq.utils.api_decorator import replace

@replace(BeamSearch)
@@ -432,6 +431,49 @@ class SequenceGeneratorV2(SequenceGenerator):
Sequence Generator is optimized by reducing the cached memory usage
during the encoding period for beam search.
"""
@torch.no_grad()
def apply_no_repeat_ngram_cpu(self, tokens,lprobs, bsz,step,
beam_size, no_repeat_ngram_size):
""" Fairseq implementation of blocking
repeated ngrams
"""
banned_list = [[] for bbsz_idx in range(bsz * beam_size)]
cpu_tokens = tokens.cpu()[:, :step + 1].numpy()
check_start_pos = step + 2 - no_repeat_ngram_size
for bbsz_idx in range(bsz * beam_size):
for i in range(check_start_pos):
is_banned = True
for k in range(no_repeat_ngram_size - 1):
if cpu_tokens[bbsz_idx, i + k] != cpu_tokens[
bbsz_idx, check_start_pos + k]:
is_banned = False
break
if is_banned:
banned_list[bbsz_idx].append(
cpu_tokens[bbsz_idx,
i + no_repeat_ngram_size - 1])

def calculate_banned_tokens(bbsz_idx):
"""before decoding the next token, prevent decoding
of ngrams that have already appeared
"""
banned_tokens_per_sample = [
(bbsz_idx, t) for t in banned_list[bbsz_idx]
]
return banned_tokens_per_sample

banned_tokens = []
if step + 2 - no_repeat_ngram_size >= 0:
for bbsz_idx in range(bsz * beam_size):
banned_tokens.extend(calculate_banned_tokens(bbsz_idx))

if banned_tokens:
banned_tokens = torch.LongTensor(banned_tokens)
lprobs.index_put_(
tuple(banned_tokens.t()),
lprobs.new_tensor([-math.inf] * len(banned_tokens)))

return lprobs
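
To make the indexing in this fallback concrete, here is a small self-contained toy run of the same blocking rule, re-implemented outside the class (the values and the standalone re-implementation are illustrative, not part of the diff). For a single hypothesis 1 2 3 1 2 with no_repeat_ngram_size=3, the trailing bigram (1, 2) already occurred at position 0, so the token 3 that completed it into a 3-gram the first time is banned for the next step:

import math
import torch

tokens = torch.tensor([[1, 2, 3, 1, 2]])  # one hypothesis (bsz * beam_size = 1)
step = 4                                  # index of the last generated token
n = 3                                     # no_repeat_ngram_size
lprobs = torch.zeros(1, 5)                # toy vocabulary of 5 tokens

cpu_tokens = tokens[:, :step + 1].numpy()
check_start_pos = step + 2 - n            # start of the trailing (n - 1)-gram
banned = []
for i in range(check_start_pos):
    # If the (n - 1)-gram starting at i matches the trailing (n - 1)-gram,
    # ban the token that followed it the first time around.
    if all(cpu_tokens[0, i + k] == cpu_tokens[0, check_start_pos + k]
           for k in range(n - 1)):
        banned.append(int(cpu_tokens[0, i + n - 1]))

for t in banned:
    lprobs[0, t] = -math.inf

print(banned)        # [3]
print(lprobs[0, 3])  # tensor(-inf)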

@torch.no_grad()
def _generate(self,
@@ -459,7 +501,13 @@ def _generate(self,
bsz = input_size[0]
src_len = input_size[1]
beam_size = self.beam_size
self.no_repeat_ngram_op = NGramRepeatBlock()
cuda_ngram_op_import = True
Contributor: Initialize no_repeat_ngram_op as None. In case of an import exception, just pass. When no_repeat_ngram_op is None, use the CPU code; otherwise, use the GPU code. That way we don't need to create the new variable cuda_ngram_op_import.

Contributor: Could we do this kind of checking (which kind of op to use, e.g., CPU vs. GPU) inside the op implementation itself? Then we would not need to do a similar check twice for fairseq and transformers, and it would be easier to maintain and change the code in the future.
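
A hypothetical sketch of the refactor suggested in these two comments; the helper names (get_ngram_blocking_op, block_repeated_ngrams) are illustrative and not part of this PR. The idea is to resolve the CUDA op once, store None on failure, and keep the CPU-vs-GPU dispatch in one shared place so fairseq and transformers do not duplicate the check:

import logging

logger = logging.getLogger(__name__)

def get_ngram_blocking_op():
    """Return the CUDA n-gram blocking op, or None if it cannot be loaded."""
    try:
        # pylint: disable=import-outside-toplevel
        from fastseq.ops.ngram_repeat_block import NGramRepeatBlock
        return NGramRepeatBlock()
    except ImportError as err:  # the PR currently catches more broadly (bare except)
        logger.warning("CUDA NGramRepeatBlock op unavailable (%s); "
                       "falling back to the CPU implementation.", err)
        return None

def block_repeated_ngrams(op, cpu_fallback, tokens, lprobs, bsz, step,
                          beam_size, ngram_size):
    """Dispatch to the CUDA op when available, otherwise to the CPU fallback."""
    if op is not None and tokens.is_cuda and lprobs.is_cuda:
        return op(tokens, lprobs, bsz, step, beam_size, ngram_size)
    return cpu_fallback(tokens, lprobs, bsz, step, beam_size, ngram_size)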

try:
#pylint: disable=import-outside-toplevel
from fastseq.ops.ngram_repeat_block import NGramRepeatBlock
self.no_repeat_ngram_op = NGramRepeatBlock()
except:
cuda_ngram_op_import = False
Contributor: Log a warning message here.
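A logger.warning call in the except branch, along the lines of the get_ngram_blocking_op sketch above, would cover this.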


if self.match_source_len:
max_len = src_lengths.max().item()
@@ -524,49 +572,6 @@ def is_finished(sent, step, unfin_idx):
return True
return False

def apply_no_repeat_ngram_cpu(self, tokens,lprobs, bsz,step,
beam_size, no_repeat_ngram_size):
""" Fairseq implementation of blocking
repeated ngrams
"""
banned_list = [[] for bbsz_idx in range(bsz * beam_size)]
cpu_tokens = tokens.cpu()[:, :step + 1].numpy()
check_start_pos = step + 2 - no_repeat_ngram_size
for bbsz_idx in range(bsz * beam_size):
for i in range(check_start_pos):
is_banned = True
for k in range(no_repeat_ngram_size - 1):
if cpu_tokens[bbsz_idx, i + k] != cpu_tokens[
bbsz_idx, check_start_pos + k]:
is_banned = False
break
if is_banned:
banned_list[bbsz_idx].append(
cpu_tokens[bbsz_idx,
i + no_repeat_ngram_size - 1])

def calculate_banned_tokens(bbsz_idx):
"""before decoding the next token, prevent decoding
of ngrams that have already appeared
"""
banned_tokens_per_sample = [
(bbsz_idx, t) for t in banned_list[bbsz_idx]
]
return banned_tokens_per_sample

banned_tokens = []
if step + 2 - no_repeat_ngram_size >= 0:
for bbsz_idx in range(bsz * beam_size):
banned_tokens.extend(calculate_banned_tokens(bbsz_idx))

if banned_tokens:
banned_tokens = torch.LongTensor(banned_tokens)
lprobs.index_put_(
tuple(banned_tokens.t()),
lprobs.new_tensor([-math.inf] * len(banned_tokens)))

return lprobs

def finalize_hypos(step, bbsz_idx, eos_scores):
"""
Finalize the given hypotheses at this step, while keeping the total
@@ -731,12 +736,12 @@ def replicate_first_beam(tensor, mask):

if self.no_repeat_ngram_size > 0:
#Applying Cuda Op for NGram repeat Blocking
if (tokens.is_cuda and lprobs.is_cuda):
if (tokens.is_cuda and lprobs.is_cuda and cuda_ngram_op_import):
lprobs = self.no_repeat_ngram_op(tokens,lprobs, bsz, step,
beam_size, self.no_repeat_ngram_size)
else:
lprobs = apply_no_repeat_ngram_cpu(tokens, lprobs, bsz,
step, beam_size, self.ngram_repeat_block_size)
lprobs = self.apply_no_repeat_ngram_cpu(tokens, lprobs, bsz,
step, beam_size, self.no_repeat_ngram_size)

cand_scores, cand_indices, cand_beams = self.search.step(
step,
12 changes: 9 additions & 3 deletions fastseq/optimizer/transformers/beam_search_optimizer.py
@@ -17,7 +17,6 @@

from transformers.modeling_bart import BartForConditionalGeneration
from transformers.modeling_t5 import T5ForConditionalGeneration
from fastseq.ops.ngram_repeat_block import NGramRepeatBlock

from fastseq.logging import get_logger
from fastseq.utils.api_decorator import replace
@@ -650,7 +649,8 @@ def _update_scores(banned_tokens):
cpu_input_ids = input_ids.cpu()
if no_repeat_ngram_size > 0:
#custom op for Ngram repeat blocking
if (input_ids.is_cuda and scores.is_cuda):
if (input_ids.is_cuda and scores.is_cuda and
self.cuda_ngram_op_import):
scores = self.no_repeat_ngram_op(input_ids,scores.float(),
batch_size, cur_len-1, num_beams, no_repeat_ngram_size)
else:
@@ -725,7 +725,13 @@ def _generate_beam_search(
done = [False for _ in range(batch_size)]

#NGram Repeat block Op
self.no_repeat_ngram_op = NGramRepeatBlock()#.to('cuda', torch.float32)
self.cuda_ngram_op_import = True
try:
#pylint: disable=import-outside-toplevel
from fastseq.ops.ngram_repeat_block import NGramRepeatBlock
self.no_repeat_ngram_op = NGramRepeatBlock()
except:
self.cuda_ngram_op_import = False
Contributor: Log a warning message here.
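Same note as on the fairseq side: a logger.warning in the except branch, as in the get_ngram_blocking_op sketch above, would cover this. This file already imports get_logger from fastseq.logging, so a module-level logger is presumably available.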


while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(