Skip to content

Commit 5a5140d

Browse files
authored
Merge pull request #62 from joaonadkarni/fix-transformer-generate-version-check
[bugfix] fix comparison of transformers version in lm generate
2 parents c2b1a47 + e48e73f commit 5a5140d

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

src/ecco/lm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from operator import attrgetter
1818
import re
1919
from ecco.util import is_partial_token, strip_tokenizer_prefix
20+
from packaging import version
2021

2122

2223
class LM(object):
@@ -199,7 +200,7 @@ def generate(self, input_str: str,
199200
# Get decoder input ids
200201
if self.model_type == 'enc-dec': # FIXME: only done because causal LMs like GPT-2 have the _prepare_decoder_input_ids_for_generation method but do not use it
201202
assert len(input_ids.size()) == 2 # will break otherwise
202-
if transformers.__version__ >= '4.13': # ALSO FIXME: awful hack. But seems to work?
203+
if version.parse(transformers.__version__) >= version.parse('4.13'): # ALSO FIXME: awful hack. But seems to work?
203204
decoder_input_ids = self.model._prepare_decoder_input_ids_for_generation(input_ids.shape[0], None, None)
204205
else:
205206
decoder_input_ids = self.model._prepare_decoder_input_ids_for_generation(input_ids, None, None)

0 commit comments

Comments
 (0)