diff --git a/libsql-server/src/database/schema.rs b/libsql-server/src/database/schema.rs index be4f61cdc4..bfa26d1474 100644 --- a/libsql-server/src/database/schema.rs +++ b/libsql-server/src/database/schema.rs @@ -51,12 +51,15 @@ impl crate::connection::Connection for SchemaC } else { check_program_auth(&ctx, &migration, &self.config.get()).await?; let connection = self.connection.clone(); - validate_migration(&mut migration)?; + let disable_foreign_key = validate_migration(&mut migration)?; let migration = Arc::new(migration); let builder = tokio::task::spawn_blocking({ let migration = migration.clone(); move || { let res = connection.with_raw(|conn| -> crate::Result<_> { + if disable_foreign_key { + conn.execute("PRAGMA foreign_keys=off", ())?; + } let mut txn = conn .transaction_with_behavior(rusqlite::TransactionBehavior::Immediate) .map_err(|_| { @@ -73,6 +76,9 @@ impl crate::connection::Connection for SchemaC &QueryBuilderConfig::default(), ); txn.rollback().unwrap(); + if disable_foreign_key { + conn.execute("PRAGMA foreign_keys=on", ())?; + } Ok(ret?) }); diff --git a/libsql-server/src/query_analysis.rs b/libsql-server/src/query_analysis.rs index 38886f3df4..5762b86d4c 100644 --- a/libsql-server/src/query_analysis.rs +++ b/libsql-server/src/query_analysis.rs @@ -406,6 +406,14 @@ impl Statement { StmtKind::Read | StmtKind::TxnEnd | StmtKind::TxnBegin ) } + + pub(crate) fn is_pragma(&self) -> bool { + // adding a flag to the program would break the serialization, so we do that instead + match self.stmt.split_whitespace().next() { + Some(s) => s.trim().eq_ignore_ascii_case("pragma"), + None => false, + } + } } /// Given a an initial state and an array of queries, attempts to predict what the final state will diff --git a/libsql-server/src/schema/db.rs b/libsql-server/src/schema/db.rs index f644b2420f..7e20fe7822 100644 --- a/libsql-server/src/schema/db.rs +++ b/libsql-server/src/schema/db.rs @@ -9,6 +9,7 @@ use crate::namespace::NamespaceName; use crate::schema::status::{MigrationJobProgress, MigrationJobSummary}; use super::status::MigrationProgress; +use super::validate_migration; use super::{ status::{MigrationJob, MigrationTask}, Error, MigrationDetails, MigrationJobStatus, MigrationSummary, MigrationTaskStatus, @@ -328,15 +329,17 @@ pub(super) fn get_next_pending_migration_job( |row| { let job_id = row.get::<_, i64>(0)?; let status = MigrationJobStatus::from_int(row.get::<_, u64>(1)?); - let migration = serde_json::from_str(row.get_ref(2)?.as_str()?).unwrap(); + let mut migration = serde_json::from_str(row.get_ref(2)?.as_str()?).unwrap(); let schema = NamespaceName::from_string(row.get::<_, String>(3)?).unwrap(); + let disable_foreign_key = validate_migration(&mut migration).unwrap(); Ok(MigrationJob { schema, job_id, status, - migration, progress: Default::default(), task_error: None, + disable_foreign_key, + migration: migration.into(), }) }, ) diff --git a/libsql-server/src/schema/mod.rs b/libsql-server/src/schema/mod.rs index e7c7681262..7854276ed0 100644 --- a/libsql-server/src/schema/mod.rs +++ b/libsql-server/src/schema/mod.rs @@ -47,37 +47,47 @@ pub use scheduler::Scheduler; pub use status::{MigrationDetails, MigrationJobStatus, MigrationSummary, MigrationTaskStatus}; use crate::connection::program::Program; -use crate::query::{Params, Query}; -use crate::query_analysis::{Statement, StmtKind}; +use crate::query_analysis::StmtKind; -pub fn validate_migration(migration: &mut Program) -> Result<(), Error> { - if !migration.steps.is_empty() - && matches!(migration.steps[0].query.stmt.kind, StmtKind::TxnBegin) - { - if !matches!( - migration.steps.last().map(|s| &s.query.stmt.kind), - Some(&StmtKind::TxnEnd) - ) { - return Err(Error::MigrationContainsTransactionStatements); +// validate program is valid for migration, and return whether foreign keys should be disabled +pub fn validate_migration(migration: &mut Program) -> Result { + let mut steps = migration.steps_mut().unwrap().iter_mut().peekable(); + let mut explicit_tx = false; + let mut disable_foreign_key = false; + // skip pragmas prologue + while steps.next_if(|s| s.query.stmt.is_pragma()).is_some() { + disable_foreign_key = true; + } + + // first step can be a BEGIN + if let Some(step) = steps.next() { + if matches!(step.query.stmt.kind, StmtKind::TxnBegin) { + // neutralize step + step.query.stmt.stmt = r#"SELECT 'neutralized txn begin'"#.into(); + explicit_tx = true; } - migration.steps_mut().unwrap()[0].query = Query { - stmt: Statement::parse("PRAGMA max_page_count") - .next() - .unwrap() - .unwrap(), - params: Params::empty(), - want_rows: false, - }; - while let Some(step) = migration.steps.last() { - if !matches!(step.query.stmt.kind, StmtKind::TxnEnd) { - break; + } + + // skip all steps that are not tx items + while steps.next_if(|s| !s.query.stmt.kind.is_txn()).is_some() {} + + // last stmt can be a tx commit + while let Some(step) = steps.next_if(|s| s.query.stmt.kind.is_txn()) { + if matches!(step.query.stmt.kind, StmtKind::TxnEnd) { + if !explicit_tx { + // transaction is closed but was never opened + return Err(Error::MigrationContainsTransactionStatements); } - migration.steps_mut().unwrap().pop(); + // neutralize step + step.query.stmt.stmt = r#"SELECT 'neutralized txn component'"#.into(); } } - if migration.steps().iter().any(|s| s.query.stmt.kind.is_txn()) { - Err(Error::MigrationContainsTransactionStatements) - } else { - Ok(()) + + // validate pragma epilogue + if steps.by_ref().any(|s| !s.query.stmt.is_pragma()) { + // only accept pragmas after tx end + return Err(Error::MigrationContainsTransactionStatements); } + + Ok(disable_foreign_key) } diff --git a/libsql-server/src/schema/scheduler.rs b/libsql-server/src/schema/scheduler.rs index 8aacd62627..a88fa4f739 100644 --- a/libsql-server/src/schema/scheduler.rs +++ b/libsql-server/src/schema/scheduler.rs @@ -272,6 +272,7 @@ impl Scheduler { job.job_id(), self.namespace_store.clone(), self.migration_db.clone(), + job.disable_foreign_key, )); // do not enqueue anything until the schema migration is complete self.has_work = false; @@ -374,6 +375,7 @@ impl Scheduler { job.migration.clone(), task, block_writes, + job.disable_foreign_key, )); } else { // there is still a job, but the queue is empty, it means that we are waiting for the @@ -434,6 +436,7 @@ async fn try_step_task( migration: Arc, mut task: MigrationTask, block_writes: Arc, + disable_foreign_key: bool, ) -> WorkResult { let old_status = *task.status(); let error = match try_step_task_inner( @@ -443,6 +446,7 @@ async fn try_step_task( migration, &task, block_writes, + disable_foreign_key, ) .await { @@ -485,6 +489,7 @@ async fn try_step_task_inner( migration: Arc, task: &MigrationTask, block_writes: Arc, + disable_foreign_key: bool, ) -> Result<(MigrationTaskStatus, Option), Error> { let status = *task.status(); let mut db_connection = connection_maker @@ -508,6 +513,9 @@ async fn try_step_task_inner( let job_id = task.job_id(); let (status, error) = tokio::task::spawn_blocking(move || -> Result<_, Error> { db_connection.with_raw(move |conn| { + if disable_foreign_key { + conn.execute("PRAGMA foreign_keys=off", ())?; + } let mut txn = conn.transaction()?; match status { @@ -526,6 +534,10 @@ async fn try_step_task_inner( let (new_status, error) = step_task(&mut txn, job_id)?; txn.commit()?; + if disable_foreign_key { + conn.execute("PRAGMA foreign_keys=off", ())?; + } + if new_status.is_finished() { block_writes.store(false, std::sync::atomic::Ordering::SeqCst); } @@ -737,6 +749,7 @@ async fn step_job_run_success( job_id: i64, namespace_store: NamespaceStore, migration_db: Arc>, + disable_foreign_key: bool, ) -> WorkResult { try_step_job(MigrationJobStatus::WaitingRun, async move { // TODO: check that all tasks actually reported success before migration @@ -757,6 +770,9 @@ async fn step_job_run_success( .map_err(|e| Error::FailedToConnect(schema.clone(), e.into()))?; tokio::task::spawn_blocking(move || -> Result<(), Error> { connection.with_raw(|conn| -> Result<(), Error> { + if disable_foreign_key { + conn.execute("PRAGMA foreign_keys=off", ())?; + } let mut txn = conn.transaction()?; let schema_version = txn.query_row("PRAGMA schema_version", (), |row| row.get::<_, i64>(0))?; @@ -774,6 +790,9 @@ async fn step_job_run_success( txn.pragma_update(None, "schema_version", job_id)?; // update schema version to job_id? txn.commit()?; + if disable_foreign_key { + conn.execute("PRAGMA foreign_keys=on", ())?; + } } Ok(()) diff --git a/libsql-server/src/schema/snapshots/libsql_server__schema__db__test__pending_job-3.snap b/libsql-server/src/schema/snapshots/libsql_server__schema__db__test__pending_job-3.snap index c1aa3f09cd..61cede754e 100644 --- a/libsql-server/src/schema/snapshots/libsql_server__schema__db__test__pending_job-3.snap +++ b/libsql-server/src/schema/snapshots/libsql_server__schema__db__test__pending_job-3.snap @@ -35,4 +35,5 @@ MigrationJob { 0, ], task_error: None, + disable_foreign_key: false, } diff --git a/libsql-server/src/schema/snapshots/libsql_server__schema__db__test__pending_job.snap b/libsql-server/src/schema/snapshots/libsql_server__schema__db__test__pending_job.snap index 1f2502c4f1..f2bfca9d43 100644 --- a/libsql-server/src/schema/snapshots/libsql_server__schema__db__test__pending_job.snap +++ b/libsql-server/src/schema/snapshots/libsql_server__schema__db__test__pending_job.snap @@ -35,4 +35,5 @@ MigrationJob { 0, ], task_error: None, + disable_foreign_key: false, } diff --git a/libsql-server/src/schema/status.rs b/libsql-server/src/schema/status.rs index 30e241986c..b2b46533c8 100644 --- a/libsql-server/src/schema/status.rs +++ b/libsql-server/src/schema/status.rs @@ -45,6 +45,7 @@ pub struct MigrationJob { pub(super) progress: MigrationProgress, /// error info for the task that failed the job pub(super) task_error: Option<(i64, String, NamespaceName)>, + pub(super) disable_foreign_key: bool, } impl MigrationJob { diff --git a/libsql-server/tests/namespaces/shared_schema.rs b/libsql-server/tests/namespaces/shared_schema.rs index 680d2e80af..98faa310b8 100644 --- a/libsql-server/tests/namespaces/shared_schema.rs +++ b/libsql-server/tests/namespaces/shared_schema.rs @@ -216,53 +216,6 @@ fn no_job_created_when_migration_job_is_invalid() { sim.run().unwrap(); } -#[test] -fn migration_contains_txn_statements() { - let mut sim = Builder::new() - .simulation_duration(Duration::from_secs(100000)) - .build(); - let tmp = tempdir().unwrap(); - make_primary(&mut sim, tmp.path().to_path_buf()); - - sim.client("client", async { - let client = Client::new(); - client - .post( - "http://primary:9090/v1/namespaces/schema/create", - json!({"shared_schema": true }), - ) - .await - .unwrap(); - - let schema_db = Database::open_remote_with_connector( - "http://schema.primary:8080", - String::new(), - TurmoilConnector, - ) - .unwrap(); - let schema_conn = schema_db.connect().unwrap(); - schema_conn - .execute_batch("begin; create table test1 (c);commit") - .await - .unwrap(); - assert_debug_snapshot!(schema_conn - .execute_batch("begin; create table test (c)") - .await - .unwrap_err()); - - let resp = client - .get("http://schema.primary:8080/v1/jobs/2") - .await - .unwrap(); - assert_eq!(resp.status(), StatusCode::NOT_FOUND); - assert_debug_snapshot!(resp.json_value().await.unwrap()); - - Ok(()) - }); - - sim.run().unwrap(); -} - #[test] fn dry_run_failure() { let mut sim = Builder::new()