Skip to content

Commit

Permalink
Add exponential backoff to handle Play.ht HTTP 429 (Too Many Requests) responses (#438)
Browse files Browse the repository at this point in the history
* Add exponential backoff to handle Play.ht HTTP 429 (Too Many Requests) responses

* Fix unit test
  • Loading branch information
skirdey authored Nov 30, 2023
1 parent 5b42c89 commit 09aed42
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 46 deletions.
24 changes: 14 additions & 10 deletions tests/synthesizer/conftest.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,36 @@
import pytest
from aioresponses import aioresponses, CallbackResult
from vocode.streaming.models.audio_encoding import AudioEncoding
from vocode.streaming.models.synthesizer import ElevenLabsSynthesizerConfig, PlayHtSynthesizerConfig
from vocode.streaming.models.synthesizer import (
ElevenLabsSynthesizerConfig,
PlayHtSynthesizerConfig,
)
import re
from vocode.streaming.synthesizer.eleven_labs_synthesizer import (
ElevenLabsSynthesizer,
ELEVEN_LABS_BASE_URL,
)
from vocode.streaming.synthesizer.play_ht_synthesizer import (
PlayHtSynthesizer,
TTS_ENDPOINT
TTS_ENDPOINT,
)

import re
from tests.streaming.data.loader import get_audio_path
import asyncio
import pytest

DEFAULT_PARAMS = {"sampling_rate": 16000,
"audio_encoding": AudioEncoding.LINEAR16}
DEFAULT_PARAMS = {"sampling_rate": 16000, "audio_encoding": AudioEncoding.LINEAR16}

MOCK_API_KEY = "my_api_key"
MOCK_USER_ID = "my_user_id"


def create_eleven_labs_request_handler(optimize_streaming_latency=False):
def request_handler(url, headers, **kwargs):
if optimize_streaming_latency and not re.search(r"optimize_streaming_latency=\d", url):
if optimize_streaming_latency and not re.search(
r"optimize_streaming_latency=\d", url
):
raise Exception("optimize_streaming_latency not found in url")
if headers["xi-api-key"] != MOCK_API_KEY:
return CallbackResult(status=401)
Expand All @@ -39,8 +43,7 @@ def request_handler(url, headers, **kwargs):
@pytest.fixture
def mock_eleven_labs_api():
with aioresponses() as m:
pattern = re.compile(
rf"{re.escape(ELEVEN_LABS_BASE_URL)}text-to-speech/\w+")
pattern = re.compile(rf"{re.escape(ELEVEN_LABS_BASE_URL)}text-to-speech/\w+")
m.post(pattern, callback=create_eleven_labs_request_handler())
yield m

Expand Down Expand Up @@ -70,11 +73,12 @@ async def fixture_eleven_labs_synthesizer_env_api_key():

# PlayHT Setup


def create_play_ht_request_handler():
def request_handler(url, headers, **kwargs):
if headers["Authorization"] != f"Bearer {MOCK_API_KEY}":
if headers["AUTHORIZATION"] != f"Bearer {MOCK_API_KEY}":
return CallbackResult(status=401)
if headers["X-User-ID"] != MOCK_USER_ID:
if headers["X-USER-ID"] != MOCK_USER_ID:
return CallbackResult(status=401)
with open(get_audio_path("fake_audio.mp3"), "rb") as audio_file:
return CallbackResult(content_type="audio/mpeg", body=audio_file.read())
Expand Down Expand Up @@ -104,6 +108,7 @@ async def fixture_play_ht_synthesizer_wrong_api_key():
params["user_id"] = MOCK_USER_ID
return PlayHtSynthesizer(PlayHtSynthesizerConfig(**params))


@pytest.fixture(scope="module")
async def fixture_play_ht_synthesizer_wrong_user_id():
params = DEFAULT_PARAMS.copy()
Expand All @@ -120,4 +125,3 @@ async def fixture_play_ht_synthesizer_env_api_key():
os.environ["PLAY_HT_API_KEY"] = MOCK_API_KEY
os.environ["PLAY_HT_USER_ID"] = MOCK_USER_ID
return PlayHtSynthesizer(PlayHtSynthesizerConfig(**params))

87 changes: 51 additions & 36 deletions vocode/streaming/synthesizer/play_ht_synthesizer.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,8 @@
import asyncio
import io
import logging
from typing import Optional

from aiohttp import ClientSession, ClientTimeout
from pydub import AudioSegment
import requests
from opentelemetry.context.context import Context

from vocode import getenv
from vocode.streaming.agent.bot_sentiment_analyser import BotSentiment
from vocode.streaming.models.message import BaseMessage
Expand All @@ -19,7 +14,6 @@
)
from vocode.streaming.utils.mp3_helper import decode_mp3


TTS_ENDPOINT = "https://play.ht/api/v2/tts/stream"


Expand All @@ -29,6 +23,8 @@ def __init__(
synthesizer_config: PlayHtSynthesizerConfig,
logger: Optional[logging.Logger] = None,
aiohttp_session: Optional[ClientSession] = None,
max_backoff_retries=3,
backoff_retry_delay=2,
):
super().__init__(synthesizer_config, aiohttp_session)
self.synthesizer_config = synthesizer_config
Expand All @@ -40,6 +36,8 @@ def __init__(
)
self.words_per_minute = 150
self.experimental_streaming = synthesizer_config.experimental_streaming
self.max_backoff_retries = max_backoff_retries
self.backoff_retry_delay = backoff_retry_delay

async def create_speech(
self,
Expand All @@ -48,12 +46,13 @@ async def create_speech(
bot_sentiment: Optional[BotSentiment] = None,
) -> SynthesisResult:
headers = {
"Authorization": f"Bearer {self.api_key}",
"X-User-ID": self.user_id,
"AUTHORIZATION": f"Bearer {self.api_key}",
"X-USER-ID": self.user_id,
"Accept": "audio/mpeg",
"Content-Type": "application/json",
}
body = {
"quality": "draft",
"voice": self.synthesizer_config.voice_id,
"text": message.text,
"sample_rate": self.synthesizer_config.sampling_rate,
Expand All @@ -69,33 +68,49 @@ async def create_speech(
f"synthesizer.{SynthesizerType.PLAY_HT.value.split('_', 1)[-1]}.create_total",
)

response = await self.aiohttp_session.post(
TTS_ENDPOINT, headers=headers, json=body, timeout=ClientTimeout(total=15)
)
if not response.ok:
raise Exception(f"Play.ht API error status code {response.status}")
if self.experimental_streaming:
return SynthesisResult(
self.experimental_mp3_streaming_output_generator(
response, chunk_size, create_speech_span
), # should be wav
lambda seconds: self.get_message_cutoff_from_voice_speed(
message, seconds, self.words_per_minute
),
)
else:
read_response = await response.read()
create_speech_span.end()
convert_span = tracer.start_span(
f"synthesizer.{SynthesizerType.PLAY_HT.value.split('_', 1)[-1]}.convert",
)
output_bytes_io = decode_mp3(read_response)
backoff_retry_delay = self.backoff_retry_delay
max_backoff_retries = self.max_backoff_retries

result = self.create_synthesis_result_from_wav(
synthesizer_config=self.synthesizer_config,
file=output_bytes_io,
message=message,
chunk_size=chunk_size,
for attempt in range(max_backoff_retries):
response = await self.aiohttp_session.post(
TTS_ENDPOINT,
headers=headers,
json=body,
timeout=ClientTimeout(total=15),
)
convert_span.end()
return result

if response.status == 429 and attempt < max_backoff_retries - 1:
await asyncio.sleep(backoff_retry_delay)
backoff_retry_delay *= 2 # Exponentially increase delay
continue

if not response.ok:
raise Exception(f"Play.ht API error status code {response.status}")

if self.experimental_streaming:
return SynthesisResult(
self.experimental_mp3_streaming_output_generator(
response, chunk_size, create_speech_span
),
lambda seconds: self.get_message_cutoff_from_voice_speed(
message, seconds, self.words_per_minute
),
)
else:
read_response = await response.read()
create_speech_span.end()
convert_span = tracer.start_span(
f"synthesizer.{SynthesizerType.PLAY_HT.value.split('_', 1)[-1]}.convert",
)
output_bytes_io = decode_mp3(read_response)

result = self.create_synthesis_result_from_wav(
synthesizer_config=self.synthesizer_config,
file=output_bytes_io,
message=message,
chunk_size=chunk_size,
)
convert_span.end()
return result

raise Exception("Max retries reached for Play.ht API")

0 comments on commit 09aed42

Please sign in to comment.