11labs: send phoneme in one entire xml chunk (#766)
theomonnom authored Sep 17, 2024
1 parent 044a29d commit b18447a
Showing 5 changed files with 117 additions and 43 deletions.
6 changes: 6 additions & 0 deletions .changeset/large-dogs-notice.md
@@ -0,0 +1,6 @@
---
"livekit-agents": patch
"livekit-plugins-elevenlabs": patch
---

11labs: send phoneme in one entire xml chunk
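
For context: a word tokenizer splits an SSML `<phoneme>` tag into several tokens (see the new test added below), so forwarding tokens one by one would hand ElevenLabs a fragment of the tag. A minimal sketch of the regrouping idea, written as a standalone function with a hypothetical name rather than the plugin's actual code:

```python
def regroup_phonemes(tokens: list[str]) -> list[str]:
    """Buffer tokens from "<phoneme" until "</phoneme>" and emit the tag whole."""
    chunks: list[str] = []
    xml_buf: list[str] = []
    for tok in tokens:
        if tok.startswith("<phoneme") or xml_buf:
            xml_buf.append(tok)
            if "</phoneme>" in tok:
                chunks.append(" ".join(xml_buf))  # the entire tag as one chunk
                xml_buf = []
        else:
            chunks.append(tok)
    return chunks


words = ["Hello", "<phoneme", 'alphabet="cmu-arpabet"', 'ph="AE">a</phoneme>', "test"]
assert regroup_phonemes(words) == [
    "Hello",
    '<phoneme alphabet="cmu-arpabet" ph="AE">a</phoneme>',
    "test",
]
```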
65 changes: 34 additions & 31 deletions livekit-agents/livekit/agents/tokenize/token_stream.py
@@ -26,65 +26,68 @@ def __init__(
self._current_segment_id = shortuuid()

self._buf_tokens: list[str] = [] # <= min_token_len
self._buf = ""
self._in_buf = ""
self._out_buf = ""

@typing.no_type_check
def push_text(self, text: str) -> None:
self._check_not_closed()
self._buf += text
self._in_buf += text

if len(self._buf) < self._min_ctx_len:
if len(self._in_buf) < self._min_ctx_len:
return

tokens = self._tokenize_fnc(self._buf)
while True:
tokens = self._tokenize_fnc(self._in_buf)
if len(tokens) <= 1:
break

buf_toks = []
buf = ""
while len(tokens) > 1:
if buf:
buf += " "
if self._out_buf:
self._out_buf += " "

tok = tokens.pop(0)
tok_text = tok
if isinstance(tok, tuple):
tok_text = tok[0]

buf += tok_text
buf_toks.append(tok)
if len(buf) >= self._min_token_len:
self._out_buf += tok_text
if len(self._out_buf) >= self._min_token_len:
self._event_ch.send_nowait(
TokenData(token=buf, segment_id=self._current_segment_id)
TokenData(token=self._out_buf, segment_id=self._current_segment_id)
)

if isinstance(tok, tuple):
self._buf = self._buf[tok[2] :]
else:
for i, tok in enumerate(buf_toks):
tok_i = max(self._buf.find(tok), 0)
self._buf = self._buf[tok_i + len(tok) :].lstrip()
self._out_buf = ""

buf_toks = []
buf = ""
if isinstance(tok, tuple):
self._in_buf = self._in_buf[tok[2] :]
else:
tok_i = max(self._in_buf.find(tok), 0)
self._in_buf = self._in_buf[tok_i + len(tok) :].lstrip()

@typing.no_type_check
def flush(self) -> None:
self._check_not_closed()
if self._buf:
tokens = self._tokenize_fnc(self._buf)

if self._in_buf or self._out_buf:
tokens = self._tokenize_fnc(self._in_buf)
if tokens:
if self._out_buf:
self._out_buf += " "

if isinstance(tokens[0], tuple):
buf = " ".join([tok[0] for tok in tokens])
self._out_buf += " ".join([tok[0] for tok in tokens])
else:
buf = " ".join(tokens)
else:
buf = self._buf
self._out_buf += " ".join(tokens)

if self._out_buf:
self._event_ch.send_nowait(
TokenData(token=self._out_buf, segment_id=self._current_segment_id)
)

self._event_ch.send_nowait(
TokenData(token=buf, segment_id=self._current_segment_id)
)
self._current_segment_id = shortuuid()

self._buf = ""
self._in_buf = ""
self._out_buf = ""

def end_input(self) -> None:
self.flush()
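To make the intent of the `_in_buf`/`_out_buf` split above easier to follow: `_in_buf` holds raw text that has not been tokenized yet, while `_out_buf` accumulates tokens across `push_text` calls until they reach the minimum length, and `flush` drains both. A toy version of the same two-buffer pattern (simplified whitespace tokenizer, no segment ids or event channel — not the livekit-agents implementation):

```python
MIN_TOKEN_LEN = 8


class ToyTokenStream:
    def __init__(self) -> None:
        self._in_buf = ""   # raw text not yet consumed
        self._out_buf = ""  # token batch kept until it reaches MIN_TOKEN_LEN
        self.emitted: list[str] = []

    def push_text(self, text: str) -> None:
        self._in_buf += text
        while True:
            tokens = self._in_buf.split()
            if len(tokens) <= 1:
                break  # keep the last (possibly incomplete) token in _in_buf
            tok = tokens[0]
            if self._out_buf:
                self._out_buf += " "
            self._out_buf += tok
            if len(self._out_buf) >= MIN_TOKEN_LEN:
                self.emitted.append(self._out_buf)
                self._out_buf = ""
            # drop the consumed token from the input buffer
            self._in_buf = self._in_buf[self._in_buf.find(tok) + len(tok):].lstrip()

    def flush(self) -> None:
        tokens = self._in_buf.split()
        if tokens:
            if self._out_buf:
                self._out_buf += " "
            self._out_buf += " ".join(tokens)
        if self._out_buf:
            self.emitted.append(self._out_buf)
        self._in_buf = self._out_buf = ""


stream = ToyTokenStream()
stream.push_text("one two three ")
stream.push_text("four")
stream.flush()
print(stream.emitted)  # ['one two three', 'four']
```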
10 changes: 1 addition & 9 deletions livekit-agents/livekit/agents/tokenize/utils.py
@@ -36,14 +36,6 @@ def replace_words(

replacements = {k.lower(): v for k, v in replacements.items()}

def _match_case(word, replacement):
if word.isupper():
return replacement.upper()
elif word.istitle():
return replacement.title()
else:
return replacement.lower()

def _process_words(text, words):
offset = 0
processed_index = 0
@@ -54,7 +46,7 @@ def _process_words(text, words):
if replacement:
text = (
text[: start_index + offset]
+ _match_case(word, replacement)
+ replacement
+ text[end_index + offset - punctuation_off :]
)
offset += len(replacement) - len(word) + punctuation_off
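With the `_match_case` helper deleted, `replace_words` now inserts the replacement string verbatim instead of coercing it to the matched word's upper/title/lower case (hence the lowercase "twice" in the updated test expectation further down). A rough standalone illustration of the new behaviour — a hypothetical regex-based stand-in, not the library's token-based implementation:

```python
import re


def replace_words(text: str, replacements: dict[str, str]) -> str:
    replacements = {k.lower(): v for k, v in replacements.items()}

    def _sub(match: re.Match) -> str:
        # use the replacement exactly as provided, no case matching
        return replacements.get(match.group(0).lower(), match.group(0))

    pattern = "|".join(re.escape(k) for k in replacements)
    return re.sub(pattern, _sub, text, flags=re.IGNORECASE)


print(replace_words("Twice upon a time", {"twice": "once"}))
# -> "once upon a time"  (no longer title-cased to "Once")
```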
@@ -202,7 +202,7 @@ async def _main_task(self) -> None:
if encoding == "mp3":
async for bytes_data, _ in resp.content.iter_chunks():
for frame in self._mp3_decoder.decode_chunk(bytes_data):
for frame in bstream.write(frame.data):
for frame in bstream.write(frame.data.tobytes()):
self._event_ch.send_nowait(
tts.SynthesizedAudio(
request_id=request_id,
@@ -325,15 +325,34 @@ async def _run_ws(
async def send_task():
nonlocal eos_sent

xml_content = []
async for data in word_stream:
text = data.token

# send the xml phoneme in one go
if (
self._opts.enable_ssml_parsing
and data.token.startswith("<phoneme")
or xml_content
):
xml_content.append(text)
if data.token.find("</phoneme>") > -1:
text = self._opts.word_tokenizer.format_words(xml_content)
xml_content = []
else:
continue

# try_trigger_generation=True is a bad practice, we expose
# chunk_length_schedule instead
data_pkt = dict(
text=f"{data.token} ", # must always end with a space
text=f"{text} ", # must always end with a space
try_trigger_generation=False,
)
await ws_conn.send_str(json.dumps(data_pkt))

if xml_content:
logger.warning("11labs stream ended with incomplete xml content")

# no more token, mark eos
eos_pkt = dict(text="")
await ws_conn.send_str(json.dumps(eos_pkt))
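Isolating the send loop above: each chunk goes out as a JSON packet whose text ends with a space, and an empty text field marks end-of-stream. A self-contained sketch with a fake word stream and a stubbed `send_str`, both hypothetical stand-ins for the plugin's `word_stream` and websocket connection:

```python
import asyncio
import json


async def fake_word_stream():
    # tokens already regrouped, so the phoneme tag arrives as one chunk
    for tok in ["Hello", '<phoneme ph="AE">a</phoneme>', "world"]:
        yield tok


async def send_task(send_str) -> None:
    async for text in fake_word_stream():
        data_pkt = dict(
            text=f"{text} ",  # each chunk must end with a space
            try_trigger_generation=False,
        )
        await send_str(json.dumps(data_pkt))

    # no more tokens: an empty text field marks end-of-stream
    await send_str(json.dumps(dict(text="")))


async def main() -> None:
    sent: list[str] = []

    async def send_str(msg: str) -> None:
        sent.append(msg)

    await send_task(send_str)
    for msg in sent:
        print(msg)


asyncio.run(main())
```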
56 changes: 55 additions & 1 deletion tests/test_tokenizer.py
@@ -118,6 +118,60 @@ async def test_streamed_word_tokenizer(tokenizer: tokenize.WordTokenizer):
assert ev.token == WORDS_EXPECTED[i]


WORDS_PUNCT_TEXT = 'This is <phoneme alphabet="cmu-arpabet" ph="AE K CH UW AH L IY">actually</phoneme> tricky to handle.'

WORDS_PUNCT_EXPECTED = [
"This",
"is",
"<phoneme",
'alphabet="cmu-arpabet"',
'ph="AE',
"K",
"CH",
"UW",
"AH",
"L",
'IY">actually</phoneme>',
"tricky",
"to",
"handle.",
]

WORD_PUNCT_TOKENIZERS = [basic.WordTokenizer(ignore_punctuation=False)]


@pytest.mark.parametrize("tokenizer", WORD_PUNCT_TOKENIZERS)
def test_punct_word_tokenizer(tokenizer: tokenize.WordTokenizer):
tokens = tokenizer.tokenize(text=WORDS_PUNCT_TEXT)
for i, token in enumerate(WORDS_PUNCT_EXPECTED):
assert token == tokens[i]


@pytest.mark.parametrize("tokenizer", WORD_PUNCT_TOKENIZERS)
async def test_streamed_punct_word_tokenizer(tokenizer: tokenize.WordTokenizer):
# divide text by chunks of arbitrary length (1-4)
pattern = [1, 2, 4]
text = WORDS_PUNCT_TEXT
chunks = []
pattern_iter = iter(pattern * (len(text) // sum(pattern) + 1))

for chunk_size in pattern_iter:
if not text:
break
chunks.append(text[:chunk_size])
text = text[chunk_size:]

stream = tokenizer.stream()
for chunk in chunks:
stream.push_text(chunk)

stream.end_input()

for i in range(len(WORDS_PUNCT_EXPECTED)):
ev = await stream.__anext__()
assert ev.token == WORDS_PUNCT_EXPECTED[i]


HYPHENATOR_TEXT = [
"Segment",
"expected",
@@ -148,7 +202,7 @@ def test_hyphenate_word():
"framework. A.B.C"
)
REPLACE_EXPECTED = (
"This is a test. Hello universe, I'm creating this assistants.. library. Twice again "
"This is a test. Hello universe, I'm creating this assistants.. library. twice again "
"library. A.B.C.D"
)

