Skip to content

Commit 28d6f90

Browse files
committed
replace use_fast_tokenizer with tokenizer_args
1 parent 55756ad commit 28d6f90

File tree

1 file changed

+2
-6
lines changed

1 file changed

+2
-6
lines changed

sentence_transformers/cross_encoder/CrossEncoder.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616

1717
class CrossEncoder():
18-
def __init__(self, model_name:str, num_labels:int = None, max_length:int = None, device:str = None, use_fast_tokenizer:bool = None):
18+
def __init__(self, model_name:str, num_labels:int = None, max_length:int = None, device:str = None, tokenizer_args:Dict = {}):
1919
"""
2020
A CrossEncoder takes exactly two sentences / texts as input and either predicts
2121
a score or label for this sentence pair. It can for example predict the similarity of the sentence pair
@@ -27,7 +27,7 @@ def __init__(self, model_name:str, num_labels:int = None, max_length:int = None,
2727
:param num_labels: Number of labels of the classifier. If 1, the CrossEncoder is a regression model that outputs a continuous score 0...1. If > 1, it outputs several scores that can be soft-maxed to get probability scores for the different classes.
2828
:param max_length: Max length for input sequences. Longer sequences will be truncated. If None, max length of the model will be used
2929
:param device: Device that should be used for the model. If None, it will use CUDA if available.
30-
:param use_fast_tokenizer: Use fast tokenizer from hugging face.
30+
:param tokenizer_args: Arguments passed to AutoTokenizer
3131
"""
3232

3333
self.config = AutoConfig.from_pretrained(model_name)
@@ -42,10 +42,6 @@ def __init__(self, model_name:str, num_labels:int = None, max_length:int = None,
4242
self.config.num_labels = num_labels
4343

4444
self.model = AutoModelForSequenceClassification.from_pretrained(model_name, config=self.config)
45-
tokenizer_args = {}
46-
if use_fast_tokenizer is not None:
47-
tokenizer_args['use_fast'] = use_fast_tokenizer
48-
4945
self.tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_args)
5046

5147
self.max_length = max_length

0 commit comments

Comments (0)