Skip to content

Commit

Permalink
voicepipeline: support recursive/chained function calls (#970)
Browse files Browse the repository at this point in the history
  • Loading branch information
theomonnom authored Oct 22, 2024
1 parent 7cffc9b commit 966e52b
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 39 deletions.
5 changes: 5 additions & 0 deletions .changeset/twenty-poems-whisper.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"livekit-agents": patch
---

voicepipeline: support recursive/chained function calls
97 changes: 58 additions & 39 deletions livekit-agents/livekit/agents/pipeline/pipeline_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ class _ImplOptions:
int_speech_duration: float
int_min_words: int
min_endpointing_delay: float
max_recursive_fnc_calls: int
preemptive_synthesis: bool
before_llm_cb: BeforeLLMCallback
before_tts_cb: BeforeTTSCallback
Expand Down Expand Up @@ -149,6 +150,7 @@ def __init__(
interrupt_speech_duration: float = 0.5,
interrupt_min_words: int = 0,
min_endpointing_delay: float = 0.5,
max_recursive_fnc_calls: int = 1,
preemptive_synthesis: bool = False,
transcription: AgentTranscriptionOptions = AgentTranscriptionOptions(),
before_llm_cb: BeforeLLMCallback = _default_before_llm_cb,
Expand Down Expand Up @@ -203,6 +205,7 @@ def __init__(
int_speech_duration=interrupt_speech_duration,
int_min_words=interrupt_min_words,
min_endpointing_delay=min_endpointing_delay,
max_recursive_fnc_calls=max_recursive_fnc_calls,
preemptive_synthesis=preemptive_synthesis,
transcription=transcription,
before_llm_cb=before_llm_cb,
Expand Down Expand Up @@ -683,62 +686,70 @@ def _commit_user_question_if_needed() -> None:
not user_question or speech_handle.user_commited
), "user speech should have been committed before using tools"

llm_stream = speech_handle.source

# execute functions
call_ctx = AgentCallContext(self, speech_handle.source)
call_ctx = AgentCallContext(self, llm_stream)
tk = _CallContextVar.set(call_ctx)
self.emit("function_calls_collected", speech_handle.source.function_calls)
called_fncs_info = speech_handle.source.function_calls

called_fncs = []
for fnc in called_fncs_info:
called_fnc = fnc.execute()
called_fncs.append(called_fnc)
logger.debug(
"executing ai function",
extra={
"function": fnc.function_info.name,
"speech_id": speech_handle.id,
},
)
try:
await called_fnc.task
except Exception as e:
logger.exception(
"error executing ai function",

new_function_calls = llm_stream.function_calls

for i in range(self._opts.max_recursive_fnc_calls):
self.emit("function_calls_collected", new_function_calls)

called_fncs = []
for fnc in new_function_calls:
called_fnc = fnc.execute()
called_fncs.append(called_fnc)
logger.debug(
"executing ai function",
extra={
"function": fnc.function_info.name,
"speech_id": speech_handle.id,
},
exc_info=e,
)
try:
await called_fnc.task
except Exception as e:
logger.exception(
"error executing ai function",
extra={
"function": fnc.function_info.name,
"speech_id": speech_handle.id,
},
exc_info=e,
)

tool_calls_info = []
tool_calls_results = []

for called_fnc in called_fncs:
# ignore the function calls that returns None
if called_fnc.result is None:
continue

tool_calls_info.append(called_fnc.call_info)
tool_calls_results.append(
ChatMessage.create_tool_from_called_function(called_fnc)
)

self.emit("function_calls_finished", called_fncs)
_CallContextVar.reset(tk)

tool_calls = []
tool_calls_results_msg = []

for called_fnc in called_fncs:
# ignore the function calls that returns None
if called_fnc.result is None:
continue

tool_calls.append(called_fnc.call_info)
tool_calls_results_msg.append(
ChatMessage.create_tool_from_called_function(called_fnc)
)
if not tool_calls_info:
break

if tool_calls:
# generate an answer from the tool calls
extra_tools_messages.append(
ChatMessage.create_tool_calls(tool_calls, text=collected_text)
ChatMessage.create_tool_calls(tool_calls_info, text=collected_text)
)
extra_tools_messages.extend(tool_calls_results_msg)
extra_tools_messages.extend(tool_calls_results)

chat_ctx = speech_handle.source.chat_ctx.copy()
chat_ctx.messages.extend(extra_tools_messages)

answer_llm_stream = self._llm.chat(
chat_ctx=chat_ctx,
fnc_ctx=self.fnc_ctx
if i < self._opts.max_recursive_fnc_calls - 1
else None,
)
answer_synthesis = self._synthesize_agent_speech(
speech_handle.id, answer_llm_stream
Expand All @@ -750,6 +761,14 @@ def _commit_user_question_if_needed() -> None:

collected_text = answer_synthesis.tts_forwarder.played_text
interrupted = answer_synthesis.interrupted
new_function_calls = answer_llm_stream.function_calls

self.emit("function_calls_finished", called_fncs)

if not new_function_calls:
break

_CallContextVar.reset(tk)

if speech_handle.add_to_chat_ctx and (
not user_question or speech_handle.user_commited
Expand Down

0 comments on commit 966e52b

Please sign in to comment.