From 9b220d84772d752007501d2ad8db1790093f29c4 Mon Sep 17 00:00:00 2001 From: David Zhao Date: Mon, 6 Jan 2025 01:16:03 -0800 Subject: [PATCH 1/6] fix Google STT handling of session timeouts also reduces initial retry delay with STT connection errors --- livekit-agents/livekit/agents/stt/stt.py | 19 +++++---- .../livekit/plugins/google/stt.py | 40 ++++++++++++++++--- 2 files changed, 46 insertions(+), 13 deletions(-) diff --git a/livekit-agents/livekit/agents/stt/stt.py b/livekit-agents/livekit/agents/stt/stt.py index e2f79f93c..89c7c6ee3 100644 --- a/livekit-agents/livekit/agents/stt/stt.py +++ b/livekit-agents/livekit/agents/stt/stt.py @@ -210,15 +210,19 @@ def __init__( async def _run(self) -> None: ... async def _main_task(self) -> None: - for i in range(self._conn_options.max_retry + 1): + max_retries = self._conn_options.max_retry + num_retries = 0 + + while num_retries <= max_retries: try: - return await self._run() + await self._run() + num_retries = 0 except APIError as e: - if self._conn_options.max_retry == 0: + if max_retries == 0: raise - elif i == self._conn_options.max_retry: + elif num_retries == max_retries: raise APIConnectionError( - f"failed to recognize speech after {self._conn_options.max_retry + 1} attempts", + f"failed to recognize speech after {num_retries} attempts", ) from e else: logger.warning( @@ -226,12 +230,13 @@ async def _main_task(self) -> None: exc_info=e, extra={ "tts": self._stt._label, - "attempt": i + 1, + "attempt": num_retries, "streamed": True, }, ) - await asyncio.sleep(self._conn_options.retry_interval) + if num_retries > 0: + await asyncio.sleep(self._conn_options.retry_interval) async def _metrics_monitor_task( self, event_aiter: AsyncIterable[SpeechEvent] diff --git a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/stt.py b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/stt.py index 7fe2a527d..4675a2ee2 100644 --- a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/stt.py +++ 
b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/stt.py @@ -16,6 +16,7 @@ import asyncio import dataclasses +import time import weakref from dataclasses import dataclass from typing import List, Union @@ -44,6 +45,10 @@ LgType = Union[SpeechLanguages, str] LanguageCode = Union[LgType, List[LgType]] +# Google STT has a timeout of 5 mins, we'll attempt to restart the session +# before that timeout is reached (value is in seconds, i.e. 4 mins) +_max_session_duration = 240 + # This class is only be used internally to encapsulate the options @dataclass @@ -229,8 +234,6 @@ async def _recognize_impl( raise APIStatusError( e.message, status_code=e.code or -1, - request_id=None, - body=None, ) except Exception as e: raise APIConnectionError() from e @@ -312,6 +315,7 @@ def __init__( self._recognizer = recognizer self._config = config self._reconnect_event = asyncio.Event() + self._session_connected_at: float = 0 def update_options( self, @@ -347,7 +351,7 @@ def update_options( async def _run(self) -> None: # google requires a async generator when calling streaming_recognize # this function basically convert the queue into a async generator - async def input_generator(): + async def input_generator(should_stop: asyncio.Event): try: # first request should contain the config yield cloud_speech.StreamingRecognizeRequest( @@ -356,6 +360,12 @@ async def input_generator(): ) async for frame in self._input_ch: + # when the stream is aborted due to reconnect, this input_generator + # needs to stop consuming frames + # when the generator stops, the previous gRPC stream will close + if should_stop.is_set(): + return + if isinstance(frame, rtc.AudioFrame): yield cloud_speech.StreamingRecognizeRequest( audio=frame.data.tobytes() @@ -399,6 +409,13 @@ async def process_stream(stream): alternatives=[speech_data], ) ) + if ( + time.time() - self._session_connected_at + > _max_session_duration + ): + logger.debug("restarting session due to timeout") + self._reconnect_event.set() + return if ( 
resp.speech_event_type @@ -431,12 +448,15 @@ async def process_stream(stream): ), ) + should_stop = asyncio.Event() stream = await self._client.streaming_recognize( - requests=input_generator(), + requests=input_generator(should_stop), ) + self._session_connected_at = time.time() process_stream_task = asyncio.create_task(process_stream(stream)) wait_reconnect_task = asyncio.create_task(self._reconnect_event.wait()) + try: done, _ = await asyncio.wait( [process_stream_task, wait_reconnect_task], @@ -449,9 +469,17 @@ async def process_stream(stream): await utils.aio.gracefully_cancel( process_stream_task, wait_reconnect_task ) + should_stop.set() + except DeadlineExceeded: + raise APITimeoutError() + except GoogleAPICallError as e: + raise APIStatusError( + e.message, + status_code=e.code or -1, + ) + except Exception as e: + raise APIConnectionError() from e finally: - if not self._reconnect_event.is_set(): - break self._reconnect_event.clear() From 0718006b182205b9a6fa1d5eff225d3592bce2af Mon Sep 17 00:00:00 2001 From: jayesh Date: Mon, 6 Jan 2025 16:26:37 +0530 Subject: [PATCH 2/6] move `break` inside try block --- .../livekit-plugins-azure/livekit/plugins/azure/stt.py | 6 +++--- .../livekit-plugins-google/livekit/plugins/google/stt.py | 5 +++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/livekit-plugins/livekit-plugins-azure/livekit/plugins/azure/stt.py b/livekit-plugins/livekit-plugins-azure/livekit/plugins/azure/stt.py index 2bda776fd..59363114d 100644 --- a/livekit-plugins/livekit-plugins-azure/livekit/plugins/azure/stt.py +++ b/livekit-plugins/livekit-plugins-azure/livekit/plugins/azure/stt.py @@ -206,6 +206,9 @@ async def process_input(): for task in done: if task != wait_reconnect_task: task.result() + if wait_reconnect_task not in done: + break + self._reconnect_event.clear() finally: await utils.aio.gracefully_cancel( process_input_task, wait_reconnect_task @@ -220,9 +223,6 @@ def _cleanup(): del self._recognizer await 
asyncio.to_thread(_cleanup) - if not self._reconnect_event.is_set(): - break - self._reconnect_event.clear() def _on_recognized(self, evt: speechsdk.SpeechRecognitionEventArgs): detected_lg = speechsdk.AutoDetectSourceLanguageResult(evt.result).language diff --git a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/stt.py b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/stt.py index 4675a2ee2..41868810f 100644 --- a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/stt.py +++ b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/stt.py @@ -465,6 +465,9 @@ async def process_stream(stream): for task in done: if task != wait_reconnect_task: task.result() + if wait_reconnect_task not in done: + break + self._reconnect_event.clear() finally: await utils.aio.gracefully_cancel( process_stream_task, wait_reconnect_task @@ -479,8 +482,6 @@ async def process_stream(stream): ) except Exception as e: raise APIConnectionError() from e - finally: - self._reconnect_event.clear() def _recognize_response_to_speech_event( From 621de94928da9d485fc023d4bb5c4c16b38664c2 Mon Sep 17 00:00:00 2001 From: David Zhao Date: Mon, 6 Jan 2025 13:01:51 -0800 Subject: [PATCH 3/6] fix num_retries increment --- livekit-agents/livekit/agents/stt/stt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/livekit-agents/livekit/agents/stt/stt.py b/livekit-agents/livekit/agents/stt/stt.py index 89c7c6ee3..3dea525f1 100644 --- a/livekit-agents/livekit/agents/stt/stt.py +++ b/livekit-agents/livekit/agents/stt/stt.py @@ -234,9 +234,9 @@ async def _main_task(self) -> None: "streamed": True, }, ) - if num_retries > 0: await asyncio.sleep(self._conn_options.retry_interval) + num_retries += 1 async def _metrics_monitor_task( self, event_aiter: AsyncIterable[SpeechEvent] From 9ed565554c7ce71a234f75994f4c60d7fbfa3c40 Mon Sep 17 00:00:00 2001 From: David Zhao Date: Mon, 6 Jan 2025 14:18:41 -0800 Subject: [PATCH 4/6] fix infinite loop, faster initial 
retry across the board --- livekit-agents/livekit/agents/llm/llm.py | 5 +++-- livekit-agents/livekit/agents/stt/stt.py | 15 ++++++++------- livekit-agents/livekit/agents/tts/tts.py | 10 ++++++---- livekit-agents/livekit/agents/types.py | 10 ++++++++++ 4 files changed, 27 insertions(+), 13 deletions(-) diff --git a/livekit-agents/livekit/agents/llm/llm.py b/livekit-agents/livekit/agents/llm/llm.py index 099e3139c..0275f64c1 100644 --- a/livekit-agents/livekit/agents/llm/llm.py +++ b/livekit-agents/livekit/agents/llm/llm.py @@ -148,6 +148,7 @@ async def _main_task(self) -> None: try: return await self._run() except APIError as e: + retry_interval = self._conn_options._interval_for_retry(i) if self._conn_options.max_retry == 0 or not e.retryable: raise elif i == self._conn_options.max_retry: @@ -156,7 +157,7 @@ async def _main_task(self) -> None: ) from e else: logger.warning( - f"failed to generate LLM completion, retrying in {self._conn_options.retry_interval}s", + f"failed to generate LLM completion, retrying in {retry_interval}s", exc_info=e, extra={ "llm": self._llm._label, @@ -164,7 +165,7 @@ async def _main_task(self) -> None: }, ) - await asyncio.sleep(self._conn_options.retry_interval) + await asyncio.sleep(retry_interval) @utils.log_exceptions(logger=logger) async def _metrics_monitor_task( diff --git a/livekit-agents/livekit/agents/stt/stt.py b/livekit-agents/livekit/agents/stt/stt.py index 3dea525f1..a0956e621 100644 --- a/livekit-agents/livekit/agents/stt/stt.py +++ b/livekit-agents/livekit/agents/stt/stt.py @@ -119,6 +119,7 @@ async def recognize( return event except APIError as e: + retry_interval = conn_options._interval_for_retry(i) if conn_options.max_retry == 0: raise elif i == conn_options.max_retry: @@ -127,7 +128,7 @@ async def recognize( ) from e else: logger.warning( - f"failed to recognize speech, retrying in {conn_options.retry_interval}s", + f"failed to recognize speech, retrying in {retry_interval}s", exc_info=e, extra={ "tts": self._label, 
@@ -136,7 +137,7 @@ async def recognize( }, ) - await asyncio.sleep(conn_options.retry_interval) + await asyncio.sleep(retry_interval) raise RuntimeError("unreachable") @@ -215,8 +216,7 @@ async def _main_task(self) -> None: while num_retries <= max_retries: try: - await self._run() - num_retries = 0 + return await self._run() except APIError as e: if max_retries == 0: raise @@ -225,8 +225,9 @@ async def _main_task(self) -> None: f"failed to recognize speech after {num_retries} attempts", ) from e else: + retry_interval = self._conn_options._interval_for_retry(num_retries) logger.warning( - f"failed to recognize speech, retrying in {self._conn_options.retry_interval}s", + f"failed to recognize speech, retrying in {retry_interval}s", exc_info=e, extra={ "tts": self._stt._label, @@ -234,8 +235,8 @@ async def _main_task(self) -> None: "streamed": True, }, ) - if num_retries > 0: - await asyncio.sleep(self._conn_options.retry_interval) + await asyncio.sleep(retry_interval) + num_retries += 1 async def _metrics_monitor_task( diff --git a/livekit-agents/livekit/agents/tts/tts.py b/livekit-agents/livekit/agents/tts/tts.py index e641bf39d..a4e1f2089 100644 --- a/livekit-agents/livekit/agents/tts/tts.py +++ b/livekit-agents/livekit/agents/tts/tts.py @@ -178,6 +178,7 @@ async def _main_task(self) -> None: try: return await self._run() except APIError as e: + retry_interval = self._conn_options._interval_for_retry(i) if self._conn_options.max_retry == 0: raise elif i == self._conn_options.max_retry: @@ -186,7 +187,7 @@ async def _main_task(self) -> None: ) from e else: logger.warning( - f"failed to synthesize speech, retrying in {self._conn_options.retry_interval}s", + f"failed to synthesize speech, retrying in {retry_interval}s", exc_info=e, extra={ "tts": self._tts._label, @@ -195,7 +196,7 @@ async def _main_task(self) -> None: }, ) - await asyncio.sleep(self._conn_options.retry_interval) + await asyncio.sleep(retry_interval) async def aclose(self) -> None: """Close is 
automatically called if the stream is completely collected""" @@ -258,6 +259,7 @@ async def _main_task(self) -> None: try: return await self._run() except APIError as e: + retry_interval = self._conn_options._interval_for_retry(i) if self._conn_options.max_retry == 0: raise elif i == self._conn_options.max_retry: @@ -266,7 +268,7 @@ async def _main_task(self) -> None: ) from e else: logger.warning( - f"failed to synthesize speech, retrying in {self._conn_options.retry_interval}s", + f"failed to synthesize speech, retrying in {retry_interval}s", exc_info=e, extra={ "tts": self._tts._label, @@ -275,7 +277,7 @@ async def _main_task(self) -> None: }, ) - await asyncio.sleep(self._conn_options.retry_interval) + await asyncio.sleep(retry_interval) async def _metrics_monitor_task( self, event_aiter: AsyncIterable[SynthesizedAudio] diff --git a/livekit-agents/livekit/agents/types.py b/livekit-agents/livekit/agents/types.py index 0cb76882a..d7378c3ea 100644 --- a/livekit-agents/livekit/agents/types.py +++ b/livekit-agents/livekit/agents/types.py @@ -57,5 +57,15 @@ def __post_init__(self): if self.timeout < 0: raise ValueError("timeout must be greater than or equal to 0") + def _interval_for_retry(self, num_retries: int) -> float: + """ + Return the interval for the given number of retries. 
+ + The first retry is near-immediate (0.1s), later retries use the specified retry_interval + """ + if num_retries == 0: + return 0.1 + return self.retry_interval + DEFAULT_API_CONNECT_OPTIONS = APIConnectOptions() From 9cb35be97c2b0a76b8b2672cd518368d26801b8b Mon Sep 17 00:00:00 2001 From: David Zhao Date: Mon, 6 Jan 2025 14:23:25 -0800 Subject: [PATCH 5/6] changeset --- .changeset/funny-buttons-taste.md | 6 ++++++ .changeset/wild-planes-cheer.md | 5 +++++ 2 files changed, 11 insertions(+) create mode 100644 .changeset/funny-buttons-taste.md create mode 100644 .changeset/wild-planes-cheer.md diff --git a/.changeset/funny-buttons-taste.md b/.changeset/funny-buttons-taste.md new file mode 100644 index 000000000..97ee4cc39 --- /dev/null +++ b/.changeset/funny-buttons-taste.md @@ -0,0 +1,6 @@ +--- +"livekit-plugins-azure": patch +"livekit-agents": patch +--- + +reduces initial delay before model retries diff --git a/.changeset/wild-planes-cheer.md b/.changeset/wild-planes-cheer.md new file mode 100644 index 000000000..00c68d562 --- /dev/null +++ b/.changeset/wild-planes-cheer.md @@ -0,0 +1,5 @@ +--- +"livekit-plugins-google": patch +--- + +fix Google STT handling of session timeouts From e973c65d65b4a991c1bf17dd40b208423ce33809 Mon Sep 17 00:00:00 2001 From: David Zhao Date: Mon, 6 Jan 2025 15:01:56 -0800 Subject: [PATCH 6/6] send end-of-speech event if final transcript came before --- .../livekit/plugins/google/stt.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/stt.py b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/stt.py index 41868810f..5afa411c3 100644 --- a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/stt.py +++ b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/stt.py @@ -377,6 +377,7 @@ async def input_generator(should_stop: asyncio.Event): ) async def process_stream(stream): + has_started = False async for resp in stream: if ( resp.speech_event_type @@ 
-385,6 +386,7 @@ async def process_stream(stream): self._event_ch.send_nowait( stt.SpeechEvent(type=stt.SpeechEventType.START_OF_SPEECH) ) + has_started = True if ( resp.speech_event_type @@ -414,6 +416,13 @@ async def process_stream(stream): > _max_session_duration ): logger.debug("restarting session due to timeout") + if has_started: + self._event_ch.send_nowait( + stt.SpeechEvent( + type=stt.SpeechEventType.END_OF_SPEECH + ) + ) + has_started = False self._reconnect_event.set() return @@ -424,6 +433,7 @@ async def process_stream(stream): self._event_ch.send_nowait( stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH) ) + has_started = False while True: try: