diff --git a/rustbot/src/audio.rs b/rustbot/src/audio.rs
new file mode 100644
index 0000000..94c69b3
--- /dev/null
+++ b/rustbot/src/audio.rs
@@ -0,0 +1,57 @@
+use crate::prelude::*;
+use bytes::BytesMut;
+use rodio::{Decoder, Sink};
+use std::io::Cursor;
+use tokio::{
+    self,
+    io::{self, AsyncReadExt},
+};
+
+/// Plays the encoded audio arriving on `audio_rd` through `sink`.
+///
+/// Bytes are accumulated until more than `AUDIO_BUFFER_SIZE` of them are buffered and are
+/// then handed to a [`Decoder`]. A segment that fails to decode stays in the buffer and is
+/// retried together with the bytes that follow it. Whatever is left once the stream ends
+/// is decoded one last time, after which the call waits for the sink to drain.
+///
+/// # Errors
+///
+/// Fails only if the blocking task that waits for playback to end panics or is cancelled.
+/// Read and decode failures are reported on stderr and do not abort playback.
+pub async fn play(mut audio_rd: io::DuplexStream, sink: Sink) -> Result<()> {
+    let mut audio_data = BytesMut::new();
+    loop {
+        match audio_rd.read_buf(&mut audio_data).await {
+            // A zero-length read is end-of-stream.
+            Ok(0) => break,
+            Ok(_) => {}
+            Err(e) => {
+                // Stop reading, but still play what has been buffered so far.
+                eprintln!("Failed to read audio stream: {}", e);
+                break;
+            }
+        }
+        if audio_data.len() > AUDIO_BUFFER_SIZE {
+            match Decoder::new(Cursor::new(audio_data.to_vec())) {
+                Ok(source) => {
+                    sink.append(source);
+                    audio_data.clear();
+                }
+                // Keep the bytes so the next read can complete the segment.
+                Err(e) => eprintln!("Failed to decode received audio: {}", e),
+            }
+        }
+    }
+
+    // Flush any remaining data at the end
+    if !audio_data.is_empty() {
+        match Decoder::new(Cursor::new(audio_data.to_vec())) {
+            Ok(source) => sink.append(source),
+            Err(e) => eprintln!("Remaining data could not be decoded: {}", e),
+        }
+    }
+
+    // `sleep_until_end` blocks the calling thread; keep it off the async workers.
+    tokio::task::spawn_blocking(move || sink.sleep_until_end()).await?;
+    Ok(())
+}
diff --git a/rustbot/src/jet.rs b/rustbot/src/jet.rs
index 6580038..bdff0f2 100644
--- a/rustbot/src/jet.rs
+++ b/rustbot/src/jet.rs
@@ -114,6 +114,7 @@ impl Writer {
     pub async fn write(
         self,
         mut chunks: Receiver,
+        mut tts_done: watch::Receiver<bool>,
         mut done: watch::Receiver,
     ) -> Result<()> {
         println!("launching JetStream Writer");
@@ -129,10 +130,17 @@ impl Writer {
                     if chunk.is_empty() {
                         let msg = String::from_utf8(b.to_vec()).unwrap();
                         println!("\n[A]: {}", msg);
-                        self.tx.publish(self.subject.to_string(), b.clone().freeze())
-                            .await?;
-                        b.clear();
-                        continue;
+                        // Hold the transcript back until TTS signals it has finished writing
+                        // audio. If the sender is gone no signal can ever arrive, so publish
+                        // right away rather than spinning on a `changed()` that fails at once.
+                        loop {
+                            if tts_done.changed().await.is_err() || *tts_done.borrow() {
+                                break;
+                            }
+                        }
+                        self.tx.publish(self.subject.to_string(), b.clone().freeze())
+                            .await?;
+                        b.clear();
                     }
                     b.extend_from_slice(&chunk);
                 }
diff --git a/rustbot/src/main.rs b/rustbot/src/main.rs
index f17ff99..3896a3f 100644
--- a/rustbot/src/main.rs
+++ b/rustbot/src/main.rs
@@ -1,16 +1,14 @@
-use bytes::{Bytes, BytesMut};
+use bytes::Bytes;
 use clap::Parser;
 use prelude::*;
-use rodio::{Decoder, OutputStream, Sink};
-use std::io::Cursor;
+use rodio::{OutputStream, Sink};
 use tokio::{
-    self,
-    io::{self, AsyncReadExt},
-    signal,
+    self, io, signal,
     sync::{mpsc, watch},
     task::JoinHandle,
 };
 
+mod audio;
 mod buffer;
 mod cli;
 mod history;
@@ -59,6 +57,7 @@ async fn main() -> Result<()> {
     let (prompts_tx, prompts_rx) = mpsc::channel::(32);
     let (jet_chunks_tx, jet_chunks_rx) = mpsc::channel::(32);
     let (tts_chunks_tx, tts_chunks_rx) = mpsc::channel::(32);
+    let (tts_done_tx, tts_done_rx) = watch::channel(false);
 
     // NOTE: used for cancellation when SIGINT is trapped.
     let (watch_tx, watch_rx) = watch::channel(false);
@@ -70,13 +69,13 @@ async fn main() -> Result<()> {
     let (_stream, stream_handle) = OutputStream::try_default().unwrap();
     let sink = Sink::try_new(&stream_handle).unwrap();
-    let (mut audio_wr, mut audio_rd) = io::duplex(1024);
+    let (audio_wr, audio_rd) = io::duplex(1024);
 
-    let tts_stream =
-        tokio::spawn(async move { t.stream(&mut audio_wr, tts_chunks_rx, tts_watch_rx).await });
+    let tts_stream = tokio::spawn(t.stream(audio_wr, tts_chunks_rx, tts_done_tx, tts_watch_rx));
     let llm_stream = tokio::spawn(l.stream(prompts_rx, jet_chunks_tx, tts_chunks_tx, watch_rx));
-    let jet_write = tokio::spawn(s.writer.write(jet_chunks_rx, jet_wr_watch_rx));
+    let jet_write = tokio::spawn(s.writer.write(jet_chunks_rx, tts_done_rx, jet_wr_watch_rx));
     let jet_read = tokio::spawn(s.reader.read(prompts_tx, jet_rd_watch_rx));
+    let audio_task = tokio::spawn(audio::play(audio_rd, sink));
 
     let sig_handler: JoinHandle> = tokio::spawn(async move {
         tokio::select! {
             _ = signal::ctrl_c() => {
@@ -86,38 +85,8 @@ async fn main() -> Result<()> {
         }
         Ok(())
     });
 
-    let play_task = tokio::spawn(async move {
-        let mut audio_data = BytesMut::new();
-        while let Ok(chunk) = audio_rd.read_buf(&mut audio_data).await {
-            if chunk == 0 {
-                break;
-            }
-            if audio_data.len() > AUDIO_BUFFER_SIZE {
-                let cursor = Cursor::new(audio_data.clone().freeze().to_vec());
-                match Decoder::new(cursor) {
-                    Ok(source) => {
-                        sink.append(source);
-                        audio_data.clear(); // Clear the buffer on successful append
-                    }
-                    Err(e) => {
-                        eprintln!("Failed to decode received audio: {}", e);
-                    }
-                }
-            }
-        }
-
-        // Flush any remaining data at the end
-        if !audio_data.is_empty() {
-            let cursor = Cursor::new(audio_data.to_vec());
-            match Decoder::new(cursor) {
-                Ok(source) => sink.append(source),
-                Err(e) => println!("Remaining data could not be decoded: {}", e),
-            }
-        }
-        sink.sleep_until_end();
-    });
-    match tokio::try_join!(tts_stream, llm_stream, jet_write, jet_read, play_task) {
+    match tokio::try_join!(tts_stream, llm_stream, jet_write, jet_read, audio_task) {
         Ok(_) => {}
         Err(e) => {
             println!("Error running bot: {}", e);
diff --git a/rustbot/src/tts.rs b/rustbot/src/tts.rs
index ca51d0f..73167d2 100644
--- a/rustbot/src/tts.rs
+++ b/rustbot/src/tts.rs
@@ -39,8 +39,9 @@ impl TTS {
 
     pub async fn stream(
         self,
-        w: &mut W,
+        mut w: W,
         mut chunks: Receiver,
+        tts_done: watch::Sender<bool>,
         mut done: watch::Receiver,
     ) -> Result<()>
     where
@@ -67,8 +68,9 @@ impl TTS {
                     if chunk.is_empty() {
                         let text = String::from_utf8(buf.as_bytes().to_vec())?;
                         req.text = Some(text);
-                        self.client.write_audio_stream(w, &req).await?;
+                        self.client.write_audio_stream(&mut w, &req).await?;
                         buf.reset();
+                        tts_done.send(true)?;
                         continue
                     }
                     match buf.write(chunk.as_ref()) {
@@ -76,8 +78,11 @@ impl TTS {
                         Err(e) => {
                             let text = String::from_utf8(buf.as_bytes().to_vec())?;
                             req.text = Some(text);
-                            self.client.write_audio_stream(w, &req).await?;
+                            self.client.write_audio_stream(&mut w, &req).await?;
                             buf.reset();
+                            // NOTE(review): this also fires when a chunk merely did not fit the
+                            // buffer mid-message, so the writer may publish early; confirm intent.
+                            tts_done.send(true)?;
                             let rem = chunk.len() - e.bytes_written;
                             let chunk_slice = chunk.as_ref();
                             buf.write(&chunk_slice[rem..])?;