-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix: support xml tags inside chat messages (#30)
* support xml tags inside chat messages * fix tests * linting
- Loading branch information
Showing
5 changed files
with
36 additions
and
89 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,51 +1,15 @@ | ||
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]> | ||
# | ||
# SPDX-License-Identifier: MIT | ||
from html.parser import HTMLParser | ||
import re | ||
|
||
from jinja2 import TemplateSyntaxError, nodes | ||
from jinja2.ext import Extension | ||
|
||
from banks.types import ChatMessage, ChatMessageContent, ContentBlock, ContentBlockType | ||
from banks.types import ChatMessage, ContentBlock, ContentBlockType | ||
|
||
SUPPORTED_TYPES = ("system", "user") | ||
|
||
|
||
class _ContentBlockParser(HTMLParser): | ||
"""A parser used to extract text surrounded by `<content_block_txt>` and `</content_block_txt>` tags.""" | ||
|
||
def __init__(self, *args, **kwargs) -> None: | ||
super().__init__(*args, **kwargs) | ||
self._parse_block_content = False | ||
self._content_blocks: list[ContentBlock] = [] | ||
|
||
@property | ||
def content(self) -> ChatMessageContent: | ||
"""Returns ChatMessageContent data that can be directly assigned to ChatMessage.content. | ||
If only one block is present, this block is of type text and has no cache control set, we just | ||
return it as plain text for simplicity. | ||
""" | ||
if len(self._content_blocks) == 1: | ||
block = self._content_blocks[0] | ||
if block.type == "text" and block.cache_control is None: | ||
return block.text or "" | ||
|
||
return self._content_blocks | ||
|
||
def handle_starttag(self, tag, attrs): # noqa | ||
if tag == "content_block": | ||
self._parse_block_content = True | ||
|
||
def handle_endtag(self, tag): | ||
if tag == "content_block": | ||
self._parse_block_content = False | ||
|
||
def handle_data(self, data): | ||
if self._parse_block_content: | ||
self._content_blocks.append(ContentBlock.model_validate_json(data)) | ||
else: | ||
self._content_blocks.append(ContentBlock(type=ContentBlockType.text, text=data)) | ||
CONTENT_BLOCK_REGEX = re.compile(r"<content_block>((?s:.)*)<\/content_block>") | ||
|
||
|
||
class ChatExtension(Extension): | ||
|
@@ -105,7 +69,19 @@ def _store_chat_messages(self, role, caller): | |
""" | ||
Helper callback. | ||
""" | ||
parser = _ContentBlockParser() | ||
parser.feed(caller()) | ||
cm = ChatMessage(role=role, content=parser.content) | ||
content_blocks: list[ContentBlock] = [] | ||
result = CONTENT_BLOCK_REGEX.match(caller()) | ||
if result is not None: | ||
for g in result.groups(): | ||
content_blocks.append(ContentBlock.model_validate_json(g)) | ||
else: | ||
content_blocks.append(ContentBlock(type=ContentBlockType.text, text=caller())) | ||
|
||
content = content_blocks | ||
if len(content_blocks) == 1: | ||
block = content_blocks[0] | ||
if block.type == "text" and block.cache_control is None: | ||
content = block.text or "" | ||
|
||
cm = ChatMessage(role=role, content=content) | ||
return cm.model_dump_json(exclude_none=True) + "\n" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters