From c156d3c5eedcea3d829342e0a3787cd525c849a5 Mon Sep 17 00:00:00 2001 From: aoife cassidy Date: Mon, 20 May 2024 03:33:00 +0300 Subject: [PATCH] Add ChunkedStream --- agents/src/ipc/job_main.ts | 10 ++-- agents/src/ipc/job_process.ts | 7 +-- agents/src/tts/index.ts | 2 + agents/src/tts/stream_adapter.ts | 20 +++---- agents/src/tts/tts.ts | 24 +++++++- agents/src/worker.ts | 6 +- examples/src/minimal.ts | 2 +- examples/src/tts.ts | 20 ++++--- plugins/elevenlabs/src/tts.ts | 96 ++++++++++++++++++++++---------- 9 files changed, 124 insertions(+), 63 deletions(-) diff --git a/agents/src/ipc/job_main.ts b/agents/src/ipc/job_main.ts index 8077b84d..7d8466fa 100644 --- a/agents/src/ipc/job_main.ts +++ b/agents/src/ipc/job_main.ts @@ -10,15 +10,15 @@ import { log } from '../log.js'; import { IPC_MESSAGE, type JobMainArgs, type Message, type Ping } from './protocol.js'; export const runJob = (args: JobMainArgs): ChildProcess => { - return fork(__filename, [args.raw, args.entry, args.fallbackURL]); + return fork(import.meta.filename, [args.raw, args.entry, args.fallbackURL]); }; if (process.send) { // process.argv: - // [0] `node' or `bun' - // [1] __filename + // [0] `node' + // [1] import.meta.filename // [2] proto.JobAssignment, serialized to JSON string - // [3] __filename of function containing entry file + // [3] import.meta.filename of function containing entry file // [4] fallback URL in case JobAssignment.url is empty const msg = new ServerMessage(); @@ -56,7 +56,7 @@ if (process.send) { // here we import the file containing the exported entry function, and call it. // the file must export default an Agent, usually using defineAgent(). import(process.argv[3]).then((agent) => { - agent.entry(new JobContext(closeEvent, args.job!, room)); + agent.default.entry(new JobContext(closeEvent, args.job!, room)); }); } }; diff --git a/agents/src/ipc/job_process.ts b/agents/src/ipc/job_process.ts index 5bb7b08a..690135fd 100644 --- a/agents/src/ipc/job_process.ts +++ b/agents/src/ipc/job_process.ts @@ -7,7 +7,6 @@ import { once } from 'events'; import type { Logger } from 'pino'; import type { AcceptData } from '../job_request.js'; import { log } from '../log.js'; -import { runJob } from './job_main.js'; import { IPC_MESSAGE, type JobMainArgs, @@ -41,7 +40,7 @@ export class JobProcess { } async close() { - this.logger.info('closing job process'); + this.logger.debug('closing job process'); await this.clear(); this.process!.send({ type: IPC_MESSAGE.ShutdownRequest }); await once(this.process!, 'disconnect'); @@ -75,7 +74,7 @@ export class JobProcess { this.clear(); }, PING_TIMEOUT); - this.process = runJob(this.args); + this.process = await import('./job_main.js').then((main) => main.runJob(this.args)); this.process.on('message', (msg: Message) => { if (msg.type === IPC_MESSAGE.StartJobResponse) { @@ -87,7 +86,7 @@ export class JobProcess { } this.pongTimeout?.refresh(); } else if (msg.type === IPC_MESSAGE.UserExit || msg.type === IPC_MESSAGE.ShutdownResponse) { - this.logger.info('job exiting'); + this.logger.debug('job exiting'); this.clear(); } }); diff --git a/agents/src/tts/index.ts b/agents/src/tts/index.ts index cc1d6e06..ac236d28 100644 --- a/agents/src/tts/index.ts +++ b/agents/src/tts/index.ts @@ -3,6 +3,7 @@ // SPDX-License-Identifier: Apache-2.0 import { StreamAdapter, StreamAdapterWrapper } from './stream_adapter.js'; import { + ChunkedStream, SynthesisEvent, SynthesisEventType, SynthesizeStream, @@ -18,4 +19,5 @@ export { SynthesizeStream, StreamAdapter, StreamAdapterWrapper, + ChunkedStream, }; diff --git a/agents/src/tts/stream_adapter.ts b/agents/src/tts/stream_adapter.ts index 7c0f4c17..9f1abc30 100644 --- a/agents/src/tts/stream_adapter.ts +++ b/agents/src/tts/stream_adapter.ts @@ -2,13 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 import type { SentenceStream, SentenceTokenizer } from '../tokenize.js'; -import { - SynthesisEvent, - SynthesisEventType, - SynthesizeStream, - type SynthesizedAudio, - TTS, -} from './tts.js'; +import { ChunkedStream, SynthesisEvent, SynthesisEventType, SynthesizeStream, TTS } from './tts.js'; export class StreamAdapterWrapper extends SynthesizeStream { closed: boolean; @@ -41,10 +35,12 @@ export class StreamAdapterWrapper extends SynthesizeStream { reject(new Error('cancelled')); }; for await (const sentence of this.sentenceStream) { - const audio = await this.tts.synthesize(sentence.text); - this.eventQueue.push(new SynthesisEvent(SynthesisEventType.STARTED)); - this.eventQueue.push(new SynthesisEvent(SynthesisEventType.AUDIO, audio)); - this.eventQueue.push(new SynthesisEvent(SynthesisEventType.FINISHED)); + const audio = await this.tts.synthesize(sentence.text).then((data) => data.next()); + if (!audio.done) { + this.eventQueue.push(new SynthesisEvent(SynthesisEventType.STARTED)); + this.eventQueue.push(new SynthesisEvent(SynthesisEventType.AUDIO, audio.value)); + this.eventQueue.push(new SynthesisEvent(SynthesisEventType.FINISHED)); + } } } } @@ -86,7 +82,7 @@ export class StreamAdapter extends TTS { this.tokenizer = tokenizer; } - synthesize(text: string): Promise { + synthesize(text: string): Promise { return this.tts.synthesize(text); } diff --git a/agents/src/tts/tts.ts b/agents/src/tts/tts.ts index 2c5bf19e..0f88a8ab 100644 --- a/agents/src/tts/tts.ts +++ b/agents/src/tts/tts.ts @@ -2,6 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 import type { AudioFrame } from '@livekit/rtc-node'; +import { mergeFrames } from '../utils.js'; export interface SynthesizedAudio { text: string; @@ -46,7 +47,7 @@ export abstract class TTS { this.#streamingSupported = streamingSupported; } - abstract synthesize(text: string): Promise; + abstract synthesize(text: string): Promise; abstract stream(): SynthesizeStream; @@ -54,3 +55,24 @@ export abstract class TTS { return this.#streamingSupported; } } + +export abstract class ChunkedStream implements AsyncIterableIterator { + async collect(): Promise { + const frames = []; + for await (const ev of this) { + frames.push(ev.data); + } + return mergeFrames(frames); + } + + abstract close(): Promise; + abstract next(): Promise>; + + [Symbol.iterator](): ChunkedStream { + return this; + } + + [Symbol.asyncIterator](): ChunkedStream { + return this; + } +} diff --git a/agents/src/worker.ts b/agents/src/worker.ts index 4a043623..8ab21bb5 100644 --- a/agents/src/worker.ts +++ b/agents/src/worker.ts @@ -216,8 +216,8 @@ export class Worker { this.processes[job.id] = { proc, activeJob: new ActiveJob(job, acceptData) }; proc .run() - .catch(() => { - proc.logger.error(`error running job process ${proc.job.id}`); + .catch((e) => { + proc.logger.error(`error running job process ${proc.job.id}: ${e}`); }) .finally(() => { proc.clear(); @@ -374,7 +374,7 @@ export class Worker { async close() { if (this.closed) return; this.closed = true; - this.logger.info('shutting down worker'); + this.logger.debug('shutting down worker'); await this.httpServer.close(); for await (const value of Object.values(this.processes)) { await value.proc.close(); diff --git a/examples/src/minimal.ts b/examples/src/minimal.ts index cbc2b0c7..055a77ea 100644 --- a/examples/src/minimal.ts +++ b/examples/src/minimal.ts @@ -6,7 +6,7 @@ import { fileURLToPath } from 'url'; const requestFunc = async (req: JobRequest) => { console.log('received request', req); - await req.accept(__filename); + await req.accept(import.meta.filename); }; if (process.argv[1] === fileURLToPath(import.meta.url)) { diff --git a/examples/src/tts.ts b/examples/src/tts.ts index 4730c9c6..e89e2bfa 100644 --- a/examples/src/tts.ts +++ b/examples/src/tts.ts @@ -14,7 +14,7 @@ import { AudioSource, LocalAudioTrack, TrackPublishOptions, TrackSource } from ' import { fileURLToPath } from 'url'; const requestFunc = async (req: JobRequest) => { - await req.accept(__filename); + await req.accept(import.meta.filename); }; if (process.argv[1] === fileURLToPath(import.meta.url)) { @@ -32,14 +32,20 @@ export default defineAgent(async (job: JobContext) => { const tts = new TTS(); log.info('speaking "Hello!"'); - await tts.synthesize('Hello!').then((output) => { - source.captureFrame(output.data); - }); + await tts + .synthesize('Hello!') + .then((output) => output.collect()) + .then((output) => { + source.captureFrame(output); + }); await new Promise((resolve) => setTimeout(resolve, 1000)); log.info('speaking "Goodbye."'); - await tts.synthesize('Goodbye.').then((output) => { - source.captureFrame(output.data); - }); + await tts + .synthesize('Goodbye.') + .then((output) => output.collect()) + .then((output) => { + source.captureFrame(output); + }); }); diff --git a/plugins/elevenlabs/src/tts.ts b/plugins/elevenlabs/src/tts.ts index d3352872..e1172445 100644 --- a/plugins/elevenlabs/src/tts.ts +++ b/plugins/elevenlabs/src/tts.ts @@ -98,36 +98,8 @@ export class TTS extends tts.TTS { }); } - synthesize(text: string): Promise { - const voice = this.config.voice; - - return new Promise((resolve) => { - const url = new URL(`${this.config.baseURL}/text-to-speech/${voice.id}`); - url.searchParams.append('output_format', 'pcm_' + this.config.sampleRate); - - fetch(url.toString(), { - method: 'POST', - headers: { [AUTHORIZATION_HEADER]: this.config.apiKey }, - body: JSON.stringify({ - text, - model_id: this.config.modelID, - voice_settings: this.config.voice.settings || undefined, - }), - }) - .then((data) => data.arrayBuffer()) - .then((data) => new DataView(data, 0, data.byteLength)) - .then((data) => - resolve({ - text, - data: new AudioFrame( - new Uint16Array(data.buffer), - this.config.sampleRate, - 1, - data.byteLength, - ), - }), - ); - }); + async synthesize(text: string): Promise { + return new ChunkedStream(text, this.config); } stream(): tts.SynthesizeStream { @@ -293,3 +265,67 @@ export class SynthesizeStream extends tts.SynthesizeStream { } } } + +class ChunkedStream extends tts.ChunkedStream { + config: TTSOptions; + text: string; + queue: (tts.SynthesizedAudio | undefined)[] = []; + + constructor(text: string, config: TTSOptions) { + super(); + this.config = config; + this.text = text; + } + + async next(): Promise> { + await this.run(); + const audio = this.queue.shift(); + if (audio) { + return { done: false, value: audio }; + } else { + return { done: true, value: undefined }; + } + } + + async close() { + this.queue.push(undefined); + } + + async run() { + const voice = this.config.voice; + + const url = new URL(`${this.config.baseURL}/text-to-speech/${voice.id}/stream`); + url.searchParams.append('output_format', 'pcm_' + this.config.sampleRate); + url.searchParams.append('optimize_streaming_latency', this.config.latency.toString()); + + await fetch(url.toString(), { + method: 'POST', + headers: { + [AUTHORIZATION_HEADER]: this.config.apiKey, + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + text: this.text, + model_id: this.config.modelID, + voice_settings: this.config.voice.settings || undefined, + }), + }) + .then((data) => data.arrayBuffer()) + .then((data) => new DataView(data, 0, data.byteLength)) + .then((data) => + this.queue.push( + { + text: this.text, + data: new AudioFrame( + new Uint16Array(data.buffer), + this.config.sampleRate, + 1, + data.byteLength / 2, + ), + }, + undefined, + ), + ) + .catch(() => this.queue.push(undefined)); + } +}