Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 57 additions & 1 deletion guidance/models/transformers/_transformers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os
import re
import uuid
import jinja2
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really need jinja2?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we do -- the chat templates are written in jinja2 as a standard


try:
import torch
Expand Down Expand Up @@ -277,4 +279,58 @@ def __init__(


class TransformersChat(Transformers, Chat):

    def __init__(self, *args, chat_template=None, **kwargs):
        """Chat-capable Transformers model.

        Parameters
        ----------
        chat_template : optional
            Reserved for a caller-supplied chat template override.
            NOTE(review): currently accepted but unused — the tokenizer's own
            template is always consulted; confirm intended behavior.
        """
        super().__init__(*args, **kwargs)

        # Unique sentinel placed as the message content so we can locate the
        # role prefix/suffix inside the tokenizer's rendered chat template.
        self._fake_content = str(uuid.uuid4())

    def _render_role_message(self, role_name):
        """Render a single sentinel message through the tokenizer's chat template.

        Returns the serialized string, or ``None`` when the tokenizer has no
        chat template (neither an explicit nor a default one), signalling the
        caller to fall back to the GPT-style role tag conventions.
        """
        tokenizer = self.engine.tokenizer._orig_tokenizer
        # Fix: the original checked a misspelled `defaut_chat_template`
        # attribute, which raised AttributeError whenever `chat_template` was
        # None and made the <|im_start|> fallback unreachable. getattr guards
        # against transformers versions that lack `default_chat_template`.
        if (
            tokenizer.chat_template is None
            and getattr(tokenizer, "default_chat_template", None) is None
        ):
            return None
        messages = [{"role": role_name, "content": self._fake_content}]
        return tokenizer.apply_chat_template(messages, tokenize=False)

    def get_role_start(self, role_name, **kwargs):
        """The starting grammar for a role.

        Uses the tokenizer's chat template when one is available; otherwise
        falls back to the GPT role tag start conventions.

        Parameters
        ----------
        role_name : str
            The name of the role, like "user", or "assistant"
        kwargs : dict
            These kwargs are added to the role start as arguments.
            NOTE(review): kwargs are ignored on the chat-template path —
            confirm this is acceptable.
        """
        serialized = self._render_role_message(role_name)
        if serialized is not None:
            # Everything before the sentinel content is the role's opening text.
            return serialized[: serialized.find(self._fake_content)]
        return (
            "<|im_start|>"
            + role_name
            + "".join([f' {k}="{v}"' for k, v in kwargs.items()])
            + "\n"
        )

    def get_role_end(self, role_name=None):
        """The ending bytes for a role.

        Note that we cannot use a grammar in closers because they need to remain constant
        so we can append them whenever we need a representation before the final closing of the context.
        Uses the tokenizer's chat template when one is available; otherwise
        follows the GPT role tag end conventions.

        Parameters
        ----------
        role_name : str
            The name of the role, like "user", or "assistant"
        """
        serialized = self._render_role_message(role_name)
        if serialized is not None:
            # Everything after the sentinel content is the role's closing text.
            end = serialized.find(self._fake_content) + len(self._fake_content)
            return serialized[end:]
        return "<|im_end|>"