diff --git a/libsql/src/connection.rs b/libsql/src/connection.rs index 08966f5aad..8b553c700a 100644 --- a/libsql/src/connection.rs +++ b/libsql/src/connection.rs @@ -9,6 +9,13 @@ use crate::statement::Statement; use crate::transaction::Transaction; use crate::{Result, TransactionBehavior}; +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum Op { + Insert = 0, + Delete = 1, + Update = 2, +} + #[async_trait::async_trait] pub(crate) trait Conn { async fn execute(&self, sql: &str, params: Params) -> Result; @@ -38,6 +45,10 @@ pub(crate) trait Conn { fn load_extension(&self, _dylib_path: &Path, _entry_point: Option<&str>) -> Result<()> { Err(crate::Error::LoadExtensionNotSupported) } + + fn add_update_hook(&self, _cb: Box) -> Result<()> { + Err(crate::Error::UpdateHookNotSupported) + } } /// A set of rows returned from `execute_batch`/`execute_transactional_batch`. It is essentially @@ -244,6 +255,13 @@ impl Connection { ) -> Result<()> { self.conn.load_extension(dylib_path.as_ref(), entry_point) } + + pub fn add_update_hook( + &self, + cb: Box, + ) -> Result<()> { + self.conn.add_update_hook(cb) + } } impl fmt::Debug for Connection { diff --git a/libsql/src/errors.rs b/libsql/src/errors.rs index cc82028826..ff793f2f9f 100644 --- a/libsql/src/errors.rs +++ b/libsql/src/errors.rs @@ -21,6 +21,8 @@ pub enum Error { SyncNotSupported(String), // Not in rusqlite #[error("Loading extension is only supported in local databases.")] LoadExtensionNotSupported, // Not in rusqlite + #[error("Update hooks are only supported in local databases.")] + UpdateHookNotSupported, // Not in rusqlite #[error("Column not found: {0}")] ColumnNotFound(i32), // Not in rusqlite #[error("Hrana: `{0}`")] diff --git a/libsql/src/lib.rs b/libsql/src/lib.rs index 06985f9ec2..e35a339bfb 100644 --- a/libsql/src/lib.rs +++ b/libsql/src/lib.rs @@ -175,7 +175,7 @@ cfg_hrana! { } pub use self::{ - connection::{BatchRows, Connection}, + connection::{BatchRows, Connection, Op}, database::{Builder, Database}, load_extension_guard::LoadExtensionGuard, rows::{Column, Row, Rows}, diff --git a/libsql/src/local/connection.rs b/libsql/src/local/connection.rs index bb1c7b7ab0..9141d95062 100644 --- a/libsql/src/local/connection.rs +++ b/libsql/src/local/connection.rs @@ -2,7 +2,10 @@ use crate::local::rows::BatchedRows; use crate::params::Params; -use crate::{connection::BatchRows, errors}; +use crate::{ + connection::{BatchRows, Op}, + errors, +}; use super::{Database, Error, Result, Rows, RowsFuture, Statement, Transaction}; @@ -11,6 +14,10 @@ use crate::TransactionBehavior; use libsql_sys::ffi; use std::{ffi::c_int, fmt, path::Path, sync::Arc}; +struct Container { + cb: Box, +} + /// A connection to a libSQL database. #[derive(Clone)] pub struct Connection { @@ -384,6 +391,24 @@ impl Connection { }) } + /// Installs update hook + pub fn add_update_hook(&self, cb: Box) { + let c = Box::new(Container { cb }); + let ptr: *mut Container = std::ptr::from_mut(Box::leak(c)); + + let old_data = unsafe { + ffi::sqlite3_update_hook( + self.raw, + Some(update_hook_cb), + ptr as *mut ::std::os::raw::c_void, + ) + }; + + if !old_data.is_null() { + let _ = unsafe { Box::from_raw(old_data as *mut Container) }; + } + } + pub fn enable_load_extension(&self, onoff: bool) -> Result<()> { // SQLITE_DBCONFIG_ENABLE_LOAD_EXTENSION configration verb accepts 2 additional parameters: an on/off flag and a pointer to an c_int where new state of the parameter will be written (or NULL if reporting back the setting is not needed) // See: https://sqlite.org/c3ref/c_dbconfig_defensive.html#sqlitedbconfigenableloadextension @@ -489,3 +514,25 @@ impl fmt::Debug for Connection { f.debug_struct("Connection").finish() } } + +#[no_mangle] +extern "C" fn update_hook_cb( + data: *mut ::std::os::raw::c_void, + op: ::std::os::raw::c_int, + db_name: *const ::std::os::raw::c_char, + table_name: *const ::std::os::raw::c_char, + row_id: i64, +) { + let db = unsafe { std::ffi::CStr::from_ptr(db_name).to_string_lossy() }; + let table = unsafe { std::ffi::CStr::from_ptr(table_name).to_string_lossy() }; + + let c = unsafe { &mut *(data as *mut Container) }; + let o = match op { + 9 => Op::Delete, + 18 => Op::Insert, + 23 => Op::Update, + _ => unreachable!("Unknown operation {op}"), + }; + + (*c.cb)(o, &db, &table, row_id); +} diff --git a/libsql/src/local/impls.rs b/libsql/src/local/impls.rs index 2338317a34..a935896632 100644 --- a/libsql/src/local/impls.rs +++ b/libsql/src/local/impls.rs @@ -1,9 +1,8 @@ use std::sync::Arc; use std::{fmt, path::Path}; -use crate::connection::BatchRows; +use crate::connection::{Conn, BatchRows, Op}; use crate::{ - connection::Conn, params::Params, rows::{ColumnsInner, RowInner, RowsInner}, statement::Stmt, @@ -79,6 +78,10 @@ impl Conn for LibsqlConnection { fn load_extension(&self, dylib_path: &Path, entry_point: Option<&str>) -> Result<()> { self.conn.load_extension(dylib_path, entry_point) } + + fn add_update_hook(&self, cb: Box) -> Result<()> { + Ok(self.conn.add_update_hook(cb)) + } } impl Drop for LibsqlConnection { diff --git a/libsql/tests/integration_tests.rs b/libsql/tests/integration_tests.rs index 92d8d358d8..85401a815a 100644 --- a/libsql/tests/integration_tests.rs +++ b/libsql/tests/integration_tests.rs @@ -4,11 +4,12 @@ use futures::{StreamExt, TryStreamExt}; use libsql::{ named_params, params, params::{IntoParams, IntoValue}, - Connection, Database, Value, + Connection, Database, Op, Value, }; use rand::distributions::Uniform; use rand::prelude::*; use std::collections::HashSet; +use std::sync::{Arc, Mutex}; async fn setup() -> Connection { let db = Database::open(":memory:").unwrap(); @@ -27,6 +28,77 @@ async fn enable_disable_extension() { conn.load_extension_disable().unwrap(); } +#[tokio::test] +async fn add_update_hook() { + let conn = setup().await; + + #[derive(PartialEq, Debug)] + struct Data { + op: Op, + db: String, + table: String, + row_id: i64, + } + + let d = Arc::new(Mutex::new(None::)); + + let d_clone = d.clone(); + conn.add_update_hook(Box::new(move |op, db, table, row_id| { + *d_clone.lock().unwrap() = Some(Data { + op, + db: db.to_string(), + table: table.to_string(), + row_id, + }); + })) + .unwrap(); + + let _ = conn + .execute("INSERT INTO users (id, name) VALUES (2, 'Alice')", ()) + .await + .unwrap(); + + assert_eq!( + *d.lock().unwrap().as_ref().unwrap(), + Data { + op: Op::Insert, + db: "main".to_string(), + table: "users".to_string(), + row_id: 1, + } + ); + + let _ = conn + .execute("UPDATE users SET name = 'Bob' WHERE id = 2", ()) + .await + .unwrap(); + + assert_eq!( + *d.lock().unwrap().as_ref().unwrap(), + Data { + op: Op::Update, + db: "main".to_string(), + table: "users".to_string(), + row_id: 1, + } + ); + + let _ = conn + .execute("DELETE FROM users WHERE id = 2", ()) + .await + .unwrap(); + + assert_eq!( + *d.lock().unwrap().as_ref().unwrap(), + Data { + op: Op::Delete, + db: "main".to_string(), + table: "users".to_string(), + row_id: 1, + } + ); +} + #[tokio::test] async fn connection_drops_before_statements() { let db = Database::open(":memory:").unwrap();