Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TransformersChat code to figure out correct role start and end tokens #791

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 57 additions & 1 deletion guidance/models/transformers/_transformers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os
import re
import uuid
import jinja2
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really need jinja2?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we do -- the chat templates are written in jinja2 as a standard


try:
import torch
Expand Down Expand Up @@ -277,4 +279,58 @@ def __init__(


class TransformersChat(Transformers, Chat):
pass

def __init__(self, *args, chat_template=None, **kwargs):
super().__init__(*args, **kwargs)

self._fake_content = str(uuid.uuid4())


def get_role_start(self, role_name, **kwargs):
"""The starting grammar for a role.

By default we follow the GPT role tag start conventions.

Parameters
----------
role_name : str
The name of the role, like "user", or "assistant"
kwargs : dict
This kwargs are added to the role start as arguments.
"""
if self.engine.tokenizer._orig_tokenizer.chat_template is not None or self.engine.tokenizer._orig_tokenizer.defaut_chat_template is not None:
messages = [
{"role": role_name, "content": self._fake_content}
]
sereialized_messages = self.engine.tokenizer._orig_tokenizer.apply_chat_template(messages, tokenize=False)
start = sereialized_messages.find(self._fake_content)
return sereialized_messages[:start]
else:
return (
"<|im_start|>"
+ role_name
+ "".join([f' {k}="{v}"' for k, v in kwargs.items()])
+ "\n"
)

def get_role_end(self, role_name=None):
"""The ending bytes for a role.

Note that we cannot use a grammar in closers because they need to remain constant
so we can append them whenever we need a representation before the final closing of the context.
By default we follow the GPT role tag end conventions.

Parameters
----------
role_name : str
The name of the role, like "user", or "assistant"
"""
if self.engine.tokenizer._orig_tokenizer.chat_template is not None or self.engine.tokenizer._orig_tokenizer.defaut_chat_template is not None:
messages = [
{"role": role_name, "content": self._fake_content}
]
sereialized_messages = sereialized_messages = self.engine.tokenizer._orig_tokenizer.apply_chat_template(messages, tokenize=False)
end = sereialized_messages.find(self._fake_content) + len(self._fake_content)
return sereialized_messages[end:]
else:
return "<|im_end|>"
Loading