Skip to content

Commit

Permalink
feat: give rustbot a voice
Browse files Browse the repository at this point in the history
We introduce a tts module that handles the TTS tasks
of rustbot. We had to revamp a few things, such as the llm having to
send the chunks to both the TTS and the jet.Reader.
We're using rodio for playing the speech.

Signed-off-by: Milos Gajdos <[email protected]>
  • Loading branch information
milosgajdos committed Apr 17, 2024
1 parent a4347c6 commit 992dec6
Show file tree
Hide file tree
Showing 7 changed files with 239 additions and 10 deletions.
3 changes: 3 additions & 0 deletions rustbot/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,6 @@ rand = "0.8"
ollama-rs = { version = "0.1", features = ["stream"] }
bytes = { version = "1", features = ["serde"] }
clap = { version = "4.5.4", features = ["derive"] }
#playht_rs = "0.1.0"
playht_rs = { path = "/Users/milosgajdos/rust/github.com/milosgajdos/playht_rs" }
rodio = "0.17.3"
52 changes: 52 additions & 0 deletions rustbot/src/buffer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
use bytes::{BufMut, Bytes, BytesMut};
use std::error::Error;
use std::fmt;

/// Error reported by `Buffer::write` when the buffer has reached its size
/// limit.
///
/// The error does not necessarily mean data was dropped: compare
/// `bytes_written` with the length of the slice that was passed in to find
/// out how much of it was stored before the limit was hit.
#[derive(Debug)]
pub struct BufferFullError {
    /// Number of bytes of the offending write that were stored.
    pub bytes_written: usize,
}

impl fmt::Display for BufferFullError {
    /// Renders the error as a short human readable message.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let written = self.bytes_written;
        f.write_fmt(format_args!("buffer is full, {} bytes written", written))
    }
}

// No source error to expose; the default `Error` methods are sufficient.
impl Error for BufferFullError {}

pub struct Buffer {
buffer: BytesMut,
max_size: usize,
}

impl Buffer {
pub fn new(max_size: usize) -> Self {
Buffer {
buffer: BytesMut::with_capacity(max_size),
max_size,
}
}

pub fn write(&mut self, data: &[u8]) -> Result<usize, BufferFullError> {
let available = self.max_size - self.buffer.len();
let write_len = std::cmp::min(data.len(), available);

self.buffer.put_slice(&data[..write_len]);

if self.buffer.len() == self.max_size {
return Err(BufferFullError {
bytes_written: write_len,
});
}
Ok(write_len)
}

pub fn reset(&mut self) {
self.buffer.clear();
}

pub fn as_bytes(&self) -> Bytes {
self.buffer.clone().freeze()
}
}
8 changes: 8 additions & 0 deletions rustbot/src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ pub struct App {
pub llm: LLM,
#[command(flatten)]
pub bot: Bot,
#[command(flatten)]
pub tts: TTS,
}

#[derive(Args, Debug)]
Expand Down Expand Up @@ -40,3 +42,9 @@ pub struct Bot {
#[arg(short = 'b', long, default_value = BOT_SUB_SUBJECT, help = "jetstream subscribe subject")]
pub sub_subject: String,
}

// Command line options for the text-to-speech side of the bot.
// (Plain comments on purpose: clap turns doc comments into help text.)
#[derive(Args, Debug)]
pub struct TTS {
    // Voice id handed to the TTS module; defaults to DEFAULT_VOICE_ID.
    // The help text used to read "bot name", copied from the bot options.
    #[arg(short, default_value = DEFAULT_VOICE_ID, help = "TTS voice id")]
    pub voice_id: String,
}
24 changes: 22 additions & 2 deletions rustbot/src/llm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use tokio::{
self,
sync::mpsc::{Receiver, Sender},
sync::watch,
task::JoinHandle,
};
use tokio_stream::StreamExt;

Expand Down Expand Up @@ -46,7 +47,8 @@ impl LLM {
pub async fn stream(
self,
mut prompts: Receiver<String>,
chunks: Sender<Bytes>,
jet_chunks: Sender<Bytes>,
tts_chunks: Sender<Bytes>,
mut done: watch::Receiver<bool>,
) -> Result<()> {
println!("launching LLM stream");
Expand Down Expand Up @@ -77,7 +79,25 @@ impl LLM {
while let Some(res) = stream.next().await {
let responses = res?;
for resp in responses {
chunks.send(Bytes::from(resp.response)).await?;
let resp_bytes = Bytes::from(resp.response);
let jet_bytes = resp_bytes.clone();
let jet_ch = jet_chunks.clone();
let jet_task: JoinHandle<Result<()>> = tokio::spawn(async move {
jet_ch.send(Bytes::from(jet_bytes)).await?;
Ok(())
});
let tts_bytes = resp_bytes.clone();
let tts_ch = tts_chunks.clone();
let tts_task: JoinHandle<Result<()>> = tokio::spawn(async move {
tts_ch.send(Bytes::from(tts_bytes)).await?;
Ok(())
});
match tokio::try_join!(jet_task, tts_task) {
Ok(_) => {}
Err(e) => {
return Err(Box::new(e));
}
}
}
}
},
Expand Down
68 changes: 60 additions & 8 deletions rustbot/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
use bytes::Bytes;
use bytes::{Bytes, BytesMut};
use clap::Parser;
use prelude::*;
use rodio::{Decoder, OutputStream, Sink};
use std::io::Cursor;
use tokio::{
self, signal,
self,
io::{self, AsyncReadExt},
signal,
sync::{mpsc, watch},
task::JoinHandle,
};

mod buffer;
mod cli;
mod history;
mod jet;
mod llm;
mod prelude;
mod tts;

#[tokio::main]
async fn main() -> Result<()> {
Expand All @@ -21,7 +27,7 @@ async fn main() -> Result<()> {
let seed_prompt = args.prompt.seed.unwrap();
let prompt = system_prompt + "\n" + &seed_prompt;

// NOTE: we could also add Stream::builder to jet module
// NOTE: we could also add Stream::builder to the jet module
// and instead of passing config we could build it by chaining methods.
let c = jet::Config {
durable_name: args.bot.name,
Expand All @@ -32,7 +38,7 @@ async fn main() -> Result<()> {
};
let s = jet::Stream::new(c).await?;

// NOTE: we could also add LLM::builder to llm module
// NOTE: we could also add LLM::builder to the llm module
// and instead of passing config we could build it by chaining methods.
let c = llm::Config {
hist_size: args.llm.hist_size,
Expand All @@ -42,18 +48,34 @@ async fn main() -> Result<()> {
};
let l = llm::LLM::new(c);

// NOTE: we could also add TTS::builder to the tts module
// and instead of passing config we could build it by chaining methods.
let c = tts::Config {
voice_id: Some(args.tts.voice_id),
..tts::Config::default()
};
let t = tts::TTS::new(c);

let (prompts_tx, prompts_rx) = mpsc::channel::<String>(32);
let (chunks_tx, chunks_rx) = mpsc::channel::<Bytes>(32);
let (jet_chunks_tx, jet_chunks_rx) = mpsc::channel::<Bytes>(32);
let (tts_chunks_tx, tts_chunks_rx) = mpsc::channel::<Bytes>(32);

// NOTE: used for cancellation when SIGINT is trapped.
let (watch_tx, watch_rx) = watch::channel(false);
let jet_wr_watch_rx = watch_rx.clone();
let jet_rd_watch_rx = watch_rx.clone();
let tts_watch_rx = watch_rx.clone();

println!("launching workers");

let llm_stream = tokio::spawn(l.stream(prompts_rx, chunks_tx, watch_rx));
let jet_write = tokio::spawn(s.writer.write(chunks_rx, jet_wr_watch_rx));
let (_stream, stream_handle) = OutputStream::try_default().unwrap();
let sink = Sink::try_new(&stream_handle).unwrap();
let (mut audio_wr, mut audio_rd) = io::duplex(1024);

let tts_stream =
tokio::spawn(async move { t.stream(&mut audio_wr, tts_chunks_rx, tts_watch_rx).await });
let llm_stream = tokio::spawn(l.stream(prompts_rx, jet_chunks_tx, tts_chunks_tx, watch_rx));
let jet_write = tokio::spawn(s.writer.write(jet_chunks_rx, jet_wr_watch_rx));
let jet_read = tokio::spawn(s.reader.read(prompts_tx, jet_rd_watch_rx));
let sig_handler: JoinHandle<Result<()>> = tokio::spawn(async move {
tokio::select! {
Expand All @@ -64,8 +86,38 @@ async fn main() -> Result<()> {
}
Ok(())
});
let play_task = tokio::spawn(async move {
let mut audio_data = BytesMut::new();
while let Ok(chunk) = audio_rd.read_buf(&mut audio_data).await {
if chunk == 0 {
break;
}
if audio_data.len() > AUDIO_BUFFER_SIZE {
let cursor = Cursor::new(audio_data.clone().freeze().to_vec());
match Decoder::new(cursor) {
Ok(source) => {
sink.append(source);
audio_data.clear(); // Clear the buffer on successful append
}
Err(e) => {
eprintln!("Failed to decode received audio: {}", e);
}
}
}
}

// Flush any remaining data at the end
if !audio_data.is_empty() {
let cursor = Cursor::new(audio_data.to_vec());
match Decoder::new(cursor) {
Ok(source) => sink.append(source),
Err(e) => println!("Remaining data could not be decoded: {}", e),
}
}
sink.sleep_until_end();
});

match tokio::try_join!(llm_stream, jet_write, jet_read) {
match tokio::try_join!(tts_stream, llm_stream, jet_write, jet_read, play_task) {
Ok(_) => {}
Err(e) => {
println!("Error running bot: {}", e);
Expand Down
4 changes: 4 additions & 0 deletions rustbot/src/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,7 @@ Assistant: Rust's biggest strength lies in its focus on safety, particularly mem
safety, without sacrificing performance. Can you tell me what are some of the biggest \
strengths of Go that make it stand out from other programming languages?
Question: ";

/// Voice id used for TTS when none is given on the command line.
pub const DEFAULT_VOICE_ID: &str = "s3://mockingbird-prod/abigail_vo_6661b91f-4012-44e3-ad12-589fbdee9948/voices/speaker/manifest.json";
/// Default size, in bytes, of the buffer that collects text before it is
/// sent to TTS.
pub const MAX_TTS_BUFFER_SIZE: usize = 1_000;
/// Amount of received audio, in bytes, that has to be exceeded before the
/// player tries to decode and queue it.
pub const AUDIO_BUFFER_SIZE: usize = 10 * 1024;
90 changes: 90 additions & 0 deletions rustbot/src/tts.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
use crate::{buffer, prelude::*};
use bytes::Bytes;
use playht_rs::api::{self, stream::TTSStreamReq, tts::Quality};
use tokio::{self, sync::mpsc::Receiver, sync::watch};

/// Settings for the TTS worker.
///
/// The optional fields are copied verbatim into every streaming request;
/// `buf_size` only affects how much text is collected locally per request.
#[derive(Debug, Clone)]
pub struct Config {
    /// Voice to synthesize with.
    pub voice_id: Option<String>,
    /// Requested synthesis quality.
    pub quality: Option<Quality>,
    /// Requested speech speed.
    pub speed: Option<f32>,
    /// Requested audio sample rate.
    pub sample_rate: Option<i32>,
    /// Size, in bytes, of the text buffer that is filled before a request
    /// is issued.
    pub buf_size: usize,
}

impl Default for Config {
    /// Default voice, low quality, speed 1.0, sample rate 24000 and a text
    /// buffer of `MAX_TTS_BUFFER_SIZE` bytes.
    fn default() -> Self {
        Self {
            buf_size: MAX_TTS_BUFFER_SIZE,
            sample_rate: Some(24000),
            speed: Some(1.0),
            quality: Some(Quality::Low),
            voice_id: Some(String::from(DEFAULT_VOICE_ID)),
        }
    }
}

/// Text-to-speech worker: collects text chunks and streams the audio
/// produced for them into a writer.
pub struct TTS {
    // API client used to issue the streaming TTS requests.
    client: api::Client,
    // Request settings and the size of the text buffer.
    config: Config,
}

impl TTS {
    /// Creates a worker that uses the settings in `c` and a fresh API client.
    pub fn new(c: Config) -> TTS {
        TTS {
            client: api::Client::new(),
            config: c,
        }
    }

    /// Runs the TTS loop until cancellation is signalled or an error occurs.
    ///
    /// Text received on `chunks` is accumulated in a buffer of
    /// `config.buf_size` bytes. The buffered text is sent through
    /// `write_audio_stream`, with `w` as the destination, when either
    ///
    /// * an empty chunk arrives (presumably the end-of-response marker of the
    ///   producer -- TODO confirm against the LLM side), or
    /// * the buffer becomes full; the part of the chunk that did not fit is
    ///   kept for the next request.
    ///
    /// Returns `Ok(())` once `done` changes to `true`.
    ///
    /// # Errors
    ///
    /// Fails if the buffered bytes are not valid UTF-8, if the TTS request
    /// fails, or if `config.buf_size` is zero and text has to be buffered.
    pub async fn stream<W>(
        self,
        w: &mut W,
        mut chunks: Receiver<Bytes>,
        mut done: watch::Receiver<bool>,
    ) -> Result<()>
    where
        W: tokio::io::AsyncWriteExt + Unpin,
    {
        println!("launching TTS stream");
        let buf_size = self.config.buf_size;
        let mut buf = buffer::Buffer::new(buf_size);
        let mut req = TTSStreamReq {
            voice: self.config.voice_id,
            quality: self.config.quality,
            speed: self.config.speed,
            sample_rate: self.config.sample_rate,
            ..Default::default()
        };

        loop {
            tokio::select! {
                _ = done.changed() => {
                    if *done.borrow() {
                        return Ok(())
                    }
                },
                Some(chunk) = chunks.recv() => {
                    if chunk.is_empty() {
                        let text = String::from_utf8(buf.as_bytes().to_vec())?;
                        req.text = Some(text);
                        self.client.write_audio_stream(w, &req).await?;
                        buf.reset();
                        continue
                    }

                    // Feed the chunk into the buffer, flushing every time the
                    // buffer fills up, until the whole chunk is consumed.
                    let mut rest: &[u8] = chunk.as_ref();
                    while !rest.is_empty() {
                        match buf.write(rest) {
                            // `Ok` is only returned when everything fit.
                            Ok(_) => break,
                            Err(e) => {
                                // A zero-sized buffer can never make progress.
                                if buf_size == 0 {
                                    return Err(e.into());
                                }
                                let written = e.bytes_written;
                                // NOTE(review): the buffer fills at an
                                // arbitrary byte offset, so a multi-byte
                                // UTF-8 character split across two flushes
                                // makes this conversion fail.
                                let text = String::from_utf8(buf.as_bytes().to_vec())?;
                                req.text = Some(text);
                                self.client.write_audio_stream(w, &req).await?;
                                buf.reset();
                                // Continue with the bytes that were NOT
                                // stored: they start at offset `written`
                                // (the old code started at `len - written`,
                                // losing or duplicating text).
                                rest = &rest[written..];
                            }
                        }
                    }
                }
            }
        }
    }
}

0 comments on commit 992dec6

Please sign in to comment.