diff --git a/tests/synthesizer/conftest.py b/tests/synthesizer/conftest.py index 040ebeb18..c537e7ce1 100644 --- a/tests/synthesizer/conftest.py +++ b/tests/synthesizer/conftest.py @@ -1,7 +1,10 @@ import pytest from aioresponses import aioresponses, CallbackResult from vocode.streaming.models.audio_encoding import AudioEncoding -from vocode.streaming.models.synthesizer import ElevenLabsSynthesizerConfig, PlayHtSynthesizerConfig +from vocode.streaming.models.synthesizer import ( + ElevenLabsSynthesizerConfig, + PlayHtSynthesizerConfig, +) import re from vocode.streaming.synthesizer.eleven_labs_synthesizer import ( ElevenLabsSynthesizer, @@ -9,7 +12,7 @@ ) from vocode.streaming.synthesizer.play_ht_synthesizer import ( PlayHtSynthesizer, - TTS_ENDPOINT + TTS_ENDPOINT, ) import re @@ -17,8 +20,7 @@ import asyncio import pytest -DEFAULT_PARAMS = {"sampling_rate": 16000, - "audio_encoding": AudioEncoding.LINEAR16} +DEFAULT_PARAMS = {"sampling_rate": 16000, "audio_encoding": AudioEncoding.LINEAR16} MOCK_API_KEY = "my_api_key" MOCK_USER_ID = "my_user_id" @@ -26,7 +28,9 @@ def create_eleven_labs_request_handler(optimize_streaming_latency=False): def request_handler(url, headers, **kwargs): - if optimize_streaming_latency and not re.search(r"optimize_streaming_latency=\d", url): + if optimize_streaming_latency and not re.search( + r"optimize_streaming_latency=\d", url + ): raise Exception("optimize_streaming_latency not found in url") if headers["xi-api-key"] != MOCK_API_KEY: return CallbackResult(status=401) @@ -39,8 +43,7 @@ def request_handler(url, headers, **kwargs): @pytest.fixture def mock_eleven_labs_api(): with aioresponses() as m: - pattern = re.compile( - rf"{re.escape(ELEVEN_LABS_BASE_URL)}text-to-speech/\w+") + pattern = re.compile(rf"{re.escape(ELEVEN_LABS_BASE_URL)}text-to-speech/\w+") m.post(pattern, callback=create_eleven_labs_request_handler()) yield m @@ -70,11 +73,12 @@ async def fixture_eleven_labs_synthesizer_env_api_key(): # PlayHT Setup + def create_play_ht_request_handler(): def request_handler(url, headers, **kwargs): - if headers["Authorization"] != f"Bearer {MOCK_API_KEY}": + if headers["AUTHORIZATION"] != f"Bearer {MOCK_API_KEY}": return CallbackResult(status=401) - if headers["X-User-ID"] != MOCK_USER_ID: + if headers["X-USER-ID"] != MOCK_USER_ID: return CallbackResult(status=401) with open(get_audio_path("fake_audio.mp3"), "rb") as audio_file: return CallbackResult(content_type="audio/mpeg", body=audio_file.read()) @@ -104,6 +108,7 @@ async def fixture_play_ht_synthesizer_wrong_api_key(): params["user_id"] = MOCK_USER_ID return PlayHtSynthesizer(PlayHtSynthesizerConfig(**params)) + @pytest.fixture(scope="module") async def fixture_play_ht_synthesizer_wrong_user_id(): params = DEFAULT_PARAMS.copy() @@ -120,4 +125,3 @@ async def fixture_play_ht_synthesizer_env_api_key(): os.environ["PLAY_HT_API_KEY"] = MOCK_API_KEY os.environ["PLAY_HT_USER_ID"] = MOCK_USER_ID return PlayHtSynthesizer(PlayHtSynthesizerConfig(**params)) - diff --git a/vocode/streaming/synthesizer/play_ht_synthesizer.py b/vocode/streaming/synthesizer/play_ht_synthesizer.py index b80c021df..d6deb49af 100644 --- a/vocode/streaming/synthesizer/play_ht_synthesizer.py +++ b/vocode/streaming/synthesizer/play_ht_synthesizer.py @@ -1,13 +1,8 @@ import asyncio -import io import logging from typing import Optional from aiohttp import ClientSession, ClientTimeout -from pydub import AudioSegment -import requests -from opentelemetry.context.context import Context - from vocode import getenv from vocode.streaming.agent.bot_sentiment_analyser import BotSentiment from vocode.streaming.models.message import BaseMessage @@ -19,7 +14,6 @@ ) from vocode.streaming.utils.mp3_helper import decode_mp3 - TTS_ENDPOINT = "https://play.ht/api/v2/tts/stream" @@ -29,6 +23,8 @@ def __init__( synthesizer_config: PlayHtSynthesizerConfig, logger: Optional[logging.Logger] = None, aiohttp_session: Optional[ClientSession] = None, + max_backoff_retries=3, + backoff_retry_delay=2, ): super().__init__(synthesizer_config, aiohttp_session) self.synthesizer_config = synthesizer_config @@ -40,6 +36,8 @@ def __init__( ) self.words_per_minute = 150 self.experimental_streaming = synthesizer_config.experimental_streaming + self.max_backoff_retries = max_backoff_retries + self.backoff_retry_delay = backoff_retry_delay async def create_speech( self, @@ -48,12 +46,13 @@ async def create_speech( bot_sentiment: Optional[BotSentiment] = None, ) -> SynthesisResult: headers = { - "Authorization": f"Bearer {self.api_key}", - "X-User-ID": self.user_id, + "AUTHORIZATION": f"Bearer {self.api_key}", + "X-USER-ID": self.user_id, "Accept": "audio/mpeg", "Content-Type": "application/json", } body = { + "quality": "draft", "voice": self.synthesizer_config.voice_id, "text": message.text, "sample_rate": self.synthesizer_config.sampling_rate, @@ -69,33 +68,49 @@ async def create_speech( f"synthesizer.{SynthesizerType.PLAY_HT.value.split('_', 1)[-1]}.create_total", ) - response = await self.aiohttp_session.post( - TTS_ENDPOINT, headers=headers, json=body, timeout=ClientTimeout(total=15) - ) - if not response.ok: - raise Exception(f"Play.ht API error status code {response.status}") - if self.experimental_streaming: - return SynthesisResult( - self.experimental_mp3_streaming_output_generator( - response, chunk_size, create_speech_span - ), # should be wav - lambda seconds: self.get_message_cutoff_from_voice_speed( - message, seconds, self.words_per_minute - ), - ) - else: - read_response = await response.read() - create_speech_span.end() - convert_span = tracer.start_span( - f"synthesizer.{SynthesizerType.PLAY_HT.value.split('_', 1)[-1]}.convert", - ) - output_bytes_io = decode_mp3(read_response) + backoff_retry_delay = self.backoff_retry_delay + max_backoff_retries = self.max_backoff_retries - result = self.create_synthesis_result_from_wav( - synthesizer_config=self.synthesizer_config, - file=output_bytes_io, - message=message, - chunk_size=chunk_size, + for attempt in range(max_backoff_retries): + response = await self.aiohttp_session.post( + TTS_ENDPOINT, + headers=headers, + json=body, + timeout=ClientTimeout(total=15), ) - convert_span.end() - return result + + if response.status == 429 and attempt < max_backoff_retries - 1: + await asyncio.sleep(backoff_retry_delay) + backoff_retry_delay *= 2 # Exponentially increase delay + continue + + if not response.ok: + raise Exception(f"Play.ht API error status code {response.status}") + + if self.experimental_streaming: + return SynthesisResult( + self.experimental_mp3_streaming_output_generator( + response, chunk_size, create_speech_span + ), + lambda seconds: self.get_message_cutoff_from_voice_speed( + message, seconds, self.words_per_minute + ), + ) + else: + read_response = await response.read() + create_speech_span.end() + convert_span = tracer.start_span( + f"synthesizer.{SynthesizerType.PLAY_HT.value.split('_', 1)[-1]}.convert", + ) + output_bytes_io = decode_mp3(read_response) + + result = self.create_synthesis_result_from_wav( + synthesizer_config=self.synthesizer_config, + file=output_bytes_io, + message=message, + chunk_size=chunk_size, + ) + convert_span.end() + return result + + raise Exception("Max retries reached for Play.ht API")