diff --git a/agents/src/generator.ts b/agents/src/generator.ts new file mode 100644 index 00000000..6e559661 --- /dev/null +++ b/agents/src/generator.ts @@ -0,0 +1,18 @@ +// SPDX-FileCopyrightText: 2024 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +import type { JobContext } from './job_context.js'; + +type entryFunction = (job: JobContext) => Promise; + +export interface Agent { + entry: entryFunction; +} + +/** + * Helper to define an agent according to the required interface. + * @example `export default defineAgent(myAgent);` + */ +export function defineAgent(entry: entryFunction): Agent { + return { entry }; +} diff --git a/agents/src/index.ts b/agents/src/index.ts index 8a4b1f64..885e9985 100644 --- a/agents/src/index.ts +++ b/agents/src/index.ts @@ -13,5 +13,6 @@ export * from './job_request.js'; export * from './worker.js'; export * from './utils.js'; export * from './log.js'; +export * from './generator.js'; export { cli, stt, tts }; diff --git a/agents/src/ipc/job_main.ts b/agents/src/ipc/job_main.ts index 34a72bb0..7d8466fa 100644 --- a/agents/src/ipc/job_main.ts +++ b/agents/src/ipc/job_main.ts @@ -10,15 +10,15 @@ import { log } from '../log.js'; import { IPC_MESSAGE, type JobMainArgs, type Message, type Ping } from './protocol.js'; export const runJob = (args: JobMainArgs): ChildProcess => { - return fork(__filename, [args.raw, args.entry, args.fallbackURL]); + return fork(import.meta.filename, [args.raw, args.entry, args.fallbackURL]); }; if (process.send) { // process.argv: - // [0] `node' or `bun' - // [1] __filename + // [0] `node' + // [1] import.meta.filename // [2] proto.JobAssignment, serialized to JSON string - // [3] __filename of function containing entry file + // [3] import.meta.filename of function containing entry file // [4] fallback URL in case JobAssignment.url is empty const msg = new ServerMessage(); @@ -54,9 +54,9 @@ if (process.send) { process.send!({ type: IPC_MESSAGE.StartJobResponse }); // here we import the file containing the exported entry function, and call it. - // the function in that file /has/ to be called [entry] and /has/ to be exported. - import(process.argv[3]).then((ext) => { - ext.entry(new JobContext(closeEvent, args.job!, room)); + // the file must export default an Agent, usually using defineAgent(). + import(process.argv[3]).then((agent) => { + agent.default.entry(new JobContext(closeEvent, args.job!, room)); }); } }; diff --git a/agents/src/ipc/job_process.ts b/agents/src/ipc/job_process.ts index 5bb7b08a..690135fd 100644 --- a/agents/src/ipc/job_process.ts +++ b/agents/src/ipc/job_process.ts @@ -7,7 +7,6 @@ import { once } from 'events'; import type { Logger } from 'pino'; import type { AcceptData } from '../job_request.js'; import { log } from '../log.js'; -import { runJob } from './job_main.js'; import { IPC_MESSAGE, type JobMainArgs, @@ -41,7 +40,7 @@ export class JobProcess { } async close() { - this.logger.info('closing job process'); + this.logger.debug('closing job process'); await this.clear(); this.process!.send({ type: IPC_MESSAGE.ShutdownRequest }); await once(this.process!, 'disconnect'); @@ -75,7 +74,7 @@ export class JobProcess { this.clear(); }, PING_TIMEOUT); - this.process = runJob(this.args); + this.process = await import('./job_main.js').then((main) => main.runJob(this.args)); this.process.on('message', (msg: Message) => { if (msg.type === IPC_MESSAGE.StartJobResponse) { @@ -87,7 +86,7 @@ export class JobProcess { } this.pongTimeout?.refresh(); } else if (msg.type === IPC_MESSAGE.UserExit || msg.type === IPC_MESSAGE.ShutdownResponse) { - this.logger.info('job exiting'); + this.logger.debug('job exiting'); this.clear(); } }); diff --git a/agents/src/tts/index.ts b/agents/src/tts/index.ts index cc1d6e06..ac236d28 100644 --- a/agents/src/tts/index.ts +++ b/agents/src/tts/index.ts @@ -3,6 +3,7 @@ // SPDX-License-Identifier: Apache-2.0 import { StreamAdapter, StreamAdapterWrapper } from './stream_adapter.js'; import { + ChunkedStream, SynthesisEvent, SynthesisEventType, SynthesizeStream, @@ -18,4 +19,5 @@ export { SynthesizeStream, StreamAdapter, StreamAdapterWrapper, + ChunkedStream, }; diff --git a/agents/src/tts/stream_adapter.ts b/agents/src/tts/stream_adapter.ts index 7c0f4c17..9f1abc30 100644 --- a/agents/src/tts/stream_adapter.ts +++ b/agents/src/tts/stream_adapter.ts @@ -2,13 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 import type { SentenceStream, SentenceTokenizer } from '../tokenize.js'; -import { - SynthesisEvent, - SynthesisEventType, - SynthesizeStream, - type SynthesizedAudio, - TTS, -} from './tts.js'; +import { ChunkedStream, SynthesisEvent, SynthesisEventType, SynthesizeStream, TTS } from './tts.js'; export class StreamAdapterWrapper extends SynthesizeStream { closed: boolean; @@ -41,10 +35,12 @@ export class StreamAdapterWrapper extends SynthesizeStream { reject(new Error('cancelled')); }; for await (const sentence of this.sentenceStream) { - const audio = await this.tts.synthesize(sentence.text); - this.eventQueue.push(new SynthesisEvent(SynthesisEventType.STARTED)); - this.eventQueue.push(new SynthesisEvent(SynthesisEventType.AUDIO, audio)); - this.eventQueue.push(new SynthesisEvent(SynthesisEventType.FINISHED)); + const audio = await this.tts.synthesize(sentence.text).then((data) => data.next()); + if (!audio.done) { + this.eventQueue.push(new SynthesisEvent(SynthesisEventType.STARTED)); + this.eventQueue.push(new SynthesisEvent(SynthesisEventType.AUDIO, audio.value)); + this.eventQueue.push(new SynthesisEvent(SynthesisEventType.FINISHED)); + } } } } @@ -86,7 +82,7 @@ export class StreamAdapter extends TTS { this.tokenizer = tokenizer; } - synthesize(text: string): Promise { + synthesize(text: string): Promise { return this.tts.synthesize(text); } diff --git a/agents/src/tts/tts.ts b/agents/src/tts/tts.ts index 2c5bf19e..0f88a8ab 100644 --- a/agents/src/tts/tts.ts +++ b/agents/src/tts/tts.ts @@ -2,6 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 import type { AudioFrame } from '@livekit/rtc-node'; +import { mergeFrames } from '../utils.js'; export interface SynthesizedAudio { text: string; @@ -46,7 +47,7 @@ export abstract class TTS { this.#streamingSupported = streamingSupported; } - abstract synthesize(text: string): Promise; + abstract synthesize(text: string): Promise; abstract stream(): SynthesizeStream; @@ -54,3 +55,24 @@ export abstract class TTS { return this.#streamingSupported; } } + +export abstract class ChunkedStream implements AsyncIterableIterator { + async collect(): Promise { + const frames = []; + for await (const ev of this) { + frames.push(ev.data); + } + return mergeFrames(frames); + } + + abstract close(): Promise; + abstract next(): Promise>; + + [Symbol.iterator](): ChunkedStream { + return this; + } + + [Symbol.asyncIterator](): ChunkedStream { + return this; + } +} diff --git a/agents/src/worker.ts b/agents/src/worker.ts index 4a043623..8ab21bb5 100644 --- a/agents/src/worker.ts +++ b/agents/src/worker.ts @@ -216,8 +216,8 @@ export class Worker { this.processes[job.id] = { proc, activeJob: new ActiveJob(job, acceptData) }; proc .run() - .catch(() => { - proc.logger.error(`error running job process ${proc.job.id}`); + .catch((e) => { + proc.logger.error(`error running job process ${proc.job.id}: ${e}`); }) .finally(() => { proc.clear(); @@ -374,7 +374,7 @@ export class Worker { async close() { if (this.closed) return; this.closed = true; - this.logger.info('shutting down worker'); + this.logger.debug('shutting down worker'); await this.httpServer.close(); for await (const value of Object.values(this.processes)) { await value.proc.close(); diff --git a/examples/src/minimal.ts b/examples/src/minimal.ts index 8602072d..055a77ea 100644 --- a/examples/src/minimal.ts +++ b/examples/src/minimal.ts @@ -1,23 +1,22 @@ // SPDX-FileCopyrightText: 2024 LiveKit, Inc. // // SPDX-License-Identifier: Apache-2.0 -import { type JobContext, type JobRequest, WorkerOptions, cli } from '@livekit/agents'; +import { type JobContext, type JobRequest, WorkerOptions, cli, defineAgent } from '@livekit/agents'; import { fileURLToPath } from 'url'; -// your entry file *has* to include an exported function [entry]. -// this file will be imported from inside the library, and this function -// will be called. -export const entry = async (job: JobContext) => { - console.log('starting voice assistant...'); - job; - // etc -}; - const requestFunc = async (req: JobRequest) => { console.log('received request', req); - await req.accept(__filename); + await req.accept(import.meta.filename); }; if (process.argv[1] === fileURLToPath(import.meta.url)) { cli.runApp(new WorkerOptions({ requestFunc })); } + +// your entry file has to provide a default export of type Agent. +// use the defineAgent() helper function to generate your agent. +export default defineAgent(async (job: JobContext) => { + console.log('starting voice assistant...'); + job; + // etc +}); diff --git a/examples/src/tts.ts b/examples/src/tts.ts index a687411a..e89e2bfa 100644 --- a/examples/src/tts.ts +++ b/examples/src/tts.ts @@ -1,12 +1,27 @@ // SPDX-FileCopyrightText: 2024 LiveKit, Inc. // // SPDX-License-Identifier: Apache-2.0 -import { type JobContext, type JobRequest, WorkerOptions, cli, log } from '@livekit/agents'; +import { + type JobContext, + type JobRequest, + WorkerOptions, + cli, + defineAgent, + log, +} from '@livekit/agents'; import { TTS } from '@livekit/agents-plugin-elevenlabs'; import { AudioSource, LocalAudioTrack, TrackPublishOptions, TrackSource } from '@livekit/rtc-node'; import { fileURLToPath } from 'url'; -export const entry = async (job: JobContext) => { +const requestFunc = async (req: JobRequest) => { + await req.accept(import.meta.filename); +}; + +if (process.argv[1] === fileURLToPath(import.meta.url)) { + cli.runApp(new WorkerOptions({ requestFunc })); +} + +export default defineAgent(async (job: JobContext) => { log.info('starting TTS example agent'); const source = new AudioSource(24000, 1); @@ -17,22 +32,20 @@ export const entry = async (job: JobContext) => { const tts = new TTS(); log.info('speaking "Hello!"'); - await tts.synthesize('Hello!').then((output) => { - source.captureFrame(output.data); - }); + await tts + .synthesize('Hello!') + .then((output) => output.collect()) + .then((output) => { + source.captureFrame(output); + }); await new Promise((resolve) => setTimeout(resolve, 1000)); log.info('speaking "Goodbye."'); - await tts.synthesize('Goodbye.').then((output) => { - source.captureFrame(output.data); - }); -}; - -const requestFunc = async (req: JobRequest) => { - await req.accept(__filename); -}; - -if (process.argv[1] === fileURLToPath(import.meta.url)) { - cli.runApp(new WorkerOptions({ requestFunc })); -} + await tts + .synthesize('Goodbye.') + .then((output) => output.collect()) + .then((output) => { + source.captureFrame(output); + }); +}); diff --git a/plugins/elevenlabs/src/tts.ts b/plugins/elevenlabs/src/tts.ts index d3352872..e1172445 100644 --- a/plugins/elevenlabs/src/tts.ts +++ b/plugins/elevenlabs/src/tts.ts @@ -98,36 +98,8 @@ export class TTS extends tts.TTS { }); } - synthesize(text: string): Promise { - const voice = this.config.voice; - - return new Promise((resolve) => { - const url = new URL(`${this.config.baseURL}/text-to-speech/${voice.id}`); - url.searchParams.append('output_format', 'pcm_' + this.config.sampleRate); - - fetch(url.toString(), { - method: 'POST', - headers: { [AUTHORIZATION_HEADER]: this.config.apiKey }, - body: JSON.stringify({ - text, - model_id: this.config.modelID, - voice_settings: this.config.voice.settings || undefined, - }), - }) - .then((data) => data.arrayBuffer()) - .then((data) => new DataView(data, 0, data.byteLength)) - .then((data) => - resolve({ - text, - data: new AudioFrame( - new Uint16Array(data.buffer), - this.config.sampleRate, - 1, - data.byteLength, - ), - }), - ); - }); + async synthesize(text: string): Promise { + return new ChunkedStream(text, this.config); } stream(): tts.SynthesizeStream { @@ -293,3 +265,67 @@ export class SynthesizeStream extends tts.SynthesizeStream { } } } + +class ChunkedStream extends tts.ChunkedStream { + config: TTSOptions; + text: string; + queue: (tts.SynthesizedAudio | undefined)[] = []; + + constructor(text: string, config: TTSOptions) { + super(); + this.config = config; + this.text = text; + } + + async next(): Promise> { + await this.run(); + const audio = this.queue.shift(); + if (audio) { + return { done: false, value: audio }; + } else { + return { done: true, value: undefined }; + } + } + + async close() { + this.queue.push(undefined); + } + + async run() { + const voice = this.config.voice; + + const url = new URL(`${this.config.baseURL}/text-to-speech/${voice.id}/stream`); + url.searchParams.append('output_format', 'pcm_' + this.config.sampleRate); + url.searchParams.append('optimize_streaming_latency', this.config.latency.toString()); + + await fetch(url.toString(), { + method: 'POST', + headers: { + [AUTHORIZATION_HEADER]: this.config.apiKey, + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + text: this.text, + model_id: this.config.modelID, + voice_settings: this.config.voice.settings || undefined, + }), + }) + .then((data) => data.arrayBuffer()) + .then((data) => new DataView(data, 0, data.byteLength)) + .then((data) => + this.queue.push( + { + text: this.text, + data: new AudioFrame( + new Uint16Array(data.buffer), + this.config.sampleRate, + 1, + data.byteLength / 2, + ), + }, + undefined, + ), + ) + .catch(() => this.queue.push(undefined)); + } +}