Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

libsql: WAL pull support #1858

Merged
merged 1 commit into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions libsql/examples/offline_writes_pull.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Example of using a offline writes with libSQL.

use libsql::Builder;

#[tokio::main]
async fn main() {
tracing_subscriber::fmt::init();

// The local database path where the data will be stored.
let db_path = std::env::var("LIBSQL_DB_PATH")
.map_err(|_| {
eprintln!(
"Please set the LIBSQL_DB_PATH environment variable to set to local database path."
)
})
.unwrap();

// The remote sync URL to use.
let sync_url = std::env::var("LIBSQL_SYNC_URL")
.map_err(|_| {
eprintln!(
"Please set the LIBSQL_SYNC_URL environment variable to set to remote sync URL."
)
})
.unwrap();

// The authentication token to use.
let auth_token = std::env::var("LIBSQL_AUTH_TOKEN").unwrap_or("".to_string());

let db_builder = Builder::new_synced_database(db_path, sync_url, auth_token);

let db = match db_builder.build().await {
Ok(db) => db,
Err(error) => {
eprintln!("Error connecting to remote sync server: {}", error);
return;
}
};

println!("Syncing database from remote...");
db.sync().await.unwrap();

let conn = db.connect().unwrap();
let mut results = conn
.query("SELECT * FROM guest_book_entries", ())
.await
.unwrap();
println!("Guest book entries:");
while let Some(row) = results.next().await.unwrap() {
let text: String = row.get(0).unwrap();
println!(" {}", text);
}

println!("Done!");
}
2 changes: 1 addition & 1 deletion libsql/src/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ cfg_replication! {
#[cfg(feature = "replication")]
DbType::Sync { db, encryption_config: _ } => db.sync().await,
#[cfg(feature = "sync")]
DbType::Offline { db } => db.push().await,
DbType::Offline { db } => db.sync_offline().await,
_ => Err(Error::SyncNotSupported(format!("{:?}", self.db_type))),
}
}
Expand Down
40 changes: 40 additions & 0 deletions libsql/src/local/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,46 @@ impl Connection {

Ok(buf)
}

pub(crate) fn wal_insert_begin(&self) -> Result<()> {
let rc = unsafe { libsql_sys::ffi::libsql_wal_insert_begin(self.handle()) };
if rc != 0 {
return Err(crate::errors::Error::SqliteFailure(
rc as std::ffi::c_int,
format!("wal_insert_begin failed"),
));
}
Ok(())
}

pub(crate) fn wal_insert_end(&self) -> Result<()> {
let rc = unsafe { libsql_sys::ffi::libsql_wal_insert_end(self.handle()) };
if rc != 0 {
return Err(crate::errors::Error::SqliteFailure(
rc as std::ffi::c_int,
format!("wal_insert_end failed"),
));
}
Ok(())
}

pub(crate) fn wal_insert_frame(&self, frame: &[u8]) -> Result<()> {
let rc = unsafe {
libsql_sys::ffi::libsql_wal_insert_frame(
self.handle(),
frame.len() as u32,
frame.as_ptr() as *mut std::ffi::c_void,
0
)
};
if rc != 0 {
return Err(crate::errors::Error::SqliteFailure(
rc as std::ffi::c_int,
format!("wal_insert_frame failed"),
));
}
Ok(())
}
}

impl fmt::Debug for Connection {
Expand Down
39 changes: 37 additions & 2 deletions libsql/src/local/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -386,8 +386,8 @@ impl Database {
}

#[cfg(feature = "sync")]
/// Push WAL frames to remote.
pub async fn push(&self) -> Result<crate::database::Replicated> {
/// Sync WAL frames to remote.
pub async fn sync_offline(&self) -> Result<crate::database::Replicated> {
use crate::sync::SyncError;
use crate::Error;

Expand Down Expand Up @@ -425,6 +425,10 @@ impl Database {

let max_frame_no = conn.wal_frame_count();

if max_frame_no == 0 {
return self.try_pull(&mut sync_ctx).await;
}

let generation = sync_ctx.generation(); // TODO: Probe from WAL.
let start_frame_no = sync_ctx.durable_frame_num() + 1;
let end_frame_no = max_frame_no;
Expand All @@ -448,6 +452,10 @@ impl Database {

sync_ctx.write_metadata().await?;

if start_frame_no > end_frame_no {
return self.try_pull(&mut sync_ctx).await;
}

// TODO(lucio): this can underflow if the server previously returned a higher max_frame_no
// than what we have stored here.
let frame_count = end_frame_no - start_frame_no + 1;
Expand All @@ -457,6 +465,33 @@ impl Database {
})
}

#[cfg(feature = "sync")]
async fn try_pull(&self, sync_ctx: &mut SyncContext) -> Result<crate::database::Replicated> {
let generation = sync_ctx.generation();
let mut frame_no = sync_ctx.durable_frame_num() + 1;
let conn = self.connect()?;
conn.wal_insert_begin()?;
loop {
match sync_ctx.pull_one_frame(generation, frame_no).await {
Ok(frame) => {
conn.wal_insert_frame(&frame)?;
frame_no += 1;
}
Err(e) => {
println!("pull_one_frame error: {:?}", e);
break;
}
}

}
conn.wal_insert_end()?;
sync_ctx.write_metadata().await?;
Ok(crate::database::Replicated {
frame_no: None,
frames_synced: 1,
})
}

pub(crate) fn path(&self) -> &str {
&self.db_path
}
Expand Down
64 changes: 64 additions & 0 deletions libsql/src/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ pub enum SyncError {
InvalidPushFrameNoLow(u32, u32),
#[error("server returned a higher frame_no: sent={0}, got={1}")]
InvalidPushFrameNoHigh(u32, u32),
#[error("failed to pull frame: status={0}, error={1}")]
PullFrame(StatusCode, String),
}

impl SyncError {
Expand Down Expand Up @@ -104,6 +106,21 @@ impl SyncContext {
Ok(me)
}

#[tracing::instrument(skip(self))]
pub(crate) async fn pull_one_frame(&mut self, generation: u32, frame_no: u32) -> Result<Bytes> {
let uri = format!(
"{}/sync/{}/{}/{}",
self.sync_url,
generation,
frame_no,
frame_no + 1
);
tracing::debug!("pulling frame");
let frame = self.pull_with_retry(uri, self.max_retries).await?;
self.durable_frame_num = frame_no;
Ok(frame)
}

#[tracing::instrument(skip(self, frame))]
pub(crate) async fn push_one_frame(
&mut self,
Expand Down Expand Up @@ -215,6 +232,53 @@ impl SyncContext {
}
}

async fn pull_with_retry(&self, uri: String, max_retries: usize) -> Result<Bytes> {
let mut nr_retries = 0;
loop {
let mut req = http::Request::builder().method("GET").uri(uri.clone());

match &self.auth_token {
Some(auth_token) => {
req = req.header("Authorization", auth_token);
}
None => {}
}

let req = req.body(Body::empty())
.expect("valid request");

let res = self
.client
.request(req)
.await
.map_err(SyncError::HttpDispatch)?;

if res.status().is_success() {
let frame = hyper::body::to_bytes(res.into_body())
.await
.map_err(SyncError::HttpBody)?;
return Ok(frame);
}
// If we've retried too many times or the error is not a server error,
// return the error.
if nr_retries > max_retries || !res.status().is_server_error() {
let status = res.status();

let res_body = hyper::body::to_bytes(res.into_body())
.await
.map_err(SyncError::HttpBody)?;

let msg = String::from_utf8_lossy(&res_body[..]);

return Err(SyncError::PullFrame(status, msg.to_string()).into());
}

let delay = std::time::Duration::from_millis(100 * (1 << nr_retries));
tokio::time::sleep(delay).await;
nr_retries += 1;
}
}

pub(crate) fn durable_frame_num(&self) -> u32 {
self.durable_frame_num
}
Expand Down
Loading