Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix Google STT handling of session timeouts #1337

Merged
merged 6 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changeset/funny-buttons-taste.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
"livekit-plugins-azure": patch
"livekit-agents": patch
---

reduces initial delay before model retries
5 changes: 5 additions & 0 deletions .changeset/wild-planes-cheer.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"livekit-plugins-google": patch
---

fix Google STT handling of session timeouts
5 changes: 3 additions & 2 deletions livekit-agents/livekit/agents/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ async def _main_task(self) -> None:
try:
return await self._run()
except APIError as e:
retry_interval = self._conn_options._interval_for_retry(i)
if self._conn_options.max_retry == 0 or not e.retryable:
raise
elif i == self._conn_options.max_retry:
Expand All @@ -156,15 +157,15 @@ async def _main_task(self) -> None:
) from e
else:
logger.warning(
f"failed to generate LLM completion, retrying in {self._conn_options.retry_interval}s",
f"failed to generate LLM completion, retrying in {retry_interval}s",
exc_info=e,
extra={
"llm": self._llm._label,
"attempt": i + 1,
},
)

await asyncio.sleep(self._conn_options.retry_interval)
await asyncio.sleep(retry_interval)

@utils.log_exceptions(logger=logger)
async def _metrics_monitor_task(
Expand Down
24 changes: 15 additions & 9 deletions livekit-agents/livekit/agents/stt/stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ async def recognize(
return event

except APIError as e:
retry_interval = conn_options._interval_for_retry(i)
if conn_options.max_retry == 0:
raise
elif i == conn_options.max_retry:
Expand All @@ -127,7 +128,7 @@ async def recognize(
) from e
else:
logger.warning(
f"failed to recognize speech, retrying in {conn_options.retry_interval}s",
f"failed to recognize speech, retrying in {retry_interval}s",
exc_info=e,
extra={
"tts": self._label,
Expand All @@ -136,7 +137,7 @@ async def recognize(
},
)

await asyncio.sleep(conn_options.retry_interval)
await asyncio.sleep(retry_interval)

raise RuntimeError("unreachable")

Expand Down Expand Up @@ -210,28 +211,33 @@ def __init__(
async def _run(self) -> None: ...

async def _main_task(self) -> None:
for i in range(self._conn_options.max_retry + 1):
max_retries = self._conn_options.max_retry
num_retries = 0

while num_retries <= max_retries:
try:
return await self._run()
except APIError as e:
if self._conn_options.max_retry == 0:
if max_retries == 0:
raise
elif i == self._conn_options.max_retry:
elif num_retries == max_retries:
raise APIConnectionError(
f"failed to recognize speech after {self._conn_options.max_retry + 1} attempts",
f"failed to recognize speech after {num_retries} attempts",
) from e
else:
retry_interval = self._conn_options._interval_for_retry(num_retries)
logger.warning(
f"failed to recognize speech, retrying in {self._conn_options.retry_interval}s",
f"failed to recognize speech, retrying in {retry_interval}s",
exc_info=e,
extra={
"tts": self._stt._label,
"attempt": i + 1,
"attempt": num_retries,
davidzhao marked this conversation as resolved.
Show resolved Hide resolved
"streamed": True,
},
)
davidzhao marked this conversation as resolved.
Show resolved Hide resolved
await asyncio.sleep(retry_interval)

await asyncio.sleep(self._conn_options.retry_interval)
num_retries += 1

async def _metrics_monitor_task(
self, event_aiter: AsyncIterable[SpeechEvent]
Expand Down
10 changes: 6 additions & 4 deletions livekit-agents/livekit/agents/tts/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ async def _main_task(self) -> None:
try:
return await self._run()
except APIError as e:
retry_interval = self._conn_options._interval_for_retry(i)
if self._conn_options.max_retry == 0:
raise
elif i == self._conn_options.max_retry:
Expand All @@ -186,7 +187,7 @@ async def _main_task(self) -> None:
) from e
else:
logger.warning(
f"failed to synthesize speech, retrying in {self._conn_options.retry_interval}s",
f"failed to synthesize speech, retrying in {retry_interval}s",
exc_info=e,
extra={
"tts": self._tts._label,
Expand All @@ -195,7 +196,7 @@ async def _main_task(self) -> None:
},
)

await asyncio.sleep(self._conn_options.retry_interval)
await asyncio.sleep(retry_interval)

async def aclose(self) -> None:
"""Close is automatically called if the stream is completely collected"""
Expand Down Expand Up @@ -258,6 +259,7 @@ async def _main_task(self) -> None:
try:
return await self._run()
except APIError as e:
retry_interval = self._conn_options._interval_for_retry(i)
if self._conn_options.max_retry == 0:
raise
elif i == self._conn_options.max_retry:
Expand All @@ -266,7 +268,7 @@ async def _main_task(self) -> None:
) from e
else:
logger.warning(
f"failed to synthesize speech, retrying in {self._conn_options.retry_interval}s",
f"failed to synthesize speech, retrying in {retry_interval}s",
exc_info=e,
extra={
"tts": self._tts._label,
Expand All @@ -275,7 +277,7 @@ async def _main_task(self) -> None:
},
)

await asyncio.sleep(self._conn_options.retry_interval)
await asyncio.sleep(retry_interval)

async def _metrics_monitor_task(
self, event_aiter: AsyncIterable[SynthesizedAudio]
Expand Down
10 changes: 10 additions & 0 deletions livekit-agents/livekit/agents/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,5 +57,15 @@ def __post_init__(self):
if self.timeout < 0:
raise ValueError("timeout must be greater than or equal to 0")

def _interval_for_retry(self, num_retries: int) -> float:
"""
Return the delay, in seconds, to wait before the next retry attempt.

The first retry (num_retries == 0) is not strictly immediate: it uses a
short fixed delay of 0.1s. Every later retry uses the configured
retry_interval.
"""
if num_retries == 0:
return 0.1
return self.retry_interval


DEFAULT_API_CONNECT_OPTIONS = APIConnectOptions()
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,9 @@ async def process_input():
for task in done:
if task != wait_reconnect_task:
task.result()
if wait_reconnect_task not in done:
break
self._reconnect_event.clear()
finally:
await utils.aio.gracefully_cancel(
process_input_task, wait_reconnect_task
Expand All @@ -220,9 +223,6 @@ def _cleanup():
del self._recognizer

await asyncio.to_thread(_cleanup)
if not self._reconnect_event.is_set():
break
self._reconnect_event.clear()

def _on_recognized(self, evt: speechsdk.SpeechRecognitionEventArgs):
detected_lg = speechsdk.AutoDetectSourceLanguageResult(evt.result).language
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import asyncio
import dataclasses
import time
import weakref
from dataclasses import dataclass
from typing import List, Union
Expand Down Expand Up @@ -44,6 +45,10 @@
LgType = Union[SpeechLanguages, str]
LanguageCode = Union[LgType, List[LgType]]

# Google STT has a timeout of 5 mins, we'll attempt to restart the session
# before that timeout is reached.
# The value is expressed in seconds because it is compared against a
# difference of time.time() readings: 4 minutes, not 4 seconds.
_max_session_duration = 4 * 60
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like 4 seconds to me



# This class is only used internally to encapsulate the options
@dataclass
Expand Down Expand Up @@ -229,8 +234,6 @@ async def _recognize_impl(
raise APIStatusError(
e.message,
status_code=e.code or -1,
request_id=None,
body=None,
)
except Exception as e:
raise APIConnectionError() from e
Expand Down Expand Up @@ -312,6 +315,7 @@ def __init__(
self._recognizer = recognizer
self._config = config
self._reconnect_event = asyncio.Event()
self._session_connected_at: float = 0

def update_options(
self,
Expand Down Expand Up @@ -347,7 +351,7 @@ def update_options(
async def _run(self) -> None:
# google requires an async generator when calling streaming_recognize
# this function basically converts the queue into an async generator
async def input_generator():
async def input_generator(should_stop: asyncio.Event):
try:
# first request should contain the config
yield cloud_speech.StreamingRecognizeRequest(
Expand All @@ -356,6 +360,12 @@ async def input_generator():
)

async for frame in self._input_ch:
# when the stream is aborted due to reconnect, this input_generator
# needs to stop consuming frames
# when the generator stops, the previous gRPC stream will close
if should_stop.is_set():
return

if isinstance(frame, rtc.AudioFrame):
yield cloud_speech.StreamingRecognizeRequest(
audio=frame.data.tobytes()
Expand All @@ -367,6 +377,7 @@ async def input_generator():
)

async def process_stream(stream):
has_started = False
async for resp in stream:
if (
resp.speech_event_type
Expand All @@ -375,6 +386,7 @@ async def process_stream(stream):
self._event_ch.send_nowait(
stt.SpeechEvent(type=stt.SpeechEventType.START_OF_SPEECH)
)
has_started = True

if (
resp.speech_event_type
Expand All @@ -399,6 +411,20 @@ async def process_stream(stream):
alternatives=[speech_data],
)
)
if (
davidzhao marked this conversation as resolved.
Show resolved Hide resolved
time.time() - self._session_connected_at
> _max_session_duration
):
logger.debug("restarting session due to timeout")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unnecessary log? Should be handled seamlessly IMO

if has_started:
self._event_ch.send_nowait(
stt.SpeechEvent(
type=stt.SpeechEventType.END_OF_SPEECH
)
)
has_started = False
self._reconnect_event.set()
return

if (
resp.speech_event_type
Expand All @@ -407,6 +433,7 @@ async def process_stream(stream):
self._event_ch.send_nowait(
stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH)
)
has_started = False

while True:
try:
Expand All @@ -431,12 +458,15 @@ async def process_stream(stream):
),
)

should_stop = asyncio.Event()
stream = await self._client.streaming_recognize(
requests=input_generator(),
requests=input_generator(should_stop),
)
self._session_connected_at = time.time()

process_stream_task = asyncio.create_task(process_stream(stream))
wait_reconnect_task = asyncio.create_task(self._reconnect_event.wait())

try:
done, _ = await asyncio.wait(
[process_stream_task, wait_reconnect_task],
Expand All @@ -445,14 +475,23 @@ async def process_stream(stream):
for task in done:
if task != wait_reconnect_task:
task.result()
if wait_reconnect_task not in done:
break
self._reconnect_event.clear()
finally:
await utils.aio.gracefully_cancel(
process_stream_task, wait_reconnect_task
)
finally:
if not self._reconnect_event.is_set():
break
davidzhao marked this conversation as resolved.
Show resolved Hide resolved
self._reconnect_event.clear()
should_stop.set()
except DeadlineExceeded:
raise APITimeoutError()
except GoogleAPICallError as e:
raise APIStatusError(
e.message,
status_code=e.code or -1,
)
except Exception as e:
raise APIConnectionError() from e


def _recognize_response_to_speech_event(
Expand Down
Loading