Skip to content

Commit

Permalink
added out-of-band functionality and an additional VAD option
Browse files Browse the repository at this point in the history
  • Loading branch information
tinalenguyen committed Dec 31, 2024
1 parent b718710 commit 9fa5746
Showing 1 changed file with 73 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import base64
import os
import time
import weakref
from copy import deepcopy
from dataclasses import dataclass
from typing import AsyncIterable, Literal, Optional, Union, cast, overload
Expand Down Expand Up @@ -76,6 +75,8 @@ class RealtimeResponse:
"""timestamp when the response was created"""
_first_token_timestamp: float | None = None
"""timestamp when the first token was received"""
metadata: dict[str, str] | None = None
"""developer-provided string key-value pairs"""


@dataclass
Expand Down Expand Up @@ -106,11 +107,8 @@ class RealtimeToolCall:
"""id of the tool call"""


@dataclass
class Capabilities:
supports_truncate: bool


# TODO(theomonnom): add the content type directly inside RealtimeContent?
# text/audio/transcript?
@dataclass
class RealtimeContent:
response_id: str
Expand Down Expand Up @@ -140,6 +138,7 @@ class ServerVadOptions:
threshold: float
prefix_padding_ms: int
silence_duration_ms: int
create_response: bool


@dataclass
Expand Down Expand Up @@ -191,6 +190,7 @@ class _ContentPtr(TypedDict):
threshold=0.5,
prefix_padding_ms=300,
silence_duration_ms=500,
create_response=True,
)

DEFAULT_INPUT_AUDIO_TRANSCRIPTION = InputTranscriptionOptions(model="whisper-1")
Expand Down Expand Up @@ -288,9 +288,6 @@ def __init__(
ValueError: If the API key is not provided and cannot be found in environment variables.
"""
super().__init__()
self._capabilities = Capabilities(
supports_truncate=True,
)
self._base_url = base_url

is_azure = (
Expand Down Expand Up @@ -329,7 +326,7 @@ def __init__(
)

self._loop = loop or asyncio.get_event_loop()
self._rt_sessions = weakref.WeakSet[RealtimeSession]()
self._rt_sessions: list[RealtimeSession] = []
self._http_session = http_session

@classmethod
Expand Down Expand Up @@ -434,13 +431,9 @@ def _ensure_session(self) -> aiohttp.ClientSession:
return self._http_session

@property
def sessions(self) -> weakref.WeakSet[RealtimeSession]:
def sessions(self) -> list[RealtimeSession]:
return self._rt_sessions

@property
def capabilities(self) -> Capabilities:
return self._capabilities

def session(
self,
*,
Expand Down Expand Up @@ -486,7 +479,7 @@ def session(
http_session=self._ensure_session(),
loop=self._loop,
)
self._rt_sessions.add(new_session)
self._rt_sessions.append(new_session)
return new_session

async def aclose(self) -> None:
Expand Down Expand Up @@ -715,6 +708,10 @@ def create(
on_duplicate: Literal[
"cancel_existing", "cancel_new", "keep_both"
] = "keep_both",
instructions: str | None = None,
modalities: list[api_proto.Modality] = ["text", "audio"],
conversation: Literal["auto", "none"] = "auto",
metadata: map | None = None,
) -> asyncio.Future[bool]:
"""Creates a new response.
Expand All @@ -723,6 +720,12 @@ def create(
- "cancel_existing": Cancel the existing response before creating new one
- "cancel_new": Skip creating new response if one is in progress
- "keep_both": Wait for the existing response to be done and then create a new one
instructions: explicit prompt used for out-of-band events
modalities: set of modalities that the model can respond in
conversation: specifies whether respones is out-of-band
- "auto": Contents of the response will be added to the default conversation
- "none": Creates an out-of-band response which will not add items to default conversation
metadata: set of key-value pairs that can be used for storing additional information
Returns:
Future that resolves when the response create request is queued
Expand Down Expand Up @@ -756,7 +759,29 @@ def create(
or self._sess._pending_responses[active_resp_id].done_fut.done()
):
# no active response in progress, create a new one
self._sess._queue_msg({"type": "response.create"})
if instructions is not None:
self._sess._queue_msg(
{
"type": "response.create",
"response": {
"conversation": conversation,
"metadata": metadata,
"instructions": instructions,
"modalities": modalities,
},
}
)
else:
self._sess._queue_msg(
{
"type": "response.create",
"response": {
"conversation": conversation,
"metadata": metadata,
"modalities": modalities,
},
}
)
_fut = asyncio.Future[bool]()
_fut.set_result(True)
return _fut
Expand Down Expand Up @@ -793,7 +818,29 @@ async def wait_and_create() -> bool:
)
new_create_fut = asyncio.Future[None]()
self._sess._response_create_fut = new_create_fut
self._sess._queue_msg({"type": "response.create"})
if instructions is not None:
self._sess._queue_msg(
{
"type": "response.create",
"response": {
"conversation": conversation,
"metadata": metadata,
"instructions": instructions,
"modalities": modalities,
},
}
)
else:
self._sess._queue_msg(
{
"type": "response.create",
"response": {
"conversation": conversation,
"metadata": metadata,
"modalities": modalities,
},
}
)
return True

return asyncio.create_task(wait_and_create())
Expand Down Expand Up @@ -865,9 +912,6 @@ def conversation(self) -> Conversation:
def input_audio_buffer(self) -> InputAudioBuffer:
return RealtimeSession.InputAudioBuffer(self)

def _push_audio(self, frame: rtc.AudioFrame) -> None:
self.input_audio_buffer.append(frame)

@property
def response(self) -> Response:
return RealtimeSession.Response(self)
Expand Down Expand Up @@ -924,6 +968,7 @@ def session_update(
"threshold": self._opts.turn_detection.threshold,
"prefix_padding_ms": self._opts.turn_detection.prefix_padding_ms,
"silence_duration_ms": self._opts.turn_detection.silence_duration_ms,
"create_response": self._opts.turn_detection.create_response,
}
input_audio_transcription_opts: api_proto.InputAudioTranscription | None = None
if self._opts.input_audio_transcription is not None:
Expand Down Expand Up @@ -1037,15 +1082,6 @@ def _recover_from_text_response(self, item_id: str | None = None) -> None:
self.conversation.item.create(self._create_empty_user_audio_message(1.0))
self.response.create(on_duplicate="keep_both")

def _truncate_conversation_item(
self, item_id: str, content_index: int, audio_end_ms: int
) -> None:
self.conversation.item.truncate(
item_id=item_id,
content_index=content_index,
audio_end_ms=audio_end_ms,
)

def _update_conversation_item_content(
self, item_id: str, content: llm.ChatContent | list[llm.ChatContent] | None
) -> None:
Expand Down Expand Up @@ -1220,6 +1256,7 @@ def _handle_session_updated(
threshold=session["turn_detection"]["threshold"],
prefix_padding_ms=session["turn_detection"]["prefix_padding_ms"],
silence_duration_ms=session["turn_detection"]["silence_duration_ms"],
create_response=True,
)
if session["input_audio_transcription"] is None:
input_audio_transcription = None
Expand Down Expand Up @@ -1399,11 +1436,14 @@ def _handle_response_created(
response = response_created["response"]
done_fut = self._loop.create_future()
status_details = response.get("status_details")
metadata = response.get("metadata")

new_response = RealtimeResponse(
id=response["id"],
status=response["status"],
status_details=status_details,
output=[],
metadata=metadata,
usage=response.get("usage"),
done_fut=done_fut,
_created_timestamp=time.time(),
Expand Down Expand Up @@ -1578,6 +1618,8 @@ def _handle_response_done(self, response_done: api_proto.ServerEvent.ResponseDon

response.status = response_data["status"]
response.status_details = response_data.get("status_details")
response.metadata = response_data.get("metadata")
response.output = response_data.get("output")
response.usage = response_data.get("usage")

metrics_error = None
Expand Down Expand Up @@ -1685,7 +1727,7 @@ async def _run_fnc_task(self, fnc_call_info: llm.FunctionCallInfo, item_id: str)
"function": fnc_call_info.function_info.name,
},
)
if tool_call.content is not None:
if called_fnc.result is not None:
create_fut = self.conversation.item.create(
tool_call,
previous_item_id=item_id,
Expand Down

0 comments on commit 9fa5746

Please sign in to comment.