libsql: rework sync v2 structure
This refactors the new sync code to live in sync.rs. It also adds support
for writing sync metadata to disk to improve sync behavior across restarts.
LucioFranco committed Nov 15, 2024
1 parent 89b3460 commit bdb428e
Showing 4 changed files with 213 additions and 59 deletions.
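
For reference, the persisted sync metadata is a small JSON document written next to the database at `<db_path>-info`. The sketch below mirrors the MetadataJson struct and path convention introduced in sync.rs; it is a standalone illustration rather than the crate's API, the write_metadata/read_metadata helper names are invented for the example, and it assumes the serde (with derive), serde_json, and tokio crates are available. On restart, SyncContext::new reads this file to recover the last max_frame_no acknowledged by the server.

use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize)]
struct MetadataJson {
    max_frame_no: u32,
}

// Persist the last frame number the server acknowledged, next to the database file.
async fn write_metadata(db_path: &str, max_frame_no: u32) -> std::io::Result<()> {
    let path = format!("{}-info", db_path);
    let contents = serde_json::to_vec(&MetadataJson { max_frame_no })
        .expect("metadata serializes to JSON");
    tokio::fs::write(path, contents).await
}

// Read it back on startup so a restarted process knows where the previous push stopped.
async fn read_metadata(db_path: &str) -> std::io::Result<u32> {
    let contents = tokio::fs::read(format!("{}-info", db_path)).await?;
    let metadata: MetadataJson = serde_json::from_slice(&contents)
        .expect("metadata file contains valid JSON");
    Ok(metadata.max_frame_no)
}
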
12 changes: 12 additions & 0 deletions libsql/src/database/builder.rs
@@ -424,7 +424,19 @@ cfg_sync! {

let path = path.to_str().ok_or(crate::Error::InvalidUTF8Path)?.to_owned();

// TODO: add config to set custom connector
let https = super::connector()?;

use tower::ServiceExt;

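// Adapt the HTTPS connector into the boxed Socket service that ConnectorService expects.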
let svc = https
.map_err(|e| e.into())
.map_response(|s| Box::new(s) as Box<dyn crate::util::Socket>);

let connector = crate::util::ConnectorService::new(svc);

let db = crate::local::Database::open_local_with_offline_writes(
connector,
path,
flags,
url,
34 changes: 34 additions & 0 deletions libsql/src/local/connection.rs
@@ -8,6 +8,7 @@ use super::{Database, Error, Result, Rows, RowsFuture, Statement, Transaction};

use crate::TransactionBehavior;

use bytes::{BufMut, BytesMut};
use libsql_sys::ffi;
use std::{ffi::c_int, fmt, path::Path, sync::Arc};

@@ -445,6 +446,39 @@ impl Connection {
}
}
}

pub fn wal_frame_count(&self) -> u32 {
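// Query the WAL layer, via the libsql FFI, for the number of frames currently in the log.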
let mut max_frame_no: std::os::raw::c_uint = 0;
unsafe { libsql_sys::ffi::libsql_wal_frame_count(self.handle(), &mut max_frame_no) };

max_frame_no
}

pub fn wal_get_frame(&self, frame_no: u32, page_size: u32) -> Result<BytesMut> {
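// A WAL frame is a 24-byte frame header followed by exactly one page of data.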
let frame_size: usize = 24 + page_size as usize;

let mut buf = BytesMut::with_capacity(frame_size);

let rc = unsafe {
libsql_sys::ffi::libsql_wal_get_frame(
self.handle(),
frame_no,
buf.chunk_mut().as_mut_ptr() as *mut _,
frame_size as u32,
)
};

if rc != 0 {
return Err(crate::errors::Error::SqliteFailure(
rc as std::ffi::c_int,
format!("Failed to get frame: {}", frame_no),
));
}

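// The FFI call wrote the whole frame into the buffer's spare capacity;
// advance the length so those bytes become readable.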
unsafe { buf.advance_mut(frame_size) };

Ok(buf)
}
}

impl fmt::Debug for Connection {
83 changes: 30 additions & 53 deletions libsql/src/local/database.rs
@@ -33,7 +33,7 @@ pub struct Database {
#[cfg(feature = "replication")]
pub replication_ctx: Option<ReplicationContext>,
#[cfg(feature = "sync")]
pub sync_ctx: Option<SyncContext>,
pub sync_ctx: Option<tokio::sync::Mutex<SyncContext>>,
}

impl Database {
@@ -131,6 +131,7 @@ impl Database {
#[cfg(feature = "sync")]
#[doc(hidden)]
pub async fn open_local_with_offline_writes(
connector: crate::util::ConnectorService,
db_path: impl Into<String>,
flags: OpenFlags,
endpoint: String,
@@ -143,7 +144,10 @@ impl Database {
endpoint
};
let mut db = Database::open(&db_path, flags)?;
db.sync_ctx = Some(SyncContext::new(endpoint, Some(auth_token)));

let ctx = SyncContext::new(endpoint, Some(auth_token), &db_path, connector).await;

db.sync_ctx = Some(tokio::sync::Mutex::new(ctx));
Ok(db)
}

@@ -320,7 +324,10 @@ impl Database {

#[cfg(feature = "replication")]
/// Sync with primary at least to a given replication index
pub async fn sync_until(&self, replication_index: FrameNo) -> Result<crate::database::Replicated> {
pub async fn sync_until(
&self,
replication_index: FrameNo,
) -> Result<crate::database::Replicated> {
if let Some(ctx) = &self.replication_ctx {
let mut frame_no: Option<FrameNo> = ctx.replicator.committed_frame_no().await;
let mut frames_synced: usize = 0;
@@ -380,84 +387,54 @@ impl Database {
#[cfg(feature = "sync")]
/// Push WAL frames to remote.
pub async fn push(&self) -> Result<crate::database::Replicated> {
let sync_ctx = self.sync_ctx.as_ref().unwrap();
let mut ctx = match &self.sync_ctx {
Some(ctx) => ctx.lock().await,
None => panic!("sync context not set"),
};

let conn = self.connect()?;

// TODO: can this be cached?
let page_size = {
let rows = conn.query("PRAGMA page_size", crate::params::Params::None)?.unwrap();
let rows = conn
.query("PRAGMA page_size", crate::params::Params::None)?
.unwrap();
let row = rows.next()?.unwrap();
let page_size = row.get::<u32>(0)?;
page_size
};

let mut max_frame_no: std::os::raw::c_uint = 0;
unsafe { libsql_sys::ffi::libsql_wal_frame_count(conn.handle(), &mut max_frame_no) };

let max_frame_no = conn.wal_frame_count();

let generation = 1; // TODO: Probe from WAL.
let start_frame_no = sync_ctx.durable_frame_num + 1;
let start_frame_no = ctx.durable_frame_num() + 1;
let end_frame_no = max_frame_no;

// TODO: figure out relation to durable_frame_num
// let max_frame_no = ctx.max_frame_no();

let mut frame_no = start_frame_no;
while frame_no <= end_frame_no {
// The server returns its maximum frame number. To avoid resending
// frames the server already knows about, we need to update the
// frame number to the one returned by the server.
let max_frame_no = self.push_one_frame(&conn, &sync_ctx, generation, frame_no, page_size).await?;
let frame = conn.wal_get_frame(frame_no, page_size)?;

let max_frame_no = ctx.send_frame(frame.freeze(), generation, frame_no).await?;

if max_frame_no > frame_no {
frame_no = max_frame_no;
}
frame_no += 1;
}

let frame_count = end_frame_no - start_frame_no + 1;
Ok(crate::database::Replicated{
Ok(crate::database::Replicated {
frame_no: None,
frames_synced: frame_count as usize,
})
}

#[cfg(feature = "sync")]
async fn push_one_frame(&self, conn: &Connection, sync_ctx: &SyncContext, generation: u32, frame_no: u32, page_size: u32) -> Result<u32> {
let frame_size: usize = 24+page_size as usize;
let frame = vec![0; frame_size];
let rc = unsafe {
libsql_sys::ffi::libsql_wal_get_frame(conn.handle(), frame_no, frame.as_ptr() as *mut _, frame_size as u32)
};
if rc != 0 {
return Err(crate::errors::Error::SqliteFailure(rc as std::ffi::c_int, format!("Failed to get frame: {}", frame_no)));
}
let uri = format!("{}/sync/{}/{}/{}", sync_ctx.sync_url, generation, frame_no, frame_no+1);
let max_frame_no = self.push_with_retry(uri, &sync_ctx.auth_token, frame.to_vec(), sync_ctx.max_retries).await?;
Ok(max_frame_no)
}

#[cfg(feature = "sync")]
async fn push_with_retry(&self, uri: String, auth_token: &Option<String>, frame: Vec<u8>, max_retries: usize) -> Result<u32> {
let mut nr_retries = 0;
loop {
let client = reqwest::Client::new();
let mut builder = client.post(uri.to_owned());
match auth_token {
Some(ref auth_token) => {
builder = builder.header("Authorization", format!("Bearer {}", auth_token.to_owned()));
}
None => {}
}
let res = builder.body(frame.to_vec()).send().await.unwrap();
if res.status().is_success() {
let resp = res.json::<serde_json::Value>().await.unwrap();
let max_frame_no = resp.get("max_frame_no").unwrap().as_u64().unwrap();
return Ok(max_frame_no as u32);
}
if nr_retries > max_retries {
return Err(crate::errors::Error::ConnectionFailed(format!("Failed to push frame: {}", res.status())));
}
let delay = std::time::Duration::from_millis(100 * (1 << nr_retries));
tokio::time::sleep(delay).await;
nr_retries += 1;
}
}

pub(crate) fn path(&self) -> &str {
&self.db_path
}
143 changes: 137 additions & 6 deletions libsql/src/sync.rs
@@ -1,18 +1,149 @@
const DEFAULT_MAX_RETRIES: usize = 5;

use bytes::Bytes;
use http::{HeaderValue, Request, Uri};
use tokio::sync::Mutex;

use crate::{util::ConnectorService, Result};

pub struct SyncContext {
pub sync_url: String,
pub auth_token: Option<String>,
pub max_retries: usize,
pub durable_frame_num: u32,
sync_url: String,
auth_token: Option<String>,
max_retries: usize,
durable_frame_num: u32,
db_path: String,
max_frame_no: u32,

client: hyper::Client<ConnectorService, hyper::Body>,
}

impl SyncContext {
pub fn new(sync_url: String, auth_token: Option<String>) -> Self {
Self {
pub async fn new(
sync_url: String,
auth_token: Option<String>,
db_path: impl Into<String>,
connector: ConnectorService,
) -> Self {
let mut ctx = Self {
sync_url,
auth_token,
durable_frame_num: 0,
max_retries: DEFAULT_MAX_RETRIES,
db_path: db_path.into(),
max_frame_no: 0,
client: hyper::Client::builder().build(connector),
};

ctx.read_and_update_metadata().await.unwrap();

ctx
}

pub(crate) async fn send_frame(
&mut self,
frame: Bytes,
generation: u32,
frame_no: u32,
) -> Result<u32> {
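// The sync endpoint addresses frames as /sync/{generation}/{start}/{end} (end exclusive);
// this sends a single frame per request.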
let url = format!(
"{}/sync/{}/{}/{}",
self.sync_url,
generation,
frame_no,
frame_no + 1
);

let maybe_auth_header = if let Some(auth_token) = &self.auth_token {
Some(HeaderValue::from_str(&format!("Bearer {}", auth_token)).unwrap())
} else {
None
};

let mut attempts = 0;

loop {
let mut req = Request::post(url.clone());

if let Some(auth_header) = &maybe_auth_header {
req.headers_mut()
.unwrap()
.insert("Authorization", auth_header.clone());
}

let req = req.body(frame.clone().into()).unwrap();

let res = self.client.request(req).await.unwrap();

if res.status().is_success() {
let body = hyper::body::to_bytes(res.into_body()).await.unwrap();

let resp = serde_json::from_slice::<serde_json::Value>(&body[..]).unwrap();

let max_frame_no = resp.get("max_frame_no").unwrap().as_u64().unwrap() as u32;

// Update our best known max_frame_no from the server and write it to disk.
self.set_max_frame_no(max_frame_no).await.unwrap();

return Ok(max_frame_no);
} else if res.status().is_server_error() || attempts < self.max_retries {
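// Retry with exponential backoff: 100ms, 200ms, 400ms, ...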
let delay = std::time::Duration::from_millis(100 * (1 << attempts));
tokio::time::sleep(delay).await;
attempts += 1;

continue;
} else {
return Err(crate::errors::Error::ConnectionFailed(format!(
"Failed to push frame: {}",
res.status()
)));
}
}
}

pub(crate) fn max_frame_no(&self) -> u32 {
self.max_frame_no
}

pub(crate) fn durable_frame_num(&self) -> u32 {
self.durable_frame_num
}

pub(crate) async fn set_max_frame_no(&mut self, max_frame_no: u32) -> Result<()> {
// TODO: check if max_frame_no is larger than current known max_frame_no
self.max_frame_no = max_frame_no;

self.update_metadata().await?;

Ok(())
}

async fn update_metadata(&mut self) -> Result<()> {
let path = format!("{}-info", self.db_path);

let contents = serde_json::to_vec(&MetadataJson {
max_frame_no: self.max_frame_no,
})
.unwrap();

tokio::fs::write(path, contents).await.unwrap();

Ok(())
}

async fn read_and_update_metadata(&mut self) -> Result<()> {
let path = format!("{}-info", self.db_path);
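// Load the last known max_frame_no written by update_metadata(); the read is
// unwrapped, so the `<db_path>-info` file is expected to already exist.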

let contents = tokio::fs::read(&path).await.unwrap();

let metadata = serde_json::from_slice::<MetadataJson>(&contents[..]).unwrap();

self.max_frame_no = metadata.max_frame_no;

Ok(())
}
}

#[derive(serde::Serialize, serde::Deserialize)]
struct MetadataJson {
max_frame_no: u32,
}
