Skip to content

Commit

Permalink
add update_options to TTSs (#922)
Browse files Browse the repository at this point in the history
  • Loading branch information
theomonnom authored Oct 15, 2024
1 parent eff6dfb commit c422168
Show file tree
Hide file tree
Showing 6 changed files with 151 additions and 28 deletions.
9 changes: 9 additions & 0 deletions .changeset/eight-pigs-dream.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
---
"livekit-plugins-azure": patch
"livekit-plugins-cartesia": patch
"livekit-plugins-elevenlabs": patch
"livekit-plugins-google": patch
"livekit-plugins-openai": patch
---

add update_options to TTS
17 changes: 14 additions & 3 deletions livekit-plugins/livekit-plugins-azure/livekit/plugins/azure/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,12 @@ class TTS(tts.TTS):
def __init__(
self,
*,
speech_key: str | None = None,
speech_region: str | None = None,
voice: str | None = None,
endpoint_id: str | None = None,
language: str | None = None,
prosody: ProsodyConfig | None = None,
speech_key: str | None = None,
speech_region: str | None = None,
endpoint_id: str | None = None,
) -> None:
"""
Create a new instance of Azure TTS.
Expand Down Expand Up @@ -147,6 +147,17 @@ def __init__(
prosody=prosody,
)

def update_options(
self,
*,
voice: str | None = None,
language: str | None = None,
prosody: ProsodyConfig | None = None,
) -> None:
self._opts.voice = voice or self._opts.voice
self._opts.language = language or self._opts.language
self._opts.prosody = prosody or self._opts.prosody

def synthesize(self, text: str) -> "ChunkedStream":
return ChunkedStream(text, self._opts)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@

@dataclass
class _TTSOptions:
model: TTSModels
model: TTSModels | str
encoding: TTSEncoding
sample_rate: int
voice: str | list[float]
Expand All @@ -57,7 +57,7 @@ class TTS(tts.TTS):
def __init__(
self,
*,
model: TTSModels = "sonic-english",
model: TTSModels | str = "sonic-english",
language: str = "en",
encoding: TTSEncoding = "pcm_s16le",
voice: str | list[float] = TTSDefaultVoiceId,
Expand Down Expand Up @@ -112,6 +112,35 @@ def _ensure_session(self) -> aiohttp.ClientSession:

return self._session

def update_options(
self,
*,
model: TTSModels | None = None,
language: str | None = None,
voice: str | list[float] | None = None,
speed: TTSVoiceSpeed | float | None = None,
emotion: list[TTSVoiceEmotion | str] | None = None,
) -> None:
"""
Update the Text-to-Speech (TTS) configuration options.
This method allows updating the TTS settings, including model type, language, voice, speed,
and emotion. If any parameter is not provided, the existing value will be retained.
Args:
model (TTSModels, optional): The Cartesia TTS model to use. Defaults to "sonic-english".
language (str, optional): The language code for synthesis. Defaults to "en".
voice (str | list[float], optional): The voice ID or embedding array.
speed (TTSVoiceSpeed | float, optional): Voice Control - Speed (https://docs.cartesia.ai/user-guides/voice-control)
emotion (list[TTSVoiceEmotion], optional): Voice Control - Emotion (https://docs.cartesia.ai/user-guides/voice-control)
"""
self._opts.model = model or self._opts.model
self._opts.language = language or self._opts.language
self._opts.voice = voice or self._opts.voice
self._opts.speed = speed or self._opts.speed
if emotion is not None:
self._opts.emotion = emotion

def synthesize(self, text: str) -> "ChunkedStream":
return ChunkedStream(text, self._opts, self._ensure_session())

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class Voice:
class _TTSOptions:
api_key: str
voice: Voice
model_id: TTSModels
model: TTSModels | str
base_url: str
encoding: TTSEncoding
sample_rate: int
Expand All @@ -94,7 +94,7 @@ def __init__(
self,
*,
voice: Voice = DEFAULT_VOICE,
model_id: TTSModels = "eleven_turbo_v2_5",
model: TTSModels | str = "eleven_turbo_v2_5",
api_key: str | None = None,
base_url: str | None = None,
encoding: TTSEncoding = "mp3_22050_32",
Expand All @@ -105,12 +105,23 @@ def __init__(
enable_ssml_parsing: bool = False,
chunk_length_schedule: list[int] = [80, 120, 200, 260], # range is [50, 500]
http_session: aiohttp.ClientSession | None = None,
# deprecated
model_id: TTSModels | str | None = None,
) -> None:
"""
Create a new instance of ElevenLabs TTS.
``api_key`` must be set to your ElevenLabs API key, either using the argument or by setting
the ``ELEVEN_API_KEY`` environmental variable.
Args:
voice (Voice): Voice configuration. Defaults to `DEFAULT_VOICE`.
model (TTSModels | str): TTS model to use. Defaults to "eleven_turbo_v2_5".
api_key (str | None): ElevenLabs API key. Can be set via argument or `ELEVEN_API_KEY` environment variable.
base_url (str | None): Custom base URL for the API. Optional.
encoding (TTSEncoding): Audio encoding format. Defaults to "mp3_22050_32".
streaming_latency (int): Latency in seconds for streaming. Defaults to 3.
word_tokenizer (tokenize.WordTokenizer): Tokenizer for processing text. Defaults to basic WordTokenizer.
enable_ssml_parsing (bool): Enable SSML parsing for input text. Defaults to False.
chunk_length_schedule (list[int]): Schedule for chunk lengths, ranging from 50 to 500. Defaults to [80, 120, 200, 260].
http_session (aiohttp.ClientSession | None): Custom HTTP session for API requests. Optional.
"""

super().__init__(
Expand All @@ -120,13 +131,20 @@ def __init__(
sample_rate=_sample_rate_from_format(encoding),
num_channels=1,
)

if model_id is not None:
logger.warning(
"model_id is deprecated and will be removed in 1.5.0, use model instead",
)
model = model_id

api_key = api_key or os.environ.get("ELEVEN_API_KEY")
if not api_key:
raise ValueError("ELEVEN_API_KEY must be set")

self._opts = _TTSOptions(
voice=voice,
model_id=model_id,
model=model,
api_key=api_key,
base_url=base_url or API_BASE_URL_V1,
encoding=encoding,
Expand All @@ -151,6 +169,20 @@ async def list_voices(self) -> List[Voice]:
) as resp:
return _dict_to_voices_list(await resp.json())

def update_options(
self,
*,
voice: Voice = DEFAULT_VOICE,
model: TTSModels | str = "eleven_turbo_v2_5",
) -> None:
"""
Args:
voice (Voice): Voice configuration. Defaults to `DEFAULT_VOICE`.
model (TTSModels | str): TTS model to use. Defaults to "eleven_turbo_v2_5".
"""
self._opts.model = model or self._opts.model
self._opts.voice = voice or self._opts.voice

def synthesize(self, text: str) -> "ChunkedStream":
return ChunkedStream(text, self._opts, self._ensure_session())

Expand Down Expand Up @@ -184,7 +216,7 @@ async def _main_task(self) -> None:
)
data = {
"text": self._text,
"model_id": self._opts.model_id,
"model_id": self._opts.model,
"voice_settings": voice_settings,
}

Expand Down Expand Up @@ -450,7 +482,7 @@ def _strip_nones(data: dict[str, Any]):
def _synthesize_url(opts: _TTSOptions) -> str:
base_url = opts.base_url
voice_id = opts.voice.id
model_id = opts.model_id
model_id = opts.model
output_format = opts.encoding
latency = opts.streaming_latency
return (
Expand All @@ -462,7 +494,7 @@ def _synthesize_url(opts: _TTSOptions) -> str:
def _stream_url(opts: _TTSOptions) -> str:
base_url = opts.base_url
voice_id = opts.voice.id
model_id = opts.model_id
model_id = opts.model
output_format = opts.encoding
latency = opts.streaming_latency
enable_ssml = str(opts.enable_ssml_parsing).lower()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Union

from livekit import rtc
from livekit.agents import tts, utils
Expand All @@ -26,10 +25,6 @@
from .log import logger
from .models import AudioEncoding, Gender, SpeechLanguages

LgType = Union[SpeechLanguages, str]
GenderType = Union[Gender, str]
AudioEncodingType = Union[AudioEncoding, str]


@dataclass
class _TTSOptions:
Expand All @@ -41,10 +36,10 @@ class TTS(tts.TTS):
def __init__(
self,
*,
language: LgType = "en-US",
gender: GenderType = "neutral",
language: SpeechLanguages | str = "en-US",
gender: Gender | str = "neutral",
voice_name: str = "", # Not required
encoding: AudioEncodingType = "linear16",
encoding: AudioEncoding | str = "linear16",
sample_rate: int = 24000,
speaking_rate: float = 1.0,
credentials_info: dict | None = None,
Expand All @@ -56,6 +51,16 @@ def __init__(
Credentials must be provided, either by using the ``credentials_info`` dict, or reading
from the file specified in ``credentials_file`` or the ``GOOGLE_APPLICATION_CREDENTIALS``
environmental variable.
Args:
language (SpeechLanguages | str, optional): Language code (e.g., "en-US"). Default is "en-US".
gender (Gender | str, optional): Voice gender ("male", "female", "neutral"). Default is "neutral".
voice_name (str, optional): Specific voice name. Default is an empty string.
encoding (AudioEncoding | str, optional): Audio encoding format (e.g., "linear16"). Default is "linear16".
sample_rate (int, optional): Audio sample rate in Hz. Default is 24000.
speaking_rate (float, optional): Speed of speech. Default is 1.0.
credentials_info (dict, optional): Dictionary containing Google Cloud credentials. Default is None.
credentials_file (str, optional): Path to the Google Cloud credentials JSON file. Default is None.
"""

super().__init__(
Expand All @@ -70,14 +75,10 @@ def __init__(
self._credentials_info = credentials_info
self._credentials_file = credentials_file

ssml_gender = SsmlVoiceGender.NEUTRAL
if gender == "male":
ssml_gender = SsmlVoiceGender.MALE
elif gender == "female":
ssml_gender = SsmlVoiceGender.FEMALE

voice = texttospeech.VoiceSelectionParams(
name=voice_name, language_code=language, ssml_gender=ssml_gender
name=voice_name,
language_code=language,
ssml_gender=_gender_from_str(gender),
)

if encoding == "linear16" or encoding == "wav":
Expand All @@ -96,6 +97,30 @@ def __init__(
),
)

def update_options(
self,
*,
language: SpeechLanguages | str = "en-US",
gender: Gender | str = "neutral",
voice_name: str = "", # Not required
speaking_rate: float = 1.0,
) -> None:
"""
Update the TTS options.
Args:
language (SpeechLanguages | str, optional): Language code (e.g., "en-US"). Default is "en-US".
gender (Gender | str, optional): Voice gender ("male", "female", "neutral"). Default is "neutral".
voice_name (str, optional): Specific voice name. Default is an empty string.
speaking_rate (float, optional): Speed of speech. Default is 1.0.
"""
self._opts.voice = texttospeech.VoiceSelectionParams(
name=voice_name,
language_code=language,
ssml_gender=_gender_from_str(gender),
)
self._opts.audio_config.speaking_rate = speaking_rate

def _ensure_client(self) -> texttospeech.TextToSpeechAsyncClient:
if not self._client:
if self._credentials_info:
Expand Down Expand Up @@ -172,3 +197,13 @@ async def _main_task(self) -> None:
),
)
)


def _gender_from_str(gender: str) -> SsmlVoiceGender:
ssml_gender = SsmlVoiceGender.NEUTRAL
if gender == "male":
ssml_gender = SsmlVoiceGender.MALE
elif gender == "female":
ssml_gender = SsmlVoiceGender.FEMALE

return ssml_gender
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,13 @@ def __init__(
speed=speed,
)

def update_options(
self, *, model: TTSModels | None, voice: TTSVoices | None, speed: float | None
) -> None:
self._opts.model = model or self._opts.model
self._opts.voice = voice or self._opts.voice
self._opts.speed = speed or self._opts.speed

@staticmethod
def create_azure_client(
*,
Expand Down

0 comments on commit c422168

Please sign in to comment.