11# Copyright (c) OpenMMLab. All rights reserved.
22import json
33import os .path as osp
4- from typing import Sequence , Union
4+ from typing import Optional , Sequence , Union
55
66import torch
77
@@ -16,6 +16,7 @@ class SentencePieceTokenizer:
def __init__(self, model_file: str):
    """Build a tokenizer backed by a SentencePiece model.

    Args:
        model_file (str): path to the sentencepiece model file.
    """
    from sentencepiece import SentencePieceProcessor

    # Cache for the lazily computed `no_prefix_space_tokens` set.
    self._no_prefix_space_tokens = None
    self.model = SentencePieceProcessor(model_file=model_file)
1920
2021 @property
2122 def vocab_size (self ):
@@ -32,6 +33,24 @@ def eos_token_id(self):
3233 """end of the sentence token id."""
3334 return self .model .eos_id ()
3435
@property
def no_prefix_space_tokens(self):
    """Ids of tokens whose vocabulary piece does not start with '▁'.

    Computed lazily from the full vocabulary on first access and cached
    in ``self._no_prefix_space_tokens``.
    """
    if self._no_prefix_space_tokens is None:
        pieces = self.model.IdToPiece(list(range(self.vocab_size)))
        no_prefix = set()
        for token_id, piece in enumerate(pieces):
            if not piece.startswith('▁'):
                no_prefix.add(token_id)
        self._no_prefix_space_tokens = no_prefix
    return self._no_prefix_space_tokens
46+
def _maybe_add_prefix_space(self, tokens, decoded):
    """Re-insert the leading space lost by incremental decoding.

    Args:
        tokens: the token ids that were decoded.
        decoded (str): the decoded text for ``tokens``.
    Returns:
        str: ``decoded``, prefixed with a space when the first token id is
            not in ``self.no_prefix_space_tokens``.
    """
    if not tokens:
        return decoded
    if tokens[0] in self.no_prefix_space_tokens:
        return decoded
    return ' ' + decoded
53+
3554 def encode (self , s : str ):
3655 """Tokenize a prompt.
3756
@@ -50,17 +69,23 @@ def encode(self, s: str):
5069 add_eos = True
5170 return self .model .Encode (s , add_bos = add_bos , add_eos = add_eos )
5271
def decode(self, t: Sequence[int], offset: Optional[int] = None):
    """De-tokenize.

    Args:
        t (List[int]): a list of token ids
        offset (int): for incrementally decoding. Default to None, which
            means not applied.
    Returns:
        str: text of decoding tokens
    """
    token_ids = t.tolist() if isinstance(t, torch.Tensor) else t
    token_ids = token_ids[offset:]
    text = self.model.Decode(token_ids)
    # A truthy offset means we are continuing a previous decode, so the
    # leading space swallowed by sentencepiece may need to come back.
    if offset:
        text = self._maybe_add_prefix_space(token_ids, text)
    return text
6489
6590 def __call__ (self , s : Union [str , Sequence [str ]]):
6691 """Tokenize prompts.
@@ -86,7 +111,7 @@ class HuggingFaceTokenizer:
86111 """
87112
88113 def __init__ (self , model_dir : str ):
89- from transformers import AutoTokenizer
114+ from transformers import AutoTokenizer , LlamaTokenizerFast
90115 model_file = osp .join (model_dir , 'tokenizer.model' )
91116 backend_tokenizer_file = osp .join (model_dir , 'tokenizer.json' )
92117 model_file_exists = osp .exists (model_file )
@@ -95,6 +120,8 @@ def __init__(self, model_dir: str):
95120 'It may take long time to initialize the tokenizer.' )
96121 self .model = AutoTokenizer .from_pretrained (model_dir ,
97122 trust_remote_code = True )
123+ self .need_padding = isinstance (self .model , LlamaTokenizerFast )
124+ self ._no_prefix_space_tokens = None
98125 # save tokenizer.json to reuse
99126 if not osp .exists (backend_tokenizer_file ) and model_file_exists :
100127 if hasattr (self .model , 'backend_tokenizer' ):
@@ -122,6 +149,26 @@ def eos_token_id(self):
122149 """end of the sentence token id."""
123150 return self .model .eos_token_id
124151
@property
def no_prefix_space_tokens(self):
    """Ids of tokens whose string form does not start with '▁'.

    Computed lazily from the full vocabulary on first access and cached
    in ``self._no_prefix_space_tokens``.
    """
    if self._no_prefix_space_tokens is None:
        pieces = self.model.convert_ids_to_tokens(list(range(self.vocab_size)))
        no_prefix = set()
        for token_id, piece in enumerate(pieces):
            if not piece.startswith('▁'):
                no_prefix.add(token_id)
        self._no_prefix_space_tokens = no_prefix
    return self._no_prefix_space_tokens
163+
def _maybe_add_prefix_space(self, tokens, decoded):
    """Re-insert the leading space lost by incremental decoding.

    Only applies when ``self.need_padding`` is set (i.e. the underlying
    tokenizer is a LlamaTokenizerFast — see ``__init__``).

    Args:
        tokens: the token ids that were decoded.
        decoded (str): the decoded text for ``tokens``.
    Returns:
        str: ``decoded``, prefixed with a space when needed.
    """
    if not self.need_padding or not tokens:
        return decoded
    if tokens[0] in self.no_prefix_space_tokens:
        return decoded
    return ' ' + decoded
171+
125172 def encode (self , s : str ):
126173 """Tokenize a prompt.
127174
@@ -139,16 +186,23 @@ def encode(self, s: str):
139186 add_special_tokens = True
140187 return self .model .encode (s , add_special_tokens = add_special_tokens )
141188
def decode(self, t: Sequence[int], offset: Optional[int] = None):
    """De-tokenize.

    Args:
        t (List[int]): a list of token ids
        offset (int): for incrementally decoding. Default to None, which
            means not applied.
    Returns:
        str: text of decoding tokens
    """
    token_ids = t[offset:]
    text = self.model.decode(token_ids, skip_special_tokens=True)
    # A truthy offset means we are continuing a previous decode, so the
    # leading space swallowed by the tokenizer may need to come back.
    if offset:
        text = self._maybe_add_prefix_space(token_ids, text)
    return text
152206
153207 def __call__ (self , s : Union [str , Sequence [str ]]):
154208 """Tokenize prompts.
@@ -211,15 +265,17 @@ def encode(self, s: str):
211265 """
212266 return self .model .encode (s )
213267
def decode(self, t: Sequence[int], offset: Optional[int] = None):
    """De-tokenize by delegating to the wrapped tokenizer.

    Args:
        t (List[int]): a list of token ids
        offset (int): for incrementally decoding. Default to None, which
            means not applied.
    Returns:
        str: text of decoding tokens
    """
    text = self.model.decode(t, offset)
    return text
223279
224280 def __call__ (self , s : Union [str , Sequence [str ]]):
225281 """Tokenize prompts.
0 commit comments