From 966e52b063d09e2431b6953b13e5cde5b8e42ea3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Monnom?= Date: Tue, 22 Oct 2024 12:51:00 -0700 Subject: [PATCH] voicepipeline: support recursive/chained function calls (#970) --- .changeset/twenty-poems-whisper.md | 5 + .../livekit/agents/pipeline/pipeline_agent.py | 97 +++++++++++-------- 2 files changed, 63 insertions(+), 39 deletions(-) create mode 100644 .changeset/twenty-poems-whisper.md diff --git a/.changeset/twenty-poems-whisper.md b/.changeset/twenty-poems-whisper.md new file mode 100644 index 000000000..d0076bffe --- /dev/null +++ b/.changeset/twenty-poems-whisper.md @@ -0,0 +1,5 @@ +--- +"livekit-agents": patch +--- + +voicepipeline: support recursive/chained function calls diff --git a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py index ef233356b..20b34150d 100644 --- a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py +++ b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py @@ -99,6 +99,7 @@ class _ImplOptions: int_speech_duration: float int_min_words: int min_endpointing_delay: float + max_recursive_fnc_calls: int preemptive_synthesis: bool before_llm_cb: BeforeLLMCallback before_tts_cb: BeforeTTSCallback @@ -149,6 +150,7 @@ def __init__( interrupt_speech_duration: float = 0.5, interrupt_min_words: int = 0, min_endpointing_delay: float = 0.5, + max_recursive_fnc_calls: int = 1, preemptive_synthesis: bool = False, transcription: AgentTranscriptionOptions = AgentTranscriptionOptions(), before_llm_cb: BeforeLLMCallback = _default_before_llm_cb, @@ -203,6 +205,7 @@ def __init__( int_speech_duration=interrupt_speech_duration, int_min_words=interrupt_min_words, min_endpointing_delay=min_endpointing_delay, + max_recursive_fnc_calls=max_recursive_fnc_calls, preemptive_synthesis=preemptive_synthesis, transcription=transcription, before_llm_cb=before_llm_cb, @@ -683,62 +686,70 @@ def _commit_user_question_if_needed() -> None: not user_question or speech_handle.user_commited ), "user speech should have been committed before using tools" + llm_stream = speech_handle.source + # execute functions - call_ctx = AgentCallContext(self, speech_handle.source) + call_ctx = AgentCallContext(self, llm_stream) tk = _CallContextVar.set(call_ctx) - self.emit("function_calls_collected", speech_handle.source.function_calls) - called_fncs_info = speech_handle.source.function_calls - - called_fncs = [] - for fnc in called_fncs_info: - called_fnc = fnc.execute() - called_fncs.append(called_fnc) - logger.debug( - "executing ai function", - extra={ - "function": fnc.function_info.name, - "speech_id": speech_handle.id, - }, - ) - try: - await called_fnc.task - except Exception as e: - logger.exception( - "error executing ai function", + + new_function_calls = llm_stream.function_calls + + for i in range(self._opts.max_recursive_fnc_calls): + self.emit("function_calls_collected", new_function_calls) + + called_fncs = [] + for fnc in new_function_calls: + called_fnc = fnc.execute() + called_fncs.append(called_fnc) + logger.debug( + "executing ai function", extra={ "function": fnc.function_info.name, "speech_id": speech_handle.id, }, - exc_info=e, + ) + try: + await called_fnc.task + except Exception as e: + logger.exception( + "error executing ai function", + extra={ + "function": fnc.function_info.name, + "speech_id": speech_handle.id, + }, + exc_info=e, + ) + + tool_calls_info = [] + tool_calls_results = [] + + for called_fnc in called_fncs: + # ignore the function calls that returns None + if called_fnc.result is None: + continue + + tool_calls_info.append(called_fnc.call_info) + tool_calls_results.append( + ChatMessage.create_tool_from_called_function(called_fnc) ) - self.emit("function_calls_finished", called_fncs) - _CallContextVar.reset(tk) - - tool_calls = [] - tool_calls_results_msg = [] - - for called_fnc in called_fncs: - # ignore the function calls that returns None - if called_fnc.result is None: - continue - - tool_calls.append(called_fnc.call_info) - tool_calls_results_msg.append( - ChatMessage.create_tool_from_called_function(called_fnc) - ) + if not tool_calls_info: + break - if tool_calls: + # generate an answer from the tool calls extra_tools_messages.append( - ChatMessage.create_tool_calls(tool_calls, text=collected_text) + ChatMessage.create_tool_calls(tool_calls_info, text=collected_text) ) - extra_tools_messages.extend(tool_calls_results_msg) + extra_tools_messages.extend(tool_calls_results) chat_ctx = speech_handle.source.chat_ctx.copy() chat_ctx.messages.extend(extra_tools_messages) answer_llm_stream = self._llm.chat( chat_ctx=chat_ctx, + fnc_ctx=self.fnc_ctx + if i < self._opts.max_recursive_fnc_calls - 1 + else None, ) answer_synthesis = self._synthesize_agent_speech( speech_handle.id, answer_llm_stream @@ -750,6 +761,14 @@ def _commit_user_question_if_needed() -> None: collected_text = answer_synthesis.tts_forwarder.played_text interrupted = answer_synthesis.interrupted + new_function_calls = answer_llm_stream.function_calls + + self.emit("function_calls_finished", called_fncs) + + if not new_function_calls: + break + + _CallContextVar.reset(tk) if speech_handle.add_to_chat_ctx and ( not user_question or speech_handle.user_commited