Skip to content

Commit 18925a8

Browse files
committed
case of empty stop_word_ids
Signed-off-by: jiant <[email protected]>
1 parent cd3dd45 commit 18925a8

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

tensorrt_llm/sampling_params.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -403,13 +403,16 @@ def _encode(tokenizer, text, add_special_tokens):
403403
and isinstance(generation_config.eos_token_id, List)
404404
and all(isinstance(i, int) for i in generation_config.eos_token_id)
405405
):
406-
all_stop_tokens_id = set(i for sublist in self._stop_word_ids for i in sublist)
407-
from_generation_stop_tokens = [
408-
i for i in generation_config.eos_token_id if i not in all_stop_tokens_id
409-
]
410-
411-
if from_generation_stop_tokens:
412-
self._stop_word_ids.append(from_generation_stop_tokens)
406+
if self._stop_word_ids:
407+
all_stop_tokens_id = set(i for sublist in self._stop_word_ids for i in sublist)
408+
from_generation_stop_tokens = [
409+
i for i in generation_config.eos_token_id if i not in all_stop_tokens_id
410+
]
411+
412+
if from_generation_stop_tokens:
413+
self._stop_word_ids.append(from_generation_stop_tokens)
414+
else:
415+
self._stop_word_ids = [generation_config.eos_token_id]
413416

414417
return self
415418

0 commit comments

Comments
 (0)