diff --git a/.changeset/thirty-coats-tie.md b/.changeset/thirty-coats-tie.md
new file mode 100644
index 000000000..f0c6a9e67
--- /dev/null
+++ b/.changeset/thirty-coats-tie.md
@@ -0,0 +1,7 @@
+---
+"livekit-plugins-google": minor
+"livekit-plugins-openai": patch
+"livekit-agents": patch
+---
+
+make the multimodal agent class generic and add support for the Gemini Live API
diff --git a/examples/multimodal_agent/gemini_agent.py b/examples/multimodal_agent/gemini_agent.py
new file mode 100644
index 000000000..81a474609
--- /dev/null
+++ b/examples/multimodal_agent/gemini_agent.py
@@ -0,0 +1,68 @@
+from __future__ import annotations
+
+import logging
+from typing import Annotated
+
+import aiohttp
+from dotenv import load_dotenv
+from livekit.agents import (
+    AutoSubscribe,
+    JobContext,
+    WorkerOptions,
+    WorkerType,
+    cli,
+    llm,
+    multimodal,
+)
+from livekit.plugins import google
+
+load_dotenv()
+
+logger = logging.getLogger("my-worker")
+logger.setLevel(logging.INFO)
+
+
+async def entrypoint(ctx: JobContext):
+    logger.info("starting entrypoint")
+
+    fnc_ctx = llm.FunctionContext()
+
+    @fnc_ctx.ai_callable()
+    async def get_weather(
+        location: Annotated[
+            str, llm.TypeInfo(description="The location to get the weather for")
+        ],
+    ):
+        """Called when the user asks about the weather. This function will return the weather for the given location."""
+        logger.info(f"getting weather for {location}")
+        url = f"https://wttr.in/{location}?format=%C+%t"
+        async with aiohttp.ClientSession() as session:
+            async with session.get(url) as response:
+                if response.status == 200:
+                    weather_data = await response.text()
+                    # response from the function call is returned to the LLM
+                    return f"The weather in {location} is {weather_data}."
+                else:
+                    raise Exception(
+                        f"Failed to get weather data, status code: {response.status}"
+                    )
+
+    await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY)
+    participant = await ctx.wait_for_participant()
+
+    chat_ctx = llm.ChatContext()
+
+    agent = multimodal.MultimodalAgent(
+        model=google.beta.realtime.RealtimeModel(
+            voice="Charon",
+            temperature=0.8,
+            instructions="You are a helpful assistant",
+        ),
+        fnc_ctx=fnc_ctx,
+        chat_ctx=chat_ctx,
+    )
+    agent.start(ctx.room, participant)
+
+
+if __name__ == "__main__":
+    cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint, worker_type=WorkerType.ROOM))
diff --git a/examples/multimodal_agent.py b/examples/multimodal_agent/openai_agent.py
similarity index 100%
rename from examples/multimodal_agent.py
rename to examples/multimodal_agent/openai_agent.py
diff --git a/livekit-agents/livekit/agents/cli/log.py b/livekit-agents/livekit/agents/cli/log.py
index dc16bfdfa..c4b5e5e52 100644
--- a/livekit-agents/livekit/agents/cli/log.py
+++ b/livekit-agents/livekit/agents/cli/log.py
@@ -18,6 +18,7 @@
     "openai",
     "watchfiles",
     "anthropic",
+    "websockets.client",
 ]
 
 
diff --git a/livekit-agents/livekit/agents/multimodal/__init__.py b/livekit-agents/livekit/agents/multimodal/__init__.py
index d165c082a..f741e168a 100644
--- a/livekit-agents/livekit/agents/multimodal/__init__.py
+++ b/livekit-agents/livekit/agents/multimodal/__init__.py
@@ -1,3 +1,13 @@
-from .multimodal_agent import AgentTranscriptionOptions, MultimodalAgent
+from .multimodal_agent import (
+    AgentTranscriptionOptions,
+    MultimodalAgent,
+    _RealtimeAPI,
+    _RealtimeAPISession,
+)
 
-__all__ = ["MultimodalAgent", "AgentTranscriptionOptions"]
+__all__ = [
+    "MultimodalAgent",
+    "AgentTranscriptionOptions",
+    "_RealtimeAPI",
+    "_RealtimeAPISession",
+]
diff --git a/livekit-agents/livekit/agents/multimodal/multimodal_agent.py b/livekit-agents/livekit/agents/multimodal/multimodal_agent.py
index ee3a2d992..f02bb2e64 100644
--- a/livekit-agents/livekit/agents/multimodal/multimodal_agent.py
+++ b/livekit-agents/livekit/agents/multimodal/multimodal_agent.py
@@ -2,7 +2,17 @@
 
 import asyncio
 from dataclasses import dataclass
-from typing import Callable, Literal, Protocol
+from typing import (
+    Any,
+    AsyncIterable,
+    Callable,
+    Literal,
+    Optional,
+    Protocol,
+    TypeVar,
+    Union,
+    overload,
+)
 
 import aiohttp
 from livekit import rtc
@@ -28,6 +38,76 @@
 ]
 
 
+class _InputTranscriptionProto(Protocol):
+    item_id: str
+    """id of the item"""
+    transcript: str
+    """transcript of the input audio"""
+
+
+class _ContentProto(Protocol):
+    response_id: str
+    item_id: str
+    output_index: int
+    content_index: int
+    text: str
+    audio: list[rtc.AudioFrame]
+    text_stream: AsyncIterable[str]
+    audio_stream: AsyncIterable[rtc.AudioFrame]
+    content_type: Literal["text", "audio"]
+
+
+class _CapabilitiesProto(Protocol):
+    supports_truncate: bool
+
+
+class _RealtimeAPI(Protocol):
+    """Realtime API protocol"""
+
+    @property
+    def capabilities(self) -> _CapabilitiesProto: ...
+    def session(
+        self,
+        *,
+        chat_ctx: llm.ChatContext | None = None,
+        fnc_ctx: llm.FunctionContext | None = None,
+    ) -> _RealtimeAPISession:
+        """
+        Create a new realtime session with the given chat and function contexts.
+        """
+        pass
+
+
+T = TypeVar("T", bound=Callable[..., Any])
+
+
+class _RealtimeAPISession(Protocol):
+    async def set_chat_ctx(self, ctx: llm.ChatContext) -> None: ...
+    @overload
+    def on(self, event: str, callback: None = None) -> Callable[[T], T]: ...
+    @overload
+    def on(self, event: str, callback: T) -> T: ...
+    def on(
+        self, event: str, callback: Optional[T] = None
+    ) -> Union[T, Callable[[T], T]]: ...
+
+    def _push_audio(self, frame: rtc.AudioFrame) -> None: ...
+    @property
+    def fnc_ctx(self) -> llm.FunctionContext | None: ...
+    @fnc_ctx.setter
+    def fnc_ctx(self, value: llm.FunctionContext | None) -> None: ...
+    def chat_ctx_copy(self) -> llm.ChatContext: ...
+    def _recover_from_text_response(self, item_id: str) -> None: ...
+    def _update_conversation_item_content(
+        self,
+        item_id: str,
+        content: llm.ChatContent | list[llm.ChatContent] | None = None,
+    ) -> None: ...
+    def _truncate_conversation_item(
+        self, item_id: str, content_index: int, audio_end_ms: int
+    ) -> None: ...
+
+
 @dataclass(frozen=True)
 class AgentTranscriptionOptions:
     user_transcription: bool = True
@@ -50,9 +130,6 @@ class AgentTranscriptionOptions:
     representing the hyphenated parts of the word."""
 
 
-class S2SModel(Protocol): ...
-
-
 @dataclass(frozen=True)
 class _ImplOptions:
     transcription: AgentTranscriptionOptions
@@ -62,7 +139,7 @@ class MultimodalAgent(utils.EventEmitter[EventTypes]):
     def __init__(
         self,
         *,
-        model: S2SModel,
+        model: _RealtimeAPI,
         vad: vad.VAD | None = None,
         chat_ctx: llm.ChatContext | None = None,
         fnc_ctx: llm.FunctionContext | None = None,
@@ -73,7 +150,7 @@ def __init__(
         """Create a new MultimodalAgent.

         Args:
-            model: S2SModel instance.
+            model: RealtimeAPI instance.
            vad: Voice Activity Detection (VAD) instance.
            chat_ctx: Chat context for the assistant.
            fnc_ctx: Function context for the assistant.
@@ -89,10 +166,6 @@ def __init__(
         super().__init__()
         self._loop = loop or asyncio.get_event_loop()
 
-        from livekit.plugins.openai import realtime
-
-        assert isinstance(model, realtime.RealtimeModel)
-
         self._model = model
         self._vad = vad
         self._chat_ctx = chat_ctx
@@ -177,13 +250,8 @@ async def _init_and_start():
         # Schedule the initialization and start task
         asyncio.create_task(_init_and_start())
 
-        from livekit.plugins.openai import realtime
-
         @self._session.on("response_content_added")
-        def _on_content_added(message: realtime.RealtimeContent):
-            if message.content_type == "text":
-                return
-
+        def _on_content_added(message: _ContentProto):
             tr_fwd = transcription.TTSSegmentsForwarder(
                 room=self._room,
                 participant=self._room.local_participant,
@@ -202,7 +270,7 @@ def _on_content_added(message: realtime.RealtimeContent):
             )
 
         @self._session.on("response_content_done")
-        def _response_content_done(message: realtime.RealtimeContent):
+        def _response_content_done(message: _ContentProto):
             if message.content_type == "text":
                 if self._text_response_retries >= self._max_text_response_retries:
                     raise RuntimeError(
@@ -236,9 +304,7 @@ def _input_speech_committed():
             )
 
         @self._session.on("input_speech_transcription_completed")
-        def _input_speech_transcription_completed(
-            ev: realtime.InputTranscriptionCompleted,
-        ):
+        def _input_speech_transcription_completed(ev: _InputTranscriptionProto):
             self._stt_forwarder.update(
                 stt.SpeechEvent(
                     type=stt.SpeechEventType.FINAL_TRANSCRIPT,
@@ -248,6 +314,7 @@ def _input_speech_transcription_completed(ev: _InputTranscriptionProto):
             user_msg = ChatMessage.create(
                 text=ev.transcript, role="user", id=ev.item_id
             )
+
             self._session._update_conversation_item_content(
                 ev.item_id, user_msg.content
             )
@@ -265,11 +332,14 @@ def _input_speech_started():
            if self._playing_handle is not None and not self._playing_handle.done():
                self._playing_handle.interrupt()
 
-                self._session.conversation.item.truncate(
-                    item_id=self._playing_handle.item_id,
-                    content_index=self._playing_handle.content_index,
-                    audio_end_ms=int(self._playing_handle.audio_samples / 24000 * 1000),
-                )
+                if self._model.capabilities.supports_truncate:
+                    self._session._truncate_conversation_item(
+                        item_id=self._playing_handle.item_id,
+                        content_index=self._playing_handle.content_index,
+                        audio_end_ms=int(
+                            self._playing_handle.audio_samples / 24000 * 1000
+                        ),
+                    )
 
         @self._session.on("input_speech_stopped")
         def _input_speech_stopped():
@@ -330,9 +400,10 @@ def _on_playout_stopped(interrupted: bool) -> None:
                 role="assistant",
                 id=self._playing_handle.item_id,
             )
-            self._session._update_conversation_item_content(
-                self._playing_handle.item_id, msg.content
-            )
+            if self._model.capabilities.supports_truncate:
+                self._session._update_conversation_item_content(
+                    self._playing_handle.item_id, msg.content
+                )
 
             if interrupted:
                 self.emit("agent_speech_interrupted", msg)
@@ -366,7 +437,7 @@ def _on_playout_stopped(interrupted: bool) -> None:
            )
            async for frame in self._input_audio_ch:
                for f in bstream.write(frame.data.tobytes()):
-                    self._session.input_audio_buffer.append(f)
+                    self._session._push_audio(f)
 
    def _on_participant_connected(self, participant: rtc.RemoteParticipant):
        if self._linked_participant is None:
diff --git a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/__init__.py b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/__init__.py
index ca754bd30..88e163634 100644
--- a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/__init__.py
+++ b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/__init__.py
@@ -12,12 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from . import beta
 from .stt import STT, SpeechStream
 from .tts import TTS
 from .version import __version__
 
-__all__ = ["STT", "TTS", "SpeechStream", "__version__"]
-
+__all__ = ["STT", "TTS", "SpeechStream", "__version__", "beta"]
 from livekit.agents import Plugin
 
 from .log import logger
diff --git a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/__init__.py b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/__init__.py
new file mode 100644
index 000000000..89cb122c8
--- /dev/null
+++ b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/__init__.py
@@ -0,0 +1,3 @@
+from . import realtime
+
+__all__ = ["realtime"]
diff --git a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/__init__.py b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/__init__.py
new file mode 100644
index 000000000..e95a86917
--- /dev/null
+++ b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/__init__.py
@@ -0,0 +1,15 @@
+from .api_proto import (
+    ClientEvents,
+    LiveAPIModels,
+    ResponseModality,
+    Voice,
+)
+from .realtime_api import RealtimeModel
+
+__all__ = [
+    "RealtimeModel",
+    "ClientEvents",
+    "LiveAPIModels",
+    "ResponseModality",
+    "Voice",
+]
diff --git a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/api_proto.py b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/api_proto.py
new file mode 100644
index 000000000..c02fb3859
--- /dev/null
+++ b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/api_proto.py
@@ -0,0 +1,79 @@
+from __future__ import annotations
+
+import inspect
+from typing import Any, Dict, List, Literal, Sequence, Union
+
+from google.genai import types  # type: ignore
+
+LiveAPIModels = Literal["gemini-2.0-flash-exp"]
+
+Voice = Literal["Puck", "Charon", "Kore", "Fenrir", "Aoede"]
+ResponseModality = Literal["AUDIO", "TEXT"]
+
+
+ClientEvents = Union[
+    types.ContentListUnion,
+    types.ContentListUnionDict,
+    types.LiveClientContentOrDict,
+    types.LiveClientRealtimeInput,
+    types.LiveClientRealtimeInputOrDict,
+    types.LiveClientToolResponseOrDict,
+    types.FunctionResponseOrDict,
+    Sequence[types.FunctionResponseOrDict],
+]
+
+
+JSON_SCHEMA_TYPE_MAP = {
+    str: "string",
+    int: "integer",
+    float: "number",
+    bool: "boolean",
+    dict: "object",
+    list: "array",
+}
+
+
+def _build_parameters(arguments: Dict[str, Any]) -> types.SchemaDict:
+    properties: Dict[str, types.SchemaDict] = {}
+    required: List[str] = []
+
+    for arg_name, arg_info in arguments.items():
+        py_type = arg_info.type
+        if py_type not in JSON_SCHEMA_TYPE_MAP:
+            raise ValueError(f"Unsupported type: {py_type}")
+
+        prop: types.SchemaDict = {
+            "type": JSON_SCHEMA_TYPE_MAP[py_type],
+            "description": arg_info.description,
+        }
+
+        if arg_info.choices:
+            prop["enum"] = arg_info.choices
+
+        properties[arg_name] = prop
+
+        if arg_info.default is inspect.Parameter.empty:
+            required.append(arg_name)
+
+    parameters: types.SchemaDict = {"type": "object", "properties": properties}
+
+    if required:
+        parameters["required"] = required
+
+    return parameters
+
+
+def _build_tools(fnc_ctx: Any) -> List[types.FunctionDeclarationDict]:
+    function_declarations: List[types.FunctionDeclarationDict] = []
+    for fnc_info in fnc_ctx.ai_functions.values():
+        parameters = _build_parameters(fnc_info.arguments)
+
+        func_decl: types.FunctionDeclarationDict = {
+            "name": fnc_info.name,
+            "description": fnc_info.description,
+            "parameters": parameters,
+        }
+
+        function_declarations.append(func_decl)
+
+    return function_declarations
diff --git a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/realtime_api.py b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/realtime_api.py
new file mode 100644
index 000000000..40bb0d7a1
--- /dev/null
+++ b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/realtime_api.py
@@ -0,0 +1,424 @@
+from __future__ import annotations
+
+import asyncio
+import base64
+import json
+import os
+from dataclasses import dataclass
+from typing import AsyncIterable, Literal
+
+from livekit import rtc
+from livekit.agents import llm, utils
+from livekit.agents.llm.function_context import _create_ai_function_info
+
+from google import genai  # type: ignore
+from google.genai.types import (  # type: ignore
+    FunctionResponse,
+    GenerationConfigDict,
+    LiveClientToolResponse,
+    LiveConnectConfigDict,
+    PrebuiltVoiceConfig,
+    SpeechConfig,
+    VoiceConfig,
+)
+
+from ...log import logger
+from .api_proto import (
+    ClientEvents,
+    LiveAPIModels,
+    ResponseModality,
+    Voice,
+    _build_tools,
+)
+
+EventTypes = Literal[
+    "start_session",
+    "input_speech_started",
+    "response_content_added",
+    "response_content_done",
+    "function_calls_collected",
+    "function_calls_finished",
+    "function_calls_cancelled",
+]
+
+
+@dataclass
+class GeminiContent:
+    response_id: str
+    item_id: str
+    output_index: int
+    content_index: int
+    text: str
+    audio: list[rtc.AudioFrame]
+    text_stream: AsyncIterable[str]
+    audio_stream: AsyncIterable[rtc.AudioFrame]
+    content_type: Literal["text", "audio"]
+
+
+@dataclass
+class Capabilities:
+    supports_truncate: bool
+
+
+@dataclass
+class ModelOptions:
+    model: LiveAPIModels | str
+    api_key: str | None
+    voice: Voice | str
+    response_modalities: ResponseModality
+    vertexai: bool
+    project: str | None
+    location: str | None
+    candidate_count: int
+    temperature: float | None
+    max_output_tokens: int | None
+    top_p: float | None
+    top_k: int | None
+    presence_penalty: float | None
+    frequency_penalty: float | None
+    instructions: str
+
+
+class RealtimeModel:
+    def __init__(
+        self,
+        *,
+        instructions: str = "",
+        model: LiveAPIModels | str = "gemini-2.0-flash-exp",
+        api_key: str | None = None,
+        voice: Voice | str = "Puck",
+        modalities: ResponseModality = "AUDIO",
+        vertexai: bool = False,
+        project: str | None = None,
+        location: str | None = None,
+        candidate_count: int = 1,
+        temperature: float | None = None,
+        max_output_tokens: int | None = None,
+        top_p: float | None = None,
+        top_k: int | None = None,
+        presence_penalty: float | None = None,
+        frequency_penalty: float | None = None,
+        loop: asyncio.AbstractEventLoop | None = None,
+    ):
+        """
+        Initializes a RealtimeModel instance for interacting with Google's Realtime API.
+
+        Args:
+            instructions (str, optional): Initial system instructions for the model. Defaults to "".
+            api_key (str or None, optional): Google API key. If None, will attempt to read from the environment variable GOOGLE_API_KEY.
+            modalities (ResponseModality, optional): Response modality to use, either "AUDIO" or "TEXT". Defaults to "AUDIO".
+            model (str or None, optional): The name of the model to use. Defaults to "gemini-2.0-flash-exp".
+            voice (api_proto.Voice, optional): Voice setting for audio outputs. Defaults to "Puck".
+            temperature (float, optional): Sampling temperature for response generation. Defaults to None.
+            vertexai (bool, optional): Whether to use VertexAI for the API. Defaults to False.
+            project (str or None, optional): The project to use for the API. Defaults to None. (for vertexai)
+            location (str or None, optional): The location to use for the API. Defaults to None. (for vertexai)
+            candidate_count (int, optional): The number of candidate responses to generate. Defaults to 1.
+            top_p (float, optional): The top-p value for response generation.
+            top_k (int, optional): The top-k value for response generation.
+            presence_penalty (float, optional): The presence penalty for response generation.
+            frequency_penalty (float, optional): The frequency penalty for response generation.
+            loop (asyncio.AbstractEventLoop or None, optional): Event loop to use for async operations. If None, the current event loop is used.
+
+        Raises:
+            ValueError: If the API key is not provided and cannot be found in environment variables.
+        """
+        super().__init__()
+        self._capabilities = Capabilities(
+            supports_truncate=False,
+        )
+        self._model = model
+        self._loop = loop or asyncio.get_event_loop()
+        self._api_key = api_key or os.environ.get("GOOGLE_API_KEY")
+        self._vertexai = vertexai
+        self._project_id = project or os.environ.get("GOOGLE_PROJECT")
+        self._location = location or os.environ.get("GOOGLE_LOCATION")
+        if self._api_key is None and not self._vertexai:
+            raise ValueError("GOOGLE_API_KEY is not set")
+
+        self._rt_sessions: list[GeminiRealtimeSession] = []
+        self._opts = ModelOptions(
+            model=model,
+            api_key=api_key,
+            voice=voice,
+            response_modalities=modalities,
+            vertexai=vertexai,
+            project=project,
+            location=location,
+            candidate_count=candidate_count,
+            temperature=temperature,
+            max_output_tokens=max_output_tokens,
+            top_p=top_p,
+            top_k=top_k,
+            presence_penalty=presence_penalty,
+            frequency_penalty=frequency_penalty,
+            instructions=instructions,
+        )
+
+    @property
+    def sessions(self) -> list[GeminiRealtimeSession]:
+        return self._rt_sessions
+
+    @property
+    def capabilities(self) -> Capabilities:
+        return self._capabilities
+
+    def session(
+        self,
+        *,
+        chat_ctx: llm.ChatContext | None = None,
+        fnc_ctx: llm.FunctionContext | None = None,
+    ) -> GeminiRealtimeSession:
+        session = GeminiRealtimeSession(
+            opts=self._opts,
+            chat_ctx=chat_ctx or llm.ChatContext(),
+            fnc_ctx=fnc_ctx,
+            loop=self._loop,
+        )
+        self._rt_sessions.append(session)
+
+        return session
+
+    async def aclose(self) -> None:
+        for session in self._rt_sessions:
+            await session.aclose()
+
+
+class GeminiRealtimeSession(utils.EventEmitter[EventTypes]):
+    def __init__(
+        self,
+        *,
+        opts: ModelOptions,
+        chat_ctx: llm.ChatContext,
+        fnc_ctx: llm.FunctionContext | None,
+        loop: asyncio.AbstractEventLoop,
+    ):
+        """
+        Initializes a GeminiRealtimeSession instance for interacting with Google's Realtime API.
+
+        Args:
+            opts (ModelOptions): The model options for the session.
+            chat_ctx (llm.ChatContext): The chat context for the session.
+            fnc_ctx (llm.FunctionContext or None): The function context for the session.
+            loop (asyncio.AbstractEventLoop): The event loop for the session.
+ """ + super().__init__() + self._loop = loop + self._opts = opts + self._chat_ctx = chat_ctx + self._fnc_ctx = fnc_ctx + self._fnc_tasks = utils.aio.TaskSet() + + tools = [] + if self._fnc_ctx is not None: + functions = _build_tools(self._fnc_ctx) + tools.append({"function_declarations": functions}) + + self._config = LiveConnectConfigDict( + model=self._opts.model, + response_modalities=self._opts.response_modalities, + generation_config=GenerationConfigDict( + candidate_count=self._opts.candidate_count, + temperature=self._opts.temperature, + max_output_tokens=self._opts.max_output_tokens, + top_p=self._opts.top_p, + top_k=self._opts.top_k, + presence_penalty=self._opts.presence_penalty, + frequency_penalty=self._opts.frequency_penalty, + ), + system_instruction=self._opts.instructions, + speech_config=SpeechConfig( + voice_config=VoiceConfig( + prebuilt_voice_config=PrebuiltVoiceConfig( + voice_name=self._opts.voice + ) + ) + ), + tools=tools, + ) + self._client = genai.Client( + http_options={"api_version": "v1alpha"}, + api_key=self._opts.api_key, + vertexai=self._opts.vertexai, + project=self._opts.project, + location=self._opts.location, + ) + self._main_atask = asyncio.create_task( + self._main_task(), name="gemini-realtime-session" + ) + # dummy task to wait for the session to be initialized # TODO: sync chat ctx + self._init_sync_task = asyncio.create_task( + asyncio.sleep(0), name="gemini-realtime-session-init" + ) + self._send_ch = utils.aio.Chan[ClientEvents]() + self._active_response_id = None + + async def aclose(self) -> None: + if self._send_ch.closed: + return + + self._send_ch.close() + await self._main_atask + + @property + def fnc_ctx(self) -> llm.FunctionContext | None: + return self._fnc_ctx + + @fnc_ctx.setter + def fnc_ctx(self, value: llm.FunctionContext | None) -> None: + self._fnc_ctx = value + + def _push_audio(self, frame: rtc.AudioFrame) -> None: + data = base64.b64encode(frame.data).decode("utf-8") + self._queue_msg({"mime_type": "audio/pcm", "data": data}) + + def _queue_msg(self, msg: dict) -> None: + self._send_ch.send_nowait(msg) + + def chat_ctx_copy(self) -> llm.ChatContext: + return self._chat_ctx.copy() + + async def set_chat_ctx(self, ctx: llm.ChatContext) -> None: + self._chat_ctx = ctx.copy() + + @utils.log_exceptions(logger=logger) + async def _main_task(self): + @utils.log_exceptions(logger=logger) + async def _send_task(): + async for msg in self._send_ch: + await self._session.send(msg) + + await self._session.send(".", end_of_turn=True) + + @utils.log_exceptions(logger=logger) + async def _recv_task(): + while True: + async for response in self._session.receive(): + if self._active_response_id is None: + self._active_response_id = utils.shortuuid() + text_stream = utils.aio.Chan[str]() + audio_stream = utils.aio.Chan[rtc.AudioFrame]() + content = GeminiContent( + response_id=self._active_response_id, + item_id=self._active_response_id, + output_index=0, + content_index=0, + text="", + audio=[], + text_stream=text_stream, + audio_stream=audio_stream, + content_type=self._opts.response_modalities, + ) + self.emit("response_content_added", content) + + server_content = response.server_content + if server_content: + model_turn = server_content.model_turn + if model_turn: + for part in model_turn.parts: + if part.text: + content.text_stream.send_nowait(part.text) + if part.inline_data: + frame = rtc.AudioFrame( + data=part.inline_data.data, + sample_rate=24000, + num_channels=1, + samples_per_channel=len(part.inline_data.data) + // 2, + ) + 
+                                    content.audio_stream.send_nowait(frame)
+
+                        if server_content.interrupted or server_content.turn_complete:
+                            for stream in (content.text_stream, content.audio_stream):
+                                if isinstance(stream, utils.aio.Chan):
+                                    stream.close()
+
+                            if server_content.interrupted:
+                                self.emit("input_speech_started")
+                            elif server_content.turn_complete:
+                                self.emit("response_content_done", content)
+
+                            self._active_response_id = None
+
+                    if response.tool_call:
+                        if self._fnc_ctx is None:
+                            raise ValueError("Function context is not set")
+                        fnc_calls = []
+                        for fnc_call in response.tool_call.function_calls:
+                            fnc_call_info = _create_ai_function_info(
+                                self._fnc_ctx,
+                                fnc_call.id,
+                                fnc_call.name,
+                                json.dumps(fnc_call.args),
+                            )
+                            fnc_calls.append(fnc_call_info)
+
+                        self.emit("function_calls_collected", fnc_calls)
+
+                        for fnc_call_info in fnc_calls:
+                            self._fnc_tasks.create_task(
+                                self._run_fnc_task(fnc_call_info, content.item_id)
+                            )
+
+                    # Handle function call cancellations
+                    if response.tool_call_cancellation:
+                        logger.warning(
+                            "function call cancelled",
+                            extra={
+                                "function_call_ids": response.tool_call_cancellation.function_call_ids,
+                            },
+                        )
+                        self.emit(
+                            "function_calls_cancelled",
+                            response.tool_call_cancellation.function_call_ids,
+                        )
+
+        async with self._client.aio.live.connect(
+            model=self._opts.model, config=self._config
+        ) as session:
+            self._session = session
+            tasks = [
+                asyncio.create_task(_send_task(), name="gemini-realtime-send"),
+                asyncio.create_task(_recv_task(), name="gemini-realtime-recv"),
+            ]
+
+            try:
+                await asyncio.gather(*tasks)
+            finally:
+                await utils.aio.gracefully_cancel(*tasks)
+                await self._session.close()
+
+    @utils.log_exceptions(logger=logger)
+    async def _run_fnc_task(self, fnc_call_info: llm.FunctionCallInfo, item_id: str):
+        logger.debug(
+            "executing ai function",
+            extra={
+                "function": fnc_call_info.function_info.name,
+            },
+        )
+
+        called_fnc = fnc_call_info.execute()
+        try:
+            await called_fnc.task
+        except Exception as e:
+            logger.exception(
+                "error executing ai function",
+                extra={
+                    "function": fnc_call_info.function_info.name,
+                },
+                exc_info=e,
+            )
+        tool_call = llm.ChatMessage.create_tool_from_called_function(called_fnc)
+        if tool_call.content is not None:
+            tool_response = LiveClientToolResponse(
+                function_responses=[
+                    FunctionResponse(
+                        name=tool_call.name,
+                        id=tool_call.tool_call_id,
+                        response={"result": tool_call.content},
+                    )
+                ]
+            )
+            await self._session.send(tool_response)
+
+            self.emit("function_calls_finished", [called_fnc])
diff --git a/livekit-plugins/livekit-plugins-google/setup.py b/livekit-plugins/livekit-plugins-google/setup.py
index 87646895f..0db8addce 100644
--- a/livekit-plugins/livekit-plugins-google/setup.py
+++ b/livekit-plugins/livekit-plugins-google/setup.py
@@ -51,6 +51,7 @@
         "google-auth >= 2, < 3",
         "google-cloud-speech >= 2, < 3",
         "google-cloud-texttospeech >= 2, < 3",
+        "google-genai >= 0.3.0",
         "livekit-agents>=0.12.3",
     ],
     package_data={"livekit.plugins.google": ["py.typed"]},
diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/__init__.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/__init__.py
index 471deef37..fbb453609 100644
--- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/__init__.py
+++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/__init__.py
@@ -2,8 +2,6 @@
 from .realtime_model import (
     DEFAULT_INPUT_AUDIO_TRANSCRIPTION,
     DEFAULT_SERVER_VAD_OPTIONS,
-    InputTranscriptionCompleted,
-    InputTranscriptionFailed,
     InputTranscriptionOptions,
     RealtimeContent,
     RealtimeError,
@@ -17,8 +15,6 @@
 )
 
 __all__ = [
-    "InputTranscriptionCompleted",
-    "InputTranscriptionFailed",
     "RealtimeContent",
     "RealtimeOutput",
     "RealtimeResponse",
diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py
index 26bc2649b..10d7abc1f 100644
--- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py
+++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py
@@ -4,6 +4,7 @@
 import base64
 import os
 import time
+import weakref
 from copy import deepcopy
 from dataclasses import dataclass
 from typing import AsyncIterable, Literal, Optional, Union, cast, overload
@@ -105,8 +106,11 @@ class RealtimeToolCall:
     """id of the tool call"""
 
 
-# TODO(theomonnom): add the content type directly inside RealtimeContent?
-# text/audio/transcript?
+@dataclass
+class Capabilities:
+    supports_truncate: bool
+
+
 @dataclass
 class RealtimeContent:
     response_id: str
@@ -284,6 +288,9 @@ def __init__(
             ValueError: If the API key is not provided and cannot be found in environment variables.
         """
         super().__init__()
+        self._capabilities = Capabilities(
+            supports_truncate=True,
+        )
         self._base_url = base_url
 
         is_azure = (
@@ -322,7 +329,7 @@ def __init__(
         )
 
         self._loop = loop or asyncio.get_event_loop()
-        self._rt_sessions: list[RealtimeSession] = []
+        self._rt_sessions = weakref.WeakSet[RealtimeSession]()
         self._http_session = http_session
 
     @classmethod
@@ -427,9 +434,13 @@ def _ensure_session(self) -> aiohttp.ClientSession:
         return self._http_session
 
     @property
-    def sessions(self) -> list[RealtimeSession]:
+    def sessions(self) -> weakref.WeakSet[RealtimeSession]:
        return self._rt_sessions
 
+    @property
+    def capabilities(self) -> Capabilities:
+        return self._capabilities
+
     def session(
         self,
         *,
@@ -475,7 +486,7 @@ def session(
             http_session=self._ensure_session(),
             loop=self._loop,
         )
-        self._rt_sessions.append(new_session)
+        self._rt_sessions.add(new_session)
         return new_session
 
     async def aclose(self) -> None:
@@ -854,6 +865,9 @@ def conversation(self) -> Conversation:
     def input_audio_buffer(self) -> InputAudioBuffer:
         return RealtimeSession.InputAudioBuffer(self)
 
+    def _push_audio(self, frame: rtc.AudioFrame) -> None:
+        self.input_audio_buffer.append(frame)
+
     @property
     def response(self) -> Response:
         return RealtimeSession.Response(self)
@@ -1023,6 +1037,15 @@ def _recover_from_text_response(self, item_id: str | None = None) -> None:
         self.conversation.item.create(self._create_empty_user_audio_message(1.0))
         self.response.create(on_duplicate="keep_both")
 
+    def _truncate_conversation_item(
+        self, item_id: str, content_index: int, audio_end_ms: int
+    ) -> None:
+        self.conversation.item.truncate(
+            item_id=item_id,
+            content_index=content_index,
+            audio_end_ms=audio_end_ms,
+        )
+
     def _update_conversation_item_content(
         self, item_id: str, content: llm.ChatContent | list[llm.ChatContent] | None
     ) -> None:
@@ -1662,7 +1685,7 @@ async def _run_fnc_task(self, fnc_call_info: llm.FunctionCallInfo, item_id: str)
                 "function": fnc_call_info.function_info.name,
             },
         )
-        if called_fnc.result is not None:
+        if tool_call.content is not None:
             create_fut = self.conversation.item.create(
                 tool_call,
                 previous_item_id=item_id,