diff --git a/libsql-replication/src/injector/mod.rs b/libsql-replication/src/injector/mod.rs index e619175d67..bb6258886d 100644 --- a/libsql-replication/src/injector/mod.rs +++ b/libsql-replication/src/injector/mod.rs @@ -24,7 +24,7 @@ pub type FrameBuffer = Arc>>; pub struct Injector { /// The injector is in a transaction state - is_txn: bool, + pub(crate) is_txn: bool, /// Buffer for holding current transaction frames buffer: FrameBuffer, /// Maximum capacity of the frame buffer @@ -80,6 +80,13 @@ impl Injector { Ok(None) } + pub fn rollback(&mut self) { + let conn = self.connection.lock(); + let mut rollback = conn.prepare_cached("ROLLBACK").unwrap(); + let _ = rollback.execute(()); + self.is_txn = false; + } + /// Flush the buffer to libsql WAL. /// Trigger a dummy write, and flush the cache to trigger a call to xFrame. The buffer's frame /// are then injected into the wal. @@ -89,9 +96,7 @@ impl Injector { // something went wrong, rollback the connection to make sure we can retry in a // clean state self.biggest_uncommitted_seen = 0; - let connection = self.connection.lock(); - let mut rollback = connection.prepare_cached("ROLLBACK")?; - let _ = rollback.execute(()); + self.rollback(); Err(e) } Ok(ret) => Ok(ret), diff --git a/libsql-replication/src/replicator.rs b/libsql-replication/src/replicator.rs index 46051f8177..7310f0a5dd 100644 --- a/libsql-replication/src/replicator.rs +++ b/libsql-replication/src/replicator.rs @@ -214,6 +214,11 @@ impl Replicator { ReplicatorState::Exit => unreachable!("trying to step replicator on exit"), }; + // in case of error we rollback the current injector transaction, and start over. 
+ if ret.is_err() { + self.injector.lock().rollback(); + } + self.state = match ret { // perform normal operation state transition Ok(()) => match state { @@ -232,7 +237,12 @@ impl Replicator { ReplicatorState::NeedHandshake } Err(Error::NeedSnapshot) => ReplicatorState::NeedSnapshot, - Err(e) => return Err(e), + Err(e) => { + // an error here could be due to a disconnection, it's safe to rollback to a + // NeedHandshake state again, to avoid entering a busy loop. + self.state = ReplicatorState::NeedHandshake; + return Err(e); + } }; Ok(()) @@ -294,10 +304,12 @@ pub fn map_frame_err(f: Result) -> Result { #[cfg(test)] mod test { - use std::pin::Pin; + use std::{mem::size_of, pin::Pin}; use async_stream::stream; + use crate::frame::{FrameBorrowed, FrameMut}; + use super::*; #[tokio::test] @@ -635,4 +647,107 @@ mod test { Error::Fatal(_) )); } + + #[tokio::test] + async fn transaction_interupted_by_error_and_resumed() { + /// this is generated by creating a table test, inserting 5 rows into it, and then + /// truncating the wal file of its header. 
+ const WAL: &[u8] = include_bytes!("../assets/test/test_wallog"); + + fn make_wal_log() -> Vec { + let mut frames = WAL + .chunks(size_of::()) + .map(|b| FrameMut::try_from(b).unwrap()) + .map(|mut f| { + f.header_mut().size_after = 0; + f + }) + .collect::>(); + + let size_after = frames.len(); + frames.last_mut().unwrap().header_mut().size_after = size_after as _; + + frames.into_iter().map(Into::into).collect() + } + + let tmp = tempfile::NamedTempFile::new().unwrap(); + + struct Client { + frames: Vec, + should_error: bool, + committed_frame_no: Option, + } + + #[async_trait::async_trait] + impl ReplicatorClient for Client { + type FrameStream = Pin> + Send + 'static>>; + + /// Perform handshake with remote + async fn handshake(&mut self) -> Result<(), Error> { + Ok(()) + } + /// Return a stream of frames to apply to the database + async fn next_frames(&mut self) -> Result { + if self.should_error { + let frames = self + .frames + .iter() + .take(2) + .cloned() + .map(Ok) + .chain(Some(Err(Error::Client("some client error".into())))) + .collect::>(); + Ok(Box::pin(tokio_stream::iter(frames))) + } else { + let stream = tokio_stream::iter(self.frames.clone().into_iter().map(Ok)); + Ok(Box::pin(stream)) + } + } + /// Return a snapshot for the current replication index. 
Called after next_frame has returned a + /// NeedSnapshot error + async fn snapshot(&mut self) -> Result { + unimplemented!() + } + /// set the new commit frame_no + async fn commit_frame_no(&mut self, frame_no: FrameNo) -> Result<(), Error> { + self.committed_frame_no = Some(frame_no); + Ok(()) + } + /// Returns the currently committed replication index + fn committed_frame_no(&self) -> Option { + unimplemented!() + } + } + + let client = Client { + frames: make_wal_log(), + should_error: true, + committed_frame_no: None, + }; + + let mut replicator = Replicator::new(client, tmp.path().to_path_buf(), 10000) + .await + .unwrap(); + + replicator.try_replicate_step().await.unwrap(); + assert_eq!(replicator.state, ReplicatorState::NeedFrames); + + assert!(matches!( + replicator.try_replicate_step().await.unwrap_err(), + Error::Client(_) + )); + assert!(!replicator.injector.lock().is_txn); + assert!(replicator.client_mut().committed_frame_no.is_none()); + assert_eq!(replicator.state, ReplicatorState::NeedHandshake); + + replicator.try_replicate_step().await.unwrap(); + assert_eq!(replicator.state, ReplicatorState::NeedFrames); + + replicator.client_mut().should_error = false; + + replicator.try_replicate_step().await.unwrap(); + assert!(!replicator.injector.lock().is_txn); + assert_eq!(replicator.state, ReplicatorState::Exit); + assert_eq!(replicator.client_mut().committed_frame_no, Some(6)); + } }