Skip to content

Commit

Permalink
Merge pull request #393 from ag2ai/realtime-webrtc-ag2-repo
Browse files Browse the repository at this point in the history
RealtimeAgent WebRTC
  • Loading branch information
davorrunje authored Jan 8, 2025
2 parents a1e9dc3 + 12b2e50 commit 984897e
Show file tree
Hide file tree
Showing 6 changed files with 639 additions and 8 deletions.
184 changes: 182 additions & 2 deletions autogen/agentchat/realtime_agent/oai_realtime_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,21 @@
#
# SPDX-License-Identifier: Apache-2.0

import asyncio
import json
from contextlib import asynccontextmanager
from logging import Logger, getLogger
from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional

from asyncer import TaskGroup, create_task_group
import httpx
from openai import DEFAULT_MAX_RETRIES, NOT_GIVEN, AsyncOpenAI
from openai.resources.beta.realtime.realtime import AsyncRealtimeConnection

from .realtime_client import Role

if TYPE_CHECKING:
from fastapi.websockets import WebSocket

from .realtime_client import RealtimeClientProtocol

__all__ = ["OpenAIRealtimeClient", "Role"]
Expand Down Expand Up @@ -168,8 +172,184 @@ async def read_events(self) -> AsyncGenerator[dict[str, Any], None]:
self._connection = None


# needed for mypy to check if OpenAIRealtimeClient implements RealtimeClientProtocol
class OpenAIRealtimeWebRTCClient:
"""(Experimental) Client for OpenAI Realtime API that uses WebRTC protocol."""

def __init__(
self,
*,
llm_config: dict[str, Any],
voice: str,
system_message: str,
websocket: "WebSocket",
logger: Optional[Logger] = None,
) -> None:
"""(Experimental) Client for OpenAI Realtime API.
Args:
llm_config (dict[str, Any]): The config for the client.
"""
self._llm_config = llm_config
self._voice = voice
self._system_message = system_message
self._logger = logger
self._websocket = websocket

config = llm_config["config_list"][0]
self._model: str = config["model"]
self._temperature: float = llm_config.get("temperature", 0.8) # type: ignore[union-attr]
self._config = config

@property
def logger(self) -> Logger:
"""Get the logger for the OpenAI Realtime API."""
return self._logger or global_logger

async def send_function_result(self, call_id: str, result: str) -> None:
"""Send the result of a function call to the OpenAI Realtime API.
Args:
call_id (str): The ID of the function call.
result (str): The result of the function call.
"""
await self._websocket.send_json(
{
"type": "conversation.item.create",
"item": {
"type": "function_call_output",
"call_id": call_id,
"output": result,
},
}
)
await self._websocket.send_json({"type": "response.create"})

async def send_text(self, *, role: Role, text: str) -> None:
"""Send a text message to the OpenAI Realtime API.
Args:
role (str): The role of the message.
text (str): The text of the message.
"""
# await self.connection.response.cancel() #why is this here?
await self._websocket.send_json(
{
"type": "connection.conversation.item.create",
"item": {"type": "message", "role": role, "content": [{"type": "input_text", "text": text}]},
}
)
# await self.connection.response.create()

async def send_audio(self, audio: str) -> None:
"""Send audio to the OpenAI Realtime API.
Args:
audio (str): The audio to send.
"""
await self._websocket.send_json({"type": "input_audio_buffer.append", "audio": audio})

async def truncate_audio(self, audio_end_ms: int, content_index: int, item_id: str) -> None:
"""Truncate audio in the OpenAI Realtime API.
Args:
audio_end_ms (int): The end of the audio to truncate.
content_index (int): The index of the content to truncate.
item_id (str): The ID of the item to truncate.
"""
await self._websocket.send_json(
{
"type": "conversation.item.truncate",
"content_index": content_index,
"item_id": item_id,
"audio_end_ms": audio_end_ms,
}
)

async def session_update(self, session_options: dict[str, Any]) -> None:
"""Send a session update to the OpenAI Realtime API.
In the case of WebRTC we can not send it directly, but we can send it
to the javascript over the websocket, and rely on it to send session
update to OpenAI
Args:
session_options (dict[str, Any]): The session options to update.
"""
logger = self.logger
logger.info(f"Sending session update: {session_options}")
# await self.connection.session.update(session=session_options) # type: ignore[arg-type]
await self._websocket.send_json({"type": "session.update", "session": session_options})
logger.info("Sending session update finished")

async def _initialize_session(self) -> None:
"""Control initial session with OpenAI."""
session_update = {
"turn_detection": {"type": "server_vad"},
"voice": self._voice,
"instructions": self._system_message,
"modalities": ["audio", "text"],
"temperature": self._temperature,
}
await self.session_update(session_options=session_update)

@asynccontextmanager
async def connect(self) -> AsyncGenerator[None, None]:
"""Connect to the OpenAI Realtime API.
In the case of WebRTC, we pass connection information over the
websocket, so that javascript on the other end of websocket open
actual connection to OpenAI
"""
try:
url = "https://api.openai.com/v1/realtime/sessions"
api_key = self._config.get("api_key", None)
headers = {
"Authorization": f"Bearer {api_key}", # Use os.getenv to get from environment
"Content-Type": "application/json",
}
data = {
# "model": "gpt-4o-realtime-preview-2024-12-17",
"model": self._model,
"voice": self._voice,
}
async with httpx.AsyncClient() as client:
response = await client.post(url, headers=headers, json=data)
response.raise_for_status()
json_data = response.json()
json_data["model"] = self._model
if self._websocket is not None:
await self._websocket.send_json({"type": "ag2.init", "config": json_data})
await asyncio.sleep(10)
await self._initialize_session()
yield
finally:
pass

async def read_events(self) -> AsyncGenerator[dict[str, Any], None]:
"""Read messages from the OpenAI Realtime API.
Again, in case of WebRTC, we do not read OpenAI messages directly since we
do not hold connection to OpenAI. Instead we read messages from the websocket, and javascript
client on the other side of the websocket that is connected to OpenAI is relaying events to us.
"""
logger = self.logger
while True:
try:
messageJSON = await self._websocket.receive_text()
message = json.loads(messageJSON)
if "function" in message["type"]:
logger.info("Received function message", message)
yield message
except Exception:
break


# needed for mypy to check if OpenAIRealtimeWebRTCClient implements RealtimeClientProtocol
if TYPE_CHECKING:
_client: RealtimeClientProtocol = OpenAIRealtimeClient(
llm_config={}, voice="alloy", system_message="You are a helpful AI voice assistant."
)

def _rtc_client(websocket: "WebSocket") -> RealtimeClientProtocol:
return OpenAIRealtimeWebRTCClient(
llm_config={}, voice="alloy", system_message="You are a helpful AI voice assistant.", websocket=websocket
)
25 changes: 19 additions & 6 deletions autogen/agentchat/realtime_agent/realtime_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,16 @@

import anyio
from asyncer import create_task_group, syncify
from fastapi import WebSocket

from autogen.agentchat.realtime_agent.realtime_client import RealtimeClientProtocol

from ... import SwarmAgent
from ...tools import Tool, get_function_schema
from ..agent import Agent
from ..conversable_agent import ConversableAgent
from .function_observer import FunctionObserver
from .oai_realtime_client import OpenAIRealtimeClient, Role
from .oai_realtime_client import OpenAIRealtimeClient, OpenAIRealtimeWebRTCClient, Role
from .realtime_observer import RealtimeObserver

F = TypeVar("F", bound=Callable[..., Any])
Expand Down Expand Up @@ -42,20 +45,22 @@ def __init__(
self,
*,
name: str,
audio_adapter: RealtimeObserver,
audio_adapter: Optional[RealtimeObserver] = None,
system_message: str = "You are a helpful AI Assistant.",
llm_config: dict[str, Any],
voice: str = "alloy",
logger: Optional[Logger] = None,
websocket: Optional[WebSocket] = None,
):
"""(Experimental) Agent for interacting with the Realtime Clients.
Args:
name (str): The name of the agent.
audio_adapter (RealtimeObserver): The audio adapter for the agent.
audio_adapter (Optional[RealtimeObserver] = None): The audio adapter for the agent.
system_message (str): The system message for the agent.
llm_config (dict[str, Any], bool): The config for the agent.
voice (str): The voice for the agent.
websocket (Optional[WebSocket] = None): WebSocket from WebRTC javascript client
"""
super().__init__(
name=name,
Expand All @@ -75,12 +80,20 @@ def __init__(
self._logger = logger
self._function_observer = FunctionObserver(logger=logger)
self._audio_adapter = audio_adapter
self._realtime_client = OpenAIRealtimeClient(
self._realtime_client: RealtimeClientProtocol = OpenAIRealtimeClient(
llm_config=llm_config, voice=voice, system_message=system_message, logger=logger
)
if websocket is not None:
self._realtime_client = OpenAIRealtimeWebRTCClient(
llm_config=llm_config, voice=voice, system_message=system_message, websocket=websocket, logger=logger
)

self._voice = voice

self._observers: list[RealtimeObserver] = [self._function_observer, self._audio_adapter]
self._observers: list[RealtimeObserver] = [self._function_observer]
if self._audio_adapter:
# audio adapter is not needed for WebRTC
self._observers.append(self._audio_adapter)

self._registred_realtime_tools: dict[str, Tool] = {}

Expand All @@ -102,7 +115,7 @@ def logger(self) -> Logger:
return self._logger or global_logger

@property
def realtime_client(self) -> OpenAIRealtimeClient:
def realtime_client(self) -> RealtimeClientProtocol:
"""Get the OpenAI Realtime Client."""
return self._realtime_client

Expand Down
Loading

0 comments on commit 984897e

Please sign in to comment.