
Commit a75896a

Harsha-Nori, riedgar-ms, and paulbkoch authored
[WIP] Remove Instruct/Chat versions of models & introduce a new ChatTemplate API, fix Anthropic API (#820)
This PR should significantly reduce the number of user-facing classes we have in Guidance, and reduce subtle bugs introduced by using a mis-specified Chat Template (models currently silently default to the ChatML syntax, which many of the latest models don't adhere to). It should also make it easier for users to add new models to guidance, either via PR or in their own codebases. Before: ```python from guidance.models.transformers import Transformers, TransformersChat class Llama(Transformers): pass # Users have to do this for most chat models in guidance class LlamaChat(TransformersChat, Llama): def get_role_start(self, role_name, **kwargs): if role_name == "system": return self._system_prefex + "<<SYS>>\n" elif role_name == "user": if str(self).endswith("\n<</SYS>>\n\n"): return "" else: return "[INST] " else: return " " def get_role_end(self, role_name=None): if role_name == "system": return "\n<</SYS>>\n\n" elif role_name == "user": return " [/INST]" else: return " " lm = LlamaChat(path_to_llama) ``` After: ```python from guidance.models import Transformers lm = Transformers(path_to_llama) # automagically works ``` If you're using a rare model and the auto import doesn't automatically work... After pt2: ```python # users can copy paste from https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/main/tokenizer_config.json#L12 llama2_template = "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}" lm = Transformers(path_to_llama, chat_template=llama2_template) ``` or, in the worst case for maximal robustness and customizability: ```python from guidance._chat import ChatTemplate, UnsupportedRoleException class Llama2ChatTemplate(ChatTemplate): template_str = llama2_template def get_role_start(self, role_name): if role_name == "system": return "[INST] <<SYS>>\n" elif role_name == "user": return "<s>[INST]" elif role_name == "assistant": return " " else: raise UnsupportedRoleException(role_name, self) def get_role_end(self, role_name=None): if role_name == "system": return "\n<</SYS>" elif role_name == "user": return " [/INST]" elif role_name == "assistant": return "</s>" else: raise UnsupportedRoleException(role_name, self) lm = Transformers(path_to_llama, chat_template=Llama2ChatTemplate) ``` The first big change is the removal of the `Chat` and `Instruct` mixins, and an introduction of a new `guidance._chat.ChatTemplate` class, which handles the same responsibilities as those mixins used to. Users can construct a subclass of `ChatTemplate` and pass it to models with a new `chat_template` argument (defaulted to None). The way this works for local models is to leverage the `chat_template` property in `huggingface transformers` and `llamacpp`'s GGUF files. 
When a user tries to load a model, guidance now follows this order of operations:

1. See if the user passed in a `ChatTemplate` -- if so, use that directly.
2. If the user passed in a string for `chat_template`, set that as the `template_str`. If the user did not pass in anything, set the `template_str` from the `huggingface.AutoTokenizer` or GGUF metadata fields.
3. Check the `template_str` against a local cache in guidance which maintains template converters for the most popular models on Hugging Face and in llama.cpp. We index this cache by the actual chat_template string, so any model that uses one of these chat templates -- even if it isn't explicitly listed in the documentation -- will automatically load the right guidance class.
4. If we don't have anything in the cache, try to automatically convert the jinja2 template into the new `guidance._chat.ChatTemplate` syntax, and warn the user that we attempted this. [NOTE: This is not yet implemented and may come in a future PR.]
5. Default to the ChatML syntax, with a warning to the user.

(A short sketch exercising this resolution logic appears at the end of this description.)

Currently this PR updates the following user-facing `guidance.models` classes:

- Transformers (removing TransformersChat and the Llama/Mistral subclasses)
- LlamaCpp (removing LlamaCppChat and Mistral subclasses)
- Anthropic

For now, `Anthropic` should be representative of how grammarless classes will work. I wanted to start with OpenAI, but many other guidance.models classes inherit from OpenAI, so I'm saving that for later. While I was at it, I also upgraded the `Anthropic` class to use their latest SDK, so `guidance.models.Anthropic` should now work with the latest Claude 3 models.

# TODO

A decent amount left to do here. In no particular order:

1. Unrelated to this change, but guidance cannot properly handle `llama3` or `phi-3`'s true tokenizers/chat templates. We need to fix that independently.
2. Add better warnings to users when we fall back to using the ChatML template.
3. Extend support through the rest of the guidance.models classes (mostly just remote ones left, starting with OpenAI).
4. Write out more templates for popular models in the ChatTemplateCache, and also add an alias system so that we can look up models in the cache by common names (e.g. "llama3").
5. Add a deprecation warning to people trying to use very old models on `Anthropic`.
6. Much more testing and documentation. We should, for example, add documentation on how to import/initialize a new ChatTemplate and use it for your own models.
7. Write the auto-converter from huggingface `jinja2` to guidance ChatTemplate. A battery of unit tests that compare against the original `transformers.apply_chat_template` method would make this more robust. Can be in a future PR, as this is complex logic. A start to this was attempted in #791 by @ilmarinen, and we could eventually pull this in and expand its coverage.
8. Probably get rid of the folders in `guidance.models.llama_cpp` and `guidance.models.transformers`, because we don't need to maintain a bunch of subclasses for them anymore.

Would appreciate any and all feedback, particularly on the logical flow and the new, simpler user-facing API. @marcotcr @paulbkoch @slundberg @riedgar-ms @hudson-ai

---------

Co-authored-by: Richard Edgar (Microsoft) <[email protected]>
Co-authored-by: Paul Koch <[email protected]>
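To make the resolution order above concrete, here is a minimal sketch exercising the `load_template_class` helper added in `guidance/_chat.py` by this PR (assuming the package is installed at this commit):

```python
# Sketch of the chat-template resolution order described above.
from guidance._chat import (
    ChatMLTemplate,
    Llama2ChatTemplate,
    llama2_template,
    load_template_class,
)

# 1. A ChatTemplate subclass is used directly.
assert load_template_class(Llama2ChatTemplate) is Llama2ChatTemplate

# 2/3. A raw jinja2 template string is looked up in the template cache
# (cache keys are compared with whitespace stripped).
assert load_template_class(llama2_template) is Llama2ChatTemplate

# 5. Nothing provided: fall back to the ChatML template.
assert load_template_class() is ChatMLTemplate
```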
1 parent 13270bf commit a75896a

30 files changed: +712 −631 lines

.github/workflows/workflow-pr-gate.yml

Lines changed: 1 addition & 0 deletions
```diff
@@ -134,6 +134,7 @@ jobs:
     needs:
       - unit-tests-linux-python-other
       - unit-tests-gpu-python-latest
+      - server-tests
     name: End Stage 2
     runs-on: ubuntu-latest
     steps:
```

.gitignore

Lines changed: 3 additions & 1 deletion
```diff
@@ -22,4 +22,6 @@ guidance/_rust/Cargo.lock
 
 notebooks/**/*.papermill_out.ipynb
 
-.mypy_cache/*
+.mypy_cache/*
+
+**/scratch.*
```

README.md

Lines changed: 2 additions & 2 deletions
````diff
@@ -55,8 +55,8 @@ else:
 ```python
 from guidance import user, assistant
 
-# load a chat model
-chat_lm = models.LlamaCppChat(path)
+# load a model
+chat_lm = models.LlamaCpp(path)
 
 # wrap with chat block contexts
 with user():
````

guidance/_chat.py

Lines changed: 210 additions & 0 deletions
```python
import warnings
import uuid
import inspect

class ChatTemplate:
    """Contains template for all chat and instruct tuned models."""

    def get_role_start(self, role_name, **kwargs):
        raise NotImplementedError(
            "You need to use a ChatTemplate subclass that overrides the get_role_start method"
        )

    def get_role_end(self, role_name=None):
        raise NotImplementedError(
            "You need to use a ChatTemplate subclass that overrides the get_role_end method"
        )

class ChatTemplateCache:
    def __init__(self):
        self._cache = {}

    def __getitem__(self, key):
        key_compact = key.replace(" ", "")
        return self._cache[key_compact]

    def __setitem__(self, key, value):
        key_compact = key.replace(" ", "")
        self._cache[key_compact] = value

    def __contains__(self, key):
        key_compact = key.replace(" ", "")
        return key_compact in self._cache

# Feels weird having to instantiate this, but it's a singleton for all purposes
# TODO [HN]: Add an alias system so we can instantiate with other simple keys (e.g. "llama2" instead of the full template string)
CHAT_TEMPLATE_CACHE = ChatTemplateCache()

class UnsupportedRoleException(Exception):
    def __init__(self, role_name, instance):
        self.role_name = role_name
        self.instance = instance
        super().__init__(self._format_message())

    def _format_message(self):
        return f"Role {self.role_name} is not supported by the {self.instance.__class__.__name__} chat template. "

def load_template_class(chat_template=None):
    """Utility method to find the best chat template.

    Order of precedence:
    - If it's a chat template class, use it directly
    - If it's a string, check the cache of popular model templates
    - If it's a string and not in the cache, try to create a class dynamically
    - [TODO] If it's a string and can't be created, default to ChatML and raise a warning
    - If it's None, default to ChatML and raise a warning
    """
    if inspect.isclass(chat_template) and issubclass(chat_template, ChatTemplate):
        if chat_template is ChatTemplate:
            raise Exception("You can't use the base ChatTemplate class directly. Create or use a subclass instead.")
        return chat_template

    elif isinstance(chat_template, str):
        # First check the cache of popular model types
        # TODO: Expand keys of cache to include aliases for popular model types (e.g. "llama2, phi3")
        # Can possibly accomplish this with an "aliases" dictionary that maps all aliases to the canonical key in cache
        if chat_template in CHAT_TEMPLATE_CACHE:
            return CHAT_TEMPLATE_CACHE[chat_template]
        # TODO: Add logic here to try to auto-create class dynamically via _template_class_from_string method

    # Only warn when a user provided a chat template that we couldn't load
    if chat_template is not None:
        warnings.warn(f"""Chat template {chat_template} was unable to be loaded directly into guidance.
                        Defaulting to the ChatML format which may not be optimal for the selected model.
                        For best results, create and pass in a `guidance.ChatTemplate` subclass for your model.""")

    # By default, use the ChatML Template. Warnings to user will happen downstream only if they use chat roles.
    return ChatMLTemplate


def _template_class_from_string(template_str):
    """Utility method to try to create a chat template class from a string."""
    # TODO: Try to build this, perhaps based on passing unit tests we create?
    pass


# CACHE IMPLEMENTATIONS:

# --------------------------------------------------
# @@@@ ChatML @@@@
# --------------------------------------------------
# Note that all grammarless models will default to this syntax, since we typically send chat formatted messages.
chatml_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}"
class ChatMLTemplate(ChatTemplate):
    template_str = chatml_template

    def get_role_start(self, role_name):
        return f"<|im_start|>{role_name}\n"

    def get_role_end(self, role_name=None):
        return "<|im_end|>\n"

CHAT_TEMPLATE_CACHE[chatml_template] = ChatMLTemplate


# --------------------------------------------------
# @@@@ Llama-2 @@@@
# --------------------------------------------------
# [05/08/24] https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/main/tokenizer_config.json#L12
llama2_template = "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}"
class Llama2ChatTemplate(ChatTemplate):
    # available_roles = ["system", "user", "assistant"]
    template_str = llama2_template

    def get_role_start(self, role_name):
        if role_name == "system":
            return "[INST] <<SYS>>\n"
        elif role_name == "user":
            return "<s>[INST]"
        elif role_name == "assistant":
            return " "
        else:
            raise UnsupportedRoleException(role_name, self)

    def get_role_end(self, role_name=None):
        if role_name == "system":
            return "\n<</SYS>"
        elif role_name == "user":
            return " [/INST]"
        elif role_name == "assistant":
            return "</s>"
        else:
            raise UnsupportedRoleException(role_name, self)

CHAT_TEMPLATE_CACHE[llama2_template] = Llama2ChatTemplate


# --------------------------------------------------
# @@@@ Llama-3 @@@@
# --------------------------------------------------
# [05/08/24] https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct/blob/main/tokenizer_config.json#L2053
llama3_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}"
class Llama3ChatTemplate(ChatTemplate):
    # available_roles = ["system", "user", "assistant"]
    template_str = llama3_template

    def get_role_start(self, role_name):
        if role_name == "system":
            return "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
        elif role_name == "user":
            return "<|start_header_id|>user<|end_header_id|>\n\n"
        elif role_name == "assistant":
            return "<|start_header_id|>assistant<|end_header_id|>\n\n"
        else:
            raise UnsupportedRoleException(role_name, self)

    def get_role_end(self, role_name=None):
        return "<|eot_id|>"

CHAT_TEMPLATE_CACHE[llama3_template] = Llama3ChatTemplate

# --------------------------------------------------
# @@@@ Phi-3 @@@@
# --------------------------------------------------
# [05/08/24] https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/tokenizer_config.json#L119
phi3_template = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}"
class Phi3ChatTemplate(ChatTemplate):
    # available_roles = ["user", "assistant"]
    template_str = phi3_template

    def get_role_start(self, role_name):
        if role_name == "user":
            return "<|user|>"
        elif role_name == "assistant":
            return "<|assistant|>"
        else:
            raise UnsupportedRoleException(role_name, self)

    def get_role_end(self, role_name=None):
        return "<|end|>"

CHAT_TEMPLATE_CACHE[phi3_template] = Phi3ChatTemplate


# --------------------------------------------------
# @@@@ Mistral-7B-Instruct-v0.2 @@@@
# --------------------------------------------------
# [05/08/24] https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2/blob/main/tokenizer_config.json#L42
mistral_7b_instruct_template = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
class Mistral7BInstructChatTemplate(ChatTemplate):
    # available_roles = ["user", "assistant"]
    template_str = mistral_7b_instruct_template

    def get_role_start(self, role_name):
        if role_name == "user":
            return "[INST] "
        elif role_name == "assistant":
            return ""
        else:
            raise UnsupportedRoleException(role_name, self)

    def get_role_end(self, role_name=None):
        if role_name == "user":
            return " [/INST]"
        elif role_name == "assistant":
            return "</s>"
        else:
            raise UnsupportedRoleException(role_name, self)

CHAT_TEMPLATE_CACHE[mistral_7b_instruct_template] = Mistral7BInstructChatTemplate
```
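For a model whose template isn't in the cache above, the same registration pattern could be followed in user code; a hedged sketch (the template string and role tokens below are illustrative placeholders, not part of this PR):

```python
# Sketch: registering an additional template, mirroring the built-in entries above.
# The template string and role tokens are placeholders.
from guidance._chat import CHAT_TEMPLATE_CACHE, ChatTemplate, UnsupportedRoleException

my_template_str = "{% for message in messages %}{{ message['content'] }}{% endfor %}"

class MyChatTemplate(ChatTemplate):
    template_str = my_template_str

    def get_role_start(self, role_name):
        if role_name in ("system", "user", "assistant"):
            return f"<|{role_name}|>\n"
        raise UnsupportedRoleException(role_name, self)

    def get_role_end(self, role_name=None):
        return "<|end|>\n"

# Any model whose tokenizer reports my_template_str now resolves to MyChatTemplate.
CHAT_TEMPLATE_CACHE[my_template_str] = MyChatTemplate
```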

guidance/library/_role.py

Lines changed: 19 additions & 9 deletions
```diff
@@ -7,14 +7,10 @@
 span_start = "<||_html:<span style='background-color: rgba(255, 180, 0, 0.3); border-radius: 3px;'>_||>"
 span_end = "<||_html:</span>_||>"
 
-
 @guidance
 def role_opener(lm, role_name, **kwargs):
     indent = getattr(lm, "indent_roles", True)
-    if not hasattr(lm, "get_role_start"):
-        raise Exception(
-            f"You need to use a chat model in order the use role blocks like `with {role_name}():`! Perhaps you meant to use the {type(lm).__name__}Chat class?"
-        )
+
 
     # Block start container (centers elements)
     if indent:
@@ -25,8 +21,17 @@ def role_opener(lm, role_name, **kwargs):
         lm += nodisp_start
     else:
         lm += span_start
-
-    lm += lm.get_role_start(role_name, **kwargs)
+
+    # TODO [HN]: Temporary change while I instrument chat_template in transformers only.
+    # Eventually have all models use chat_template.
+    if hasattr(lm, "get_role_start"):
+        lm += lm.get_role_start(role_name, **kwargs)
+    elif hasattr(lm, "chat_template"):
+        lm += lm.chat_template.get_role_start(role_name)
+    else:
+        raise Exception(
+            f"You need to use a chat model in order the use role blocks like `with {role_name}():`! Perhaps you meant to use the {type(lm).__name__}Chat class?"
+        )
 
     # End of either debug or HTML no disp block
     if indent:
@@ -46,7 +51,12 @@ def role_closer(lm, role_name, **kwargs):
     else:
         lm += span_start
 
-    lm += lm.get_role_end(role_name)
+    # TODO [HN]: Temporary change while I instrument chat_template in transformers only.
+    # Eventually have all models use chat_template.
+    if hasattr(lm, "get_role_end"):
+        lm += lm.get_role_end(role_name)
+    elif hasattr(lm, "chat_template"):
+        lm += lm.chat_template.get_role_end(role_name)
 
     # End of either debug or HTML no disp block
     if indent:
@@ -60,7 +70,7 @@ def role_closer(lm, role_name, **kwargs):
 
     return lm
 
-
+# TODO HN: Add a docstring to better describe arbitrary role functions
 def role(role_name, text=None, **kwargs):
     if text is None:
         return block(
```
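The fallback above is what lets the ordinary role blocks keep working for both the old mixin-based models and the new `chat_template`-based ones; a minimal usage sketch (the model path is an assumption, any locally available chat model would do):

```python
# Sketch: role blocks that route through chat_template.get_role_start / get_role_end.
from guidance import assistant, gen, user
from guidance.models import Transformers

lm = Transformers("path/to/a/local/chat/model")  # chat_template resolved automatically

with user():
    lm += "Write a one-line summary of constrained decoding."
with assistant():
    lm += gen(max_tokens=20)
```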

guidance/models/__init__.py

Lines changed: 4 additions & 7 deletions
```diff
@@ -1,8 +1,8 @@
 from ._model import Model, Instruct, Chat
 
 # local models
-from .transformers._transformers import Transformers, TransformersChat
-from .llama_cpp import LlamaCpp, LlamaCppChat, MistralInstruct, MistralChat
+from .transformers._transformers import Transformers
+from .llama_cpp import LlamaCpp
 from ._mock import Mock, MockChat
 
 # grammarless models (we can't do constrained decoding for them)
@@ -15,15 +15,12 @@
 )
 from ._azure_openai import (
     AzureOpenAI,
-    AzureOpenAIChat,
-    AzureOpenAICompletion,
-    AzureOpenAIInstruct,
 )
 from ._azureai_studio import AzureAIStudioChat
-from ._openai import OpenAI, OpenAIChat, OpenAIInstruct, OpenAICompletion
+from ._openai import OpenAI
 from ._lite_llm import LiteLLM, LiteLLMChat, LiteLLMInstruct, LiteLLMCompletion
 from ._cohere import Cohere, CohereCompletion, CohereInstruct
-from ._anthropic import Anthropic, AnthropicChat
+from ._anthropic import Anthropic
 from ._googleai import GoogleAI, GoogleAIChat
 from ._togetherai import (
     TogetherAI,
```
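The net effect for user code is that the surviving top-level classes are imported directly; a small sketch:

```python
# Sketch: the Chat/Instruct variants removed above no longer exist; the base
# classes (e.g. LlamaCpp instead of LlamaCppChat) now cover chat models too.
from guidance.models import Anthropic, LlamaCpp, Transformers
```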
