Skip to content

Commit

Permalink
improved handling of LLM errors, do not retry if already began (#1298)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidzhao authored Dec 25, 2024
1 parent baae79b commit 3238393
Show file tree
Hide file tree
Showing 8 changed files with 54 additions and 18 deletions.
7 changes: 7 additions & 0 deletions .changeset/gorgeous-sheep-grow.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
"livekit-plugins-anthropic": patch
"livekit-plugins-openai": patch
"livekit-agents": patch
---

improved handling of LLM errors, do not retry if already began
31 changes: 24 additions & 7 deletions livekit-agents/livekit/agents/_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,22 @@ class APIError(Exception):
body: object | None
"""The API response body, if available.
    If the API returned valid JSON, the body will contain
    the decoded result.
"""

def __init__(self, message: str, *, body: object | None) -> None:
retryable: bool = False
"""Whether the error can be retried."""

def __init__(
self, message: str, *, body: object | None, retryable: bool = True
) -> None:
super().__init__(message)

self.message = message
self.body = body
self.retryable = retryable


class APIStatusError(APIError):
Expand All @@ -51,8 +57,15 @@ def __init__(
status_code: int = -1,
request_id: str | None = None,
body: object | None = None,
retryable: bool | None = None,
) -> None:
super().__init__(message, body=body)
if retryable is None:
retryable = True
# 4xx errors are not retryable
if status_code >= 400 and status_code < 500:
retryable = False

super().__init__(message, body=body, retryable=retryable)

self.status_code = status_code
self.request_id = request_id
Expand All @@ -61,12 +74,16 @@ def __init__(
class APIConnectionError(APIError):
"""Raised when an API request failed due to a connection error."""

def __init__(self, message: str = "Connection error.") -> None:
super().__init__(message, body=None)
def __init__(
self, message: str = "Connection error.", *, retryable: bool = True
) -> None:
super().__init__(message, body=None, retryable=retryable)


class APITimeoutError(APIConnectionError):
"""Raised when an API request timed out."""

def __init__(self, message: str = "Request timed out.") -> None:
super().__init__(message)
def __init__(
self, message: str = "Request timed out.", *, retryable: bool = True
) -> None:
super().__init__(message, retryable=retryable)
2 changes: 1 addition & 1 deletion livekit-agents/livekit/agents/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ async def _main_task(self) -> None:
try:
return await self._run()
except APIError as e:
if self._conn_options.max_retry == 0:
if self._conn_options.max_retry == 0 or not e.retryable:
raise
elif i == self._conn_options.max_retry:
raise APIConnectionError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ def __init__(
self._output_tokens = 0

async def _run(self) -> None:
retryable = True
try:
if not self._anthropic_stream:
self._anthropic_stream = await self._awaitable_anthropic_stream
Expand All @@ -215,6 +216,7 @@ async def _run(self) -> None:
chat_chunk = self._parse_event(event)
if chat_chunk is not None:
self._event_ch.send_nowait(chat_chunk)
retryable = False

self._event_ch.send_nowait(
llm.ChatChunk(
Expand All @@ -227,7 +229,7 @@ async def _run(self) -> None:
)
)
except anthropic.APITimeoutError:
raise APITimeoutError()
raise APITimeoutError(retryable=retryable)
except anthropic.APIStatusError as e:
raise APIStatusError(
e.message,
Expand All @@ -236,7 +238,7 @@ async def _run(self) -> None:
body=e.body,
)
except Exception as e:
raise APIConnectionError() from e
raise APIConnectionError(retryable=retryable) from e

def _parse_event(
self, event: anthropic.types.RawMessageStreamEvent
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,10 @@ def _send_last_frame(*, segment_id: str, is_final: bool) -> None:
aiohttp.WSMsgType.CLOSE,
aiohttp.WSMsgType.CLOSING,
):
raise Exception("Cartesia connection closed unexpectedly")
raise APIStatusError(
"Cartesia connection closed unexpectedly",
request_id=request_id,
)

if msg.type != aiohttp.WSMsgType.TEXT:
logger.warning("unexpected Cartesia message type %s", msg.type)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -314,8 +314,9 @@ def _send_last_frame(*, segment_id: str, is_final: bool) -> None:
aiohttp.WSMsgType.CLOSING,
):
if not closing_ws:
raise Exception(
"Deepgram websocket connection closed unexpectedly"
raise APIStatusError(
"Deepgram websocket connection closed unexpectedly",
request_id=request_id,
)
return

Expand Down Expand Up @@ -393,7 +394,10 @@ async def _connection_timeout():
raise APITimeoutError() from e
except aiohttp.ClientResponseError as e:
raise APIStatusError(
message=e.message, status_code=e.status, request_id=None, body=None
message=e.message,
status_code=e.status,
request_id=request_id,
body=None,
) from e
except Exception as e:
raise APIConnectionError() from e
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -469,8 +469,9 @@ def _send_last_frame(*, segment_id: str, is_final: bool) -> None:
aiohttp.WSMsgType.CLOSING,
):
if not eos_sent:
raise Exception(
"11labs connection closed unexpectedly, not all tokens have been consumed"
raise APIStatusError(
"11labs connection closed unexpectedly, not all tokens have been consumed",
request_id=request_id,
)
return

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,7 @@ async def _run(self) -> None:
self._fnc_name: str | None = None
self._fnc_raw_arguments: str | None = None
self._tool_index: int | None = None
retryable = True

try:
opts: dict[str, Any] = dict()
Expand Down Expand Up @@ -755,6 +756,7 @@ async def _run(self) -> None:
for choice in chunk.choices:
chat_chunk = self._parse_choice(chunk.id, choice)
if chat_chunk is not None:
retryable = False
self._event_ch.send_nowait(chat_chunk)

if chunk.usage is not None:
Expand All @@ -771,7 +773,7 @@ async def _run(self) -> None:
)

except openai.APITimeoutError:
raise APITimeoutError()
raise APITimeoutError(retryable=retryable)
except openai.APIStatusError as e:
raise APIStatusError(
e.message,
Expand All @@ -780,7 +782,7 @@ async def _run(self) -> None:
body=e.body,
)
except Exception as e:
raise APIConnectionError() from e
raise APIConnectionError(retryable=retryable) from e

def _parse_choice(self, id: str, choice: Choice) -> llm.ChatChunk | None:
delta = choice.delta
Expand Down

0 comments on commit 3238393

Please sign in to comment.