Skip to content

Commit

Permalink
Fixes to Anthropic Function Calling (#708)
Browse files: browse the repository at this point in the history
  • Loading branch information
keepingitneil authored Sep 9, 2024
1 parent 0007461 commit b8e88c8
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 35 deletions.
5 changes: 5 additions & 0 deletions .changeset/smooth-monkeys-perform.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"livekit-plugins-anthropic": patch
---

Fixes to Anthropic Function Calling
5 changes: 5 additions & 0 deletions livekit-agents/livekit/agents/llm/chat_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class ChatMessage:
content: str | list[str | ChatImage] | None = None
tool_calls: list[function_context.FunctionCallInfo] | None = None
tool_call_id: str | None = None
tool_exception: Exception | None = None
_metadata: dict[str, Any] = field(default_factory=dict, repr=False, init=False)

@staticmethod
Expand All @@ -50,16 +51,20 @@ def create_tool_from_called_function(
if not called_function.task.done():
raise ValueError("cannot create a tool result from a running ai function")

tool_exception: Exception | None = None
try:
content = called_function.task.result()
except BaseException as e:
if isinstance(e, Exception):
tool_exception = e
content = f"Error: {e}"

return ChatMessage(
role="tool",
name=called_function.call_info.function_info.name,
content=content,
tool_call_id=called_function.call_info.tool_call_id,
tool_exception=tool_exception,
)

@staticmethod
Expand Down
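88 changes: 53 additions & 35 deletions in the livekit-plugins-anthropic LLM module (file path not captured in this extract)
Original file line number Diff line number Diff line change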
Expand Up @@ -43,7 +43,7 @@ class LLM(llm.LLM):
def __init__(
self,
*,
model: str | ChatModels = "claude-3-opus-20240229",
model: str | ChatModels = "claude-3-haiku-20240307",
api_key: str | None = None,
base_url: str | None = None,
user: str | None = None,
Expand Down Expand Up @@ -144,6 +144,9 @@ async def __anext__(self):
if not self._anthropic_stream:
self._anthropic_stream = await self._awaitable_anthropic_stream

fn_calling_enabled = self._fnc_ctx is not None
ignore = False

async for event in self._anthropic_stream:
if event.type == "message_start":
pass
Expand All @@ -159,18 +162,34 @@ async def __anext__(self):
elif event.type == "content_block_delta":
delta = event.delta
if delta.type == "text_delta":
text = delta.text

# When tool calling is enabled, Anthropic appears to inject a prompt that
# makes every response start with a "<thinking>" block containing the LLM's
# chain of thought. That block is very verbose and not useful for voice
# applications, so it is filtered out of the streamed text below.
if fn_calling_enabled:
if text.startswith("<thinking>"):
ignore = True

if "</thinking>" in text:
text = text.split("</thinking>")[-1]
ignore = False

if ignore:
continue

return llm.ChatChunk(
choices=[
llm.Choice(
delta=llm.ChoiceDelta(
content=delta.text, role="assistant"
)
delta=llm.ChoiceDelta(content=text, role="assistant")
)
]
)
elif delta.type == "input_json_delta":
assert self._fnc_raw_arguments is not None
self._fnc_raw_arguments += delta.partial_json

elif event.type == "content_block_stop":
if self._tool_call_id is not None and self._fnc_ctx:
assert self._fnc_name is not None
Expand Down Expand Up @@ -249,13 +268,15 @@ def _build_anthropic_context(
) -> List[anthropic.types.MessageParam]:
result: List[anthropic.types.MessageParam] = []
for msg in chat_ctx:
a_msg = _build_anthropic_message(msg, cache_key)
a_msg = _build_anthropic_message(msg, cache_key, chat_ctx)
if a_msg:
result.append(a_msg)
return result


def _build_anthropic_message(msg: llm.ChatMessage, cache_key: Any):
def _build_anthropic_message(
msg: llm.ChatMessage, cache_key: Any, chat_ctx: List[llm.ChatMessage]
) -> anthropic.types.MessageParam | None:
if msg.role == "user" or msg.role == "assistant":
a_msg: anthropic.types.MessageParam = {
"role": msg.role,
Expand All @@ -282,38 +303,35 @@ def _build_anthropic_message(msg: llm.ChatMessage, cache_key: Any):
a_content.append(content)
elif isinstance(cnt, llm.ChatImage):
a_content.append(_build_anthropic_image_content(cnt, cache_key))
return a_msg
elif msg.role == "tool":
ant_msg: anthropic.types.MessageParam = {
"role": "assistant",
"content": [],
}
assert isinstance(ant_msg["content"], list)
# make sure to include the function calls that were made inside the context
# (+ raw_arguments)

if msg.tool_calls is not None:
for fnc in msg.tool_calls:
ant_msg["content"].append(
{
"id": fnc.tool_call_id,
"type": "tool_use",
"input": fnc.arguments,
"name": fnc.function_info.name,
}
tool_use = anthropic.types.ToolUseBlockParam(
id=fnc.tool_call_id,
type="tool_use",
name=fnc.function_info.name,
input=fnc.arguments,
)
if isinstance(msg.content, str):
ant_msg["content"].append(
{
"tool_use_id": fnc.tool_call_id,
"type": "tool_result",
"content": msg.content,
}
)
else:
logger.warning(
"tool result content is not a string, this is not supported by anthropic"
)
return ant_msg
a_content.append(tool_use)

return a_msg
elif msg.role == "tool":
if not isinstance(msg.content, str):
logger.warning("tool message content is not a string")
return None
if not msg.tool_call_id:
return None

u_content = anthropic.types.ToolResultBlockParam(
tool_use_id=msg.tool_call_id,
type="tool_result",
content=msg.content,
is_error=msg.tool_exception is not None,
)
return {
"role": "user",
"content": [u_content],
}

return None

Expand Down

0 comments on commit b8e88c8

Please sign in to comment.