From f4a4b599d79ea126efd51841a2b82b1bf5434087 Mon Sep 17 00:00:00 2001 From: aoife cassidy Date: Mon, 2 Dec 2024 15:57:58 +0200 Subject: [PATCH 1/7] chore: move to node:stream/web from homegrown API --- agents/src/llm/llm.ts | 33 +++++--- agents/src/multimodal/agent_playout.ts | 15 ++-- agents/src/pipeline/agent_output.ts | 22 ++++-- agents/src/pipeline/agent_playout.ts | 7 +- agents/src/pipeline/pipeline_agent.ts | 17 +++-- agents/src/stt/stream_adapter.ts | 9 ++- agents/src/stt/stt.ts | 64 +++++++++------- agents/src/tokenize/token_stream.ts | 26 +++++-- agents/src/tokenize/tokenizer.ts | 97 +++++++++++++++--------- agents/src/tts/stream_adapter.ts | 7 +- agents/src/tts/tts.ts | 101 +++++++++++++++---------- agents/src/utils.ts | 38 ---------- agents/src/vad.ts | 64 +++++++++------- 13 files changed, 283 insertions(+), 217 deletions(-) diff --git a/agents/src/llm/llm.ts b/agents/src/llm/llm.ts index 1d4fec5b..2c0d1329 100644 --- a/agents/src/llm/llm.ts +++ b/agents/src/llm/llm.ts @@ -3,8 +3,9 @@ // SPDX-License-Identifier: Apache-2.0 import type { TypedEventEmitter as TypedEmitter } from '@livekit/typed-emitter'; import { EventEmitter } from 'node:events'; +import type { ReadableStream } from 'node:stream/web'; +import { TransformStream } from 'node:stream/web'; import type { LLMMetrics } from '../metrics/base.js'; -import { AsyncIterableQueue } from '../utils.js'; import type { ChatContext, ChatRole } from './chat_context.js'; import type { FunctionCallInfo, FunctionContext } from './function_context.js'; @@ -59,8 +60,7 @@ export abstract class LLM extends (EventEmitter as new () => TypedEmitter { - protected output = new AsyncIterableQueue(); - protected queue = new AsyncIterableQueue(); + protected output = new TransformStream(); protected closed = false; protected _functionCalls: FunctionCallInfo[] = []; abstract label: string; @@ -68,22 +68,24 @@ export abstract class LLMStream implements AsyncIterableIterator { #llm: LLM; #chatCtx: ChatContext; #fncCtx?: FunctionContext; + #outputReadable: ReadableStream; constructor(llm: LLM, chatCtx: ChatContext, fncCtx?: FunctionContext) { this.#llm = llm; this.#chatCtx = chatCtx; this.#fncCtx = fncCtx; - this.monitorMetrics(); + const [r1, r2] = this.output.readable.tee(); + this.#outputReadable = r1; + this.monitorMetrics(r2); } - protected async monitorMetrics() { + protected async monitorMetrics(readable: ReadableStream) { const startTime = process.hrtime.bigint(); let ttft: bigint | undefined; let requestId = ''; let usage: CompletionUsage | undefined; - for await (const ev of this.queue) { - this.output.put(ev); + for await (const ev of readable) { requestId = ev.requestId; if (!ttft) { ttft = process.hrtime.bigint() - startTime; @@ -92,7 +94,6 @@ export abstract class LLMStream implements AsyncIterableIterator { usage = ev.usage; } } - this.output.close(); const duration = process.hrtime.bigint() - startTime; const metrics: LLMMetrics = { @@ -138,13 +139,21 @@ export abstract class LLMStream implements AsyncIterableIterator { return this._functionCalls; } - next(): Promise> { - return this.output.next(); + async next(): Promise> { + return this.#outputReadable + .getReader() + .read() + .then(({ value }) => { + if (value) { + return { value, done: false }; + } else { + return { value: undefined, done: true }; + } + }); } close() { - this.output.close(); - this.queue.close(); + this.output.writable.close(); this.closed = true; } diff --git a/agents/src/multimodal/agent_playout.ts b/agents/src/multimodal/agent_playout.ts index 25f4edd0..ce758e95 100644 --- a/agents/src/multimodal/agent_playout.ts +++ b/agents/src/multimodal/agent_playout.ts @@ -4,9 +4,10 @@ import type { AudioFrame } from '@livekit/rtc-node'; import { type AudioSource } from '@livekit/rtc-node'; import { EventEmitter } from 'node:events'; +import type { TransformStream } from 'node:stream/web'; import { AudioByteStream } from '../audio.js'; import type { TranscriptionForwarder } from '../transcription.js'; -import { type AsyncIterableQueue, CancellablePromise, Future, gracefullyCancel } from '../utils.js'; +import { CancellablePromise, Future, gracefullyCancel } from '../utils.js'; export const proto = {}; @@ -112,8 +113,8 @@ export class AgentPlayout extends EventEmitter { itemId: string, contentIndex: number, transcriptionFwd: TranscriptionForwarder, - textStream: AsyncIterableQueue, - audioStream: AsyncIterableQueue, + textStream: TransformStream, + audioStream: TransformStream, ): PlayoutHandle { const handle = new PlayoutHandle( this.#audioSource, @@ -129,8 +130,8 @@ export class AgentPlayout extends EventEmitter { #makePlayoutTask( oldTask: CancellablePromise | null, handle: PlayoutHandle, - textStream: AsyncIterableQueue, - audioStream: AsyncIterableQueue, + textStream: TransformStream, + audioStream: TransformStream, ): CancellablePromise { return new CancellablePromise((resolve, reject, onCancel) => { let cancelled = false; @@ -155,7 +156,7 @@ export class AgentPlayout extends EventEmitter { (async () => { try { - for await (const text of textStream) { + for await (const text of textStream.readable) { if (cancelledText || cancelled) { break; } @@ -184,7 +185,7 @@ export class AgentPlayout extends EventEmitter { samplesPerChannel, ); - for await (const frame of audioStream) { + for await (const frame of audioStream.readable) { if (cancelledCapture || cancelled) { break; } diff --git a/agents/src/pipeline/agent_output.ts b/agents/src/pipeline/agent_output.ts index 686a789b..03de4a07 100644 --- a/agents/src/pipeline/agent_output.ts +++ b/agents/src/pipeline/agent_output.ts @@ -2,9 +2,10 @@ // // SPDX-License-Identifier: Apache-2.0 import type { AudioFrame } from '@livekit/rtc-node'; +import { TransformStream } from 'node:stream/web'; import { log } from '../log.js'; import { SynthesizeStream, type TTS } from '../tts/index.js'; -import { AsyncIterableQueue, CancellablePromise, Future, gracefullyCancel } from '../utils.js'; +import { CancellablePromise, Future, gracefullyCancel } from '../utils.js'; import type { AgentPlayout, PlayoutHandle } from './agent_playout.js'; export type SpeechSource = AsyncIterable | string | Promise; @@ -17,7 +18,10 @@ export class SynthesisHandle { ttsSource: SpeechSource; #agentPlayout: AgentPlayout; tts: TTS; - queue = new AsyncIterableQueue(); + queue = new TransformStream< + AudioFrame | typeof SynthesisHandle.FLUSH_SENTINEL, + AudioFrame | typeof SynthesisHandle.FLUSH_SENTINEL + >(); #playHandle?: PlayoutHandle; intFut = new Future(); #logger = log(); @@ -51,7 +55,7 @@ export class SynthesisHandle { throw new Error('synthesis was interrupted'); } - this.#playHandle = this.#agentPlayout.play(this.#speechId, this.queue); + this.#playHandle = this.#agentPlayout.play(this.#speechId, this.queue.readable); return this.#playHandle; } @@ -134,6 +138,8 @@ const stringSynthesisTask = (text: string, handle: SynthesisHandle): Cancellable cancelled = true; }); + const writer = handle.queue.writable.getWriter(); + const ttsStream = handle.tts.stream(); ttsStream.pushText(text); ttsStream.flush(); @@ -142,9 +148,9 @@ const stringSynthesisTask = (text: string, handle: SynthesisHandle): Cancellable if (cancelled || audio === SynthesizeStream.END_OF_STREAM) { break; } - handle.queue.put(audio.frame); + writer.write(audio.frame); } - handle.queue.put(SynthesisHandle.FLUSH_SENTINEL); + writer.write(SynthesisHandle.FLUSH_SENTINEL); resolve(text); }); @@ -162,6 +168,8 @@ const streamSynthesisTask = ( cancelled = true; }); + const writer = handle.queue.writable.getWriter(); + const ttsStream = handle.tts.stream(); const readGeneratedAudio = async () => { for await (const audio of ttsStream) { @@ -169,9 +177,9 @@ const streamSynthesisTask = ( if (audio === SynthesizeStream.END_OF_STREAM) { break; } - handle.queue.put(audio.frame); + writer.write(audio.frame); } - handle.queue.put(SynthesisHandle.FLUSH_SENTINEL); + writer.write(SynthesisHandle.FLUSH_SENTINEL); }; readGeneratedAudio(); diff --git a/agents/src/pipeline/agent_playout.ts b/agents/src/pipeline/agent_playout.ts index 4793d623..754273de 100644 --- a/agents/src/pipeline/agent_playout.ts +++ b/agents/src/pipeline/agent_playout.ts @@ -4,6 +4,7 @@ import type { AudioFrame, AudioSource } from '@livekit/rtc-node'; import type { TypedEventEmitter as TypedEmitter } from '@livekit/typed-emitter'; import EventEmitter from 'node:events'; +import type { ReadableStream } from 'node:stream/web'; import { log } from '../log.js'; import { CancellablePromise, Future, gracefullyCancel } from '../utils.js'; import { SynthesisHandle } from './agent_output.js'; @@ -21,7 +22,7 @@ export type AgentPlayoutCallbacks = { export class PlayoutHandle { #speechId: string; #audioSource: AudioSource; - playoutSource: AsyncIterable; + playoutSource: ReadableStream; totalPlayedTime?: number; #interrupted = false; pushedDuration = 0; @@ -31,7 +32,7 @@ export class PlayoutHandle { constructor( speechId: string, audioSource: AudioSource, - playoutSource: AsyncIterable, + playoutSource: ReadableStream, ) { this.#speechId = speechId; this.#audioSource = audioSource; @@ -90,7 +91,7 @@ export class AgentPlayout extends (EventEmitter as new () => TypedEmitter, + playoutSource: ReadableStream, ): PlayoutHandle { if (this.#closed) { throw new Error('source closed'); diff --git a/agents/src/pipeline/pipeline_agent.ts b/agents/src/pipeline/pipeline_agent.ts index ca6c9238..c47118ff 100644 --- a/agents/src/pipeline/pipeline_agent.ts +++ b/agents/src/pipeline/pipeline_agent.ts @@ -11,6 +11,7 @@ import { } from '@livekit/rtc-node'; import type { TypedEventEmitter as TypedEmitter } from '@livekit/typed-emitter'; import EventEmitter from 'node:events'; +import { TransformStream } from 'node:stream/web'; import type { CallableFunctionResult, FunctionCallInfo, @@ -30,7 +31,7 @@ import { import type { SentenceTokenizer, WordTokenizer } from '../tokenize/tokenizer.js'; import type { TTS } from '../tts/index.js'; import { TTSEvent, StreamAdapter as TTSStreamAdapter } from '../tts/index.js'; -import { AsyncIterableQueue, CancellablePromise, Future, gracefullyCancel } from '../utils.js'; +import { CancellablePromise, Future, gracefullyCancel } from '../utils.js'; import { type VAD, type VADEvent, VADEventType } from '../vad.js'; import type { SpeechSource, SynthesisHandle } from './agent_output.js'; import { AgentOutput } from './agent_output.js'; @@ -241,7 +242,10 @@ export class VoicePipelineAgent extends (EventEmitter as new () => TypedEmitter< #transcribedText = ''; #transcribedInterimText = ''; #speechQueueOpen = new Future(); - #speechQueue = new AsyncIterableQueue(); + #speechQueue = new TransformStream< + SpeechHandle | typeof VoicePipelineAgent.FLUSH_SENTINEL, + SpeechHandle | typeof VoicePipelineAgent.FLUSH_SENTINEL + >(); #updateStateTask?: CancellablePromise; #started = false; #room?: Room; @@ -545,7 +549,7 @@ export class VoicePipelineAgent extends (EventEmitter as new () => TypedEmitter< while (true) { await this.#speechQueueOpen.await; - for await (const speech of this.#speechQueue) { + for await (const speech of this.#speechQueue.readable) { if (speech === VoicePipelineAgent.FLUSH_SENTINEL) break; this.#playingSpeech = speech; await this.#playSpeech(speech); @@ -868,7 +872,7 @@ export class VoicePipelineAgent extends (EventEmitter as new () => TypedEmitter< // in some bad timimg, we could end up with two pushed agent replies inside the speech queue. // so make sure we directly interrupt every reply when validating a new one if (this.#speechQueueOpen.done) { - for await (const speech of this.#speechQueue) { + for await (const speech of this.#speechQueue.readable) { if (speech === VoicePipelineAgent.FLUSH_SENTINEL) break; if (!speech.isReply) continue; if (speech.allowInterruptions) speech.interrupt(); @@ -920,8 +924,9 @@ export class VoicePipelineAgent extends (EventEmitter as new () => TypedEmitter< } #addSpeechForPlayout(handle: SpeechHandle) { - this.#speechQueue.put(handle); - this.#speechQueue.put(VoicePipelineAgent.FLUSH_SENTINEL); + const writer = this.#speechQueue.writable.getWriter(); + writer.write(handle); + writer.write(VoicePipelineAgent.FLUSH_SENTINEL); this.#speechQueueOpen.resolve(); } diff --git a/agents/src/stt/stream_adapter.ts b/agents/src/stt/stream_adapter.ts index a1a71014..5452d8e7 100644 --- a/agents/src/stt/stream_adapter.ts +++ b/agents/src/stt/stream_adapter.ts @@ -52,7 +52,7 @@ export class StreamAdapterWrapper extends SpeechStream { async #run() { const forwardInput = async () => { - for await (const input of this.input) { + for await (const input of this.input.readable) { if (input === SpeechStream.FLUSH_SENTINEL) { this.#vadStream.flush(); } else { @@ -63,20 +63,21 @@ export class StreamAdapterWrapper extends SpeechStream { }; const recognize = async () => { + const writer = this.output.writable.getWriter(); for await (const ev of this.#vadStream) { switch (ev.type) { case VADEventType.START_OF_SPEECH: - this.output.put({ type: SpeechEventType.START_OF_SPEECH }); + writer.write({ type: SpeechEventType.START_OF_SPEECH }); break; case VADEventType.END_OF_SPEECH: - this.output.put({ type: SpeechEventType.END_OF_SPEECH }); + writer.write({ type: SpeechEventType.END_OF_SPEECH }); const event = await this.#stt.recognize(ev.frames); if (!event.alternatives![0].text) { continue; } - this.output.put(event); + writer.write(event); break; } } diff --git a/agents/src/stt/stt.ts b/agents/src/stt/stt.ts index 42868bfe..8968dbf9 100644 --- a/agents/src/stt/stt.ts +++ b/agents/src/stt/stt.ts @@ -4,9 +4,10 @@ import type { AudioFrame } from '@livekit/rtc-node'; import type { TypedEventEmitter as TypedEmitter } from '@livekit/typed-emitter'; import { EventEmitter } from 'node:events'; +import type { ReadableStream } from 'node:stream/web'; +import { TransformStream } from 'node:stream/web'; import type { STTMetrics } from '../metrics/base.js'; import type { AudioBuffer } from '../utils.js'; -import { AsyncIterableQueue } from '../utils.js'; /** Indicates start/middle/end of speech */ export enum SpeechEventType { @@ -137,23 +138,27 @@ export abstract class STT extends (EventEmitter as new () => TypedEmitter { protected static readonly FLUSH_SENTINEL = Symbol('FLUSH_SENTINEL'); - protected input = new AsyncIterableQueue(); - protected output = new AsyncIterableQueue(); - protected queue = new AsyncIterableQueue(); + protected input = new TransformStream< + AudioFrame | typeof SpeechStream.FLUSH_SENTINEL, + AudioFrame | typeof SpeechStream.FLUSH_SENTINEL + >(); + protected output = new TransformStream(); abstract label: string; protected closed = false; #stt: STT; + #outputReadable: ReadableStream; constructor(stt: STT) { this.#stt = stt; - this.monitorMetrics(); + const [r1, r2] = this.output.readable.tee(); + this.#outputReadable = r1; + this.monitorMetrics(r2); } - protected async monitorMetrics() { + protected async monitorMetrics(readable: ReadableStream) { const startTime = process.hrtime.bigint(); - for await (const event of this.queue) { - this.output.put(event); + for await (const event of readable) { if (event.type !== SpeechEventType.RECOGNITION_USAGE) continue; const duration = process.hrtime.bigint() - startTime; const metrics: STTMetrics = { @@ -166,51 +171,58 @@ export abstract class SpeechStream implements AsyncIterableIterator }; this.#stt.emit(SpeechEventType.METRICS_COLLECTED, metrics); } - this.output.close(); } /** Push an audio frame to the STT */ pushFrame(frame: AudioFrame) { - if (this.input.closed) { - throw new Error('Input is closed'); - } + // if (this.input.closed) { + // throw new Error('Input is closed'); + // } if (this.closed) { throw new Error('Stream is closed'); } - this.input.put(frame); + this.input.writable.getWriter().write(frame); } /** Flush the STT, causing it to process all pending text */ flush() { - if (this.input.closed) { - throw new Error('Input is closed'); - } + // if (this.input.closed) { + // throw new Error('Input is closed'); + // } if (this.closed) { throw new Error('Stream is closed'); } - this.input.put(SpeechStream.FLUSH_SENTINEL); + this.input.writable.getWriter().write(SpeechStream.FLUSH_SENTINEL); } /** Mark the input as ended and forbid additional pushes */ endInput() { - if (this.input.closed) { - throw new Error('Input is closed'); - } + // if (this.input.closed) { + // throw new Error('Input is closed'); + // } if (this.closed) { throw new Error('Stream is closed'); } - this.input.close(); + this.input.writable.close(); } - next(): Promise> { - return this.output.next(); + async next(): Promise> { + return this.#outputReadable + .getReader() + .read() + .then(({ value }) => { + if (value) { + return { value, done: false }; + } else { + return { value: undefined, done: true }; + } + }); } /** Close both the input and output of the STT stream */ close() { - this.input.close(); - this.queue.close(); - this.output.close(); + this.input.writable.close(); + this.output.writable.close(); this.closed = true; } diff --git a/agents/src/tokenize/token_stream.ts b/agents/src/tokenize/token_stream.ts index 61dcf30a..90686e6c 100644 --- a/agents/src/tokenize/token_stream.ts +++ b/agents/src/tokenize/token_stream.ts @@ -2,14 +2,13 @@ // // SPDX-License-Identifier: Apache-2.0 import { randomUUID } from 'node:crypto'; -import { AsyncIterableQueue } from '../utils.js'; import type { TokenData } from './tokenizer.js'; import { SentenceStream, WordStream } from './tokenizer.js'; type TokenizeFunc = (x: string) => string[] | [string, number, number][]; export class BufferedTokenStream implements AsyncIterableIterator { - protected queue = new AsyncIterableQueue(); + protected queue = new TransformStream(); protected closed = false; #func: TokenizeFunc; @@ -34,6 +33,8 @@ export class BufferedTokenStream implements AsyncIterableIterator { throw new Error('Stream is closed'); } + const writer = this.queue.writable.getWriter(); + this.#inBuf += text; if (this.#inBuf.length < this.#minContextLength) return; @@ -51,7 +52,7 @@ export class BufferedTokenStream implements AsyncIterableIterator { this.#outBuf += tokText; if (this.#outBuf.length >= this.#minTokenLength) { - this.queue.put({ token: this.#outBuf, segmentId: this.#currentSegmentId }); + writer.write({ token: this.#outBuf, segmentId: this.#currentSegmentId }); this.#outBuf = ''; } @@ -71,6 +72,8 @@ export class BufferedTokenStream implements AsyncIterableIterator { throw new Error('Stream is closed'); } + const writer = this.queue.writable.getWriter(); + if (this.#inBuf || this.#outBuf) { const tokens = this.#func(this.#inBuf); if (tokens) { @@ -84,7 +87,7 @@ export class BufferedTokenStream implements AsyncIterableIterator { } if (this.#outBuf) { - this.queue.put({ token: this.#outBuf, segmentId: this.#currentSegmentId }); + writer.write({ token: this.#outBuf, segmentId: this.#currentSegmentId }); } this.#currentSegmentId = randomUUID(); @@ -103,13 +106,22 @@ export class BufferedTokenStream implements AsyncIterableIterator { this.close(); } - next(): Promise> { - return this.queue.next(); + async next(): Promise> { + return this.queue.readable + .getReader() + .read() + .then(({ value }) => { + if (value) { + return { value, done: false }; + } else { + return { value: undefined, done: true }; + } + }); } /** Close both the input and output of the token stream */ close() { - this.queue.close(); + this.queue.writable.close(); this.closed = true; } diff --git a/agents/src/tokenize/tokenizer.ts b/agents/src/tokenize/tokenizer.ts index 6ee38a20..14394ae4 100644 --- a/agents/src/tokenize/tokenizer.ts +++ b/agents/src/tokenize/tokenizer.ts @@ -1,7 +1,6 @@ // SPDX-FileCopyrightText: 2024 LiveKit, Inc. // // SPDX-License-Identifier: Apache-2.0 -import { AsyncIterableQueue } from '../utils.js'; // prettier-ignore export const PUNCTUATIONS = [ @@ -26,8 +25,11 @@ export abstract class SentenceTokenizer { export abstract class SentenceStream { protected static readonly FLUSH_SENTINEL = Symbol('FLUSH_SENTINEL'); - protected input = new AsyncIterableQueue(); - protected queue = new AsyncIterableQueue(); + protected input = new TransformStream< + string | typeof SentenceStream.FLUSH_SENTINEL, + string | typeof SentenceStream.FLUSH_SENTINEL + >(); + protected output = new TransformStream(); #closed = false; get closed(): boolean { @@ -36,45 +38,54 @@ export abstract class SentenceStream { /** Push a string of text to the tokenizer */ pushText(text: string) { - if (this.input.closed) { - throw new Error('Input is closed'); - } + // if (this.input.closed) { + // throw new Error('Input is closed'); + // } if (this.#closed) { throw new Error('Stream is closed'); } - this.input.put(text); + this.input.writable.getWriter().write(text); } /** Flush the tokenizer, causing it to process all pending text */ flush() { - if (this.input.closed) { - throw new Error('Input is closed'); - } + // if (this.input.closed) { + // throw new Error('Input is closed'); + // } if (this.#closed) { throw new Error('Stream is closed'); } - this.input.put(SentenceStream.FLUSH_SENTINEL); + this.input.writable.getWriter().write(SentenceStream.FLUSH_SENTINEL); } /** Mark the input as ended and forbid additional pushes */ endInput() { - if (this.input.closed) { - throw new Error('Input is closed'); - } + // if (this.input.closed) { + // throw new Error('Input is closed'); + // } if (this.#closed) { throw new Error('Stream is closed'); } - this.input.close(); + this.input.writable.close(); } - next(): Promise> { - return this.queue.next(); + async next(): Promise> { + return this.output.readable + .getReader() + .read() + .then(({ value }) => { + if (value) { + return { value, done: false }; + } else { + return { value: undefined, done: true }; + } + }); } /** Close both the input and output of the tokenizer stream */ close() { - this.input.close(); - this.queue.close(); + this.input.writable.close(); + this.output.writable.close(); this.#closed = true; } @@ -94,8 +105,11 @@ export abstract class WordTokenizer { export abstract class WordStream { protected static readonly FLUSH_SENTINEL = Symbol('FLUSH_SENTINEL'); - protected input = new AsyncIterableQueue(); - protected queue = new AsyncIterableQueue(); + protected input = new TransformStream< + string | typeof WordStream.FLUSH_SENTINEL, + string | typeof WordStream.FLUSH_SENTINEL + >(); + protected output = new TransformStream(); #closed = false; get closed(): boolean { @@ -104,45 +118,54 @@ export abstract class WordStream { /** Push a string of text to the tokenizer */ pushText(text: string) { - if (this.input.closed) { - throw new Error('Input is closed'); - } + // if (this.input.closed) { + // throw new Error('Input is closed'); + // } if (this.#closed) { throw new Error('Stream is closed'); } - this.input.put(text); + this.input.writable.getWriter().write(text); } /** Flush the tokenizer, causing it to process all pending text */ flush() { - if (this.input.closed) { - throw new Error('Input is closed'); - } + // if (this.input.closed) { + // throw new Error('Input is closed'); + // } if (this.#closed) { throw new Error('Stream is closed'); } - this.input.put(WordStream.FLUSH_SENTINEL); + this.input.writable.getWriter().write(WordStream.FLUSH_SENTINEL); } /** Mark the input as ended and forbid additional pushes */ endInput() { - if (this.input.closed) { - throw new Error('Input is closed'); - } + // if (this.input.closed) { + // throw new Error('Input is closed'); + // } if (this.#closed) { throw new Error('Stream is closed'); } - this.input.close(); + this.input.writable.close(); } - next(): Promise> { - return this.queue.next(); + async next(): Promise> { + return this.output.readable + .getReader() + .read() + .then(({ value }) => { + if (value) { + return { value, done: false }; + } else { + return { value: undefined, done: true }; + } + }); } /** Close both the input and output of the tokenizer stream */ close() { - this.input.close(); - this.queue.close(); + this.input.writable.close(); + this.output.writable.close(); this.#closed = true; } diff --git a/agents/src/tts/stream_adapter.ts b/agents/src/tts/stream_adapter.ts index e1ab402d..7fc9bbe9 100644 --- a/agents/src/tts/stream_adapter.ts +++ b/agents/src/tts/stream_adapter.ts @@ -51,7 +51,7 @@ export class StreamAdapterWrapper extends SynthesizeStream { async #run() { const forwardInput = async () => { - for await (const input of this.input) { + for await (const input of this.input.readable) { if (input === SynthesizeStream.FLUSH_SENTINEL) { this.#sentenceStream.flush(); } else { @@ -63,12 +63,13 @@ export class StreamAdapterWrapper extends SynthesizeStream { }; const synthesize = async () => { + const writer = this.output.writable.getWriter(); for await (const ev of this.#sentenceStream) { for await (const audio of this.#tts.synthesize(ev.token)) { - this.output.put(audio); + writer.write(audio); } } - this.output.put(SynthesizeStream.END_OF_STREAM); + writer.write(SynthesizeStream.END_OF_STREAM); }; Promise.all([forwardInput(), synthesize()]); diff --git a/agents/src/tts/tts.ts b/agents/src/tts/tts.ts index 7826b446..f06c1c65 100644 --- a/agents/src/tts/tts.ts +++ b/agents/src/tts/tts.ts @@ -4,8 +4,10 @@ import type { AudioFrame } from '@livekit/rtc-node'; import type { TypedEventEmitter as TypedEmitter } from '@livekit/typed-emitter'; import { EventEmitter } from 'node:events'; +import type { ReadableStream } from 'node:stream/web'; +import { TransformStream } from 'node:stream/web'; import type { TTSMetrics } from '../metrics/base.js'; -import { AsyncIterableQueue, mergeFrames } from '../utils.js'; +import { mergeFrames } from '../utils.js'; /** SynthesizedAudio is a packet of speech synthesis as returned by the TTS. */ export interface SynthesizedAudio { @@ -105,11 +107,12 @@ export abstract class SynthesizeStream { protected static readonly FLUSH_SENTINEL = Symbol('FLUSH_SENTINEL'); static readonly END_OF_STREAM = Symbol('END_OF_STREAM'); - protected input = new AsyncIterableQueue(); - protected queue = new AsyncIterableQueue< - SynthesizedAudio | typeof SynthesizeStream.END_OF_STREAM + protected input = new TransformStream< + string | typeof SynthesizeStream.FLUSH_SENTINEL, + string | typeof SynthesizeStream.FLUSH_SENTINEL >(); - protected output = new AsyncIterableQueue< + protected output = new TransformStream< + SynthesizedAudio | typeof SynthesizeStream.END_OF_STREAM, SynthesizedAudio | typeof SynthesizeStream.END_OF_STREAM >(); protected closed = false; @@ -117,13 +120,18 @@ export abstract class SynthesizeStream #tts: TTS; #metricsPendingTexts: string[] = []; #metricsText = ''; - #monitorMetricsTask?: Promise; + #outputReadable: ReadableStream; constructor(tts: TTS) { this.#tts = tts; + const [r1, r2] = this.output.readable.tee(); + this.#outputReadable = r1; + this.monitorMetrics(r2); } - protected async monitorMetrics() { + protected async monitorMetrics( + readable: ReadableStream, + ) { const startTime = process.hrtime.bigint(); let audioDuration = 0; let ttfb: bigint | undefined; @@ -148,8 +156,7 @@ export abstract class SynthesizeStream } }; - for await (const audio of this.queue) { - this.output.put(audio); + for await (const audio of readable) { if (audio === SynthesizeStream.END_OF_STREAM) continue; requestId = audio.requestId; if (!ttfb) { @@ -164,23 +171,19 @@ export abstract class SynthesizeStream if (requestId) { emit(); } - this.output.close(); } /** Push a string of text to the TTS */ pushText(text: string) { - if (!this.#monitorMetricsTask) { - this.#monitorMetricsTask = this.monitorMetrics(); - } this.#metricsText += text; - if (this.input.closed) { - throw new Error('Input is closed'); - } + // if (this.input.closed) { + // throw new Error('Input is closed'); + // } if (this.closed) { throw new Error('Stream is closed'); } - this.input.put(text); + this.input.writable.getWriter().write(text); } /** Flush the TTS, causing it to process all pending text */ @@ -189,34 +192,43 @@ export abstract class SynthesizeStream this.#metricsPendingTexts.push(this.#metricsText); this.#metricsText = ''; } - if (this.input.closed) { - throw new Error('Input is closed'); - } + // if (this.input.closed) { + // throw new Error('Input is closed'); + // } if (this.closed) { throw new Error('Stream is closed'); } - this.input.put(SynthesizeStream.FLUSH_SENTINEL); + this.input.writable.getWriter().write(SynthesizeStream.FLUSH_SENTINEL); } /** Mark the input as ended and forbid additional pushes */ endInput() { - if (this.input.closed) { - throw new Error('Input is closed'); - } + // if (this.input.closed) { + // throw new Error('Input is closed'); + // } if (this.closed) { throw new Error('Stream is closed'); } - this.input.close(); + this.input.writable.close(); } - next(): Promise> { - return this.output.next(); + async next(): Promise> { + return this.#outputReadable + .getReader() + .read() + .then(({ value }) => { + if (value) { + return { value, done: false }; + } else { + return { value: undefined, done: true }; + } + }); } /** Close both the input and output of the TTS stream */ close() { - this.input.close(); - this.output.close(); + this.input.writable.close(); + this.output.writable.close(); this.closed = true; } @@ -240,35 +252,34 @@ export abstract class SynthesizeStream * exports its own child ChunkedStream class, which inherits this class's methods. */ export abstract class ChunkedStream implements AsyncIterableIterator { - protected queue = new AsyncIterableQueue(); - protected output = new AsyncIterableQueue(); + protected output = new TransformStream(); protected closed = false; abstract label: string; #text: string; #tts: TTS; + #outputReadable: ReadableStream; constructor(text: string, tts: TTS) { this.#text = text; this.#tts = tts; - - this.monitorMetrics(); + const [r1, r2] = this.output.readable.tee(); + this.#outputReadable = r1; + this.monitorMetrics(r2); } - protected async monitorMetrics() { + protected async monitorMetrics(readable: ReadableStream) { const startTime = process.hrtime.bigint(); let audioDuration = 0; let ttfb: bigint | undefined; let requestId = ''; - for await (const audio of this.queue) { - this.output.put(audio); + for await (const audio of readable) { requestId = audio.requestId; if (!ttfb) { ttfb = process.hrtime.bigint() - startTime; } audioDuration += audio.frame.samplesPerChannel / audio.frame.sampleRate; } - this.output.close(); const duration = process.hrtime.bigint() - startTime; const metrics: TTSMetrics = { @@ -294,14 +305,22 @@ export abstract class ChunkedStream implements AsyncIterableIterator> { - return this.output.next(); + async next(): Promise> { + return this.#outputReadable + .getReader() + .read() + .then(({ value }) => { + if (value) { + return { value, done: false }; + } else { + return { value: undefined, done: true }; + } + }); } /** Close both the input and output of the TTS stream */ close() { - this.queue.close(); - this.output.close(); + this.output.writable.close(); this.closed = true; } diff --git a/agents/src/utils.ts b/agents/src/utils.ts index bd29dd18..88843155 100644 --- a/agents/src/utils.ts +++ b/agents/src/utils.ts @@ -224,44 +224,6 @@ export async function gracefullyCancel(promise: CancellablePromise): Promi } } -/** @internal */ -export class AsyncIterableQueue implements AsyncIterableIterator { - private static readonly CLOSE_SENTINEL = Symbol('CLOSE_SENTINEL'); - #queue = new Queue(); - #closed = false; - - get closed(): boolean { - return this.#closed; - } - - put(item: T): void { - if (this.#closed) { - throw new Error('Queue is closed'); - } - this.#queue.put(item); - } - - close(): void { - this.#closed = true; - this.#queue.put(AsyncIterableQueue.CLOSE_SENTINEL); - } - - async next(): Promise> { - if (this.#closed && this.#queue.items.length === 0) { - return { value: undefined, done: true }; - } - const item = await this.#queue.get(); - if (item === AsyncIterableQueue.CLOSE_SENTINEL && this.#closed) { - return { value: undefined, done: true }; - } - return { value: item as T, done: false }; - } - - [Symbol.asyncIterator](): AsyncIterableQueue { - return this; - } -} - /** @internal */ export class ExpFilter { #alpha: number; diff --git a/agents/src/vad.ts b/agents/src/vad.ts index 766bae8b..983db6a7 100644 --- a/agents/src/vad.ts +++ b/agents/src/vad.ts @@ -4,8 +4,9 @@ import type { AudioFrame } from '@livekit/rtc-node'; import type { TypedEventEmitter as TypedEmitter } from '@livekit/typed-emitter'; import { EventEmitter } from 'node:events'; +import type { ReadableStream } from 'node:stream/web'; +import { TransformStream } from 'node:stream/web'; import type { VADMetrics } from './metrics/base.js'; -import { AsyncIterableQueue } from './utils.js'; export enum VADEventType { START_OF_SPEECH, @@ -77,24 +78,28 @@ export abstract class VAD extends (EventEmitter as new () => TypedEmitter { protected static readonly FLUSH_SENTINEL = Symbol('FLUSH_SENTINEL'); - protected input = new AsyncIterableQueue(); - protected queue = new AsyncIterableQueue(); - protected output = new AsyncIterableQueue(); + protected input = new TransformStream< + AudioFrame | typeof VADStream.FLUSH_SENTINEL, + AudioFrame | typeof VADStream.FLUSH_SENTINEL + >(); + protected output = new TransformStream(); protected closed = false; #vad: VAD; #lastActivityTime = BigInt(0); + #outputReadable: ReadableStream; constructor(vad: VAD) { this.#vad = vad; - this.monitorMetrics(); + const [r1, r2] = this.output.readable.tee(); + this.#outputReadable = r1; + this.monitorMetrics(r2); } - protected async monitorMetrics() { + protected async monitorMetrics(readable: ReadableStream) { let inferenceDurationTotal = 0; let inferenceCount = 0; - for await (const event of this.queue) { - this.output.put(event); + for await (const event of readable) { switch (event.type) { case VADEventType.START_OF_SPEECH: inferenceCount++; @@ -119,47 +124,54 @@ export abstract class VADStream implements AsyncIterableIterator { break; } } - this.output.close(); } pushFrame(frame: AudioFrame) { - if (this.input.closed) { - throw new Error('Input is closed'); - } + // if (this.input.closed) { + // throw new Error('Input is closed'); + // } if (this.closed) { throw new Error('Stream is closed'); } - this.input.put(frame); + this.input.writable.getWriter().write(frame); } flush() { - if (this.input.closed) { - throw new Error('Input is closed'); - } + // if (this.input.closed) { + // throw new Error('Input is closed'); + // } if (this.closed) { throw new Error('Stream is closed'); } - this.input.put(VADStream.FLUSH_SENTINEL); + this.input.writable.getWriter().write(VADStream.FLUSH_SENTINEL); } endInput() { - if (this.input.closed) { - throw new Error('Input is closed'); - } + // if (this.input.closed) { + // throw new Error('Input is closed'); + // } if (this.closed) { throw new Error('Stream is closed'); } - this.input.close(); + this.input.writable.close(); } - next(): Promise> { - return this.output.next(); + async next(): Promise> { + return this.#outputReadable + .getReader() + .read() + .then(({ value }) => { + if (value) { + return { value, done: false }; + } else { + return { value: undefined, done: true }; + } + }); } close() { - this.input.close(); - this.queue.close(); - this.output.close(); + this.input.writable.close(); + this.output.writable.close(); this.closed = true; } From 6cd5f42cbeb48ee2db6246284a054a09efc97291 Mon Sep 17 00:00:00 2001 From: aoife cassidy Date: Mon, 2 Dec 2024 16:55:42 +0200 Subject: [PATCH 2/7] add plugin support, almost --- plugins/deepgram/src/stt.ts | 14 +++--- plugins/elevenlabs/src/tts.ts | 49 +++++++++---------- plugins/openai/src/llm.ts | 8 +-- plugins/openai/src/realtime/realtime_model.ts | 18 +++---- plugins/openai/src/tts.ts | 6 ++- plugins/silero/src/vad.ts | 9 ++-- 6 files changed, 53 insertions(+), 51 deletions(-) diff --git a/plugins/deepgram/src/stt.ts b/plugins/deepgram/src/stt.ts index 213ac457..dea54df0 100644 --- a/plugins/deepgram/src/stt.ts +++ b/plugins/deepgram/src/stt.ts @@ -191,7 +191,7 @@ export class SpeechStream extends stt.SpeechStream { samples100Ms, ); - for await (const data of this.input) { + for await (const data of this.input.readable) { let frames: AudioFrame[]; if (data === SpeechStream.FLUSH_SENTINEL) { frames = stream.flush(); @@ -225,6 +225,8 @@ export class SpeechStream extends stt.SpeechStream { }), ); + const writer = this.output.writable.getWriter() + while (!this.closed) { try { await new Promise((resolve) => { @@ -239,7 +241,7 @@ export class SpeechStream extends stt.SpeechStream { // It's also possible we receive a transcript without a SpeechStarted event. if (this.#speaking) return; this.#speaking = true; - this.queue.put({ type: stt.SpeechEventType.START_OF_SPEECH }); + writer.write({ type: stt.SpeechEventType.START_OF_SPEECH }); break; } // see this page: @@ -257,16 +259,16 @@ export class SpeechStream extends stt.SpeechStream { if (alternatives[0] && alternatives[0].text) { if (!this.#speaking) { this.#speaking = true; - this.queue.put({ type: stt.SpeechEventType.START_OF_SPEECH }); + writer.write({ type: stt.SpeechEventType.START_OF_SPEECH }); } if (isFinal) { - this.queue.put({ + writer.write({ type: stt.SpeechEventType.FINAL_TRANSCRIPT, alternatives: [alternatives[0], ...alternatives.slice(1)], }); } else { - this.queue.put({ + writer.write({ type: stt.SpeechEventType.INTERIM_TRANSCRIPT, alternatives: [alternatives[0], ...alternatives.slice(1)], }); @@ -278,7 +280,7 @@ export class SpeechStream extends stt.SpeechStream { // a non-empty transcript (deepgram doesn't have a SpeechEnded event) if (isEndpoint && this.#speaking) { this.#speaking = false; - this.queue.put({ type: stt.SpeechEventType.END_OF_SPEECH }); + writer.write({ type: stt.SpeechEventType.END_OF_SPEECH }); } break; diff --git a/plugins/elevenlabs/src/tts.ts b/plugins/elevenlabs/src/tts.ts index 1c73d59c..92426681 100644 --- a/plugins/elevenlabs/src/tts.ts +++ b/plugins/elevenlabs/src/tts.ts @@ -1,7 +1,7 @@ // SPDX-FileCopyrightText: 2024 LiveKit, Inc. // // SPDX-License-Identifier: Apache-2.0 -import { AsyncIterableQueue, AudioByteStream, log, tokenize, tts } from '@livekit/agents'; +import { AudioByteStream, log, tokenize, tts } from '@livekit/agents'; import type { AudioFrame } from '@livekit/rtc-node'; import { randomUUID } from 'node:crypto'; import { URL } from 'node:url'; @@ -142,33 +142,28 @@ export class SynthesizeStream extends tts.SynthesizeStream { } async #run() { - const segments = new AsyncIterableQueue(); - - const tokenizeInput = async () => { - let stream: tokenize.WordStream | null = null; - for await (const text of this.input) { - if (text === SynthesizeStream.FLUSH_SENTINEL) { - stream?.endInput(); - stream = null; - } else { - if (!stream) { - stream = this.#opts.wordTokenizer.stream(); - segments.put(stream); - } - stream.pushText(text); + const segments = new WritableStream({ + write: async (chunk) => { + await this.#runWS(chunk); + this.output.writable.getWriter().write(SynthesizeStream.END_OF_STREAM); + }, + }); + const writer = segments.getWriter(); + + let stream: tokenize.WordStream | null = null; + for await (const text of this.input.readable) { + if (text === SynthesizeStream.FLUSH_SENTINEL) { + stream?.endInput(); + stream = null; + } else { + if (!stream) { + stream = this.#opts.wordTokenizer.stream(); + writer.write(stream); } + stream.pushText(text); } - segments.close(); - }; - - const runStream = async () => { - for await (const stream of segments) { - await this.#runWS(stream); - this.queue.put(SynthesizeStream.END_OF_STREAM); - } - }; - - await Promise.all([tokenizeInput(), runStream()]); + } + writer.close(); this.close(); } @@ -244,7 +239,7 @@ export class SynthesizeStream extends tts.SynthesizeStream { let lastFrame: AudioFrame | undefined; const sendLastFrame = (segmentId: string, final: boolean) => { if (lastFrame) { - this.queue.put({ requestId, segmentId, frame: lastFrame, final }); + this.output.writable.getWriter().write({ requestId, segmentId, frame: lastFrame, final }); lastFrame = undefined; } }; diff --git a/plugins/openai/src/llm.ts b/plugins/openai/src/llm.ts index 19386bb6..94534b47 100644 --- a/plugins/openai/src/llm.ts +++ b/plugins/openai/src/llm.ts @@ -435,6 +435,8 @@ export class LLMStream extends llm.LLMStream { } async #run(opts: LLMOptions, n?: number, parallelToolCalls?: boolean, temperature?: number) { + const writer = this.output.writable.getWriter(); + const tools = this.fncCtx ? Object.entries(this.fncCtx).map(([name, func]) => ({ type: 'function' as const, @@ -469,12 +471,12 @@ export class LLMStream extends llm.LLMStream { for (const choice of chunk.choices) { const chatChunk = this.#parseChoice(chunk.id, choice); if (chatChunk) { - this.queue.put(chatChunk); + writer.write(chatChunk); } if (chunk.usage) { const usage = chunk.usage; - this.queue.put({ + writer.write({ requestId: chunk.id, choices: [], usage: { @@ -487,7 +489,7 @@ export class LLMStream extends llm.LLMStream { } } } finally { - this.queue.close(); + writer.close(); } } diff --git a/plugins/openai/src/realtime/realtime_model.ts b/plugins/openai/src/realtime/realtime_model.ts index 888c5b3e..77faa856 100644 --- a/plugins/openai/src/realtime/realtime_model.ts +++ b/plugins/openai/src/realtime/realtime_model.ts @@ -2,7 +2,6 @@ // // SPDX-License-Identifier: Apache-2.0 import { - AsyncIterableQueue, Future, Queue, llm, @@ -14,6 +13,7 @@ import { import { AudioFrame } from '@livekit/rtc-node'; import { once } from 'node:events'; import { WebSocket } from 'ws'; +import {TransformStream} from 'node:stream/web' import * as api_proto from './api_proto.js'; interface ModelOptions { @@ -62,8 +62,8 @@ export interface RealtimeContent { contentIndex: number; text: string; audio: AudioFrame[]; - textStream: AsyncIterableQueue; - audioStream: AsyncIterableQueue; + textStream: TransformStream; + audioStream: TransformStream; toolCalls: RealtimeToolCall[]; } @@ -1114,8 +1114,8 @@ export class RealtimeSession extends multimodal.RealtimeSession { const outputIndex = event.output_index; const output = response!.output[outputIndex]; - const textStream = new AsyncIterableQueue(); - const audioStream = new AsyncIterableQueue(); + const textStream = new TransformStream(); + const audioStream = new TransformStream(); const newContent: RealtimeContent = { responseId: responseId, @@ -1151,12 +1151,12 @@ export class RealtimeSession extends multimodal.RealtimeSession { const transcript = event.delta; content.text += transcript; - content.textStream.put(transcript); + content.textStream.writable.getWriter().write(transcript); } #handleResponseAudioTranscriptDone(event: api_proto.ResponseAudioTranscriptDoneEvent): void { const content = this.#getContent(event); - content.textStream.close(); + content.textStream.writable.getWriter().close(); } #handleResponseAudioDelta(event: api_proto.ResponseAudioDeltaEvent): void { @@ -1170,12 +1170,12 @@ export class RealtimeSession extends multimodal.RealtimeSession { ); content.audio.push(audio); - content.audioStream.put(audio); + content.audioStream.writable.getWriter().write(audio); } #handleResponseAudioDone(event: api_proto.ResponseAudioDoneEvent): void { const content = this.#getContent(event); - content.audioStream.close(); + content.audioStream.writable.getWriter().close(); } #handleResponseFunctionCallArgumentsDelta( diff --git a/plugins/openai/src/tts.ts b/plugins/openai/src/tts.ts index 7621ba65..5521baf0 100644 --- a/plugins/openai/src/tts.ts +++ b/plugins/openai/src/tts.ts @@ -92,10 +92,12 @@ export class ChunkedStream extends tts.ChunkedStream { const audioByteStream = new AudioByteStream(OPENAI_TTS_SAMPLE_RATE, OPENAI_TTS_CHANNELS); const frames = audioByteStream.write(buffer); + const writer = this.output.writable.getWriter(); + let lastFrame: AudioFrame | undefined; const sendLastFrame = (segmentId: string, final: boolean) => { if (lastFrame) { - this.queue.put({ requestId, segmentId, frame: lastFrame, final }); + writer.write({ requestId, segmentId, frame: lastFrame, final }); lastFrame = undefined; } }; @@ -106,6 +108,6 @@ export class ChunkedStream extends tts.ChunkedStream { } sendLastFrame(requestId, true); - this.queue.close(); + writer.close(); } } diff --git a/plugins/silero/src/vad.ts b/plugins/silero/src/vad.ts index e6c4f21a..6c0efd1a 100644 --- a/plugins/silero/src/vad.ts +++ b/plugins/silero/src/vad.ts @@ -103,6 +103,7 @@ export class VADStream extends baseStream { super(vad); this.#opts = opts; this.#model = model; + const writer = this.output.writable.getWriter() this.#task = new Promise(async () => { let inferenceData = new Float32Array(this.#model.windowSizeSamples); @@ -131,7 +132,7 @@ export class VADStream extends baseStream { // used to avoid drift when the sampleRate ratio is not an integer let inputCopyRemainingFrac = 0.0; - for await (const frame of this.input) { + for await (const frame of this.input.readable) { if (typeof frame === 'symbol') { continue; // ignore flush sentinel for now } @@ -229,7 +230,7 @@ export class VADStream extends baseStream { pubSilenceDuration += inferenceDuration; } - this.queue.put({ + writer.write({ type: VADEventType.INFERENCE_DONE, samplesIndex: pubCurrentSample, timestamp: pubTimestamp, @@ -278,7 +279,7 @@ export class VADStream extends baseStream { pubSilenceDuration = 0; pubSpeechDuration = speechThresholdDuration; - this.queue.put({ + writer.write({ type: VADEventType.START_OF_SPEECH, samplesIndex: pubCurrentSample, timestamp: pubTimestamp, @@ -305,7 +306,7 @@ export class VADStream extends baseStream { pubSpeechDuration = 0; pubSilenceDuration = silenceThresholdDuration; - this.queue.put({ + writer.write({ type: VADEventType.END_OF_SPEECH, samplesIndex: pubCurrentSample, timestamp: pubTimestamp, From 6a2b499dd2ffd51fb46f9be36030afc0204f17f1 Mon Sep 17 00:00:00 2001 From: aoife cassidy Date: Mon, 2 Dec 2024 17:02:30 +0200 Subject: [PATCH 3/7] Create curvy-pumpkins-rush.md --- .changeset/curvy-pumpkins-rush.md | 9 +++++++++ 1 file changed, 9 insertions(+) create mode 100644 .changeset/curvy-pumpkins-rush.md diff --git a/.changeset/curvy-pumpkins-rush.md b/.changeset/curvy-pumpkins-rush.md new file mode 100644 index 00000000..1944f6fa --- /dev/null +++ b/.changeset/curvy-pumpkins-rush.md @@ -0,0 +1,9 @@ +--- +"@livekit/agents": patch +"@livekit/agents-plugin-deepgram": patch +"@livekit/agents-plugin-elevenlabs": patch +"@livekit/agents-plugin-openai": patch +"@livekit/agents-plugin-silero": patch +--- + +chore: move to node:stream/web from homegrown API From b474b224d0bcd3322d4b0ad53001145ca250101d Mon Sep 17 00:00:00 2001 From: aoife cassidy Date: Mon, 2 Dec 2024 17:34:46 +0200 Subject: [PATCH 4/7] input closed --- agents/src/stt/stt.ts | 20 +++++++++++--------- agents/src/tts/tts.ts | 22 +++++++++++++--------- agents/src/vad.ts | 22 +++++++++++++--------- plugins/deepgram/src/stt.ts | 2 +- 4 files changed, 38 insertions(+), 28 deletions(-) diff --git a/agents/src/stt/stt.ts b/agents/src/stt/stt.ts index 8968dbf9..9ff9b052 100644 --- a/agents/src/stt/stt.ts +++ b/agents/src/stt/stt.ts @@ -145,11 +145,13 @@ export abstract class SpeechStream implements AsyncIterableIterator protected output = new TransformStream(); abstract label: string; protected closed = false; + protected inputClosed = false; #stt: STT; #outputReadable: ReadableStream; constructor(stt: STT) { this.#stt = stt; + this.output.writable.close().then(() => { this.inputClosed = true }) const [r1, r2] = this.output.readable.tee(); this.#outputReadable = r1; this.monitorMetrics(r2); @@ -175,9 +177,9 @@ export abstract class SpeechStream implements AsyncIterableIterator /** Push an audio frame to the STT */ pushFrame(frame: AudioFrame) { - // if (this.input.closed) { - // throw new Error('Input is closed'); - // } + if (this.inputClosed) { + throw new Error('Input is closed'); + } if (this.closed) { throw new Error('Stream is closed'); } @@ -186,9 +188,9 @@ export abstract class SpeechStream implements AsyncIterableIterator /** Flush the STT, causing it to process all pending text */ flush() { - // if (this.input.closed) { - // throw new Error('Input is closed'); - // } + if (this.inputClosed) { + throw new Error('Input is closed'); + } if (this.closed) { throw new Error('Stream is closed'); } @@ -197,9 +199,9 @@ export abstract class SpeechStream implements AsyncIterableIterator /** Mark the input as ended and forbid additional pushes */ endInput() { - // if (this.input.closed) { - // throw new Error('Input is closed'); - // } + if (this.inputClosed) { + throw new Error('Input is closed'); + } if (this.closed) { throw new Error('Stream is closed'); } diff --git a/agents/src/tts/tts.ts b/agents/src/tts/tts.ts index f06c1c65..aaba317e 100644 --- a/agents/src/tts/tts.ts +++ b/agents/src/tts/tts.ts @@ -116,6 +116,7 @@ export abstract class SynthesizeStream SynthesizedAudio | typeof SynthesizeStream.END_OF_STREAM >(); protected closed = false; + protected inputClosed = false; abstract label: string; #tts: TTS; #metricsPendingTexts: string[] = []; @@ -124,6 +125,9 @@ export abstract class SynthesizeStream constructor(tts: TTS) { this.#tts = tts; + this.output.writable.close().then(() => { + this.inputClosed = true; + }); const [r1, r2] = this.output.readable.tee(); this.#outputReadable = r1; this.monitorMetrics(r2); @@ -177,9 +181,9 @@ export abstract class SynthesizeStream pushText(text: string) { this.#metricsText += text; - // if (this.input.closed) { - // throw new Error('Input is closed'); - // } + if (this.inputClosed) { + throw new Error('Input is closed'); + } if (this.closed) { throw new Error('Stream is closed'); } @@ -192,9 +196,9 @@ export abstract class SynthesizeStream this.#metricsPendingTexts.push(this.#metricsText); this.#metricsText = ''; } - // if (this.input.closed) { - // throw new Error('Input is closed'); - // } + if (this.inputClosed) { + throw new Error('Input is closed'); + } if (this.closed) { throw new Error('Stream is closed'); } @@ -203,9 +207,9 @@ export abstract class SynthesizeStream /** Mark the input as ended and forbid additional pushes */ endInput() { - // if (this.input.closed) { - // throw new Error('Input is closed'); - // } + if (this.inputClosed) { + throw new Error('Input is closed'); + } if (this.closed) { throw new Error('Stream is closed'); } diff --git a/agents/src/vad.ts b/agents/src/vad.ts index 983db6a7..6c4530e0 100644 --- a/agents/src/vad.ts +++ b/agents/src/vad.ts @@ -84,12 +84,16 @@ export abstract class VADStream implements AsyncIterableIterator { >(); protected output = new TransformStream(); protected closed = false; + protected inputClosed = false; #vad: VAD; #lastActivityTime = BigInt(0); #outputReadable: ReadableStream; constructor(vad: VAD) { this.#vad = vad; + this.output.writable.close().then(() => { + this.inputClosed = true; + }); const [r1, r2] = this.output.readable.tee(); this.#outputReadable = r1; this.monitorMetrics(r2); @@ -127,9 +131,9 @@ export abstract class VADStream implements AsyncIterableIterator { } pushFrame(frame: AudioFrame) { - // if (this.input.closed) { - // throw new Error('Input is closed'); - // } + if (this.inputClosed) { + throw new Error('Input is closed'); + } if (this.closed) { throw new Error('Stream is closed'); } @@ -137,9 +141,9 @@ export abstract class VADStream implements AsyncIterableIterator { } flush() { - // if (this.input.closed) { - // throw new Error('Input is closed'); - // } + if (this.inputClosed) { + throw new Error('Input is closed'); + } if (this.closed) { throw new Error('Stream is closed'); } @@ -147,9 +151,9 @@ export abstract class VADStream implements AsyncIterableIterator { } endInput() { - // if (this.input.closed) { - // throw new Error('Input is closed'); - // } + if (this.inputClosed) { + throw new Error('Input is closed'); + } if (this.closed) { throw new Error('Stream is closed'); } diff --git a/plugins/deepgram/src/stt.ts b/plugins/deepgram/src/stt.ts index dea54df0..a8084864 100644 --- a/plugins/deepgram/src/stt.ts +++ b/plugins/deepgram/src/stt.ts @@ -113,7 +113,7 @@ export class SpeechStream extends stt.SpeechStream { async #run(maxRetry = 32) { let retries = 0; let ws: WebSocket; - while (!this.input.closed) { + while (!this.inputClosed) { const streamURL = new URL(API_BASE_URL_V1); const params = { model: this.#opts.model, From 8f6060a7050b4a6be4620ca4f87b65962ef70d98 Mon Sep 17 00:00:00 2001 From: aoife cassidy Date: Mon, 2 Dec 2024 17:36:24 +0200 Subject: [PATCH 5/7] lint --- agents/src/stt/stt.ts | 4 +++- plugins/deepgram/src/stt.ts | 2 +- plugins/openai/src/llm.ts | 2 +- plugins/openai/src/realtime/realtime_model.ts | 12 ++---------- plugins/openai/src/tts.ts | 2 +- plugins/silero/src/vad.ts | 2 +- 6 files changed, 9 insertions(+), 15 deletions(-) diff --git a/agents/src/stt/stt.ts b/agents/src/stt/stt.ts index 9ff9b052..168b102a 100644 --- a/agents/src/stt/stt.ts +++ b/agents/src/stt/stt.ts @@ -151,7 +151,9 @@ export abstract class SpeechStream implements AsyncIterableIterator constructor(stt: STT) { this.#stt = stt; - this.output.writable.close().then(() => { this.inputClosed = true }) + this.output.writable.close().then(() => { + this.inputClosed = true; + }); const [r1, r2] = this.output.readable.tee(); this.#outputReadable = r1; this.monitorMetrics(r2); diff --git a/plugins/deepgram/src/stt.ts b/plugins/deepgram/src/stt.ts index a8084864..ea2056a7 100644 --- a/plugins/deepgram/src/stt.ts +++ b/plugins/deepgram/src/stt.ts @@ -225,7 +225,7 @@ export class SpeechStream extends stt.SpeechStream { }), ); - const writer = this.output.writable.getWriter() + const writer = this.output.writable.getWriter(); while (!this.closed) { try { diff --git a/plugins/openai/src/llm.ts b/plugins/openai/src/llm.ts index 94534b47..059bf459 100644 --- a/plugins/openai/src/llm.ts +++ b/plugins/openai/src/llm.ts @@ -436,7 +436,7 @@ export class LLMStream extends llm.LLMStream { async #run(opts: LLMOptions, n?: number, parallelToolCalls?: boolean, temperature?: number) { const writer = this.output.writable.getWriter(); - + const tools = this.fncCtx ? Object.entries(this.fncCtx).map(([name, func]) => ({ type: 'function' as const, diff --git a/plugins/openai/src/realtime/realtime_model.ts b/plugins/openai/src/realtime/realtime_model.ts index 77faa856..9ae2f2d2 100644 --- a/plugins/openai/src/realtime/realtime_model.ts +++ b/plugins/openai/src/realtime/realtime_model.ts @@ -1,19 +1,11 @@ // SPDX-FileCopyrightText: 2024 LiveKit, Inc. // // SPDX-License-Identifier: Apache-2.0 -import { - Future, - Queue, - llm, - log, - mergeFrames, - metrics, - multimodal, -} from '@livekit/agents'; +import { Future, Queue, llm, log, mergeFrames, metrics, multimodal } from '@livekit/agents'; import { AudioFrame } from '@livekit/rtc-node'; import { once } from 'node:events'; +import { TransformStream } from 'node:stream/web'; import { WebSocket } from 'ws'; -import {TransformStream} from 'node:stream/web' import * as api_proto from './api_proto.js'; interface ModelOptions { diff --git a/plugins/openai/src/tts.ts b/plugins/openai/src/tts.ts index 5521baf0..eb1c2ca6 100644 --- a/plugins/openai/src/tts.ts +++ b/plugins/openai/src/tts.ts @@ -93,7 +93,7 @@ export class ChunkedStream extends tts.ChunkedStream { const frames = audioByteStream.write(buffer); const writer = this.output.writable.getWriter(); - + let lastFrame: AudioFrame | undefined; const sendLastFrame = (segmentId: string, final: boolean) => { if (lastFrame) { diff --git a/plugins/silero/src/vad.ts b/plugins/silero/src/vad.ts index 6c0efd1a..d1c6615e 100644 --- a/plugins/silero/src/vad.ts +++ b/plugins/silero/src/vad.ts @@ -103,7 +103,7 @@ export class VADStream extends baseStream { super(vad); this.#opts = opts; this.#model = model; - const writer = this.output.writable.getWriter() + const writer = this.output.writable.getWriter(); this.#task = new Promise(async () => { let inferenceData = new Float32Array(this.#model.windowSizeSamples); From 98c22e4394fc88d3bd76419b0f11212527c07de2 Mon Sep 17 00:00:00 2001 From: aoife cassidy Date: Mon, 2 Dec 2024 20:03:48 +0200 Subject: [PATCH 6/7] testing 1 --- agents/src/pipeline/agent_output.ts | 10 +++++--- agents/src/stt/stt.ts | 13 +++++----- agents/src/tts/tts.ts | 11 ++++++--- agents/src/vad.ts | 38 +++++++++++++++-------------- 4 files changed, 40 insertions(+), 32 deletions(-) diff --git a/agents/src/pipeline/agent_output.ts b/agents/src/pipeline/agent_output.ts index 03de4a07..63eba85d 100644 --- a/agents/src/pipeline/agent_output.ts +++ b/agents/src/pipeline/agent_output.ts @@ -148,9 +148,10 @@ const stringSynthesisTask = (text: string, handle: SynthesisHandle): Cancellable if (cancelled || audio === SynthesizeStream.END_OF_STREAM) { break; } - writer.write(audio.frame); + await writer.write(audio.frame); } - writer.write(SynthesisHandle.FLUSH_SENTINEL); + await writer.write(SynthesisHandle.FLUSH_SENTINEL); + writer.releaseLock(); resolve(text); }); @@ -177,9 +178,10 @@ const streamSynthesisTask = ( if (audio === SynthesizeStream.END_OF_STREAM) { break; } - writer.write(audio.frame); + await writer.write(audio.frame); } - writer.write(SynthesisHandle.FLUSH_SENTINEL); + await writer.write(SynthesisHandle.FLUSH_SENTINEL); + writer.releaseLock(); }; readGeneratedAudio(); diff --git a/agents/src/stt/stt.ts b/agents/src/stt/stt.ts index 168b102a..bb62e8a3 100644 --- a/agents/src/stt/stt.ts +++ b/agents/src/stt/stt.ts @@ -4,7 +4,7 @@ import type { AudioFrame } from '@livekit/rtc-node'; import type { TypedEventEmitter as TypedEmitter } from '@livekit/typed-emitter'; import { EventEmitter } from 'node:events'; -import type { ReadableStream } from 'node:stream/web'; +import type { ReadableStream, WritableStreamDefaultWriter } from 'node:stream/web'; import { TransformStream } from 'node:stream/web'; import type { STTMetrics } from '../metrics/base.js'; import type { AudioBuffer } from '../utils.js'; @@ -148,12 +148,11 @@ export abstract class SpeechStream implements AsyncIterableIterator protected inputClosed = false; #stt: STT; #outputReadable: ReadableStream; + #writer: WritableStreamDefaultWriter; constructor(stt: STT) { this.#stt = stt; - this.output.writable.close().then(() => { - this.inputClosed = true; - }); + this.#writer = this.input.writable.getWriter(); const [r1, r2] = this.output.readable.tee(); this.#outputReadable = r1; this.monitorMetrics(r2); @@ -185,7 +184,7 @@ export abstract class SpeechStream implements AsyncIterableIterator if (this.closed) { throw new Error('Stream is closed'); } - this.input.writable.getWriter().write(frame); + this.#writer.write(frame); } /** Flush the STT, causing it to process all pending text */ @@ -196,7 +195,7 @@ export abstract class SpeechStream implements AsyncIterableIterator if (this.closed) { throw new Error('Stream is closed'); } - this.input.writable.getWriter().write(SpeechStream.FLUSH_SENTINEL); + this.#writer.write(SpeechStream.FLUSH_SENTINEL); } /** Mark the input as ended and forbid additional pushes */ @@ -207,6 +206,7 @@ export abstract class SpeechStream implements AsyncIterableIterator if (this.closed) { throw new Error('Stream is closed'); } + this.inputClosed = true; this.input.writable.close(); } @@ -228,6 +228,7 @@ export abstract class SpeechStream implements AsyncIterableIterator this.input.writable.close(); this.output.writable.close(); this.closed = true; + this.inputClosed = true; } [Symbol.asyncIterator](): SpeechStream { diff --git a/agents/src/tts/tts.ts b/agents/src/tts/tts.ts index aaba317e..94e86ef8 100644 --- a/agents/src/tts/tts.ts +++ b/agents/src/tts/tts.ts @@ -4,7 +4,7 @@ import type { AudioFrame } from '@livekit/rtc-node'; import type { TypedEventEmitter as TypedEmitter } from '@livekit/typed-emitter'; import { EventEmitter } from 'node:events'; -import type { ReadableStream } from 'node:stream/web'; +import type { ReadableStream, WritableStreamDefaultWriter } from 'node:stream/web'; import { TransformStream } from 'node:stream/web'; import type { TTSMetrics } from '../metrics/base.js'; import { mergeFrames } from '../utils.js'; @@ -122,12 +122,14 @@ export abstract class SynthesizeStream #metricsPendingTexts: string[] = []; #metricsText = ''; #outputReadable: ReadableStream; + #writer: WritableStreamDefaultWriter; constructor(tts: TTS) { this.#tts = tts; this.output.writable.close().then(() => { this.inputClosed = true; }); + this.#writer = this.input.writable.getWriter(); const [r1, r2] = this.output.readable.tee(); this.#outputReadable = r1; this.monitorMetrics(r2); @@ -187,7 +189,7 @@ export abstract class SynthesizeStream if (this.closed) { throw new Error('Stream is closed'); } - this.input.writable.getWriter().write(text); + this.#writer.write(text); } /** Flush the TTS, causing it to process all pending text */ @@ -202,7 +204,7 @@ export abstract class SynthesizeStream if (this.closed) { throw new Error('Stream is closed'); } - this.input.writable.getWriter().write(SynthesizeStream.FLUSH_SENTINEL); + this.#writer.write(SynthesizeStream.FLUSH_SENTINEL); } /** Mark the input as ended and forbid additional pushes */ @@ -213,7 +215,7 @@ export abstract class SynthesizeStream if (this.closed) { throw new Error('Stream is closed'); } - this.input.writable.close(); + this.#writer.close(); } async next(): Promise> { @@ -231,6 +233,7 @@ export abstract class SynthesizeStream /** Close both the input and output of the TTS stream */ close() { + this.#writer.close(); this.input.writable.close(); this.output.writable.close(); this.closed = true; diff --git a/agents/src/vad.ts b/agents/src/vad.ts index 6c4530e0..6ee52418 100644 --- a/agents/src/vad.ts +++ b/agents/src/vad.ts @@ -4,7 +4,11 @@ import type { AudioFrame } from '@livekit/rtc-node'; import type { TypedEventEmitter as TypedEmitter } from '@livekit/typed-emitter'; import { EventEmitter } from 'node:events'; -import type { ReadableStream } from 'node:stream/web'; +import type { + ReadableStream, + ReadableStreamDefaultReader, + WritableStreamDefaultWriter, +} from 'node:stream/web'; import { TransformStream } from 'node:stream/web'; import type { VADMetrics } from './metrics/base.js'; @@ -87,15 +91,14 @@ export abstract class VADStream implements AsyncIterableIterator { protected inputClosed = false; #vad: VAD; #lastActivityTime = BigInt(0); - #outputReadable: ReadableStream; + #writer: WritableStreamDefaultWriter; + #reader: ReadableStreamDefaultReader; constructor(vad: VAD) { this.#vad = vad; - this.output.writable.close().then(() => { - this.inputClosed = true; - }); const [r1, r2] = this.output.readable.tee(); - this.#outputReadable = r1; + this.#reader = r1.getReader(); + this.#writer = this.input.writable.getWriter(); this.monitorMetrics(r2); } @@ -137,7 +140,7 @@ export abstract class VADStream implements AsyncIterableIterator { if (this.closed) { throw new Error('Stream is closed'); } - this.input.writable.getWriter().write(frame); + this.#writer.write(frame); } flush() { @@ -147,7 +150,8 @@ export abstract class VADStream implements AsyncIterableIterator { if (this.closed) { throw new Error('Stream is closed'); } - this.input.writable.getWriter().write(VADStream.FLUSH_SENTINEL); + this.inputClosed = true; + this.#writer.write(VADStream.FLUSH_SENTINEL); } endInput() { @@ -157,20 +161,18 @@ export abstract class VADStream implements AsyncIterableIterator { if (this.closed) { throw new Error('Stream is closed'); } + this.inputClosed = true; this.input.writable.close(); } async next(): Promise> { - return this.#outputReadable - .getReader() - .read() - .then(({ value }) => { - if (value) { - return { value, done: false }; - } else { - return { value: undefined, done: true }; - } - }); + return this.#reader.read().then(({ value }) => { + if (value) { + return { value, done: false }; + } else { + return { value: undefined, done: true }; + } + }); } close() { From ece515b9ef2bf50838972121b89531de828a0a10 Mon Sep 17 00:00:00 2001 From: aoife cassidy Date: Mon, 2 Dec 2024 21:19:21 +0200 Subject: [PATCH 7/7] testing 2 --- agents/src/llm/llm.ts | 22 +++--- agents/src/pipeline/pipeline_agent.ts | 8 ++- agents/src/stt/stt.ts | 22 +++--- agents/src/tokenize/token_stream.ts | 35 +++++----- agents/src/tokenize/tokenizer.ts | 96 ++++++++++++++------------- agents/src/tts/tts.ts | 56 +++++++--------- agents/src/vad.ts | 1 - 7 files changed, 114 insertions(+), 126 deletions(-) diff --git a/agents/src/llm/llm.ts b/agents/src/llm/llm.ts index 2c0d1329..59241d84 100644 --- a/agents/src/llm/llm.ts +++ b/agents/src/llm/llm.ts @@ -68,14 +68,14 @@ export abstract class LLMStream implements AsyncIterableIterator { #llm: LLM; #chatCtx: ChatContext; #fncCtx?: FunctionContext; - #outputReadable: ReadableStream; + #reader: ReadableStreamDefaultReader; constructor(llm: LLM, chatCtx: ChatContext, fncCtx?: FunctionContext) { this.#llm = llm; this.#chatCtx = chatCtx; this.#fncCtx = fncCtx; const [r1, r2] = this.output.readable.tee(); - this.#outputReadable = r1; + this.#reader = r1.getReader(); this.monitorMetrics(r2); } @@ -140,20 +140,16 @@ export abstract class LLMStream implements AsyncIterableIterator { } async next(): Promise> { - return this.#outputReadable - .getReader() - .read() - .then(({ value }) => { - if (value) { - return { value, done: false }; - } else { - return { value: undefined, done: true }; - } - }); + return this.#reader.read().then(({ value }) => { + if (value) { + return { value, done: false }; + } else { + return { value: undefined, done: true }; + } + }); } close() { - this.output.writable.close(); this.closed = true; } diff --git a/agents/src/pipeline/pipeline_agent.ts b/agents/src/pipeline/pipeline_agent.ts index c47118ff..87362120 100644 --- a/agents/src/pipeline/pipeline_agent.ts +++ b/agents/src/pipeline/pipeline_agent.ts @@ -255,6 +255,7 @@ export class VoicePipelineAgent extends (EventEmitter as new () => TypedEmitter< #agentPublication?: LocalTrackPublication; #lastFinalTranscriptTime?: number; #lastSpeechTime?: number; + #writer: WritableStreamDefaultWriter; constructor( /** Voice Activity Detection instance. */ @@ -289,6 +290,8 @@ export class VoicePipelineAgent extends (EventEmitter as new () => TypedEmitter< this.#validateReplyIfPossible.bind(this), this.#opts.minEndpointingDelay, ); + + this.#writer = this.#speechQueue.writable.getWriter(); } get fncCtx(): FunctionContext | undefined { @@ -924,9 +927,8 @@ export class VoicePipelineAgent extends (EventEmitter as new () => TypedEmitter< } #addSpeechForPlayout(handle: SpeechHandle) { - const writer = this.#speechQueue.writable.getWriter(); - writer.write(handle); - writer.write(VoicePipelineAgent.FLUSH_SENTINEL); + this.#writer.write(handle); + this.#writer.write(VoicePipelineAgent.FLUSH_SENTINEL); this.#speechQueueOpen.resolve(); } diff --git a/agents/src/stt/stt.ts b/agents/src/stt/stt.ts index bb62e8a3..fa3977b9 100644 --- a/agents/src/stt/stt.ts +++ b/agents/src/stt/stt.ts @@ -147,14 +147,14 @@ export abstract class SpeechStream implements AsyncIterableIterator protected closed = false; protected inputClosed = false; #stt: STT; - #outputReadable: ReadableStream; + #reader: ReadableStreamDefaultReader; #writer: WritableStreamDefaultWriter; constructor(stt: STT) { this.#stt = stt; this.#writer = this.input.writable.getWriter(); const [r1, r2] = this.output.readable.tee(); - this.#outputReadable = r1; + this.#reader = r1.getReader(); this.monitorMetrics(r2); } @@ -211,22 +211,18 @@ export abstract class SpeechStream implements AsyncIterableIterator } async next(): Promise> { - return this.#outputReadable - .getReader() - .read() - .then(({ value }) => { - if (value) { - return { value, done: false }; - } else { - return { value: undefined, done: true }; - } - }); + return this.#reader.read().then(({ value }) => { + if (value) { + return { value, done: false }; + } else { + return { value: undefined, done: true }; + } + }); } /** Close both the input and output of the STT stream */ close() { this.input.writable.close(); - this.output.writable.close(); this.closed = true; this.inputClosed = true; } diff --git a/agents/src/tokenize/token_stream.ts b/agents/src/tokenize/token_stream.ts index 90686e6c..0f899a1d 100644 --- a/agents/src/tokenize/token_stream.ts +++ b/agents/src/tokenize/token_stream.ts @@ -18,11 +18,15 @@ export class BufferedTokenStream implements AsyncIterableIterator { #inBuf = ''; #outBuf = ''; #currentSegmentId: string; + #writer: WritableStreamDefaultWriter; + #reader: ReadableStreamDefaultReader; constructor(func: TokenizeFunc, minTokenLength: number, minContextLength: number) { this.#func = func; this.#minTokenLength = minTokenLength; this.#minContextLength = minContextLength; + this.#reader = this.queue.readable.getReader(); + this.#writer = this.queue.writable.getWriter(); this.#currentSegmentId = randomUUID(); } @@ -33,8 +37,6 @@ export class BufferedTokenStream implements AsyncIterableIterator { throw new Error('Stream is closed'); } - const writer = this.queue.writable.getWriter(); - this.#inBuf += text; if (this.#inBuf.length < this.#minContextLength) return; @@ -52,7 +54,7 @@ export class BufferedTokenStream implements AsyncIterableIterator { this.#outBuf += tokText; if (this.#outBuf.length >= this.#minTokenLength) { - writer.write({ token: this.#outBuf, segmentId: this.#currentSegmentId }); + this.#writer.write({ token: this.#outBuf, segmentId: this.#currentSegmentId }); this.#outBuf = ''; } @@ -72,8 +74,6 @@ export class BufferedTokenStream implements AsyncIterableIterator { throw new Error('Stream is closed'); } - const writer = this.queue.writable.getWriter(); - if (this.#inBuf || this.#outBuf) { const tokens = this.#func(this.#inBuf); if (tokens) { @@ -87,7 +87,7 @@ export class BufferedTokenStream implements AsyncIterableIterator { } if (this.#outBuf) { - writer.write({ token: this.#outBuf, segmentId: this.#currentSegmentId }); + this.#writer.write({ token: this.#outBuf, segmentId: this.#currentSegmentId }); } this.#currentSegmentId = randomUUID(); @@ -107,22 +107,21 @@ export class BufferedTokenStream implements AsyncIterableIterator { } async next(): Promise> { - return this.queue.readable - .getReader() - .read() - .then(({ value }) => { - if (value) { - return { value, done: false }; - } else { - return { value: undefined, done: true }; - } - }); + return this.#reader.read().then(({ value }) => { + if (value) { + return { value, done: false }; + } else { + return { value: undefined, done: true }; + } + }); } /** Close both the input and output of the token stream */ close() { - this.queue.writable.close(); - this.closed = true; + if (!this.closed) { + this.#writer.close(); + this.closed = true; + } } [Symbol.asyncIterator](): BufferedTokenStream { diff --git a/agents/src/tokenize/tokenizer.ts b/agents/src/tokenize/tokenizer.ts index 14394ae4..ea5774c0 100644 --- a/agents/src/tokenize/tokenizer.ts +++ b/agents/src/tokenize/tokenizer.ts @@ -31,6 +31,9 @@ export abstract class SentenceStream { >(); protected output = new TransformStream(); #closed = false; + #inputClosed = false; + #reader = this.output.readable.getReader(); + #writer = this.input.writable.getWriter(); get closed(): boolean { return this.#closed; @@ -38,20 +41,20 @@ export abstract class SentenceStream { /** Push a string of text to the tokenizer */ pushText(text: string) { - // if (this.input.closed) { - // throw new Error('Input is closed'); - // } + if (this.#inputClosed) { + throw new Error('Input is closed'); + } if (this.#closed) { throw new Error('Stream is closed'); } - this.input.writable.getWriter().write(text); + this.#writer.write(text); } /** Flush the tokenizer, causing it to process all pending text */ flush() { - // if (this.input.closed) { - // throw new Error('Input is closed'); - // } + if (this.#inputClosed) { + throw new Error('Input is closed'); + } if (this.#closed) { throw new Error('Stream is closed'); } @@ -60,32 +63,31 @@ export abstract class SentenceStream { /** Mark the input as ended and forbid additional pushes */ endInput() { - // if (this.input.closed) { - // throw new Error('Input is closed'); - // } + if (this.#inputClosed) { + throw new Error('Input is closed'); + } if (this.#closed) { throw new Error('Stream is closed'); } - this.input.writable.close(); + this.#writer.close(); + this.#inputClosed = true; } async next(): Promise> { - return this.output.readable - .getReader() - .read() - .then(({ value }) => { - if (value) { - return { value, done: false }; - } else { - return { value: undefined, done: true }; - } - }); + return this.#reader.read().then(({ value }) => { + if (value) { + return { value, done: false }; + } else { + return { value: undefined, done: true }; + } + }); } /** Close both the input and output of the tokenizer stream */ close() { - this.input.writable.close(); - this.output.writable.close(); + if (!this.#inputClosed) { + this.endInput(); + } this.#closed = true; } @@ -110,6 +112,9 @@ export abstract class WordStream { string | typeof WordStream.FLUSH_SENTINEL >(); protected output = new TransformStream(); + #writer = this.input.writable.getWriter(); + #reader = this.output.readable.getReader(); + #inputClosed = false; #closed = false; get closed(): boolean { @@ -118,54 +123,51 @@ export abstract class WordStream { /** Push a string of text to the tokenizer */ pushText(text: string) { - // if (this.input.closed) { - // throw new Error('Input is closed'); - // } + if (this.#inputClosed) { + throw new Error('Input is closed'); + } if (this.#closed) { throw new Error('Stream is closed'); } - this.input.writable.getWriter().write(text); + this.#writer.write(text); } /** Flush the tokenizer, causing it to process all pending text */ flush() { - // if (this.input.closed) { - // throw new Error('Input is closed'); - // } + if (this.#inputClosed) { + throw new Error('Input is closed'); + } if (this.#closed) { throw new Error('Stream is closed'); } - this.input.writable.getWriter().write(WordStream.FLUSH_SENTINEL); + this.#writer.write(WordStream.FLUSH_SENTINEL); } /** Mark the input as ended and forbid additional pushes */ endInput() { - // if (this.input.closed) { - // throw new Error('Input is closed'); - // } + if (this.#inputClosed) { + throw new Error('Input is closed'); + } if (this.#closed) { throw new Error('Stream is closed'); } - this.input.writable.close(); + this.#inputClosed = true; } async next(): Promise> { - return this.output.readable - .getReader() - .read() - .then(({ value }) => { - if (value) { - return { value, done: false }; - } else { - return { value: undefined, done: true }; - } - }); + return this.#reader.read().then(({ value }) => { + if (value) { + return { value, done: false }; + } else { + return { value: undefined, done: true }; + } + }); } /** Close both the input and output of the tokenizer stream */ close() { - this.input.writable.close(); - this.output.writable.close(); + this.endInput(); + this.#writer.close(); this.#closed = true; } diff --git a/agents/src/tts/tts.ts b/agents/src/tts/tts.ts index 94e86ef8..17ba8f31 100644 --- a/agents/src/tts/tts.ts +++ b/agents/src/tts/tts.ts @@ -4,7 +4,11 @@ import type { AudioFrame } from '@livekit/rtc-node'; import type { TypedEventEmitter as TypedEmitter } from '@livekit/typed-emitter'; import { EventEmitter } from 'node:events'; -import type { ReadableStream, WritableStreamDefaultWriter } from 'node:stream/web'; +import type { + ReadableStream, + ReadableStreamDefaultReader, + WritableStreamDefaultWriter, +} from 'node:stream/web'; import { TransformStream } from 'node:stream/web'; import type { TTSMetrics } from '../metrics/base.js'; import { mergeFrames } from '../utils.js'; @@ -121,17 +125,14 @@ export abstract class SynthesizeStream #tts: TTS; #metricsPendingTexts: string[] = []; #metricsText = ''; - #outputReadable: ReadableStream; #writer: WritableStreamDefaultWriter; + #reader: ReadableStreamDefaultReader; constructor(tts: TTS) { this.#tts = tts; - this.output.writable.close().then(() => { - this.inputClosed = true; - }); this.#writer = this.input.writable.getWriter(); const [r1, r2] = this.output.readable.tee(); - this.#outputReadable = r1; + this.#reader = r1.getReader(); this.monitorMetrics(r2); } @@ -216,26 +217,22 @@ export abstract class SynthesizeStream throw new Error('Stream is closed'); } this.#writer.close(); + this.inputClosed = true; } async next(): Promise> { - return this.#outputReadable - .getReader() - .read() - .then(({ value }) => { - if (value) { - return { value, done: false }; - } else { - return { value: undefined, done: true }; - } - }); + return this.#reader.read().then(({ value }) => { + if (value) { + return { value, done: false }; + } else { + return { value: undefined, done: true }; + } + }); } /** Close both the input and output of the TTS stream */ close() { - this.#writer.close(); - this.input.writable.close(); - this.output.writable.close(); + this.endInput(); this.closed = true; } @@ -264,13 +261,13 @@ export abstract class ChunkedStream implements AsyncIterableIterator; + #reader: ReadableStreamDefaultReader; constructor(text: string, tts: TTS) { this.#text = text; this.#tts = tts; const [r1, r2] = this.output.readable.tee(); - this.#outputReadable = r1; + this.#reader = r1.getReader(); this.monitorMetrics(r2); } @@ -313,16 +310,13 @@ export abstract class ChunkedStream implements AsyncIterableIterator> { - return this.#outputReadable - .getReader() - .read() - .then(({ value }) => { - if (value) { - return { value, done: false }; - } else { - return { value: undefined, done: true }; - } - }); + return this.#reader.read().then(({ value }) => { + if (value) { + return { value, done: false }; + } else { + return { value: undefined, done: true }; + } + }); } /** Close both the input and output of the TTS stream */ diff --git a/agents/src/vad.ts b/agents/src/vad.ts index 6ee52418..3de9b50b 100644 --- a/agents/src/vad.ts +++ b/agents/src/vad.ts @@ -177,7 +177,6 @@ export abstract class VADStream implements AsyncIterableIterator { close() { this.input.writable.close(); - this.output.writable.close(); this.closed = true; }