Skip to content

Commit

Permalink
Merge pull request #574 from tursodatabase/fix-lock-stealing
Browse files Browse the repository at this point in the history
fix bug in lock-stealing
  • Loading branch information
MarinPostma authored Nov 7, 2023
2 parents f9861d8 + fd06e6d commit 84e137f
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 13 deletions.
122 changes: 109 additions & 13 deletions libsql-server/src/connection/libsql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -389,24 +389,32 @@ unsafe extern "C" fn busy_handler<W: WalHook>(state: *mut c_void, _retries: c_in
// the current holder of the transaction has timedout, we will attempt to steal their
// lock.
_ = timeout => {
tracing::info!("transaction has timed-out, stealing lock");
// only a single connection gets to steal the lock, others retry
if let Some(mut lock) = state.slot.try_write() {
// We check that slot wasn't already stolen, and that their is still a slot.
// The ordering is relaxed because the atomic is only set under the slot lock.
if slot.is_stolen.compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed).is_ok() {
// The connection holding the current txn will set itself as stolen when it
// detects a timeout, so if we arrive to this point, then there is
// necessarily a slot, and this slot has to be the one we attempted to
// steal.
assert!(lock.take().is_some());

slot.abort();
tracing::info!("stole transaction lock");
if let Some(ref s) = *lock {
// The state contains the same lock as the one we're attempting to steal
if Arc::ptr_eq(s, &slot) {
// We check that slot wasn't already stolen, and that their is still a slot.
// The ordering is relaxed because the atomic is only set under the slot lock.
if slot.is_stolen.compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed).is_ok() {
// The connection holding the current txn will set itself as stolen when it
// detects a timeout, so if we arrive to this point, then there is
// necessarily a slot, and this slot has to be the one we attempted to
// steal.
assert!(lock.take().is_some());

slot.abort();
tracing::info!("stole transaction lock");
}
}
}
}

1
}
}

})
}

Expand Down Expand Up @@ -537,8 +545,20 @@ impl<W: WalHook> Connection<W> {
}
// lock was downgraded, notify a waiter
(Tx::Write, Tx::None | Tx::Read) => {
state.slot.write().take();
lock.slot.take();
let old_slot = lock
.slot
.take()
.expect("there should be a slot right after downgrading a txn");
let mut maybe_state_slot = state.slot.write();
// We need to make sure that the state slot is our slot before removing it.
if let Some(ref state_slot) = *maybe_state_slot {
if Arc::ptr_eq(state_slot, &old_slot) {
maybe_state_slot.take();
}
}

drop(maybe_state_slot);

state.notify.notify_one();
}
// nothing to do
Expand Down Expand Up @@ -918,10 +938,12 @@ where
#[cfg(test)]
mod test {
use itertools::Itertools;
use rand::Rng;
use sqld_libsql_bindings::wal_hook::TRANSPARENT_METHODS;
use tempfile::tempdir;
use tokio::task::JoinSet;

use crate::connection::Connection as _;
use crate::query_result_builder::test::{test_driver, TestBuilder};
use crate::query_result_builder::QueryResultBuilder;
use crate::DEFAULT_AUTO_CHECKPOINT;
Expand Down Expand Up @@ -1118,4 +1140,78 @@ mod test {
let epsilon = Duration::from_millis(100);
assert!((wait_time..wait_time + epsilon).contains(&elapsed));
}

/// The goal of this test is to run many conccurent transaction and hopefully catch a bug in
/// the lock stealing code. If this test becomes flaky check out the lock stealing code.
#[tokio::test]
async fn test_many_conccurent() {
let tmp = tempdir().unwrap();
let make_conn = MakeLibSqlConn::new(
tmp.path().into(),
&TRANSPARENT_METHODS,
|| (),
Default::default(),
Arc::new(DatabaseConfigStore::load(tmp.path()).unwrap()),
Arc::new([]),
100000000,
100000000,
DEFAULT_AUTO_CHECKPOINT,
watch::channel(None).1,
)
.await
.unwrap();
let auth = Authenticated::Authorized(Authorized {
namespace: None,
permission: Permission::FullAccess,
});

let conn = make_conn.make_connection().await.unwrap();
conn.execute_program(
Program::seq(&["CREATE TABLE test (x)"]),
auth.clone(),
TestBuilder::default(),
None,
)
.await
.unwrap();
let run_conn = |maker: Arc<MakeLibSqlConn<TransparentMethods>>| {
let auth = auth.clone();
async move {
for _ in 0..1000 {
let conn = maker.make_connection().await.unwrap();
let pgm = Program::seq(&["BEGIN IMMEDIATE", "INSERT INTO test VALUES (42)"]);
let res = conn
.execute_program(pgm, auth.clone(), TestBuilder::default(), None)
.await
.unwrap()
.into_ret();
for result in res {
result.unwrap();
}
// with 99% change, commit the txn
if rand::thread_rng().gen_range(0..100) > 1 {
let pgm = Program::seq(&["INSERT INTO test VALUES (43)", "COMMIT"]);
let res = conn
.execute_program(pgm, auth.clone(), TestBuilder::default(), None)
.await
.unwrap()
.into_ret();
for result in res {
result.unwrap();
}
}
}
}
};

let maker = Arc::new(make_conn);
let mut join_set = JoinSet::new();
for _ in 0..3 {
join_set.spawn(run_conn(maker.clone()));
}

while let Some(next) = join_set.join_next().await {
next.unwrap();
}
}
}
3 changes: 3 additions & 0 deletions libsql-server/src/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ pub mod libsql;
pub mod program;
pub mod write_proxy;

#[cfg(not(test))]
const TXN_TIMEOUT: Duration = Duration::from_secs(5);
#[cfg(test)]
const TXN_TIMEOUT: Duration = Duration::from_millis(100);

#[async_trait::async_trait]
pub trait Connection: Send + Sync + 'static {
Expand Down

0 comments on commit 84e137f

Please sign in to comment.