Commit dbedd16

mechanism added

1 parent 9d88632 commit dbedd16
4 files changed: +42 -52 lines

README.md

Lines changed: 15 additions & 34 deletions
@@ -110,7 +110,7 @@ from pyllamacpp.model import Model
 
 model = Model(ggml_model='./models/gpt4all-model.bin')
 for token in model.generate("Tell me a joke ?"):
-    print(token, end='')
+    print(token, end='', flush=True)
 ```
 
 ### Interactive Dialogue
@@ -126,8 +126,8 @@ while True:
         if prompt == '':
             continue
         print(f"AI:", end='')
-        for tok in model.generate(prompt):
-            print(f"{tok}", end='', flush=True)
+        for token in model.generate(prompt):
+            print(f"{token}", end='', flush=True)
         print()
     except KeyboardInterrupt:
         break
@@ -149,43 +149,24 @@ Bob: Welcome! I'm here to assist you with anything you need. What can I do for y
 prompt_prefix = "\nUser:"
 prompt_suffix = "\nBob:"
 
-model = Model(model_path='/path/to/ggml/model',
-              prompt_context=prompt_context,
+model = Model(model_path='/path/to/ggml/model',
+              prompt_context=prompt_context,
               prompt_prefix=prompt_prefix,
               prompt_suffix=prompt_suffix)
 
-sequence = ''
-stop_word = prompt_prefix.strip()
-
 while True:
-    try:
-        prompt = input("You: ")
-        if prompt == '':
-            continue
-        print(f"AI: ", end='')
-        for token in model.generate(prompt):
-            if token == '\n':
-                sequence += token
-                continue
-            if len(sequence) != 0:
-                if stop_word.startswith(sequence.strip()):
-                    sequence += token
-                    if sequence.strip() == stop_word:
-                        sequence = ''
-                        break
-                else:
-                    continue
-            else:
-                print(f"{sequence}", end='', flush=True)
-                sequence = ''
-            print(f"{token}", end='', flush=True)
-
-        print()
-    except KeyboardInterrupt:
-        break
+    try:
+        prompt = input("User: ")
+        if prompt == '':
+            continue
+        print(f"Bob: ", end='')
+        for token in model.generate(prompt, antiprompt='User:'):
+            print(f"{token}", end='', flush=True)
+        print()
+    except KeyboardInterrupt:
+        break
 ```
 
-
 # API reference
 You can check the [API reference documentation](https://abdeladim-s.github.io/pyllamacpp/) for more details.
 
pyllamacpp/cli.py

Lines changed: 2 additions & 16 deletions
@@ -201,28 +201,14 @@ def run(args):
     print("...")
     print("[+] Press Ctrl+C to Stop ... ")
     print("...")
-    sequence = ''
+
     while True:
         try:
            prompt = input("You: ")
            if prompt == '':
                continue
            print(f"{bcolors.OKBLUE}AI: {bcolors.ENDC}", end='', flush=True)
-            for token in model.generate(prompt, **gpt_params):
-                if token == '\n':
-                    sequence += token
-                    continue
-                if len(sequence) != 0:
-                    if PROMPT_PREFIX.strip().startswith(sequence.strip()):
-                        sequence += token
-                        if sequence.strip() == PROMPT_PREFIX.strip():
-                            sequence = ''
-                            break
-                    else:
-                        continue
-                else:
-                    print(f"{sequence}", end='', flush=True)
-                    sequence = ''
+            for token in model.generate(prompt, antiprompt=PROMPT_PREFIX.strip(), **gpt_params):
                print(f"{bcolors.OKCYAN}{token}{bcolors.ENDC}", end='', flush=True)
            print()
        except KeyboardInterrupt:

pyllamacpp/model.py

Lines changed: 24 additions & 1 deletion
@@ -28,7 +28,7 @@ class Model:
     ```python
     from pyllamacpp.model import Model
 
-    model = Model(ggml_model='./models/ggml-model-f16-q4_0.bin')
+    model = Model(ggml_model='path/to/ggml/model')
     for token in model.generate("Tell me a joke ?"):
         print(token, end='', flush=True)
     ```
@@ -113,6 +113,7 @@ def reset(self):
     def generate(self,
                  prompt: str,
                  n_predict: Union[None, int] = None,
+                 antiprompt: str = None,
                  infinite_generation: bool = False,
                  n_threads: int = 4,
                  repeat_last_n: int = 64,
@@ -126,6 +127,8 @@ def generate(self,
         :param prompt: The prompt :)
         :param n_predict: if n_predict is not None, the inference will stop if it reaches `n_predict` tokens, otherwise
                           it will continue until `EOS`
+        :param antiprompt: aka the stop word, the generation will stop if this word is predicted,
+                           keep it None to handle it in your own way
         :param infinite_generation: set it to `True` to make the generation go infinitely
         :param n_threads: The number of CPU threads
         :param repeat_last_n: last n tokens to penalize
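
For quick reference, here is a minimal usage sketch of the new `antiprompt` parameter, pieced together from the docstring and README examples above; the model path is a placeholder, not a real file:

```python
from pyllamacpp.model import Model

# placeholder path: point this at your own ggml model file
model = Model(ggml_model='path/to/ggml/model')

# generation stops as soon as the stop word 'User:' is predicted,
# so the reply never spills into the next dialogue turn
for token in model.generate("Hello, who are you?", antiprompt='User:'):
    print(token, end='', flush=True)
```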
@@ -156,6 +159,9 @@ def generate(self,
             self._last_n_tokens.append(tok)
 
         n_remain = 0
+        if antiprompt is not None:
+            sequence_queue = []
+            stop_word = antiprompt.strip()
 
         while infinite_generation or predicted_token != pp.llama_token_eos():
             if len(predicted_tokens) > 0:
@@ -178,6 +184,23 @@ def generate(self,
 
             predicted_tokens.append(predicted_token)
             token_str = pp.llama_token_to_str(self._ctx, predicted_token)
+            if antiprompt is not None:
+                if token_str == '\n':
+                    sequence_queue.append(token_str)
+                    continue
+                if len(sequence_queue) != 0:
+                    if stop_word.startswith(''.join(sequence_queue).strip()):
+                        sequence_queue.append(token_str)
+                        if ''.join(sequence_queue).strip() == stop_word:
+                            break
+                        else:
+                            continue
+                    else:
+                        # consume sequence queue tokens
+                        while len(sequence_queue) != 0:
+                            yield sequence_queue.pop(0)
+                        sequence_queue = []
+
             self._last_n_tokens.pop(0)
             self._last_n_tokens.append(predicted_token)
             yield token_str
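
The buffering logic above can be hard to read inside a diff, so here is a self-contained sketch of the same idea applied to a plain list of tokens; `stream_until_stop_word` and the sample token stream are made-up names for illustration, not part of the library:

```python
def stream_until_stop_word(tokens, stop_word):
    """Replicate the antiprompt buffering above on a plain token stream.

    After a newline, tokens are held in a queue while they still look like
    a prefix of the stop word; on a full match generation stops, and on a
    mismatch the held tokens are flushed downstream unchanged.
    """
    stop_word = stop_word.strip()
    sequence_queue = []
    for token in tokens:
        if token == '\n':
            sequence_queue.append(token)
            continue
        if sequence_queue:
            if stop_word.startswith(''.join(sequence_queue).strip()):
                sequence_queue.append(token)
                if ''.join(sequence_queue).strip() == stop_word:
                    return  # stop word fully matched: stop generating
                continue
            # false alarm: the buffered text was not the stop word
            while sequence_queue:
                yield sequence_queue.pop(0)
        yield token


# made-up stream: the model finishes its answer, then starts a "User:" turn
tokens = ['Hi', ' there', '!', '\n', 'User', ':', ' and', ' more']
print(''.join(stream_until_stop_word(tokens, '\nUser:')))  # -> Hi there!
```

Note that the queue only begins filling after a newline token, which is why the antiprompts used in this commit ('User:', PROMPT_PREFIX.strip()) are line-leading markers.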

setup.py

Lines changed: 1 addition & 1 deletion
@@ -131,7 +131,7 @@ def build_extension(self, ext: CMakeExtension) -> None:
 # logic and declaration, and simpler if you include description/version in a file.
 setup(
     name="pyllamacpp",
-    version="2.1.0",
+    version="2.1.1",
     author="Abdeladim Sadiki",
     description="Python bindings for llama.cpp",
     long_description=long_description,
