Skip to content

Commit

Permalink
Support Gemini Live API (#1240)
Browse files Browse the repository at this point in the history
  • Loading branch information
jayeshp19 authored Dec 30, 2024
1 parent bd36bc9 commit b7f2895
Show file tree
Hide file tree
Showing 14 changed files with 741 additions and 43 deletions.
7 changes: 7 additions & 0 deletions .changeset/thirty-coats-tie.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
"livekit-plugins-google": minor
"livekit-plugins-openai": patch
"livekit-agents": patch
---

Make the multimodal agent class generic over realtime backends and add support for the Gemini Live API.
68 changes: 68 additions & 0 deletions examples/multimodal_agent/gemini_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from __future__ import annotations

import logging
from typing import Annotated

import aiohttp
from dotenv import load_dotenv
from livekit.agents import (
AutoSubscribe,
JobContext,
WorkerOptions,
WorkerType,
cli,
llm,
multimodal,
)
from livekit.plugins import google

# Load API keys (e.g. GOOGLE_API_KEY, LIVEKIT_*) from a local .env file.
load_dotenv()

# Module-level logger for this example worker.
logger = logging.getLogger("my-worker")
logger.setLevel(logging.INFO)


async def entrypoint(ctx: JobContext):
    """Job entrypoint: wire up a Gemini Live realtime multimodal agent.

    Registers a single `get_weather` tool, connects to the room
    (audio only), waits for the first participant, and starts a
    MultimodalAgent backed by Google's realtime (Gemini Live) model.
    """
    logger.info("starting entrypoint")

    fnc_ctx = llm.FunctionContext()

    # Tool exposed to the model; the Annotated description is surfaced to the
    # LLM so it knows when/how to call the function.
    @fnc_ctx.ai_callable()
    async def get_weather(
        location: Annotated[
            str, llm.TypeInfo(description="The location to get the weather for")
        ],
    ):
        """Called when the user asks about the weather. This function will return the weather for the given location."""
        logger.info(f"getting weather for {location}")
        # wttr.in returns a plain-text condition + temperature for the location.
        url = f"https://wttr.in/{location}?format=%C+%t"
        async with aiohttp.ClientSession() as session:
            async with session.get(url) as response:
                if response.status == 200:
                    weather_data = await response.text()
                    # response from the function call is returned to the LLM
                    return f"The weather in {location} is {weather_data}."
                else:
                    raise Exception(
                        f"Failed to get weather data, status code: {response.status}"
                    )

    # Connect first, then block until a remote participant joins the room.
    await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY)
    participant = await ctx.wait_for_participant()

    chat_ctx = llm.ChatContext()

    agent = multimodal.MultimodalAgent(
        model=google.beta.realtime.RealtimeModel(
            voice="Charon",
            temperature=0.8,
            instructions="You are a helpful assistant",
        ),
        fnc_ctx=fnc_ctx,
        chat_ctx=chat_ctx,
    )
    # Non-blocking: starts the agent's audio loop for this participant.
    agent.start(ctx.room, participant)


if __name__ == "__main__":
    # One worker instance per room (WorkerType.ROOM).
    cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint, worker_type=WorkerType.ROOM))
File renamed without changes.
1 change: 1 addition & 0 deletions livekit-agents/livekit/agents/cli/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"openai",
"watchfiles",
"anthropic",
"websockets.client",
]


Expand Down
14 changes: 12 additions & 2 deletions livekit-agents/livekit/agents/multimodal/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
from .multimodal_agent import AgentTranscriptionOptions, MultimodalAgent
from .multimodal_agent import (
AgentTranscriptionOptions,
MultimodalAgent,
_RealtimeAPI,
_RealtimeAPISession,
)

__all__ = ["MultimodalAgent", "AgentTranscriptionOptions"]
__all__ = [
"MultimodalAgent",
"AgentTranscriptionOptions",
"_RealtimeAPI",
"_RealtimeAPISession",
]
129 changes: 100 additions & 29 deletions livekit-agents/livekit/agents/multimodal/multimodal_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,17 @@

import asyncio
from dataclasses import dataclass
from typing import Callable, Literal, Protocol
from typing import (
Any,
AsyncIterable,
Callable,
Literal,
Optional,
Protocol,
TypeVar,
Union,
overload,
)

import aiohttp
from livekit import rtc
Expand All @@ -28,6 +38,76 @@
]


class _InputTranscriptionProto(Protocol):
    """Structural type for 'input speech transcription completed' events
    emitted by a realtime session (duck-typed; no inheritance required)."""

    item_id: str
    """id of the item"""
    transcript: str
    """transcript of the input audio"""


class _ContentProto(Protocol):
    """Structural type for response-content events emitted by a realtime
    session; carries both buffered and streaming forms of the output."""

    response_id: str  # id of the model response this content belongs to
    item_id: str  # id of the conversation item being generated
    output_index: int  # index of the output within the response
    content_index: int  # index of this content part within the output
    text: str  # accumulated text (when content_type == "text")
    audio: list[rtc.AudioFrame]  # accumulated audio frames
    text_stream: AsyncIterable[str]  # incremental text chunks
    audio_stream: AsyncIterable[rtc.AudioFrame]  # incremental audio frames
    content_type: Literal["text", "audio"]  # discriminates which fields apply


class _CapabilitiesProto(Protocol):
    """Feature flags advertised by a realtime model implementation."""

    # Whether the backend supports truncating/updating conversation items
    # (e.g. on user interruption); gates the _truncate/_update calls.
    supports_truncate: bool


class _RealtimeAPI(Protocol):
    """Realtime API protocol.

    Structural interface a realtime speech-to-speech model (e.g. OpenAI
    Realtime, Gemini Live) must satisfy to be driven by MultimodalAgent.
    """

    @property
    def capabilities(self) -> _CapabilitiesProto:
        """Feature flags describing what this model implementation supports."""
        ...

    def session(
        self,
        *,
        chat_ctx: llm.ChatContext | None = None,
        fnc_ctx: llm.FunctionContext | None = None,
    ) -> _RealtimeAPISession:
        """
        Create a new realtime session with the given chat and function contexts.
        """
        # NOTE: `...` (not `pass`) keeps the stub style consistent with the
        # other protocol methods in this module; the trailing `pass` after the
        # docstring was redundant.
        ...


# Generic callback type used by the @overload'd `on` registration below.
T = TypeVar("T", bound=Callable[..., Any])


class _RealtimeAPISession(Protocol):
    """Structural interface of a live realtime-session object: event
    registration, audio input, and conversation-state management."""

    async def set_chat_ctx(self, ctx: llm.ChatContext) -> None: ...

    # `on` works both as a decorator (callback omitted) and as a direct
    # registration call (callback supplied) — hence the two overloads.
    @overload
    def on(self, event: str, callback: None = None) -> Callable[[T], T]: ...
    @overload
    def on(self, event: str, callback: T) -> T: ...
    def on(
        self, event: str, callback: Optional[T] = None
    ) -> Union[T, Callable[[T], T]]: ...

    # Feed one frame of microphone audio into the model.
    def _push_audio(self, frame: rtc.AudioFrame) -> None: ...
    @property
    def fnc_ctx(self) -> llm.FunctionContext | None: ...
    @fnc_ctx.setter
    def fnc_ctx(self, value: llm.FunctionContext | None) -> None: ...
    def chat_ctx_copy(self) -> llm.ChatContext: ...
    def _recover_from_text_response(self, item_id: str) -> None: ...
    def _update_conversation_item_content(
        self,
        item_id: str,
        content: llm.ChatContent | list[llm.ChatContent] | None = None,
    ) -> None: ...
    # Truncate a played item at audio_end_ms (used on user interruption).
    def _truncate_conversation_item(
        self, item_id: str, content_index: int, audio_end_ms: int
    ) -> None: ...


@dataclass(frozen=True)
class AgentTranscriptionOptions:
user_transcription: bool = True
Expand All @@ -50,9 +130,6 @@ class AgentTranscriptionOptions:
representing the hyphenated parts of the word."""


class S2SModel(Protocol): ...


@dataclass(frozen=True)
class _ImplOptions:
transcription: AgentTranscriptionOptions
Expand All @@ -62,7 +139,7 @@ class MultimodalAgent(utils.EventEmitter[EventTypes]):
def __init__(
self,
*,
model: S2SModel,
model: _RealtimeAPI,
vad: vad.VAD | None = None,
chat_ctx: llm.ChatContext | None = None,
fnc_ctx: llm.FunctionContext | None = None,
Expand All @@ -73,7 +150,7 @@ def __init__(
"""Create a new MultimodalAgent.
Args:
model: S2SModel instance.
model: RealtimeAPI instance.
vad: Voice Activity Detection (VAD) instance.
chat_ctx: Chat context for the assistant.
fnc_ctx: Function context for the assistant.
Expand All @@ -89,10 +166,6 @@ def __init__(
super().__init__()
self._loop = loop or asyncio.get_event_loop()

from livekit.plugins.openai import realtime

assert isinstance(model, realtime.RealtimeModel)

self._model = model
self._vad = vad
self._chat_ctx = chat_ctx
Expand Down Expand Up @@ -177,13 +250,8 @@ async def _init_and_start():
# Schedule the initialization and start task
asyncio.create_task(_init_and_start())

from livekit.plugins.openai import realtime

@self._session.on("response_content_added")
def _on_content_added(message: realtime.RealtimeContent):
if message.content_type == "text":
return

def _on_content_added(message: _ContentProto):
tr_fwd = transcription.TTSSegmentsForwarder(
room=self._room,
participant=self._room.local_participant,
Expand All @@ -202,7 +270,7 @@ def _on_content_added(message: realtime.RealtimeContent):
)

@self._session.on("response_content_done")
def _response_content_done(message: realtime.RealtimeContent):
def _response_content_done(message: _ContentProto):
if message.content_type == "text":
if self._text_response_retries >= self._max_text_response_retries:
raise RuntimeError(
Expand Down Expand Up @@ -236,9 +304,7 @@ def _input_speech_committed():
)

@self._session.on("input_speech_transcription_completed")
def _input_speech_transcription_completed(
ev: realtime.InputTranscriptionCompleted,
):
def _input_speech_transcription_completed(ev: _InputTranscriptionProto):
self._stt_forwarder.update(
stt.SpeechEvent(
type=stt.SpeechEventType.FINAL_TRANSCRIPT,
Expand All @@ -248,6 +314,7 @@ def _input_speech_transcription_completed(
user_msg = ChatMessage.create(
text=ev.transcript, role="user", id=ev.item_id
)

self._session._update_conversation_item_content(
ev.item_id, user_msg.content
)
Expand All @@ -265,11 +332,14 @@ def _input_speech_started():
if self._playing_handle is not None and not self._playing_handle.done():
self._playing_handle.interrupt()

self._session.conversation.item.truncate(
item_id=self._playing_handle.item_id,
content_index=self._playing_handle.content_index,
audio_end_ms=int(self._playing_handle.audio_samples / 24000 * 1000),
)
if self._model.capabilities.supports_truncate:
self._session._truncate_conversation_item(
item_id=self._playing_handle.item_id,
content_index=self._playing_handle.content_index,
audio_end_ms=int(
self._playing_handle.audio_samples / 24000 * 1000
),
)

@self._session.on("input_speech_stopped")
def _input_speech_stopped():
Expand Down Expand Up @@ -330,9 +400,10 @@ def _on_playout_stopped(interrupted: bool) -> None:
role="assistant",
id=self._playing_handle.item_id,
)
self._session._update_conversation_item_content(
self._playing_handle.item_id, msg.content
)
if self._model.capabilities.supports_truncate:
self._session._update_conversation_item_content(
self._playing_handle.item_id, msg.content
)

if interrupted:
self.emit("agent_speech_interrupted", msg)
Expand Down Expand Up @@ -366,7 +437,7 @@ def _on_playout_stopped(interrupted: bool) -> None:
)
async for frame in self._input_audio_ch:
for f in bstream.write(frame.data.tobytes()):
self._session.input_audio_buffer.append(f)
self._session._push_audio(f)

def _on_participant_connected(self, participant: rtc.RemoteParticipant):
if self._linked_participant is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from . import beta
from .stt import STT, SpeechStream
from .tts import TTS
from .version import __version__

__all__ = ["STT", "TTS", "SpeechStream", "__version__"]

__all__ = ["STT", "TTS", "SpeechStream", "__version__", "beta"]
from livekit.agents import Plugin

from .log import logger
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Re-export the realtime (Gemini Live) submodule as the beta public surface.
from . import realtime

__all__ = ["realtime"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Public surface of the Gemini Live realtime plugin: the model entry point
# plus the protocol type aliases callers need to configure it.
from .api_proto import (
    ClientEvents,
    LiveAPIModels,
    ResponseModality,
    Voice,
)
from .realtime_api import RealtimeModel

__all__ = [
    "RealtimeModel",
    "ClientEvents",
    "LiveAPIModels",
    "ResponseModality",
    "Voice",
]
Loading

0 comments on commit b7f2895

Please sign in to comment.