@@ -28,7 +28,7 @@ class Model:
 ```python
 from pyllamacpp.model import Model
 
-model = Model(ggml_model='./models/ggml-model-f16-q4_0.bin')
+model = Model(ggml_model='path/to/ggml/model')
 for token in model.generate("Tell me a joke ?"):
     print(token, end='', flush=True)
 ```
@@ -113,6 +113,7 @@ def reset(self):
     def generate(self,
                  prompt: str,
                  n_predict: Union[None, int] = None,
+                 antiprompt: str = None,
                  infinite_generation: bool = False,
                  n_threads: int = 4,
                  repeat_last_n: int = 64,
@@ -126,6 +127,8 @@ def generate(self,
         :param prompt: The prompt :)
         :param n_predict: if n_predict is not None, the inference will stop if it reaches `n_predict` tokens, otherwise
                           it will continue until `EOS`
+        :param antiprompt: aka the stop word; generation stops as soon as this word is predicted.
+                           Keep it `None` to handle stopping in your own way
         :param infinite_generation: set it to `True` to make the generation go infinitely
         :param n_threads: The number of CPU threads
         :param repeat_last_n: last n tokens to penalize
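
For context, here is how the new parameter might be exercised once this change lands; a minimal sketch, assuming a local GGML model file (the path and stop word below are placeholders, not values from this diff):

```python
from pyllamacpp.model import Model

# Placeholder path; point this at a real GGML model file.
model = Model(ggml_model='path/to/ggml/model')

prompt = "User: Tell me a joke\nBot:"

# With antiprompt set, generation should halt as soon as the model
# starts a new "User:" turn instead of rambling on indefinitely.
for token in model.generate(prompt, antiprompt="User:"):
    print(token, end='', flush=True)
```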
@@ -156,6 +159,9 @@ def generate(self,
             self._last_n_tokens.append(tok)
 
         n_remain = 0
+        if antiprompt is not None:
+            sequence_queue = []
+            stop_word = antiprompt.strip()
 
         while infinite_generation or predicted_token != pp.llama_token_eos():
             if len(predicted_tokens) > 0:
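
The two new locals set up a small buffering scheme: `sequence_queue` holds tokens that might be the start of the stop word, and `stop_word` is the stripped antiprompt they are compared against. A quick illustration of the prefix test the loop below relies on (the token values here are made up):

```python
stop_word = "User:"                 # stripped antiprompt
queue = ['\n', 'User']              # tokens withheld from the caller so far
partial = ''.join(queue).strip()    # -> 'User'
stop_word.startswith(partial)       # True: still a possible match, keep buffering
partial == stop_word                # False: not a complete match yet
```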
@@ -178,6 +184,23 @@ def generate(self,
 
             predicted_tokens.append(predicted_token)
             token_str = pp.llama_token_to_str(self._ctx, predicted_token)
+            if antiprompt is not None:
+                if token_str == '\n':
+                    sequence_queue.append(token_str)
+                    continue
+                if len(sequence_queue) != 0:
+                    if stop_word.startswith(''.join(sequence_queue).strip()):
+                        sequence_queue.append(token_str)
+                        if ''.join(sequence_queue).strip() == stop_word:
+                            break
+                        else:
+                            continue
+                    else:
+                        # false match: flush the queued tokens back to the caller
+                        while len(sequence_queue) != 0:
+                            yield sequence_queue.pop(0)
+                        sequence_queue = []
+
             self._last_n_tokens.pop(0)
             self._last_n_tokens.append(predicted_token)
             yield token_str
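
To sanity-check the control flow, here is a self-contained, plain-Python rehearsal of the same buffering logic over a hand-written token stream (no llama bindings involved; the token boundaries are invented for illustration):

```python
def stream_until(tokens, antiprompt):
    # Mirrors the buffering logic added above, minus the llama.cpp calls.
    stop_word = antiprompt.strip()
    queue = []
    for tok in tokens:
        if tok == '\n':                      # a newline may open the stop sequence
            queue.append(tok)
            continue
        if queue:
            if stop_word.startswith(''.join(queue).strip()):
                queue.append(tok)
                if ''.join(queue).strip() == stop_word:
                    return                   # stop word completed: end generation
                continue
            while queue:                     # false match: flush withheld tokens
                yield queue.pop(0)
        yield tok

print(list(stream_until(['Hi', '!', '\n', 'User', ':', ' there'], 'User:')))
# -> ['Hi', '!']   (the '\n' 'User' ':' run is swallowed as the stop word)
```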