Skip to content

Commit

Permalink
Add ChunkedStream
Browse files Browse the repository at this point in the history
  • Loading branch information
nbsp committed May 20, 2024
1 parent f2445c0 commit c156d3c
Show file tree
Hide file tree
Showing 9 changed files with 124 additions and 63 deletions.
10 changes: 5 additions & 5 deletions agents/src/ipc/job_main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@ import { log } from '../log.js';
import { IPC_MESSAGE, type JobMainArgs, type Message, type Ping } from './protocol.js';

export const runJob = (args: JobMainArgs): ChildProcess => {
return fork(__filename, [args.raw, args.entry, args.fallbackURL]);
return fork(import.meta.filename, [args.raw, args.entry, args.fallbackURL]);
};

if (process.send) {
// process.argv:
// [0] `node' or `bun'
// [1] __filename
// [0] `node'
// [1] import.meta.filename
// [2] proto.JobAssignment, serialized to JSON string
// [3] __filename of function containing entry file
// [3] import.meta.filename of function containing entry file
// [4] fallback URL in case JobAssignment.url is empty

const msg = new ServerMessage();
Expand Down Expand Up @@ -56,7 +56,7 @@ if (process.send) {
// here we import the file containing the exported entry function, and call it.
// the file must export default an Agent, usually using defineAgent().
import(process.argv[3]).then((agent) => {
agent.entry(new JobContext(closeEvent, args.job!, room));
agent.default.entry(new JobContext(closeEvent, args.job!, room));
});
}
};
Expand Down
7 changes: 3 additions & 4 deletions agents/src/ipc/job_process.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import { once } from 'events';
import type { Logger } from 'pino';
import type { AcceptData } from '../job_request.js';
import { log } from '../log.js';
import { runJob } from './job_main.js';
import {
IPC_MESSAGE,
type JobMainArgs,
Expand Down Expand Up @@ -41,7 +40,7 @@ export class JobProcess {
}

async close() {
this.logger.info('closing job process');
this.logger.debug('closing job process');
await this.clear();
this.process!.send({ type: IPC_MESSAGE.ShutdownRequest });
await once(this.process!, 'disconnect');
Expand Down Expand Up @@ -75,7 +74,7 @@ export class JobProcess {
this.clear();
}, PING_TIMEOUT);

this.process = runJob(this.args);
this.process = await import('./job_main.js').then((main) => main.runJob(this.args));

this.process.on('message', (msg: Message) => {
if (msg.type === IPC_MESSAGE.StartJobResponse) {
Expand All @@ -87,7 +86,7 @@ export class JobProcess {
}
this.pongTimeout?.refresh();
} else if (msg.type === IPC_MESSAGE.UserExit || msg.type === IPC_MESSAGE.ShutdownResponse) {
this.logger.info('job exiting');
this.logger.debug('job exiting');
this.clear();
}
});
Expand Down
2 changes: 2 additions & 0 deletions agents/src/tts/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// SPDX-License-Identifier: Apache-2.0
import { StreamAdapter, StreamAdapterWrapper } from './stream_adapter.js';
import {
ChunkedStream,
SynthesisEvent,
SynthesisEventType,
SynthesizeStream,
Expand All @@ -18,4 +19,5 @@ export {
SynthesizeStream,
StreamAdapter,
StreamAdapterWrapper,
ChunkedStream,
};
20 changes: 8 additions & 12 deletions agents/src/tts/stream_adapter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,7 @@
//
// SPDX-License-Identifier: Apache-2.0
import type { SentenceStream, SentenceTokenizer } from '../tokenize.js';
import {
SynthesisEvent,
SynthesisEventType,
SynthesizeStream,
type SynthesizedAudio,
TTS,
} from './tts.js';
import { ChunkedStream, SynthesisEvent, SynthesisEventType, SynthesizeStream, TTS } from './tts.js';

export class StreamAdapterWrapper extends SynthesizeStream {
closed: boolean;
Expand Down Expand Up @@ -41,10 +35,12 @@ export class StreamAdapterWrapper extends SynthesizeStream {
reject(new Error('cancelled'));
};
for await (const sentence of this.sentenceStream) {
const audio = await this.tts.synthesize(sentence.text);
this.eventQueue.push(new SynthesisEvent(SynthesisEventType.STARTED));
this.eventQueue.push(new SynthesisEvent(SynthesisEventType.AUDIO, audio));
this.eventQueue.push(new SynthesisEvent(SynthesisEventType.FINISHED));
const audio = await this.tts.synthesize(sentence.text).then((data) => data.next());
if (!audio.done) {
this.eventQueue.push(new SynthesisEvent(SynthesisEventType.STARTED));
this.eventQueue.push(new SynthesisEvent(SynthesisEventType.AUDIO, audio.value));
this.eventQueue.push(new SynthesisEvent(SynthesisEventType.FINISHED));
}
}
}
}
Expand Down Expand Up @@ -86,7 +82,7 @@ export class StreamAdapter extends TTS {
this.tokenizer = tokenizer;
}

synthesize(text: string): Promise<SynthesizedAudio> {
synthesize(text: string): Promise<ChunkedStream> {
return this.tts.synthesize(text);
}

Expand Down
24 changes: 23 additions & 1 deletion agents/src/tts/tts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
//
// SPDX-License-Identifier: Apache-2.0
import type { AudioFrame } from '@livekit/rtc-node';
import { mergeFrames } from '../utils.js';

export interface SynthesizedAudio {
text: string;
Expand Down Expand Up @@ -46,11 +47,32 @@ export abstract class TTS {
this.#streamingSupported = streamingSupported;
}

abstract synthesize(text: string): Promise<SynthesizedAudio>;
abstract synthesize(text: string): Promise<ChunkedStream>;

abstract stream(): SynthesizeStream;

get streamingSupported(): boolean {
return this.#streamingSupported;
}
}

/**
 * Async stream of synthesized audio chunks produced by a single
 * `TTS.synthesize()` call.
 *
 * Subclasses supply `next()` and `close()`; this base class wires the object
 * up as its own iterator and offers `collect()` for callers that want the
 * whole utterance as one frame.
 */
export abstract class ChunkedStream implements AsyncIterableIterator<SynthesizedAudio> {
  /** Resolves to the next synthesized chunk, or a `done` result once exhausted. */
  abstract next(): Promise<IteratorResult<SynthesizedAudio>>;

  /** Closes the stream; the exact semantics are defined by the implementing class. */
  abstract close(): Promise<void>;

  /**
   * Drains the stream and merges every chunk's audio into a single frame.
   *
   * @returns the result of `mergeFrames` applied to the `data` of each chunk,
   * in the order the chunks were produced.
   */
  async collect(): Promise<AudioFrame> {
    const chunks: AudioFrame[] = [];
    for await (const { data } of this) {
      chunks.push(data);
    }
    return mergeFrames(chunks);
  }

  /** The stream is its own async iterator. */
  [Symbol.asyncIterator](): ChunkedStream {
    return this;
  }

  /**
   * Returns the stream itself.
   *
   * NOTE(review): `next()` is asynchronous, so synchronous iteration through
   * this hook would hand out promises rather than chunks; prefer `for await`.
   */
  [Symbol.iterator](): ChunkedStream {
    return this;
  }
}
6 changes: 3 additions & 3 deletions agents/src/worker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,8 @@ export class Worker {
this.processes[job.id] = { proc, activeJob: new ActiveJob(job, acceptData) };
proc
.run()
.catch(() => {
proc.logger.error(`error running job process ${proc.job.id}`);
.catch((e) => {
proc.logger.error(`error running job process ${proc.job.id}: ${e}`);
})
.finally(() => {
proc.clear();
Expand Down Expand Up @@ -374,7 +374,7 @@ export class Worker {
async close() {
if (this.closed) return;
this.closed = true;
this.logger.info('shutting down worker');
this.logger.debug('shutting down worker');
await this.httpServer.close();
for await (const value of Object.values(this.processes)) {
await value.proc.close();
Expand Down
2 changes: 1 addition & 1 deletion examples/src/minimal.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import { fileURLToPath } from 'url';

const requestFunc = async (req: JobRequest) => {
console.log('received request', req);
await req.accept(__filename);
await req.accept(import.meta.filename);
};

if (process.argv[1] === fileURLToPath(import.meta.url)) {
Expand Down
20 changes: 13 additions & 7 deletions examples/src/tts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import { AudioSource, LocalAudioTrack, TrackPublishOptions, TrackSource } from '
import { fileURLToPath } from 'url';

const requestFunc = async (req: JobRequest) => {
await req.accept(__filename);
await req.accept(import.meta.filename);
};

if (process.argv[1] === fileURLToPath(import.meta.url)) {
Expand All @@ -32,14 +32,20 @@ export default defineAgent(async (job: JobContext) => {

const tts = new TTS();
log.info('speaking "Hello!"');
await tts.synthesize('Hello!').then((output) => {
source.captureFrame(output.data);
});
await tts
.synthesize('Hello!')
.then((output) => output.collect())
.then((output) => {
source.captureFrame(output);
});

await new Promise((resolve) => setTimeout(resolve, 1000));

log.info('speaking "Goodbye."');
await tts.synthesize('Goodbye.').then((output) => {
source.captureFrame(output.data);
});
await tts
.synthesize('Goodbye.')
.then((output) => output.collect())
.then((output) => {
source.captureFrame(output);
});
});
96 changes: 66 additions & 30 deletions plugins/elevenlabs/src/tts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -98,36 +98,8 @@ export class TTS extends tts.TTS {
});
}

synthesize(text: string): Promise<tts.SynthesizedAudio> {
const voice = this.config.voice;

return new Promise<tts.SynthesizedAudio>((resolve) => {
const url = new URL(`${this.config.baseURL}/text-to-speech/${voice.id}`);
url.searchParams.append('output_format', 'pcm_' + this.config.sampleRate);

fetch(url.toString(), {
method: 'POST',
headers: { [AUTHORIZATION_HEADER]: this.config.apiKey },
body: JSON.stringify({
text,
model_id: this.config.modelID,
voice_settings: this.config.voice.settings || undefined,
}),
})
.then((data) => data.arrayBuffer())
.then((data) => new DataView(data, 0, data.byteLength))
.then((data) =>
resolve({
text,
data: new AudioFrame(
new Uint16Array(data.buffer),
this.config.sampleRate,
1,
data.byteLength,
),
}),
);
});
async synthesize(text: string): Promise<tts.ChunkedStream> {
return new ChunkedStream(text, this.config);
}

stream(): tts.SynthesizeStream {
Expand Down Expand Up @@ -293,3 +265,67 @@ export class SynthesizeStream extends tts.SynthesizeStream {
}
}
}

/**
 * ElevenLabs implementation of `tts.ChunkedStream`.
 *
 * The whole utterance is requested from the `/text-to-speech/{voice}/stream`
 * endpoint with a single HTTP call and surfaced as one `SynthesizedAudio`
 * chunk followed by end-of-stream.
 */
class ChunkedStream extends tts.ChunkedStream {
  config: TTSOptions;
  text: string;
  /** Pending chunks; `undefined` is the end-of-stream sentinel. */
  queue: (tts.SynthesizedAudio | undefined)[] = [];

  /** Memoized request so that repeated `next()` calls share one HTTP round trip. */
  #request?: Promise<void>;
  /** Set once the stream has been closed or has reported `done`. */
  #finished = false;

  /**
   * @param text - the text to synthesize
   * @param config - TTS options (voice, model, sample rate, credentials, ...)
   */
  constructor(text: string, config: TTSOptions) {
    super();
    this.config = config;
    this.text = text;
  }

  /**
   * Returns the next synthesized chunk.
   *
   * The HTTP request is started lazily on the first call and is never
   * repeated: previously every call re-ran `run()`, which issued a new
   * request per pull and let the iterator yield audio again after it had
   * already reported `done`.
   */
  async next(): Promise<IteratorResult<tts.SynthesizedAudio>> {
    if (!this.#finished) {
      if (!this.#request) {
        this.#request = this.run();
      }
      await this.#request;
    }

    // Re-check after the await: close() may have been called while the
    // request was in flight.
    const audio = this.#finished ? undefined : this.queue.shift();
    if (audio) {
      return { done: false, value: audio };
    }

    // Once exhausted, stay exhausted.
    this.#finished = true;
    return { done: true, value: undefined };
  }

  /** Marks the stream as finished; subsequent `next()` calls report `done`. */
  async close() {
    this.#finished = true;
    // Keep enqueueing the sentinel as before for anyone inspecting `queue`.
    this.queue.push(undefined);
  }

  /**
   * Performs the synthesis request and enqueues the resulting audio followed
   * by the end-of-stream sentinel.
   *
   * Failures are best-effort: on a network error, a non-2xx response, or an
   * undecodable body, only the sentinel is enqueued, so the stream ends
   * without audio instead of throwing.
   */
  async run() {
    const voice = this.config.voice;

    const url = new URL(`${this.config.baseURL}/text-to-speech/${voice.id}/stream`);
    url.searchParams.append('output_format', 'pcm_' + this.config.sampleRate);
    url.searchParams.append('optimize_streaming_latency', this.config.latency.toString());

    try {
      const response = await fetch(url.toString(), {
        method: 'POST',
        headers: {
          [AUTHORIZATION_HEADER]: this.config.apiKey,
          'Content-Type': 'application/json',
        },
        body: JSON.stringify({
          text: this.text,
          model_id: this.config.modelID,
          voice_settings: this.config.voice.settings || undefined,
        }),
      });

      // An error response carries an error payload, not PCM samples; do not
      // turn it into an audio frame.
      if (!response.ok) {
        this.queue.push(undefined);
        return;
      }

      const pcm = await response.arrayBuffer();
      this.queue.push(
        {
          text: this.text,
          data: new AudioFrame(
            new Uint16Array(pcm),
            this.config.sampleRate,
            1,
            // 16-bit mono samples: two bytes per sample.
            pcm.byteLength / 2,
          ),
        },
        undefined,
      );
    } catch {
      this.queue.push(undefined);
    }
  }
}

0 comments on commit c156d3c

Please sign in to comment.