From 09d699120764e42f50832caeffdc634c7b4f5628 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Monnom?= Date: Sat, 12 Oct 2024 17:26:26 -0700 Subject: [PATCH] be more resilient with incomplete function tools outputs (#909) --- livekit-agents/livekit/agents/llm/chat_context.py | 2 +- .../livekit-plugins-openai/livekit/plugins/openai/llm.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/livekit-agents/livekit/agents/llm/chat_context.py b/livekit-agents/livekit/agents/llm/chat_context.py index d90bae67f..a65479288 100644 --- a/livekit-agents/livekit/agents/llm/chat_context.py +++ b/livekit-agents/livekit/agents/llm/chat_context.py @@ -80,7 +80,7 @@ def create_tool_from_called_function( def create_tool_calls( called_functions: list[function_context.FunctionCallInfo], ) -> "ChatMessage": - return ChatMessage(role="assistant", tool_calls=called_functions) + return ChatMessage(role="assistant", tool_calls=called_functions, content="") @staticmethod def create( diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py index 9e1951290..c784759d3 100644 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py @@ -467,6 +467,7 @@ def chat( temperature = self._opts.temperature messages = _build_oai_context(chat_ctx, id(self)) + cmp = self._client.chat.completions.create( messages=messages, model=self._opts.model, @@ -543,7 +544,7 @@ def _parse_choice(self, choice: Choice) -> llm.ChatChunk | None: if call_chunk is not None: return call_chunk - if choice.finish_reason == "tool_calls": + if choice.finish_reason in ("tool_calls", "stop") and self._tool_call_id: # we're done with the tool calls, run the last one return self._try_run_function(choice) @@ -576,6 +577,7 @@ def _try_run_function(self, choice: Choice) -> llm.ChatChunk | None: fnc_info = llm._oai_api.create_ai_function_info( self._fnc_ctx, self._tool_call_id, self._fnc_name, self._fnc_raw_arguments ) + self._tool_call_id = self._fnc_name = self._fnc_raw_arguments = None self._function_calls_info.append(fnc_info)