Skip to content

Commit

Permalink
fix: support xml tags inside chat messages (#30)
Browse files Browse the repository at this point in the history
* support xml tags inside chat messages

* fix tests

* linting
  • Loading branch information
masci authored Dec 31, 2024
1 parent 42a4633 commit 0737d7d
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 89 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ ban-relative-imports = "parents"

[tool.ruff.lint.per-file-ignores]
# Tests can use magic values, assertions, and relative imports
"tests/**/*" = ["PLR2004", "S101", "TID252"]
"tests/**/*" = ["PLR2004", "S101", "TID252", "E501"]


[tool.coverage.run]
Expand Down
60 changes: 18 additions & 42 deletions src/banks/extensions/chat.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,15 @@
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]>
#
# SPDX-License-Identifier: MIT
from html.parser import HTMLParser
import re

from jinja2 import TemplateSyntaxError, nodes
from jinja2.ext import Extension

from banks.types import ChatMessage, ChatMessageContent, ContentBlock, ContentBlockType
from banks.types import ChatMessage, ContentBlock, ContentBlockType

# NOTE(review): presumably the role values accepted by the `chat` tag — confirm against the
# extension's parsing code, which is not visible here.
SUPPORTED_TYPES = ("system", "user")


class _ContentBlockParser(HTMLParser):
    """Collect the content blocks contained in a piece of rendered text.

    Data found between `<content_block>` and `</content_block>` is validated as a
    JSON-serialized `ContentBlock`; data found anywhere else becomes a plain text block.
    Tags other than `content_block` do not change the parser state.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Blocks gathered so far, in the order the data was seen.
        self._content_blocks: list[ContentBlock] = []
        # True only between an opening and a closing `content_block` tag.
        self._parse_block_content = False

    @property
    def content(self) -> ChatMessageContent:
        """Return the parsed data in a form assignable to `ChatMessage.content`.

        A single text block without cache control is collapsed to its bare text
        (or an empty string when the text is missing); anything else is returned
        as the list of blocks.
        """
        blocks = self._content_blocks
        if len(blocks) != 1:
            return blocks

        only = blocks[0]
        if only.type != "text" or only.cache_control is not None:
            return blocks

        return only.text or ""

    def handle_starttag(self, tag, attrs):  # noqa
        """Start treating incoming data as block JSON when `content_block` opens."""
        if tag != "content_block":
            return
        self._parse_block_content = True

    def handle_endtag(self, tag):
        """Go back to treating incoming data as plain text when `content_block` closes."""
        if tag != "content_block":
            return
        self._parse_block_content = False

    def handle_data(self, data):
        """Turn a chunk of data into a content block and store it."""
        if not self._parse_block_content:
            new_block = ContentBlock(type=ContentBlockType.text, text=data)
        else:
            new_block = ContentBlock.model_validate_json(data)
        self._content_blocks.append(new_block)
# Captures everything between `<content_block>` and the LAST `</content_block>` in a single
# group (the quantifier is greedy); `(?s:.)` also matches newlines, so the captured payload
# may span several lines.
CONTENT_BLOCK_REGEX = re.compile(r"<content_block>((?s:.)*)<\/content_block>")


class ChatExtension(Extension):
Expand Down Expand Up @@ -105,7 +69,19 @@ def _store_chat_messages(self, role, caller):
"""
Helper callback.
"""
parser = _ContentBlockParser()
parser.feed(caller())
cm = ChatMessage(role=role, content=parser.content)
# Build the message content from the rendered body of the chat block.
# Render the body exactly once: every `caller()` invocation re-renders the block,
# so calling it again for the plain-text fallback would repeat that work (and any
# side effects of the filters/macros used inside the block).
rendered = caller()

content_blocks: list[ContentBlock] = []
result = CONTENT_BLOCK_REGEX.match(rendered)
if result is not None:
    # The pattern defines a single capture group, so this yields one JSON payload.
    # NOTE(review): `match` only succeeds when the rendered text starts with the
    # opening tag, and the greedy group spans up to the last closing tag — confirm
    # this is the intended handling for bodies with several blocks or leading text.
    content_blocks.extend(ContentBlock.model_validate_json(g) for g in result.groups())
else:
    # No content block markup: keep the whole rendered text (inner tags included).
    content_blocks.append(ContentBlock(type=ContentBlockType.text, text=rendered))

# A lone text block without cache control is stored as a bare string for simplicity.
content = content_blocks
if len(content_blocks) == 1:
    block = content_blocks[0]
    if block.type == "text" and block.cache_control is None:
        content = block.text or ""

cm = ChatMessage(role=role, content=content)
return cm.model_dump_json(exclude_none=True) + "\n"
4 changes: 2 additions & 2 deletions tests/templates/chat.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ You are a helpful assistant.
{% endchat %}

{% chat role="user" %}
Hello, how are you?
{{ "Hello, <bold>how are you?</bold>" | cache_control("ephemeral") }}
{% endchat %}

{% chat role="system" %}
Expand All @@ -14,4 +14,4 @@ I'm doing well, thank you! How can I assist you today?
Can you explain quantum computing?
{% endchat %}

Some random text.
Some random text.
42 changes: 0 additions & 42 deletions tests/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
from jinja2 import TemplateSyntaxError

from banks import Prompt
from banks.extensions.chat import _ContentBlockParser
from banks.types import CacheControl, ContentBlock, ContentBlockType


def test_wrong_tag():
Expand All @@ -19,43 +17,3 @@ def test_wrong_tag_params():
def test_wrong_role_type():
    """Constructing a Prompt whose chat tag has an unknown role raises TemplateSyntaxError."""
    template = '{% chat role="does not exist" %}{% endchat %}'
    with pytest.raises(TemplateSyntaxError):
        Prompt(template)


def test_content_block_parser_init():
    """A new parser has collected nothing and is not inside a content block."""
    parser = _ContentBlockParser()
    assert parser._content_blocks == []
    assert parser._parse_block_content is False


def test_content_block_parser_single_with_cache_control():
    """A single block that carries cache control is returned as a one-element list."""
    expected_block = ContentBlock(
        type=ContentBlockType.text,
        cache_control=CacheControl(type="ephemeral"),
        text="foo",
        source=None,
    )
    parser = _ContentBlockParser()
    parser.feed(
        '<content_block>{"type":"text","cache_control":{"type":"ephemeral"},"text":"foo","source":null}</content_block>'
    )
    assert parser.content == [expected_block]


def test_content_block_parser_single_no_cache_control():
    """A single text block without cache control collapses to its bare text."""
    markup = '<content_block>{"type":"text","cache_control":null,"text":"foo","source":null}</content_block>'
    parser = _ContentBlockParser()
    parser.feed(markup)
    assert parser.content == "foo"


def test_content_block_parser_multiple():
    """Two consecutive blocks are returned as a list, in the order they were fed."""
    markup = (
        '<content_block>{"type":"text","cache_control":null,"text":"foo","source":null}</content_block>'
        '<content_block>{"type":"text","cache_control":null,"text":"bar","source":null}</content_block>'
    )
    expected = [
        ContentBlock(type=ContentBlockType.text, cache_control=None, text=text, source=None)
        for text in ("foo", "bar")
    ]
    parser = _ContentBlockParser()
    parser.feed(markup)
    assert parser.content == expected


def test_content_block_parser_other_tags():
    """Text wrapped in an unrelated tag comes back as plain text, without the tag."""
    parser = _ContentBlockParser()
    parser.feed("<some_tag>FOO</some_tag>")
    parsed = parser.content
    assert parsed == "FOO"
17 changes: 15 additions & 2 deletions tests/test_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from banks import AsyncPrompt, ChatMessage, Prompt
from banks.cache import DefaultCache
from banks.errors import AsyncError
from banks.types import CacheControl, ContentBlock, ContentBlockType


def test_canary_word_generation():
Expand Down Expand Up @@ -90,7 +91,7 @@ def test_chat_messages():
== """
{"role":"system","content":"You are a helpful assistant.\\n"}
{"role":"user","content":"Hello, how are you?\\n"}
{"role":"user","content":[{"type":"text","cache_control":{"type":"ephemeral"},"text":"Hello, <bold>how are you?</bold>"}]}
{"role":"system","content":"I'm doing well, thank you! How can I assist you today?\\n"}
Expand All @@ -102,7 +103,19 @@ def test_chat_messages():

assert p.chat_messages() == [
ChatMessage(role="system", content="You are a helpful assistant.\n"),
ChatMessage(role="user", content="Hello, how are you?\n"),
ChatMessage(
role="user",
content=[
ContentBlock(
type=ContentBlockType.text,
cache_control=CacheControl(type="ephemeral"),
text="Hello, <bold>how are you?</bold>",
image_url=None,
)
],
tool_call_id=None,
name=None,
),
ChatMessage(role="system", content="I'm doing well, thank you! How can I assist you today?\n"),
ChatMessage(role="user", content="Can you explain quantum computing?\n"),
]
Expand Down

0 comments on commit 0737d7d

Please sign in to comment.