Skip to content

Commit

Permalink
let rustbot speak first before sending a prompt to gobot
Browse files Browse the repository at this point in the history
Signed-off-by: Milos Gajdos <[email protected]>
  • Loading branch information
milosgajdos committed Apr 21, 2024
1 parent 5477f6b commit 572ff02
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 48 deletions.
40 changes: 40 additions & 0 deletions rustbot/src/audio.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
use crate::prelude::*;
use bytes::BytesMut;
use rodio::{Decoder, Sink};
use std::io::Cursor;
use tokio::{
self,
io::{self, AsyncReadExt},
};

pub async fn play(mut audio_rd: io::DuplexStream, sink: Sink) -> Result<()> {
let mut audio_data = BytesMut::new();
while let Ok(chunk) = audio_rd.read_buf(&mut audio_data).await {
if chunk == 0 {
break;
}
if audio_data.len() > AUDIO_BUFFER_SIZE {
let cursor = Cursor::new(audio_data.clone().freeze().to_vec());
match Decoder::new(cursor) {
Ok(source) => {
sink.append(source);
audio_data.clear(); // Clear the buffer on successful append
}
Err(e) => {
eprintln!("Failed to decode received audio: {}", e);
}
}
}
}

// Flush any remaining data at the end
if !audio_data.is_empty() {
let cursor = Cursor::new(audio_data.to_vec());
match Decoder::new(cursor) {
Ok(source) => sink.append(source),
Err(e) => println!("Remaining data could not be decoded: {}", e),
}
}
sink.sleep_until_end();
Ok(())
}
17 changes: 13 additions & 4 deletions rustbot/src/jet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ impl Writer {
pub async fn write(
self,
mut chunks: Receiver<Bytes>,
mut tts_done: watch::Receiver<bool>,
mut done: watch::Receiver<bool>,
) -> Result<()> {
println!("launching JetStream Writer");
Expand All @@ -129,10 +130,18 @@ impl Writer {
if chunk.is_empty() {
let msg = String::from_utf8(b.to_vec()).unwrap();
println!("\n[A]: {}", msg);
self.tx.publish(self.subject.to_string(), b.clone().freeze())
.await?;
b.clear();
continue;
loop {
tokio::select! {
_ = tts_done.changed() => {
if *tts_done.borrow() {
self.tx.publish(self.subject.to_string(), b.clone().freeze())
.await?;
b.clear();
break;
}
},
}
}
}
b.extend_from_slice(&chunk);
}
Expand Down
51 changes: 10 additions & 41 deletions rustbot/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
use bytes::{Bytes, BytesMut};
use bytes::Bytes;
use clap::Parser;
use prelude::*;
use rodio::{Decoder, OutputStream, Sink};
use std::io::Cursor;
use rodio::{OutputStream, Sink};
use tokio::{
self,
io::{self, AsyncReadExt},
signal,
self, io, signal,
sync::{mpsc, watch},
task::JoinHandle,
};

mod audio;
mod buffer;
mod cli;
mod history;
Expand Down Expand Up @@ -59,6 +57,7 @@ async fn main() -> Result<()> {
let (prompts_tx, prompts_rx) = mpsc::channel::<String>(32);
let (jet_chunks_tx, jet_chunks_rx) = mpsc::channel::<Bytes>(32);
let (tts_chunks_tx, tts_chunks_rx) = mpsc::channel::<Bytes>(32);
let (tts_done_tx, tts_done_rx) = watch::channel(false);

// NOTE: used for cancellation when SIGINT is trapped.
let (watch_tx, watch_rx) = watch::channel(false);
Expand All @@ -70,13 +69,13 @@ async fn main() -> Result<()> {

let (_stream, stream_handle) = OutputStream::try_default().unwrap();
let sink = Sink::try_new(&stream_handle).unwrap();
let (mut audio_wr, mut audio_rd) = io::duplex(1024);
let (audio_wr, audio_rd) = io::duplex(1024);

let tts_stream =
tokio::spawn(async move { t.stream(&mut audio_wr, tts_chunks_rx, tts_watch_rx).await });
let tts_stream = tokio::spawn(t.stream(audio_wr, tts_chunks_rx, tts_done_tx, tts_watch_rx));
let llm_stream = tokio::spawn(l.stream(prompts_rx, jet_chunks_tx, tts_chunks_tx, watch_rx));
let jet_write = tokio::spawn(s.writer.write(jet_chunks_rx, jet_wr_watch_rx));
let jet_write = tokio::spawn(s.writer.write(jet_chunks_rx, tts_done_rx, jet_wr_watch_rx));
let jet_read = tokio::spawn(s.reader.read(prompts_tx, jet_rd_watch_rx));
let audio_task = tokio::spawn(audio::play(audio_rd, sink));
let sig_handler: JoinHandle<Result<()>> = tokio::spawn(async move {
tokio::select! {
_ = signal::ctrl_c() => {
Expand All @@ -86,38 +85,8 @@ async fn main() -> Result<()> {
}
Ok(())
});
let play_task = tokio::spawn(async move {
let mut audio_data = BytesMut::new();
while let Ok(chunk) = audio_rd.read_buf(&mut audio_data).await {
if chunk == 0 {
break;
}
if audio_data.len() > AUDIO_BUFFER_SIZE {
let cursor = Cursor::new(audio_data.clone().freeze().to_vec());
match Decoder::new(cursor) {
Ok(source) => {
sink.append(source);
audio_data.clear(); // Clear the buffer on successful append
}
Err(e) => {
eprintln!("Failed to decode received audio: {}", e);
}
}
}
}

// Flush any remaining data at the end
if !audio_data.is_empty() {
let cursor = Cursor::new(audio_data.to_vec());
match Decoder::new(cursor) {
Ok(source) => sink.append(source),
Err(e) => println!("Remaining data could not be decoded: {}", e),
}
}
sink.sleep_until_end();
});

match tokio::try_join!(tts_stream, llm_stream, jet_write, jet_read, play_task) {
match tokio::try_join!(tts_stream, llm_stream, jet_write, jet_read, audio_task) {
Ok(_) => {}
Err(e) => {
println!("Error running bot: {}", e);
Expand Down
9 changes: 6 additions & 3 deletions rustbot/src/tts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@ impl TTS {

pub async fn stream<W>(
self,
w: &mut W,
mut w: W,
mut chunks: Receiver<Bytes>,
tts_done: watch::Sender<bool>,
mut done: watch::Receiver<bool>,
) -> Result<()>
where
Expand All @@ -67,17 +68,19 @@ impl TTS {
if chunk.is_empty() {
let text = String::from_utf8(buf.as_bytes().to_vec())?;
req.text = Some(text);
self.client.write_audio_stream(w, &req).await?;
self.client.write_audio_stream(&mut w, &req).await?;
buf.reset();
tts_done.send(true)?;
continue
}
match buf.write(chunk.as_ref()) {
Ok(_) => {},
Err(e) => {
let text = String::from_utf8(buf.as_bytes().to_vec())?;
req.text = Some(text);
self.client.write_audio_stream(w, &req).await?;
self.client.write_audio_stream(&mut w, &req).await?;
buf.reset();
tts_done.send(true)?;
let rem = chunk.len() - e.bytes_written;
let chunk_slice = chunk.as_ref();
buf.write(&chunk_slice[rem..])?;
Expand Down

0 comments on commit 572ff02

Please sign in to comment.