From 0888d85f69285e5e4a200c7199f2c3cd92d131a7 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 31 Oct 2023 14:50:01 +0100 Subject: [PATCH 01/26] introduce streaming proxy --- libsql-server/src/rpc/mod.rs | 1 + libsql-server/src/rpc/streaming_exec.rs | 555 ++++++++++++++++++++++++ 2 files changed, 556 insertions(+) create mode 100644 libsql-server/src/rpc/streaming_exec.rs diff --git a/libsql-server/src/rpc/mod.rs b/libsql-server/src/rpc/mod.rs index b29f9d1dc3..e81dc1aa10 100644 --- a/libsql-server/src/rpc/mod.rs +++ b/libsql-server/src/rpc/mod.rs @@ -20,6 +20,7 @@ pub mod proxy; pub mod replica_proxy; pub mod replication_log; pub mod replication_log_proxy; +pub mod streaming_exec; /// A tonic error code to signify that a namespace doesn't exist. pub const NAMESPACE_DOESNT_EXIST: &str = "NAMESPACE_DOESNT_EXIST"; diff --git a/libsql-server/src/rpc/streaming_exec.rs b/libsql-server/src/rpc/streaming_exec.rs new file mode 100644 index 0000000000..420bc90fe7 --- /dev/null +++ b/libsql-server/src/rpc/streaming_exec.rs @@ -0,0 +1,555 @@ +use std::sync::Arc; + +use futures_core::future::BoxFuture; +use futures_core::Stream; +use futures_option::OptionExt; +use libsql_replication::rpc::proxy::exec_req::Request; +use libsql_replication::rpc::proxy::exec_resp::{self, Response}; +use libsql_replication::rpc::proxy::resp_step::Step; +use libsql_replication::rpc::proxy::row_value::Value; +use libsql_replication::rpc::proxy::{ + AddRowValue, BeginRow, BeginRows, BeginStep, ColsDescription, DescribeCol, DescribeParam, + DescribeResp, ExecReq, ExecResp, Finish, FinishRow, FinishRows, FinishStep, Init, ProgramResp, + RespStep, RowValue, State as RpcState, StepError, StreamDescribeReq, +}; +use prost::Message; +use rusqlite::types::ValueRef; +use tokio::pin; +use tokio::sync::mpsc; +use tokio_stream::StreamExt; +use tonic::{Code, Status}; + +use crate::auth::Authenticated; +use crate::connection::Connection; +use crate::error::Error; +use crate::query_analysis::TxnStatus; +use crate::query_result_builder::{ + Column, QueryBuilderConfig, QueryResultBuilder, QueryResultBuilderError, +}; +use crate::replication::FrameNo; + +const MAX_RESPONSE_SIZE: usize = bytesize::ByteSize::mb(1).as_u64() as usize; + +pub fn make_proxy_stream( + conn: C, + auth: Authenticated, + request_stream: S, +) -> impl Stream> +where + S: Stream>, + C: Connection, +{ + make_proxy_stream_inner(conn, auth, request_stream, MAX_RESPONSE_SIZE) +} + +fn make_proxy_stream_inner( + conn: C, + auth: Authenticated, + request_stream: S, + max_program_resp_size: usize, +) -> impl Stream> +where + S: Stream>, + C: Connection, +{ + async_stream::stream! { + let mut current_request_fut: Option, u32)>> = None; + let (snd, mut recv) = mpsc::channel(1); + let conn = Arc::new(conn); + + pin!(request_stream); + + loop { + tokio::select! { + biased; + maybe_req = request_stream.next() => { + let Some(maybe_req) = maybe_req else { break }; + match maybe_req { + Err(e) => { + tracing::error!("stream error: {e}"); + break + } + Ok(req) => { + let request_id = req.request_id; + match req.request { + Some(Request::Execute(pgm)) => { + let Ok(pgm) = + crate::connection::program::Program::try_from(pgm.pgm.unwrap()) else { + yield Err(Status::new(Code::InvalidArgument, "invalid program")); + break + }; + let conn = conn.clone(); + let auth = auth.clone(); + let sender = snd.clone(); + + let fut = async move { + let builder = StreamResponseBuilder { + request_id, + sender, + current: None, + current_size: 0, + max_program_resp_size, + }; + + let ret = conn.execute_program(pgm, auth, builder, None).await.map(|_| ()); + (ret, request_id) + }; + + current_request_fut.replace(Box::pin(fut)); + } + Some(Request::Describe(StreamDescribeReq { stmt })) => { + let auth = auth.clone(); + let sender = snd.clone(); + let conn = conn.clone(); + let fut = async move { + let do_describe = || async move { + let ret = conn.describe(stmt, auth, None).await??; + Ok(DescribeResp { + cols: ret.cols.into_iter().map(|c| DescribeCol { name: c.name, decltype: c.decltype }).collect(), + params: ret.params.into_iter().map(|p| DescribeParam { name: p.name }).collect(), + is_explain: ret.is_explain, + is_readonly: ret.is_readonly + }) + }; + + let ret: crate::Result<()> = match do_describe().await { + Ok(resp) => { + let _ = sender.send(ExecResp { request_id, response: Some(Response::DescribeResp(resp)) }).await; + Ok(()) + } + Err(e) => Err(e), + }; + + (ret, request_id) + }; + + current_request_fut.replace(Box::pin(fut)); + + }, + None => { + yield Err(Status::new(Code::InvalidArgument, "invalid request")); + break + } + } + } + } + }, + Some(res) = recv.recv() => { + yield Ok(res); + }, + (ret, request_id) = current_request_fut.current(), if current_request_fut.is_some() => { + if let Err(e) = ret { + yield Ok(ExecResp { request_id, response: Some(Response::Error(e.into())) }) + } + }, + else => break, + } + } + } +} + +struct StreamResponseBuilder { + request_id: u32, + sender: mpsc::Sender, + current: Option, + current_size: usize, + max_program_resp_size: usize, +} + +impl StreamResponseBuilder { + fn current(&mut self) -> &mut ProgramResp { + self.current + .get_or_insert_with(|| ProgramResp { steps: Vec::new() }) + } + + fn push(&mut self, step: Step) -> Result<(), QueryResultBuilderError> { + let current = self.current(); + let step = RespStep { step: Some(step) }; + let size = step.encoded_len(); + current.steps.push(step); + self.current_size += size; + + if self.current_size >= self.max_program_resp_size { + self.flush()?; + } + + Ok(()) + } + + fn flush(&mut self) -> Result<(), QueryResultBuilderError> { + if let Some(current) = self.current.take() { + let resp = ExecResp { + request_id: self.request_id, + response: Some(exec_resp::Response::ProgramResp(current)), + }; + self.current_size = 0; + self.sender + .blocking_send(resp) + .map_err(|_| QueryResultBuilderError::Internal(anyhow::anyhow!("stream closed")))?; + } + + Ok(()) + } +} + +/// Apply the response to the the builder, and return whether the builder need more steps +pub fn apply_program_resp_to_builder( + config: &QueryBuilderConfig, + builder: &mut B, + resp: ProgramResp, + mut on_finish: impl FnMut(Option, TxnStatus), +) -> crate::Result { + for step in resp.steps { + let Some(step) = step.step else { + return Err(Error::PrimaryStreamMisuse); + }; + match step { + Step::Init(_) => builder.init(config)?, + Step::BeginStep(_) => builder.begin_step()?, + Step::FinishStep(FinishStep { + affected_row_count, + last_insert_rowid, + }) => builder.finish_step(affected_row_count, last_insert_rowid)?, + Step::StepError(StepError { error: Some(err) }) => { + builder.step_error(crate::error::Error::RpcQueryError(err))? + } + Step::ColsDescription(ColsDescription { columns }) => { + let cols = columns.iter().map(|c| Column { + name: &c.name, + decl_ty: c.decltype.as_deref(), + }); + builder.cols_description(cols)? + } + Step::BeginRows(_) => builder.begin_rows()?, + Step::BeginRow(_) => builder.begin_row()?, + Step::AddRowValue(AddRowValue { + val: Some(RowValue { value: Some(val) }), + }) => { + let val = match &val { + Value::Text(s) => ValueRef::Text(s.as_bytes()), + Value::Integer(i) => ValueRef::Integer(*i), + Value::Real(x) => ValueRef::Real(*x), + Value::Blob(b) => ValueRef::Blob(b.as_slice()), + Value::Null(_) => ValueRef::Null, + }; + builder.add_row_value(val)?; + } + Step::FinishRow(_) => builder.finish_row()?, + Step::FinishRows(_) => builder.finish_rows()?, + Step::Finish(f @ Finish { last_frame_no, .. }) => { + let txn_status = TxnStatus::from(f.state()); + on_finish(last_frame_no, txn_status); + builder.finish(last_frame_no, txn_status)?; + return Ok(false); + } + _ => return Err(Error::PrimaryStreamMisuse), + } + } + + Ok(true) +} + +impl QueryResultBuilder for StreamResponseBuilder { + type Ret = (); + + fn init(&mut self, _config: &QueryBuilderConfig) -> Result<(), QueryResultBuilderError> { + self.push(Step::Init(Init {}))?; + Ok(()) + } + + fn begin_step(&mut self) -> Result<(), QueryResultBuilderError> { + self.push(Step::BeginStep(BeginStep {}))?; + Ok(()) + } + + fn finish_step( + &mut self, + affected_row_count: u64, + last_insert_rowid: Option, + ) -> Result<(), QueryResultBuilderError> { + self.push(Step::FinishStep(FinishStep { + affected_row_count, + last_insert_rowid, + }))?; + Ok(()) + } + + fn step_error(&mut self, error: crate::error::Error) -> Result<(), QueryResultBuilderError> { + self.push(Step::StepError(StepError { + error: Some(error.into()), + }))?; + Ok(()) + } + + fn cols_description<'a>( + &mut self, + cols: impl IntoIterator>>, + ) -> Result<(), QueryResultBuilderError> { + self.push(Step::ColsDescription(ColsDescription { + columns: cols + .into_iter() + .map(Into::into) + .map(|c| libsql_replication::rpc::proxy::Column { + name: c.name.into(), + decltype: c.decl_ty.map(Into::into), + }) + .collect::>(), + }))?; + Ok(()) + } + + fn begin_rows(&mut self) -> Result<(), QueryResultBuilderError> { + self.push(Step::BeginRows(BeginRows {}))?; + Ok(()) + } + + fn begin_row(&mut self) -> Result<(), QueryResultBuilderError> { + self.push(Step::BeginRow(BeginRow {}))?; + Ok(()) + } + + fn add_row_value(&mut self, v: ValueRef) -> Result<(), QueryResultBuilderError> { + self.push(Step::AddRowValue(AddRowValue { + val: Some(v.into()), + }))?; + Ok(()) + } + + fn finish_row(&mut self) -> Result<(), QueryResultBuilderError> { + self.push(Step::FinishRow(FinishRow {}))?; + Ok(()) + } + + fn finish_rows(&mut self) -> Result<(), QueryResultBuilderError> { + self.push(Step::FinishRows(FinishRows {}))?; + Ok(()) + } + + fn finish( + &mut self, + last_frame_no: Option, + state: TxnStatus, + ) -> Result<(), QueryResultBuilderError> { + self.push(Step::Finish(Finish { + last_frame_no, + state: RpcState::from(state).into(), + }))?; + self.flush()?; + Ok(()) + } + + fn into_ret(self) -> Self::Ret {} +} + +#[cfg(test)] +pub mod test { + use insta::{assert_debug_snapshot, assert_snapshot}; + use tempfile::tempdir; + use tokio_stream::wrappers::ReceiverStream; + + use crate::auth::{Authorized, Permission}; + use crate::connection::libsql::LibSqlConnection; + use crate::connection::program::Program; + use crate::query_result_builder::test::{ + fsm_builder_driver, random_transition, TestBuilder, ValidateTraceBuilder, + }; + use crate::rpc::proxy::rpc::StreamProgramReq; + + use super::*; + + fn exec_req_stmt(s: &str, id: u32) -> ExecReq { + ExecReq { + request_id: id, + request: Some(Request::Execute(StreamProgramReq { + pgm: Some(Program::seq(&[s]).into()), + })), + } + } + + #[tokio::test] + async fn invalid_request() { + let tmp = tempdir().unwrap(); + let conn = LibSqlConnection::new_test(tmp.path()); + let (snd, rcv) = mpsc::channel(1); + let stream = make_proxy_stream(conn, Authenticated::Anonymous, ReceiverStream::new(rcv)); + pin!(stream); + + let req = ExecReq { + request_id: 0, + request: None, + }; + + snd.send(Ok(req)).await.unwrap(); + + assert_snapshot!(stream.next().await.unwrap().unwrap_err().to_string()); + } + + #[tokio::test] + async fn request_stream_dropped() { + let tmp = tempdir().unwrap(); + let conn = LibSqlConnection::new_test(tmp.path()); + let (snd, rcv) = mpsc::channel(1); + let auth = Authenticated::Authorized(Authorized { + namespace: None, + permission: Permission::FullAccess, + }); + let stream = make_proxy_stream(conn, auth, ReceiverStream::new(rcv)); + + pin!(stream); + + drop(snd); + + assert!(stream.next().await.is_none()); + } + + #[tokio::test] + async fn perform_query_simple() { + let tmp = tempdir().unwrap(); + let conn = LibSqlConnection::new_test(tmp.path()); + let (snd, rcv) = mpsc::channel(1); + let auth = Authenticated::Authorized(Authorized { + namespace: None, + permission: Permission::FullAccess, + }); + let stream = make_proxy_stream(conn, auth, ReceiverStream::new(rcv)); + + pin!(stream); + + let req = exec_req_stmt("create table test (foo)", 0); + + snd.send(Ok(req)).await.unwrap(); + + assert_debug_snapshot!(stream.next().await.unwrap().unwrap()); + } + + #[tokio::test] + async fn single_query_split_response() { + let tmp = tempdir().unwrap(); + let conn = LibSqlConnection::new_test(tmp.path()); + let (snd, rcv) = mpsc::channel(1); + let auth = Authenticated::Authorized(Authorized { + namespace: None, + permission: Permission::FullAccess, + }); + // limit the size of the response to force a split + let stream = make_proxy_stream_inner(conn, auth, ReceiverStream::new(rcv), 500); + + pin!(stream); + + let req = exec_req_stmt("create table test (foo)", 0); + snd.send(Ok(req)).await.unwrap(); + let resp = stream.next().await.unwrap().unwrap(); + assert_eq!(resp.request_id, 0); + for i in 1..50 { + let req = exec_req_stmt( + r#"insert into test values ("something moderately long")"#, + i, + ); + snd.send(Ok(req)).await.unwrap(); + let resp = stream.next().await.unwrap().unwrap(); + assert_eq!(resp.request_id, i); + } + + let req = exec_req_stmt("select * from test", 100); + snd.send(Ok(req)).await.unwrap(); + + let mut num_resp = 0; + let mut builder = TestBuilder::default(); + loop { + let Response::ProgramResp(resp) = + stream.next().await.unwrap().unwrap().response.unwrap() + else { + panic!() + }; + if !apply_program_resp_to_builder( + &QueryBuilderConfig::default(), + &mut builder, + resp, + |_, _| (), + ) + .unwrap() + { + break; + } + num_resp += 1; + } + + assert_eq!(num_resp, 3); + assert_debug_snapshot!(builder.into_ret()); + } + + #[tokio::test] + async fn request_interupted() { + let tmp = tempdir().unwrap(); + let conn = LibSqlConnection::new_test(tmp.path()); + let (snd, rcv) = mpsc::channel(2); + let auth = Authenticated::Authorized(Authorized { + namespace: None, + permission: Permission::FullAccess, + }); + let stream = make_proxy_stream(conn, auth, ReceiverStream::new(rcv)); + + pin!(stream); + + // request 0 should be dropped, and request 1 should be processed instead + let req1 = exec_req_stmt("create table test (foo)", 0); + let req2 = exec_req_stmt("create table test (foo)", 1); + snd.send(Ok(req1)).await.unwrap(); + snd.send(Ok(req2)).await.unwrap(); + + let resp = stream.next().await.unwrap().unwrap(); + assert_eq!(resp.request_id, 1); + } + + #[tokio::test] + async fn describe() { + let tmp = tempdir().unwrap(); + let conn = LibSqlConnection::new_test(tmp.path()); + let (snd, rcv) = mpsc::channel(1); + let auth = Authenticated::Authorized(Authorized { + namespace: None, + permission: Permission::FullAccess, + }); + let stream = make_proxy_stream(conn, auth, ReceiverStream::new(rcv)); + + pin!(stream); + + // request 0 should be dropped, and request 1 should be processed instead + let req = ExecReq { + request_id: 0, + request: Some(Request::Describe(StreamDescribeReq { + stmt: "select $hello".into(), + })), + }; + + snd.send(Ok(req)).await.unwrap(); + + assert_debug_snapshot!(stream.next().await.unwrap().unwrap()); + } + + /// This fuction returns a random, valid, program resp for use in other tests + pub fn random_valid_program_resp( + size: usize, + max_resp_size: usize, + ) -> (impl Stream, ValidateTraceBuilder) { + let (sender, receiver) = mpsc::channel(1); + let builder = StreamResponseBuilder { + request_id: 0, + sender, + current: None, + current_size: 0, + max_program_resp_size: max_resp_size, + }; + + let trace = random_transition(size); + tokio::task::spawn_blocking({ + let trace = trace.clone(); + move || fsm_builder_driver(&trace, builder) + }); + + ( + ReceiverStream::new(receiver), + ValidateTraceBuilder::new(trace), + ) + } +} From 80789ad88362c1579fbdec6b231b1b11c4975f00 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 31 Oct 2023 14:57:50 +0100 Subject: [PATCH 02/26] remove libsql proto --- libsql/proto/proxy.proto | 158 ----------------------------- libsql/proto/replication_log.proto | 32 ------ 2 files changed, 190 deletions(-) delete mode 100644 libsql/proto/proxy.proto delete mode 100644 libsql/proto/replication_log.proto diff --git a/libsql/proto/proxy.proto b/libsql/proto/proxy.proto deleted file mode 100644 index 87da8f967c..0000000000 --- a/libsql/proto/proxy.proto +++ /dev/null @@ -1,158 +0,0 @@ -syntax = "proto3"; -package proxy; - -message Queries { - repeated Query queries = 1; - // Uuid - string clientId = 2; -} - -message Query { - string stmt = 1; - oneof Params { - Positional positional = 2; - Named named = 3; - } - bool skip_rows = 4; -} - -message Positional { - repeated Value values = 1; -} - -message Named { - repeated string names = 1; - repeated Value values = 2; -} - -message QueryResult { - oneof row_result { - Error error = 1; - ResultRows row = 2; - } -} - -message Error { - enum ErrorCode { - SQLError = 0; - TxBusy = 1; - TxTimeout = 2; - Internal = 3; - } - - ErrorCode code = 1; - string message = 2; - int32 extended_code = 3; -} - -message ResultRows { - repeated Column column_descriptions = 1; - repeated Row rows = 2; - uint64 affected_row_count = 3; - optional int64 last_insert_rowid = 4; -} - -message DescribeRequest { - string client_id = 1; - string stmt = 2; -} - -message DescribeResult { - oneof describe_result { - Error error = 1; - Description description = 2; - } -} - -message Description { - repeated Column column_descriptions = 1; - repeated string param_names = 2; - uint64 param_count = 3; -} - -message Value { - /// bincode encoded Value - bytes data = 1; -} - -message Row { - repeated Value values = 1; -} - -message Column { - string name = 1; - optional string decltype = 3; -} - -message DisconnectMessage { - string clientId = 1; -} - -message Ack { } - -message ExecuteResults { - repeated QueryResult results = 1; - enum State { - Init = 0; - Invalid = 1; - Txn = 2; - } - /// State after executing the queries - State state = 2; - /// Primary frame_no after executing the request. - optional uint64 current_frame_no = 3; -} - -message Program { - repeated Step steps = 1; -} - -message Step { - optional Cond cond = 1; - Query query = 2; -} - -message Cond { - oneof cond { - OkCond ok = 1; - ErrCond err = 2; - NotCond not = 3; - AndCond and = 4; - OrCond or = 5; - IsAutocommitCond is_autocommit = 6; - } -} - -message OkCond { - int64 step = 1; -} - -message ErrCond { - int64 step = 1; -} - -message NotCond { - Cond cond = 1; -} - -message AndCond { - repeated Cond conds = 1; -} - -message OrCond { - repeated Cond conds = 1; -} - -message IsAutocommitCond { -} - -message ProgramReq { - string client_id = 1; - Program pgm = 2; -} - -service Proxy { - rpc Execute(ProgramReq) returns (ExecuteResults) {} - rpc Describe(DescribeRequest) returns (DescribeResult) {} - rpc Disconnect(DisconnectMessage) returns (Ack) {} -} diff --git a/libsql/proto/replication_log.proto b/libsql/proto/replication_log.proto deleted file mode 100644 index 4d5e651243..0000000000 --- a/libsql/proto/replication_log.proto +++ /dev/null @@ -1,32 +0,0 @@ -syntax = "proto3"; -package wal_log; - -message LogOffset { - uint64 next_offset = 1; -} - -message HelloRequest { } - -message HelloResponse { - /// Uuid of the current generation - string generation_id = 1; - /// First frame_no in the current generation - uint64 generation_start_index = 2; - /// Uuid of the database being replicated - string database_id = 3; -} - -message Frame { - bytes data = 1; -} - -message Frames { - repeated Frame frames = 1; -} - -service ReplicationLog { - rpc Hello(HelloRequest) returns (HelloResponse) {} - rpc LogEntries(LogOffset) returns (stream Frame) {} - rpc BatchLogEntries(LogOffset) returns (Frames) {} - rpc Snapshot(LogOffset) returns (stream Frame) {} -} From f6d73619f9f5726983490d9692535b441e5c9434 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 31 Oct 2023 14:58:09 +0100 Subject: [PATCH 03/26] add streaming replication proto --- libsql-replication/build.rs | 20 ++++ libsql-replication/proto/proxy.proto | 132 ++++++++++++++++++++++++--- libsql-replication/src/lib.rs | 22 ----- libsql-replication/src/rpc.rs | 18 ++++ 4 files changed, 157 insertions(+), 35 deletions(-) create mode 100644 libsql-replication/build.rs diff --git a/libsql-replication/build.rs b/libsql-replication/build.rs new file mode 100644 index 0000000000..fbf6f79416 --- /dev/null +++ b/libsql-replication/build.rs @@ -0,0 +1,20 @@ +use prost_build::Config; + +fn main() -> Result<(), Box> { + std::env::set_var("PROTOC", protobuf_src::protoc()); + + let mut config = Config::new(); + config.bytes([".wal_log"]); + tonic_build::configure() + .protoc_arg("--experimental_allow_proto3_optional") + .type_attribute(".proxy", "#[cfg_attr(test, derive(arbitrary::Arbitrary))]") + .compile_with_config( + config, + &["proto/replication_log.proto", "proto/proxy.proto"], + &["proto"], + )?; + + println!("cargo:rerun-if-changed=proto/"); + + Ok(()) +} diff --git a/libsql-replication/proto/proxy.proto b/libsql-replication/proto/proxy.proto index 87da8f967c..cbbfb67307 100644 --- a/libsql-replication/proto/proxy.proto +++ b/libsql-replication/proto/proxy.proto @@ -4,7 +4,7 @@ package proxy; message Queries { repeated Query queries = 1; // Uuid - string clientId = 2; + string client_id = 2; } message Query { @@ -34,10 +34,10 @@ message QueryResult { message Error { enum ErrorCode { - SQLError = 0; - TxBusy = 1; - TxTimeout = 2; - Internal = 3; + SQL_ERROR = 0; + TX_BUSY = 1; + TX_TIMEOUT = 2; + INTERNAL = 3; } ErrorCode code = 1; @@ -72,7 +72,7 @@ message Description { message Value { /// bincode encoded Value - bytes data = 1; + bytes data = 1; } message Row { @@ -85,18 +85,19 @@ message Column { } message DisconnectMessage { - string clientId = 1; + string client_id = 1; } message Ack { } +enum State { + INIT = 0; + INVALID = 1; + TXN = 2; +} + message ExecuteResults { repeated QueryResult results = 1; - enum State { - Init = 0; - Invalid = 1; - Txn = 2; - } /// State after executing the queries State state = 2; /// Primary frame_no after executing the request. @@ -111,7 +112,6 @@ message Step { optional Cond cond = 1; Query query = 2; } - message Cond { oneof cond { OkCond ok = 1; @@ -151,7 +151,113 @@ message ProgramReq { Program pgm = 2; } +/// Streaming exec request +message ExecReq { + /// id of the request. The response will contain this id. + uint32 request_id = 1; + oneof request { + StreamProgramReq execute = 2; + StreamDescribeReq describe = 3; + } +} + +/// Describe request for the streaming protocol +message StreamProgramReq { + Program pgm = 1; +} + +/// descibre request for the streaming protocol +message StreamDescribeReq { + string stmt = 1; +} + +/// Response message for the streaming proto + +/// Request response types +message Init { } +message BeginStep { } +message FinishStep { + uint64 affected_row_count = 1; + optional int64 last_insert_rowid = 2; +} +message StepError { + Error error = 1; +} +message ColsDescription { + repeated Column columns = 1; +} +message RowValue { + oneof value { + string text = 1; + int64 integer = 2; + double real = 3; + bytes blob = 4; + // null if present + bool null = 5; + } +} +message BeginRows { } +message BeginRow { } +message AddRowValue { + RowValue val = 1; +} +message FinishRow { } +message FinishRows { } +message Finish { + optional uint64 last_frame_no = 1; + State state = 2; +} + +/// Stream execx dexcribe response messages +message DescribeParam { + optional string name = 1; +} + +message DescribeCol { + string name = 1; + optional string decltype = 2; +} + +message DescribeResp { + repeated DescribeParam params = 1; + repeated DescribeCol cols = 2; + bool is_explain = 3; + bool is_readonly = 4; +} + +message RespStep { + oneof step { + Init init = 1; + BeginStep begin_step = 2; + FinishStep finish_step = 3; + StepError step_error = 4; + ColsDescription cols_description = 5; + BeginRows begin_rows = 6; + BeginRow begin_row = 7; + AddRowValue add_row_value = 8; + FinishRow finish_row = 9; + FinishRows finish_rows = 10; + Finish finish = 11; + } +} + +message ProgramResp { + repeated RespStep steps = 1; +} + +message ExecResp { + uint32 request_id = 1; + oneof response { + ProgramResp program_resp = 2; + DescribeResp describe_resp = 3; + Error error = 4; + } +} + service Proxy { + rpc StreamExec(stream ExecReq) returns (stream ExecResp) {} + + // Deprecated: rpc Execute(ProgramReq) returns (ExecuteResults) {} rpc Describe(DescribeRequest) returns (DescribeResult) {} rpc Disconnect(DisconnectMessage) returns (Ack) {} diff --git a/libsql-replication/src/lib.rs b/libsql-replication/src/lib.rs index 32986167de..b569ceca41 100644 --- a/libsql-replication/src/lib.rs +++ b/libsql-replication/src/lib.rs @@ -8,25 +8,3 @@ pub mod snapshot; mod error; pub const LIBSQL_PAGE_SIZE: usize = 4096; - -#[cfg(test)] -pub mod test { - use arbitrary::Unstructured; - use bytes::Bytes; - - /// generate an arbitrary rpc value. see build.rs for usage. - pub fn arbitrary_rpc_value(_u: &mut Unstructured) -> arbitrary::Result> { - todo!(); - // let data = bincode::serialize(&crate::query::Value::arbitrary(u)?).unwrap(); - // - // Ok(data) - } - - /// generate an arbitrary `Bytes` value. see build.rs for usage. - pub fn arbitrary_bytes(_u: &mut Unstructured) -> arbitrary::Result { - todo!() - // let v: Vec = Arbitrary::arbitrary(u)?; - // - // Ok(v.into()) - } -} diff --git a/libsql-replication/src/rpc.rs b/libsql-replication/src/rpc.rs index 47d1fe7644..667daad513 100644 --- a/libsql-replication/src/rpc.rs +++ b/libsql-replication/src/rpc.rs @@ -1,6 +1,24 @@ pub mod proxy { #![allow(clippy::all)] include!("generated/proxy.rs"); + + use sqld_libsql_bindings::rusqlite::types::ValueRef; + + impl From> for RowValue { + fn from(value: ValueRef<'_>) -> Self { + use row_value::Value; + + let value = Some(match value { + ValueRef::Null => Value::Null(true), + ValueRef::Integer(i) => Value::Integer(i), + ValueRef::Real(x) => Value::Real(x), + ValueRef::Text(s) => Value::Text(String::from_utf8(s.to_vec()).unwrap()), + ValueRef::Blob(b) => Value::Blob(b.to_vec()), + }); + + RowValue { value } + } + } } pub mod replication { From 8a093e9b68a650758686791b21db2f69f4418cd8 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 31 Oct 2023 15:02:38 +0100 Subject: [PATCH 04/26] add futures_option dep --- libsql-server/Cargo.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/libsql-server/Cargo.toml b/libsql-server/Cargo.toml index e8e006cbff..6f500ab786 100644 --- a/libsql-server/Cargo.toml +++ b/libsql-server/Cargo.toml @@ -26,6 +26,7 @@ enclose = "1.1" fallible-iterator = "0.3.0" futures = "0.3.25" futures-core = "0.3" +futures-option = "0.2.0" hmac = "0.12" hyper = { version = "0.14.23", features = ["http2"] } hyper-rustls = { git = "https://github.com/rustls/hyper-rustls.git", rev = "163b3f5" } From f66fd99b273702cd91eb7e41efeb122990a55909 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 31 Oct 2023 15:03:09 +0100 Subject: [PATCH 05/26] implement streaming proxy --- libsql-server/src/connection/write_proxy.rs | 396 +++++++++++++------- 1 file changed, 261 insertions(+), 135 deletions(-) diff --git a/libsql-server/src/connection/write_proxy.rs b/libsql-server/src/connection/write_proxy.rs index 3d8f209e81..f7c96b8efc 100644 --- a/libsql-server/src/connection/write_proxy.rs +++ b/libsql-server/src/connection/write_proxy.rs @@ -1,35 +1,34 @@ use std::path::PathBuf; use std::sync::Arc; +use futures_core::future::BoxFuture; +use futures_core::Stream; +use libsql_replication::rpc::proxy::proxy_client::ProxyClient; +use libsql_replication::rpc::proxy::{ + exec_req, exec_resp, ExecReq, ExecResp, StreamDescribeReq, StreamProgramReq, +}; use libsql_replication::rpc::replication::NAMESPACE_METADATA_KEY; use parking_lot::Mutex as PMutex; -use rusqlite::types::ValueRef; use sqld_libsql_bindings::wal_hook::{TransparentMethods, TRANSPARENT_METHODS}; -use tokio::sync::{watch, Mutex}; +use tokio::sync::{mpsc, watch, Mutex}; +use tokio_stream::StreamExt; use tonic::metadata::BinaryMetadataValue; use tonic::transport::Channel; -use tonic::Request; -use uuid::Uuid; +use tonic::{Request, Streaming}; use crate::auth::Authenticated; +use crate::connection::program::{DescribeCol, DescribeParam}; use crate::error::Error; -use crate::metrics::REQUESTS_PROXIED; use crate::namespace::NamespaceName; -use crate::query::Value; -use crate::query_analysis::State; -use crate::query_result_builder::{ - Column, QueryBuilderConfig, QueryResultBuilder, QueryResultBuilderError, -}; +use crate::query_analysis::TxnStatus; +use crate::query_result_builder::{QueryBuilderConfig, QueryResultBuilder}; use crate::replication::FrameNo; -use crate::rpc::proxy::rpc::proxy_client::ProxyClient; -use crate::rpc::proxy::rpc::query_result::RowResult; -use crate::rpc::proxy::rpc::{DisconnectMessage, ExecuteResults}; use crate::stats::Stats; use crate::{Result, DEFAULT_AUTO_CHECKPOINT}; use super::config::DatabaseConfigStore; use super::libsql::{LibSqlConnection, MakeLibSqlConn}; -use super::program::DescribeResult; +use super::program::DescribeResponse; use super::Connection; use super::{MakeConnection, Program}; @@ -105,13 +104,11 @@ impl MakeConnection for MakeWriteProxyConn { } } -#[derive(Debug)] -pub struct WriteProxyConnection { +pub struct WriteProxyConnection> { /// Lazily initialized read connection read_conn: LibSqlConnection, write_proxy: ProxyClient, - state: Mutex, - client_id: Uuid, + state: Mutex, /// FrameNo of the last write performed by this connection on the primary. /// any subsequent read on this connection must wait for the replicator to catch up with this /// frame_no @@ -121,51 +118,8 @@ pub struct WriteProxyConnection { builder_config: QueryBuilderConfig, stats: Arc, namespace: NamespaceName, -} - -fn execute_results_to_builder( - execute_result: ExecuteResults, - mut builder: B, - config: &QueryBuilderConfig, -) -> Result { - builder.init(config)?; - for result in execute_result.results { - match result.row_result { - Some(RowResult::Row(rows)) => { - builder.begin_step()?; - builder.cols_description(rows.column_descriptions.iter().map(|c| Column { - name: &c.name, - decl_ty: c.decltype.as_deref(), - }))?; - - builder.begin_rows()?; - for row in rows.rows { - builder.begin_row()?; - for value in row.values { - let value: Value = bincode::deserialize(&value.data) - // something is wrong, better stop right here - .map_err(QueryResultBuilderError::from_any)?; - builder.add_row_value(ValueRef::from(&value))?; - } - builder.finish_row()?; - } - - builder.finish_rows()?; - - builder.finish_step(rows.affected_row_count, rows.last_insert_rowid)?; - } - Some(RowResult::Error(err)) => { - builder.begin_step()?; - builder.step_error(Error::RpcQueryError(err))?; - builder.finish_step(0, None)?; - } - None => (), - } - } - - builder.finish(execute_result.current_frame_no)?; - Ok(builder) + remote_conn: Mutex>>, } impl WriteProxyConnection { @@ -181,58 +135,73 @@ impl WriteProxyConnection { Ok(Self { read_conn, write_proxy, - state: Mutex::new(State::Init), - client_id: Uuid::new_v4(), - last_write_frame_no: PMutex::new(None), + state: Mutex::new(TxnStatus::Init), + last_write_frame_no: Default::default(), applied_frame_no_receiver, builder_config, stats, namespace, + remote_conn: Default::default(), }) } + async fn with_remote_conn( + &self, + auth: Authenticated, + builder_config: QueryBuilderConfig, + cb: F, + ) -> crate::Result + where + F: FnOnce(&mut RemoteConnection) -> BoxFuture<'_, crate::Result>, + { + let mut remote_conn = self.remote_conn.lock().await; + if remote_conn.is_some() { + cb(remote_conn.as_mut().unwrap()).await + } else { + let conn = RemoteConnection::connect( + self.write_proxy.clone(), + self.namespace.clone(), + auth, + builder_config, + ) + .await?; + let conn = remote_conn.insert(conn); + cb(conn).await + } + } + async fn execute_remote( &self, pgm: Program, - state: &mut State, + status: &mut TxnStatus, auth: Authenticated, builder: B, - ) -> Result<(B, State)> { - REQUESTS_PROXIED.increment(1); - + ) -> Result { self.stats.inc_write_requests_delegated(); - let mut client = self.write_proxy.clone(); - - let mut req = Request::new(crate::rpc::proxy::rpc::ProgramReq { - client_id: self.client_id.to_string(), - pgm: Some(pgm.into()), - }); - - let namespace = BinaryMetadataValue::from_bytes(self.namespace.as_slice()); - req.metadata_mut() - .insert_bin(NAMESPACE_METADATA_KEY, namespace); - auth.upgrade_grpc_request(&mut req); - - match client.execute(req).await { - Ok(r) => { - let execute_result = r.into_inner(); - *state = execute_result.state().into(); - let current_frame_no = execute_result.current_frame_no; - let builder = - execute_results_to_builder(execute_result, builder, &self.builder_config)?; - if let Some(current_frame_no) = current_frame_no { - self.update_last_write_frame_no(current_frame_no); - } - - Ok((builder, *state)) - } - Err(e) => { - // Set state to invalid, so next call is sent to remote, and we have a chance - // to recover state. - *state = State::Invalid; - Err(Error::RpcQueryExecutionError(e)) + *status = TxnStatus::Invalid; + let res = self + .with_remote_conn(auth, self.builder_config, |conn| { + Box::pin(conn.execute(pgm, builder)) + }) + .await; + + let (builder, new_status, new_frame_no) = match res { + Ok(res) => res, + Err(e @ (Error::PrimaryStreamDisconnect | Error::PrimaryStreamMisuse)) => { + // drop the connection, and reset the state. + self.remote_conn.lock().await.take(); + *status = TxnStatus::Init; + return Err(e); } + Err(e) => return Err(e), + }; + + *status = new_status; + if let Some(current_frame_no) = new_frame_no { + self.update_last_write_frame_no(current_frame_no); } + + Ok(builder) } fn update_last_write_frame_no(&self, new_frame_no: FrameNo) { @@ -264,6 +233,161 @@ impl WriteProxyConnection { } } +struct RemoteConnection> { + response_stream: R, + request_sender: mpsc::Sender, + current_request_id: u32, + builder_config: QueryBuilderConfig, +} + +impl RemoteConnection { + async fn connect( + mut client: ProxyClient, + namespace: NamespaceName, + auth: Authenticated, + builder_config: QueryBuilderConfig, + ) -> crate::Result { + let (request_sender, receiver) = mpsc::channel(1); + + let stream = tokio_stream::wrappers::ReceiverStream::new(receiver); + let mut req = Request::new(stream); + let namespace = BinaryMetadataValue::from_bytes(namespace.as_slice()); + req.metadata_mut() + .insert_bin(NAMESPACE_METADATA_KEY, namespace); + auth.upgrade_grpc_request(&mut req); + let response_stream = client.stream_exec(req).await.unwrap().into_inner(); + + Ok(Self { + response_stream, + request_sender, + current_request_id: 0, + builder_config, + }) + } +} + +impl RemoteConnection +where + R: Stream> + Unpin, +{ + /// Perform a request on to the remote peer, and call message_cb for every message received for + /// that request. message cb should return whether to expect more message for that request. + async fn make_request( + &mut self, + req: exec_req::Request, + mut response_cb: impl FnMut(exec_resp::Response) -> crate::Result, + ) -> crate::Result<()> { + let request_id = self.current_request_id; + self.current_request_id += 1; + + let req = ExecReq { + request_id, + request: Some(req), + }; + + self.request_sender + .send(req) + .await + .map_err(|_| Error::PrimaryStreamDisconnect)?; + + while let Some(resp) = self.response_stream.next().await { + match resp { + Ok(resp) => { + // there was an interuption, and we moved to the next query + if resp.request_id > request_id { + return Err(Error::PrimaryStreamInterupted); + } + + // we can ignore response for previously interupted requests + if resp.request_id < request_id { + continue; + } + + if !response_cb(resp.response.ok_or(Error::PrimaryStreamMisuse)?)? { + break; + } + } + Err(e) => { + tracing::error!("received an error from connection stream: {e}"); + return Err(Error::PrimaryStreamDisconnect); + } + } + } + + Ok(()) + } + + async fn execute( + &mut self, + program: Program, + mut builder: B, + ) -> crate::Result<(B, TxnStatus, Option)> { + let mut txn_status = TxnStatus::Invalid; + let mut new_frame_no = None; + let builder_config = self.builder_config; + let cb = |response: exec_resp::Response| match response { + exec_resp::Response::ProgramResp(resp) => { + crate::rpc::streaming_exec::apply_program_resp_to_builder( + &builder_config, + &mut builder, + resp, + |last_frame_no, status| { + txn_status = status; + new_frame_no = last_frame_no; + }, + ) + } + exec_resp::Response::DescribeResp(_) => Err(Error::PrimaryStreamMisuse), + exec_resp::Response::Error(e) => Err(Error::RpcQueryError(e)), + }; + + self.make_request( + exec_req::Request::Execute(StreamProgramReq { + pgm: Some(program.into()), + }), + cb, + ) + .await?; + + Ok((builder, txn_status, new_frame_no)) + } + + #[allow(dead_code)] // reference implementation + async fn describe(&mut self, stmt: String) -> crate::Result { + let mut out = None; + let cb = |response: exec_resp::Response| match response { + exec_resp::Response::DescribeResp(resp) => { + out = Some(DescribeResponse { + params: resp + .params + .into_iter() + .map(|p| DescribeParam { name: p.name }) + .collect(), + cols: resp + .cols + .into_iter() + .map(|c| DescribeCol { + name: c.name, + decltype: c.decltype, + }) + .collect(), + is_explain: resp.is_explain, + is_readonly: resp.is_readonly, + }); + + Ok(false) + } + exec_resp::Response::Error(e) => Err(Error::RpcQueryError(e)), + exec_resp::Response::ProgramResp(_) => Err(Error::PrimaryStreamMisuse), + }; + + self.make_request(exec_req::Request::Describe(StreamDescribeReq { stmt }), cb) + .await?; + + out.ok_or(Error::PrimaryStreamMisuse) + } +} + #[async_trait::async_trait] impl Connection for WriteProxyConnection { async fn execute_program( @@ -272,26 +396,30 @@ impl Connection for WriteProxyConnection { auth: Authenticated, builder: B, replication_index: Option, - ) -> Result<(B, State)> { + ) -> Result { let mut state = self.state.lock().await; // This is a fresh namespace, and it is not replicated yet, proxy the first request. if self.applied_frame_no_receiver.borrow().is_none() { self.execute_remote(pgm, &mut state, auth, builder).await - } else if *state == State::Init && pgm.is_read_only() { + } else if *state == TxnStatus::Init && pgm.is_read_only() { + // set the state to invalid before doing anything, and set it to a valid state after. + *state = TxnStatus::Invalid; self.wait_replication_sync(replication_index).await?; // We know that this program won't perform any writes. We attempt to run it on the // replica. If it leaves an open transaction, then this program is an interactive // transaction, so we rollback the replica, and execute again on the primary. - let (builder, new_state) = self + let builder = self .read_conn .execute_program(pgm.clone(), auth.clone(), builder, replication_index) .await?; - if new_state != State::Init { + let new_state = self.read_conn.txn_status()?; + if new_state != TxnStatus::Init { self.read_conn.rollback(auth.clone()).await?; self.execute_remote(pgm, &mut state, auth, builder).await } else { - Ok((builder, new_state)) + *state = new_state; + Ok(builder) } } else { self.execute_remote(pgm, &mut state, auth, builder).await @@ -303,7 +431,7 @@ impl Connection for WriteProxyConnection { sql: String, auth: Authenticated, replication_index: Option, - ) -> Result { + ) -> Result> { self.wait_replication_sync(replication_index).await?; self.read_conn.describe(sql, auth, replication_index).await } @@ -311,8 +439,8 @@ impl Connection for WriteProxyConnection { async fn is_autocommit(&self) -> Result { let state = self.state.lock().await; Ok(match *state { - State::Txn => false, - State::Init | State::Invalid => true, + TxnStatus::Txn => false, + TxnStatus::Init | TxnStatus::Invalid => true, }) } @@ -331,34 +459,32 @@ impl Connection for WriteProxyConnection { } } -impl Drop for WriteProxyConnection { - fn drop(&mut self) { - // best effort attempt to disconnect - let mut remote = self.write_proxy.clone(); - let client_id = self.client_id.to_string(); - tokio::spawn(async move { - let _ = remote.disconnect(DisconnectMessage { client_id }).await; - }); - } -} - #[cfg(test)] pub mod test { - // use arbitrary::{Arbitrary, Unstructured}; - // use rand::Fill; - // - // use super::*; - // use crate::query_result_builder::test::test_driver; - - // In this test, we generate random ExecuteResults, and ensures that the `execute_results_to_builder` drives the builder FSM correctly. - // #[test] - // fn test_execute_results_to_builder() { - // test_driver(1000, |b| { - // let mut data = [0; 10_000]; - // data.try_fill(&mut rand::thread_rng()).unwrap(); - // let mut un = Unstructured::new(&data); - // let res = ExecuteResults::arbitrary(&mut un).unwrap(); - // execute_results_to_builder(res, b, &QueryBuilderConfig::default()) - // }); - // } + use super::*; + use crate::rpc::streaming_exec::test::random_valid_program_resp; + + #[tokio::test] + // in this test we do a roundtrip: generate a random valid program, stream it to + // RemoteConnection, and make sure that the remote connection drives the builder with the same + // state transitions. + async fn validate_random_stream_response() { + for _ in 0..10 { + let (response_stream, validator) = random_valid_program_resp(500, 150); + let (request_sender, _request_recver) = mpsc::channel(1); + let mut remote = RemoteConnection { + response_stream: response_stream.map(Ok), + request_sender, + current_request_id: 0, + builder_config: QueryBuilderConfig::default(), + }; + + remote + .execute(Program::seq(&[]), validator) + .await + .unwrap() + .0 + .into_ret(); + } + } } From bdfc01dd533ab3befb3ee5fb2846f5e668528334 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 31 Oct 2023 15:03:47 +0100 Subject: [PATCH 06/26] implement rpc streaming proxy --- libsql-server/src/rpc/proxy.rs | 128 +++++++++++++++++++++++---------- 1 file changed, 89 insertions(+), 39 deletions(-) diff --git a/libsql-server/src/rpc/proxy.rs b/libsql-server/src/rpc/proxy.rs index 0d0dd6d87d..e11fee8338 100644 --- a/libsql-server/src/rpc/proxy.rs +++ b/libsql-server/src/rpc/proxy.rs @@ -1,25 +1,30 @@ use std::collections::HashMap; +use std::pin::Pin; use std::str::FromStr; use std::sync::Arc; use async_lock::{RwLock, RwLockUpgradableReadGuard}; +use futures_core::Stream; +use libsql_replication::rpc::proxy::proxy_server::Proxy; +use libsql_replication::rpc::proxy::query_result::RowResult; +use libsql_replication::rpc::proxy::{ + describe_result, Ack, DescribeRequest, DescribeResult, Description, DisconnectMessage, ExecReq, + ExecResp, ExecuteResults, QueryResult, ResultRows, Row, +}; +use rusqlite::types::ValueRef; use uuid::Uuid; use crate::auth::{Auth, Authenticated}; use crate::connection::Connection; use crate::database::{Database, PrimaryConnection}; use crate::namespace::{NamespaceStore, PrimaryNamespaceMaker}; +use crate::query_analysis::TxnStatus; use crate::query_result_builder::{ Column, QueryBuilderConfig, QueryResultBuilder, QueryResultBuilderError, }; use crate::replication::FrameNo; +use crate::rpc::streaming_exec::make_proxy_stream; -use self::rpc::proxy_server::Proxy; -use self::rpc::query_result::RowResult; -use self::rpc::{ - describe_result, Ack, DescribeRequest, DescribeResult, Description, DisconnectMessage, - ExecuteResults, QueryResult, ResultRows, Row, -}; use super::NAMESPACE_DOESNT_EXIST; pub mod rpc { @@ -31,7 +36,7 @@ pub mod rpc { use crate::query_analysis::Statement; use crate::{connection, error::Error as SqldError}; - use self::{error::ErrorCode, execute_results::State}; + use error::ErrorCode; impl From for Error { fn from(other: SqldError) -> Self { @@ -56,22 +61,33 @@ pub mod rpc { } } - impl From for State { - fn from(other: crate::query_analysis::State) -> Self { + impl From for ErrorCode { + fn from(other: SqldError) -> Self { match other { - crate::query_analysis::State::Txn => Self::Txn, - crate::query_analysis::State::Init => Self::Init, - crate::query_analysis::State::Invalid => Self::Invalid, + SqldError::LibSqlInvalidQueryParams(_) => ErrorCode::SqlError, + SqldError::LibSqlTxTimeout => ErrorCode::TxTimeout, + SqldError::LibSqlTxBusy => ErrorCode::TxBusy, + _ => ErrorCode::Internal, } } } - impl From for crate::query_analysis::State { + impl From for State { + fn from(other: crate::query_analysis::TxnStatus) -> Self { + match other { + crate::query_analysis::TxnStatus::Txn => Self::Txn, + crate::query_analysis::TxnStatus::Init => Self::Init, + crate::query_analysis::TxnStatus::Invalid => Self::Invalid, + } + } + } + + impl From for crate::query_analysis::TxnStatus { fn from(other: State) -> Self { match other { - State::Txn => crate::query_analysis::State::Txn, - State::Init => crate::query_analysis::State::Init, - State::Invalid => crate::query_analysis::State::Invalid, + State::Txn => crate::query_analysis::TxnStatus::Txn, + State::Init => crate::query_analysis::TxnStatus::Init, + State::Invalid => crate::query_analysis::TxnStatus::Invalid, } } } @@ -291,7 +307,8 @@ impl ProxyService { } #[derive(Debug, Default)] -struct ExecuteResultBuilder { +struct ExecuteResultsBuilder { + output: Option, results: Vec, current_rows: Vec, current_row: rpc::Row, @@ -302,8 +319,8 @@ struct ExecuteResultBuilder { current_step_size: u64, } -impl QueryResultBuilder for ExecuteResultBuilder { - type Ret = Vec; +impl QueryResultBuilder for ExecuteResultsBuilder { + type Ret = ExecuteResults; fn init(&mut self, config: &QueryBuilderConfig) -> Result<(), QueryResultBuilderError> { *self = Self { @@ -398,10 +415,7 @@ impl QueryResultBuilder for ExecuteResultBuilder { Ok(()) } - fn add_row_value( - &mut self, - v: rusqlite::types::ValueRef, - ) -> Result<(), QueryResultBuilderError> { + fn add_row_value(&mut self, v: ValueRef) -> Result<(), QueryResultBuilderError> { let data = bincode::serialize( &crate::query::Value::try_from(v).map_err(QueryResultBuilderError::from_any)?, ) @@ -436,12 +450,21 @@ impl QueryResultBuilder for ExecuteResultBuilder { Ok(()) } - fn finish(&mut self, _last_frame_no: Option) -> Result<(), QueryResultBuilderError> { + fn finish( + &mut self, + last_frame_no: Option, + txn_status: TxnStatus, + ) -> Result<(), QueryResultBuilderError> { + self.output = Some(ExecuteResults { + results: std::mem::take(&mut self.results), + state: rpc::State::from(txn_status).into(), + current_frame_no: last_frame_no, + }); Ok(()) } fn into_ret(self) -> Self::Ret { - self.results + self.output.unwrap() } } @@ -460,6 +483,42 @@ pub async fn garbage_collect(clients: &mut HashMap> #[tonic::async_trait] impl Proxy for ProxyService { + type StreamExecStream = Pin> + Send>>; + + async fn stream_exec( + &self, + req: tonic::Request>, + ) -> Result, tonic::Status> { + let auth = if let Some(auth) = &self.auth { + auth.authenticate_grpc(&req, self.disable_namespaces)? + } else { + Authenticated::from_proxy_grpc_request(&req, self.disable_namespaces)? + }; + + let namespace = super::extract_namespace(self.disable_namespaces, &req)?; + let (connection_maker, _new_frame_notifier) = self + .namespaces + .with(namespace, |ns| { + let connection_maker = ns.db.connection_maker(); + let notifier = ns.db.logger.new_frame_notifier.subscribe(); + (connection_maker, notifier) + }) + .await + .map_err(|e| { + if let crate::error::Error::NamespaceDoesntExist(_) = e { + tonic::Status::failed_precondition(NAMESPACE_DOESNT_EXIST) + } else { + tonic::Status::internal(e.to_string()) + } + })?; + + let conn = connection_maker.create().await.unwrap(); + + let stream = make_proxy_stream(conn, auth, req.into_inner()); + + Ok(tonic::Response::new(Box::pin(stream))) + } + async fn execute( &self, req: tonic::Request, @@ -475,13 +534,9 @@ impl Proxy for ProxyService { .map_err(|e| tonic::Status::new(tonic::Code::InvalidArgument, e.to_string()))?; let client_id = Uuid::from_str(&req.client_id).unwrap(); - let (connection_maker, new_frame_notifier) = self + let connection_maker = self .namespaces - .with(namespace, |ns| { - let connection_maker = ns.db.connection_maker(); - let notifier = ns.db.logger.new_frame_notifier.subscribe(); - (connection_maker, notifier) - }) + .with(namespace, |ns| ns.db.connection_maker()) .await .map_err(|e| { if let crate::error::Error::NamespaceDoesntExist(_) = e { @@ -510,19 +565,14 @@ impl Proxy for ProxyService { tracing::debug!("executing request for {client_id}"); - let builder = ExecuteResultBuilder::default(); - let (results, state) = db + let builder = ExecuteResultsBuilder::default(); + let builder = db .execute_program(pgm, auth, builder, None) .await // TODO: this is no necessarily a permission denied error! .map_err(|e| tonic::Status::new(tonic::Code::PermissionDenied, e.to_string()))?; - let current_frame_no = *new_frame_notifier.borrow(); - Ok(tonic::Response::new(ExecuteResults { - current_frame_no, - results: results.into_ret(), - state: rpc::execute_results::State::from(state).into(), - })) + Ok(tonic::Response::new(builder.into_ret())) } //TODO: also handle cleanup on peer disconnect From 20bcca17e02018342cf6a2651700e0c84164091e Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 31 Oct 2023 15:04:00 +0100 Subject: [PATCH 07/26] implement replica rpc streaming proxy --- libsql-server/src/rpc/replica_proxy.rs | 37 +++++++++++++++++++++----- 1 file changed, 31 insertions(+), 6 deletions(-) diff --git a/libsql-server/src/rpc/replica_proxy.rs b/libsql-server/src/rpc/replica_proxy.rs index c4aa71798c..e0910c83e7 100644 --- a/libsql-server/src/rpc/replica_proxy.rs +++ b/libsql-server/src/rpc/replica_proxy.rs @@ -1,15 +1,15 @@ use std::sync::Arc; use hyper::Uri; +use libsql_replication::rpc::proxy::{ + proxy_client::ProxyClient, proxy_server::Proxy, Ack, DescribeRequest, DescribeResult, + DisconnectMessage, ExecReq, ExecResp, ExecuteResults, ProgramReq, +}; +use tokio_stream::StreamExt; use tonic::{transport::Channel, Request, Status}; use crate::auth::Auth; -use super::proxy::rpc::{ - self, proxy_client::ProxyClient, proxy_server::Proxy, Ack, DescribeRequest, DescribeResult, - DisconnectMessage, ExecuteResults, -}; - pub struct ReplicaProxyService { client: ProxyClient, auth: Arc, @@ -32,9 +32,34 @@ impl ReplicaProxyService { #[tonic::async_trait] impl Proxy for ReplicaProxyService { + type StreamExecStream = tonic::codec::Streaming; + + async fn stream_exec( + &self, + req: tonic::Request>, + ) -> Result, tonic::Status> { + let (meta, ext, mut stream) = req.into_parts(); + let stream = async_stream::stream! { + while let Some(it) = stream.next().await { + match it { + Ok(it) => yield it, + Err(e) => { + // close the stream on error + tracing::error!("error proxying stream request: {e}"); + break + }, + } + } + }; + let mut req = tonic::Request::from_parts(meta, ext, stream); + self.do_auth(&mut req)?; + let mut client = self.client.clone(); + client.stream_exec(req).await + } + async fn execute( &self, - mut req: tonic::Request, + mut req: tonic::Request, ) -> Result, tonic::Status> { self.do_auth(&mut req)?; From 7ad783b051be3b3bf48bcf928f9120c301d830c1 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 31 Oct 2023 15:04:38 +0100 Subject: [PATCH 08/26] add test_connection methods to libsql connection creates a test connection with the transparent wal methods. --- libsql-server/src/connection/libsql.rs | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/libsql-server/src/connection/libsql.rs b/libsql-server/src/connection/libsql.rs index 7aa6ed3f36..312009ee60 100644 --- a/libsql-server/src/connection/libsql.rs +++ b/libsql-server/src/connection/libsql.rs @@ -227,6 +227,27 @@ where inner: Arc::new(Mutex::new(conn)), }) } +#[cfg(test)] +impl LibSqlConnection { + pub fn new_test(path: &Path) -> Self { + let (_snd, rcv) = watch::channel(None); + let conn = Connection::new( + path, + Arc::new([]), + &crate::libsql_bindings::wal_hook::TRANSPARENT_METHODS, + (), + Default::default(), + DatabaseConfigStore::new_test().into(), + QueryBuilderConfig::default(), + rcv, + Default::default(), + ) + .unwrap(); + + Self { + inner: Arc::new(Mutex::new(conn)), + } + } } struct Connection { From d576a5c887bd0b58e4afa0709ea4ee353ef31f10 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 31 Oct 2023 15:05:25 +0100 Subject: [PATCH 09/26] add txn_status method to libsql connections --- libsql-server/src/connection/libsql.rs | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/libsql-server/src/connection/libsql.rs b/libsql-server/src/connection/libsql.rs index 312009ee60..a18482dc45 100644 --- a/libsql-server/src/connection/libsql.rs +++ b/libsql-server/src/connection/libsql.rs @@ -227,6 +227,17 @@ where inner: Arc::new(Mutex::new(conn)), }) } + + pub fn txn_status(&self) -> crate::Result { + Ok(self + .inner + .lock() + .conn + .transaction_state(Some(DatabaseName::Main))? + .into()) + } +} + #[cfg(test)] impl LibSqlConnection { pub fn new_test(path: &Path) -> Self { From 1c3ed4abd5810311e258225fc52cc85cf556e040 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 31 Oct 2023 15:07:19 +0100 Subject: [PATCH 10/26] update result builder finish to take txn state --- libsql-server/src/query_result_builder.rs | 174 +++++++++++++++++++--- 1 file changed, 150 insertions(+), 24 deletions(-) diff --git a/libsql-server/src/query_result_builder.rs b/libsql-server/src/query_result_builder.rs index ec7c2f97ca..121d1a000f 100644 --- a/libsql-server/src/query_result_builder.rs +++ b/libsql-server/src/query_result_builder.rs @@ -8,12 +8,14 @@ use serde::Serialize; use serde_json::ser::Formatter; use std::sync::atomic::AtomicUsize; +use crate::query_analysis::TxnStatus; use crate::replication::FrameNo; pub static TOTAL_RESPONSE_SIZE: AtomicUsize = AtomicUsize::new(0); #[derive(Debug)] pub enum QueryResultBuilderError { + /// The response payload is too large ResponseTooLarge(u64), Internal(anyhow::Error), } @@ -120,7 +122,11 @@ pub trait QueryResultBuilder: Send + 'static { /// end adding rows fn finish_rows(&mut self) -> Result<(), QueryResultBuilderError>; /// finish serialization. - fn finish(&mut self, last_frame_no: Option) -> Result<(), QueryResultBuilderError>; + fn finish( + &mut self, + last_frame_no: Option, + state: TxnStatus, + ) -> Result<(), QueryResultBuilderError>; /// returns the inner ret fn into_ret(self) -> Self::Ret; /// Returns a `QueryResultBuilder` that wraps Self and takes at most `n` steps @@ -311,7 +317,11 @@ impl QueryResultBuilder for StepResultsBuilder { Ok(()) } - fn finish(&mut self, _last_frame_no: Option) -> Result<(), QueryResultBuilderError> { + fn finish( + &mut self, + _last_frame_no: Option, + _state: TxnStatus, + ) -> Result<(), QueryResultBuilderError> { Ok(()) } @@ -372,7 +382,11 @@ impl QueryResultBuilder for IgnoreResult { Ok(()) } - fn finish(&mut self, _last_frame_no: Option) -> Result<(), QueryResultBuilderError> { + fn finish( + &mut self, + _last_frame_no: Option, + _state: TxnStatus, + ) -> Result<(), QueryResultBuilderError> { Ok(()) } @@ -481,8 +495,12 @@ impl QueryResultBuilder for Take { } } - fn finish(&mut self, last_frame_no: Option) -> Result<(), QueryResultBuilderError> { - self.inner.finish(last_frame_no) + fn finish( + &mut self, + last_frame_no: Option, + state: TxnStatus, + ) -> Result<(), QueryResultBuilderError> { + self.inner.finish(last_frame_no, state) } fn into_ret(self) -> Self::Ret { @@ -599,6 +617,7 @@ pub mod test { fn finish( &mut self, _last_frame_no: Option, + _txn_status: TxnStatus, ) -> Result<(), QueryResultBuilderError> { Ok(()) } @@ -614,7 +633,7 @@ pub mod test { #[derive(Debug, PartialEq, Eq, Clone, Copy)] #[repr(usize)] // do not reorder! - enum FsmState { + pub enum FsmState { Init = 0, Finish, BeginStep, @@ -658,12 +677,12 @@ pub mod test { .filter_map(|(i, ss)| ss[self as usize].then_some(i)) .collect_vec(); // distribution is somewhat tweaked to be biased towards more real-world test cases - let weights = valid_next_states + let weigths = valid_next_states .iter() .enumerate() .map(|(p, i)| i.pow(p as _)) .collect_vec(); - let dist = WeightedIndex::new(weights).unwrap(); + let dist = WeightedIndex::new(weigths).unwrap(); unsafe { std::mem::transmute(valid_next_states[dist.sample(&mut thread_rng())]) } } @@ -683,11 +702,31 @@ pub mod test { } } - pub fn random_builder_driver(mut max_steps: usize, mut b: B) -> B { + pub fn random_transition(mut max_steps: usize) -> Vec { + let mut trace = Vec::with_capacity(max_steps); + let mut state = Init; + trace.push(state); + loop { + if max_steps > 0 { + state = state.rand_transition(false); + } else { + state = state.toward_finish() + } + + trace.push(state); + if state == FsmState::Finish { + break; + } + + max_steps = max_steps.saturating_sub(1); + } + trace + } + + pub fn fsm_builder_driver(trace: &[FsmState], mut b: B) -> B { let mut rand_data = [0; 10_000]; rand_data.try_fill(&mut rand::thread_rng()).unwrap(); let mut u = Unstructured::new(&rand_data); - let mut trace = Vec::new(); #[derive(Arbitrary)] pub enum ValueRef<'a> { @@ -710,9 +749,7 @@ pub mod test { } } - let mut state = Init; - trace.push(state); - loop { + for state in trace { match state { Init => b.init(&QueryBuilderConfig::default()).unwrap(), BeginStep => b.begin_step().unwrap(), @@ -734,26 +771,114 @@ pub mod test { FinishRow => b.finish_row().unwrap(), FinishRows => b.finish_rows().unwrap(), Finish => { - b.finish(Some(0)).unwrap(); + b.finish(Some(0), TxnStatus::Init).unwrap(); break; } BuilderError => return b, } + } - if max_steps > 0 { - state = state.rand_transition(false); - } else { - state = state.toward_finish() - } + b + } - trace.push(state); + /// A Builder that validates a given execution trace + pub struct ValidateTraceBuilder { + trace: Vec, + current: usize, + } - max_steps = max_steps.saturating_sub(1); + impl ValidateTraceBuilder { + pub fn new(trace: Vec) -> Self { + Self { trace, current: 0 } } + } - // this can be useful to help debug the generated test case + impl QueryResultBuilder for ValidateTraceBuilder { + type Ret = (); - b + fn init(&mut self, _config: &QueryBuilderConfig) -> Result<(), QueryResultBuilderError> { + assert_eq!(self.trace[self.current], FsmState::Init); + self.current += 1; + Ok(()) + } + + fn begin_step(&mut self) -> Result<(), QueryResultBuilderError> { + assert_eq!(self.trace[self.current], FsmState::BeginStep); + self.current += 1; + Ok(()) + } + + fn finish_step( + &mut self, + _affected_row_count: u64, + _last_insert_rowid: Option, + ) -> Result<(), QueryResultBuilderError> { + assert_eq!(self.trace[self.current], FsmState::FinishStep); + self.current += 1; + Ok(()) + } + + fn step_error( + &mut self, + _error: crate::error::Error, + ) -> Result<(), QueryResultBuilderError> { + assert_eq!(self.trace[self.current], FsmState::StepError); + self.current += 1; + Ok(()) + } + + fn cols_description<'a>( + &mut self, + _cols: impl IntoIterator>>, + ) -> Result<(), QueryResultBuilderError> { + assert_eq!(self.trace[self.current], FsmState::ColsDescription); + self.current += 1; + Ok(()) + } + + fn begin_rows(&mut self) -> Result<(), QueryResultBuilderError> { + assert_eq!(self.trace[self.current], FsmState::BeginRows); + self.current += 1; + Ok(()) + } + + fn begin_row(&mut self) -> Result<(), QueryResultBuilderError> { + assert_eq!(self.trace[self.current], FsmState::BeginRow); + self.current += 1; + Ok(()) + } + + fn add_row_value(&mut self, _v: ValueRef) -> Result<(), QueryResultBuilderError> { + assert_eq!(self.trace[self.current], FsmState::AddRowValue); + self.current += 1; + Ok(()) + } + + fn finish_row(&mut self) -> Result<(), QueryResultBuilderError> { + assert_eq!(self.trace[self.current], FsmState::FinishRow); + self.current += 1; + Ok(()) + } + + fn finish_rows(&mut self) -> Result<(), QueryResultBuilderError> { + assert_eq!(self.trace[self.current], FsmState::FinishRows); + self.current += 1; + Ok(()) + } + + fn finish( + &mut self, + _last_frame_no: Option, + _state: TxnStatus, + ) -> Result<(), QueryResultBuilderError> { + assert_eq!(self.trace[self.current], FsmState::Finish); + self.current += 1; + Ok(()) + } + + fn into_ret(self) -> Self::Ret { + assert_eq!(self.current, self.trace.len()); + } } pub struct FsmQueryBuilder { @@ -881,6 +1006,7 @@ pub mod test { fn finish( &mut self, _last_frame_no: Option, + _txn_status: TxnStatus, ) -> Result<(), QueryResultBuilderError> { self.maybe_inject_error()?; self.transition(Finish) @@ -929,7 +1055,7 @@ pub mod test { builder.finish_rows().unwrap(); builder.finish_step(0, None).unwrap(); - builder.finish(Some(0)).unwrap(); + builder.finish(Some(0), TxnStatus::Init).unwrap(); } #[test] From a5e6561983024cff2520e9c5d8a6aefe29465ad4 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 31 Oct 2023 15:07:44 +0100 Subject: [PATCH 11/26] fix http result builder --- libsql-server/src/http/user/result_builder.rs | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/libsql-server/src/http/user/result_builder.rs b/libsql-server/src/http/user/result_builder.rs index fa7c471089..4f56f7dbb1 100644 --- a/libsql-server/src/http/user/result_builder.rs +++ b/libsql-server/src/http/user/result_builder.rs @@ -6,6 +6,7 @@ use serde::{Serialize, Serializer}; use serde_json::ser::{CompactFormatter, Formatter}; use std::sync::atomic::Ordering; +use crate::query_analysis::TxnStatus; use crate::query_result_builder::{ Column, JsonFormatter, QueryBuilderConfig, QueryResultBuilder, QueryResultBuilderError, TOTAL_RESPONSE_SIZE, @@ -293,7 +294,11 @@ impl QueryResultBuilder for JsonHttpPayloadBuilder { } // TODO: how do we return last_frame_no? - fn finish(&mut self, _last_frame_no: Option) -> Result<(), QueryResultBuilderError> { + fn finish( + &mut self, + _last_frame_no: Option, + _state: TxnStatus, + ) -> Result<(), QueryResultBuilderError> { self.formatter.end_array(&mut self.buffer)?; Ok(()) @@ -306,7 +311,8 @@ impl QueryResultBuilder for JsonHttpPayloadBuilder { #[cfg(test)] mod test { - use crate::query_result_builder::test::random_builder_driver; + + use crate::query_result_builder::test::{fsm_builder_driver, random_transition}; use super::*; @@ -314,7 +320,8 @@ mod test { fn test_json_builder() { for _ in 0..1000 { let builder = JsonHttpPayloadBuilder::new(); - let ret = random_builder_driver(100, builder).into_ret(); + let trace = random_transition(100); + let ret = fsm_builder_driver(&trace, builder).into_ret(); println!("{}", std::str::from_utf8(&ret).unwrap()); // we produce valid json serde_json::from_slice::>(&ret).unwrap(); From 099eb14329939bd49f7167180159a2eb39e4b60c Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 31 Oct 2023 15:08:00 +0100 Subject: [PATCH 12/26] fix hrana result builder --- libsql-server/src/hrana/result_builder.rs | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/libsql-server/src/hrana/result_builder.rs b/libsql-server/src/hrana/result_builder.rs index d2e199109e..70d1890af6 100644 --- a/libsql-server/src/hrana/result_builder.rs +++ b/libsql-server/src/hrana/result_builder.rs @@ -6,6 +6,7 @@ use bytes::Bytes; use rusqlite::types::ValueRef; use crate::hrana::stmt::{proto_error_from_stmt_error, stmt_error_from_sqld_error}; +use crate::query_analysis::TxnStatus; use crate::query_result_builder::{ Column, QueryBuilderConfig, QueryResultBuilder, QueryResultBuilderError, TOTAL_RESPONSE_SIZE, }; @@ -225,7 +226,11 @@ impl QueryResultBuilder for SingleStatementBuilder { Ok(()) } - fn finish(&mut self, last_frame_no: Option) -> Result<(), QueryResultBuilderError> { + fn finish( + &mut self, + last_frame_no: Option, + _state: TxnStatus, + ) -> Result<(), QueryResultBuilderError> { self.last_frame_no = last_frame_no; Ok(()) } @@ -344,7 +349,11 @@ impl QueryResultBuilder for HranaBatchProtoBuilder { Ok(()) } - fn finish(&mut self, _last_frame_no: Option) -> Result<(), QueryResultBuilderError> { + fn finish( + &mut self, + _last_frame_no: Option, + _state: TxnStatus, + ) -> Result<(), QueryResultBuilderError> { Ok(()) } From d646403ed57c407b243f2e517b5154e5c3a02fea Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 31 Oct 2023 15:21:20 +0100 Subject: [PATCH 13/26] connection execute doesn't return txn state anymore the state is returned by the result builder --- libsql-server/src/connection/libsql.rs | 103 ++++++++++++++---------- libsql-server/src/connection/mod.rs | 28 +++---- libsql-server/src/connection/program.rs | 2 - libsql-server/src/hrana/batch.rs | 4 +- libsql-server/src/hrana/cursor.rs | 1 + libsql-server/src/hrana/stmt.rs | 2 +- 6 files changed, 79 insertions(+), 61 deletions(-) diff --git a/libsql-server/src/connection/libsql.rs b/libsql-server/src/connection/libsql.rs index a18482dc45..d5dffc069b 100644 --- a/libsql-server/src/connection/libsql.rs +++ b/libsql-server/src/connection/libsql.rs @@ -15,14 +15,14 @@ use crate::error::Error; use crate::libsql_bindings::wal_hook::WalHook; use crate::metrics::{READ_QUERY_COUNT, VACUUM_COUNT, WAL_CHECKPOINT_COUNT, WRITE_QUERY_COUNT}; use crate::query::Query; -use crate::query_analysis::{State, StmtKind}; +use crate::query_analysis::{StmtKind, TxnStatus}; use crate::query_result_builder::{QueryBuilderConfig, QueryResultBuilder}; use crate::replication::FrameNo; use crate::stats::Stats; use crate::Result; use super::config::DatabaseConfigStore; -use super::program::{Cond, DescribeCol, DescribeParam, DescribeResponse, DescribeResult}; +use super::program::{Cond, DescribeCol, DescribeParam, DescribeResponse}; use super::{MakeConnection, Program, Step, TXN_TIMEOUT}; pub struct MakeLibSqlConn { @@ -146,7 +146,6 @@ where } } -#[derive(Clone)] pub struct LibSqlConnection { inner: Arc>>, } @@ -162,6 +161,14 @@ impl std::fmt::Debug for LibSqlConnection { } } +impl Clone for LibSqlConnection { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + } + } +} + pub fn open_conn( path: &Path, wal_methods: &'static WalMethodsHook, @@ -287,7 +294,7 @@ struct TxnSlot { /// is stolen. conn: Arc>>, /// Time at which the transaction can be stolen - created_at: tokio::time::Instant, + timeout_at: tokio::time::Instant, /// The transaction lock was stolen is_stolen: AtomicBool, } @@ -414,6 +421,17 @@ fn value_size(val: &rusqlite::types::ValueRef) -> usize { } } +impl From for TxnStatus { + fn from(value: TransactionState) -> Self { + use TransactionState as Tx; + match value { + Tx::None => TxnStatus::Init, + Tx::Read | Tx::Write => TxnStatus::Txn, + _ => unreachable!(), + } + } +} + impl Connection { fn new( path: &Path, @@ -468,7 +486,7 @@ impl Connection { this: Arc>, pgm: Program, mut builder: B, - ) -> Result<(B, State)> { + ) -> Result { use rusqlite::TransactionState as Tx; let state = this.lock().state.clone(); @@ -532,20 +550,18 @@ impl Connection { results.push(res); } - builder.finish(*this.lock().current_frame_no_receiver.borrow_and_update())?; + let status = this + .lock() + .conn + .transaction_state(Some(DatabaseName::Main))? + .into(); - let state = if matches!( - this.lock() - .conn - .transaction_state(Some(DatabaseName::Main))?, - Tx::Read | Tx::Write - ) { - State::Txn - } else { - State::Init - }; + builder.finish( + *this.lock().current_frame_no_receiver.borrow_and_update(), + status, + )?; - Ok((builder, state)) + Ok(builder) } fn execute_step( @@ -605,7 +621,7 @@ impl Connection { let blocked = match query.stmt.kind { StmtKind::Read | StmtKind::TxnBegin | StmtKind::Other => config.block_reads, StmtKind::Write => config.block_reads || config.block_writes, - StmtKind::TxnEnd | StmtKind::Release | StmtKind::Savepoint => false, + StmtKind::TxnEnd => false, }; if blocked { return Err(Error::Blocked(config.block_reason.clone())); @@ -735,7 +751,7 @@ impl Connection { .update_query_metrics(sql, rows_read, rows_written, mem_used, elapsed) } - fn describe(&self, sql: &str) -> DescribeResult { + fn describe(&self, sql: &str) -> crate::Result { let stmt = self.conn.prepare(sql)?; let params = (1..=stmt.parameter_count()) @@ -840,7 +856,7 @@ where auth: Authenticated, builder: B, _replication_index: Option, - ) -> Result<(B, State)> { + ) -> Result { check_program_auth(auth, &pgm)?; let conn = self.inner.clone(); tokio::task::spawn_blocking(move || Connection::run(conn, pgm, builder)) @@ -853,7 +869,7 @@ where sql: String, auth: Authenticated, _replication_index: Option, - ) -> Result { + ) -> Result> { check_describe_auth(auth)?; let conn = self.inner.clone(); let res = tokio::task::spawn_blocking(move || conn.lock().describe(&sql)) @@ -932,7 +948,7 @@ mod test { fn test_libsql_conn_builder_driver() { test_driver(1000, |b| { let conn = setup_test_conn(); - Connection::run(conn, Program::seq(&["select * from test"]), b).map(|x| x.0) + Connection::run(conn, Program::seq(&["select * from test"]), b) }) } @@ -956,23 +972,23 @@ mod test { tokio::time::pause(); let conn = make_conn.make_connection().await.unwrap(); - let (_builder, state) = Connection::run( + let _builder = Connection::run( conn.inner.clone(), Program::seq(&["BEGIN IMMEDIATE"]), TestBuilder::default(), ) .unwrap(); - assert_eq!(state, State::Txn); + assert_eq!(conn.txn_status().unwrap(), TxnStatus::Txn); tokio::time::advance(TXN_TIMEOUT * 2).await; - let (builder, state) = Connection::run( + let builder = Connection::run( conn.inner.clone(), Program::seq(&["BEGIN IMMEDIATE"]), TestBuilder::default(), ) .unwrap(); - assert_eq!(state, State::Init); + assert_eq!(conn.txn_status().unwrap(), TxnStatus::Init); assert!(matches!(builder.into_ret()[0], Err(Error::LibSqlTxTimeout))); } @@ -1000,13 +1016,13 @@ mod test { for _ in 0..10 { let conn = make_conn.make_connection().await.unwrap(); set.spawn_blocking(move || { - let (builder, state) = Connection::run( - conn.inner, + let builder = Connection::run( + conn.inner.clone(), Program::seq(&["BEGIN IMMEDIATE"]), TestBuilder::default(), ) .unwrap(); - assert_eq!(state, State::Txn); + assert_eq!(conn.txn_status().unwrap(), TxnStatus::Txn); assert!(builder.into_ret()[0].is_ok()); }); } @@ -1041,15 +1057,15 @@ mod test { let conn1 = make_conn.make_connection().await.unwrap(); tokio::task::spawn_blocking({ - let conn = conn1.inner.clone(); + let conn = conn1.clone(); move || { - let (builder, state) = Connection::run( - conn, + let builder = Connection::run( + conn.inner.clone(), Program::seq(&["BEGIN IMMEDIATE"]), TestBuilder::default(), ) .unwrap(); - assert_eq!(state, State::Txn); + assert_eq!(conn.txn_status().unwrap(), TxnStatus::Txn); assert!(builder.into_ret()[0].is_ok()); } }) @@ -1058,16 +1074,16 @@ mod test { let conn2 = make_conn.make_connection().await.unwrap(); let handle = tokio::task::spawn_blocking({ - let conn = conn2.inner.clone(); + let conn = conn2.clone(); move || { let before = Instant::now(); - let (builder, state) = Connection::run( - conn, + let builder = Connection::run( + conn.inner.clone(), Program::seq(&["BEGIN IMMEDIATE"]), TestBuilder::default(), ) .unwrap(); - assert_eq!(state, State::Txn); + assert_eq!(conn.txn_status().unwrap(), TxnStatus::Txn); assert!(builder.into_ret()[0].is_ok()); before.elapsed() } @@ -1077,12 +1093,15 @@ mod test { tokio::time::sleep(wait_time).await; tokio::task::spawn_blocking({ - let conn = conn1.inner.clone(); + let conn = conn1.clone(); move || { - let (builder, state) = - Connection::run(conn, Program::seq(&["COMMIT"]), TestBuilder::default()) - .unwrap(); - assert_eq!(state, State::Init); + let builder = Connection::run( + conn.inner.clone(), + Program::seq(&["COMMIT"]), + TestBuilder::default(), + ) + .unwrap(); + assert_eq!(conn.txn_status().unwrap(), TxnStatus::Init); assert!(builder.into_ret()[0].is_ok()); } }) diff --git a/libsql-server/src/connection/mod.rs b/libsql-server/src/connection/mod.rs index e439ad2dc6..43334bc194 100644 --- a/libsql-server/src/connection/mod.rs +++ b/libsql-server/src/connection/mod.rs @@ -10,12 +10,12 @@ use crate::auth::Authenticated; use crate::error::Error; use crate::metrics::CONCCURENT_CONNECTIONS_COUNT; use crate::query::{Params, Query}; -use crate::query_analysis::{State, Statement}; +use crate::query_analysis::Statement; use crate::query_result_builder::{IgnoreResult, QueryResultBuilder}; use crate::replication::FrameNo; use crate::Result; -use self::program::{Cond, DescribeResult, Program, Step}; +use self::program::{Cond, DescribeResponse, Program, Step}; pub mod config; pub mod dump; @@ -34,7 +34,7 @@ pub trait Connection: Send + Sync + 'static { auth: Authenticated, response_builder: B, replication_index: Option, - ) -> Result<(B, State)>; + ) -> Result; /// Execute all the queries in the batch sequentially. /// If an query in the batch fails, the remaining queries are ignores, and the batch current @@ -45,12 +45,12 @@ pub trait Connection: Send + Sync + 'static { auth: Authenticated, result_builder: B, replication_index: Option, - ) -> Result<(B, State)> { + ) -> Result { let batch_len = batch.len(); let mut steps = make_batch_program(batch); if !steps.is_empty() { - // We add a conditional rollback step if the last step was not successful. + // We add a conditional rollback step if the last step was not sucessful. steps.push(Step { query: Query { stmt: Statement::parse("ROLLBACK").next().unwrap().unwrap(), @@ -69,11 +69,11 @@ pub trait Connection: Send + Sync + 'static { // ignore the rollback result let builder = result_builder.take(batch_len); - let (builder, state) = self + let builder = self .execute_program(pgm, auth, builder, replication_index) .await?; - Ok((builder.into_inner(), state)) + Ok(builder.into_inner()) } /// Execute all the queries in the batch sequentially. @@ -84,7 +84,7 @@ pub trait Connection: Send + Sync + 'static { auth: Authenticated, result_builder: B, replication_index: Option, - ) -> Result<(B, State)> { + ) -> Result { let steps = make_batch_program(batch); let pgm = Program::new(steps); self.execute_program(pgm, auth, result_builder, replication_index) @@ -113,7 +113,7 @@ pub trait Connection: Send + Sync + 'static { sql: String, auth: Authenticated, replication_index: Option, - ) -> Result; + ) -> Result>; /// Check whether the connection is in autocommit mode. async fn is_autocommit(&self) -> Result; @@ -336,7 +336,7 @@ impl Connection for TrackedConnection { auth: Authenticated, builder: B, replication_index: Option, - ) -> crate::Result<(B, State)> { + ) -> crate::Result { self.atime.store(now_millis(), Ordering::Relaxed); self.inner .execute_program(pgm, auth, builder, replication_index) @@ -349,7 +349,7 @@ impl Connection for TrackedConnection { sql: String, auth: Authenticated, replication_index: Option, - ) -> crate::Result { + ) -> crate::Result> { self.atime.store(now_millis(), Ordering::Relaxed); self.inner.describe(sql, auth, replication_index).await } @@ -377,7 +377,7 @@ impl Connection for TrackedConnection { } #[cfg(test)] -mod test { +pub mod test { use super::*; #[derive(Debug)] @@ -391,7 +391,7 @@ mod test { _auth: Authenticated, _builder: B, _replication_index: Option, - ) -> crate::Result<(B, State)> { + ) -> crate::Result { unreachable!() } @@ -400,7 +400,7 @@ mod test { _sql: String, _auth: Authenticated, _replication_index: Option, - ) -> crate::Result { + ) -> crate::Result> { unreachable!() } diff --git a/libsql-server/src/connection/program.rs b/libsql-server/src/connection/program.rs index fabfbd18bf..3017232af7 100644 --- a/libsql-server/src/connection/program.rs +++ b/libsql-server/src/connection/program.rs @@ -60,8 +60,6 @@ pub enum Cond { IsAutocommit, } -pub type DescribeResult = crate::Result; - #[derive(Debug, Clone)] pub struct DescribeResponse { pub params: Vec, diff --git a/libsql-server/src/hrana/batch.rs b/libsql-server/src/hrana/batch.rs index cb0deb6388..a2ddd9e291 100644 --- a/libsql-server/src/hrana/batch.rs +++ b/libsql-server/src/hrana/batch.rs @@ -110,7 +110,7 @@ pub async fn execute_batch( replication_index: Option, ) -> Result { let batch_builder = HranaBatchProtoBuilder::default(); - let (builder, _state) = db + let builder = db .execute_program(pgm, auth, batch_builder, replication_index) .await .map_err(catch_batch_error)?; @@ -151,7 +151,7 @@ pub async fn execute_sequence( replication_index: Option, ) -> Result<()> { let builder = StepResultsBuilder::default(); - let (builder, _state) = db + let builder = db .execute_program(pgm, auth, builder, replication_index) .await .map_err(catch_batch_error)?; diff --git a/libsql-server/src/hrana/cursor.rs b/libsql-server/src/hrana/cursor.rs index 005799a737..3383e02ce1 100644 --- a/libsql-server/src/hrana/cursor.rs +++ b/libsql-server/src/hrana/cursor.rs @@ -8,6 +8,7 @@ use tokio::sync::{mpsc, oneshot}; use crate::auth::Authenticated; use crate::connection::program::Program; use crate::connection::Connection; +use crate::query_analysis::TxnStatus; use crate::query_result_builder::{ Column, QueryBuilderConfig, QueryResultBuilder, QueryResultBuilderError, }; diff --git a/libsql-server/src/hrana/stmt.rs b/libsql-server/src/hrana/stmt.rs index e595bf1ab8..56569f57ce 100644 --- a/libsql-server/src/hrana/stmt.rs +++ b/libsql-server/src/hrana/stmt.rs @@ -58,7 +58,7 @@ pub async fn execute_stmt( replication_index: Option, ) -> Result { let builder = SingleStatementBuilder::default(); - let (stmt_res, _) = db + let stmt_res = db .execute_batch(vec![query], auth, builder, replication_index) .await .map_err(catch_stmt_error)?; From e664e46e6dcdd021211fcdae5a39b2bd1afb1294 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 31 Oct 2023 15:22:23 +0100 Subject: [PATCH 14/26] handle stream proxy error --- libsql-server/src/error.rs | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/libsql-server/src/error.rs b/libsql-server/src/error.rs index 8ac081b18e..a31a46fdf1 100644 --- a/libsql-server/src/error.rs +++ b/libsql-server/src/error.rs @@ -82,6 +82,13 @@ pub enum Error { Fork(#[from] ForkError), #[error("Fatal replication error")] FatalReplicationError, + + #[error("Connection with primary broken")] + PrimaryStreamDisconnect, + #[error("Proxy protocal misuse")] + PrimaryStreamMisuse, + #[error("Proxy request interupted")] + PrimaryStreamInterupted, } trait ResponseError: std::error::Error { @@ -135,6 +142,9 @@ impl IntoResponse for Error { FatalReplicationError => self.format_err(StatusCode::INTERNAL_SERVER_ERROR), ReplicatorError(_) => self.format_err(StatusCode::INTERNAL_SERVER_ERROR), ReplicaMetaError(_) => self.format_err(StatusCode::INTERNAL_SERVER_ERROR), + PrimaryStreamDisconnect => self.format_err(StatusCode::INTERNAL_SERVER_ERROR), + PrimaryStreamMisuse => self.format_err(StatusCode::INTERNAL_SERVER_ERROR), + PrimaryStreamInterupted => self.format_err(StatusCode::INTERNAL_SERVER_ERROR), } } } From 9b1e121631e2bd6f73df57eea641aed94c72574e Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 31 Oct 2023 15:23:58 +0100 Subject: [PATCH 15/26] fix import in libsql --- libsql/src/replication/connection.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libsql/src/replication/connection.rs b/libsql/src/replication/connection.rs index 675287b7a4..49e667afca 100644 --- a/libsql/src/replication/connection.rs +++ b/libsql/src/replication/connection.rs @@ -5,7 +5,7 @@ use std::sync::Arc; use parking_lot::Mutex; use libsql_replication::rpc::proxy::{ - describe_result, execute_results::State as RemoteState, query_result::RowResult, DescribeResult, + describe_result, State as RemoteState, query_result::RowResult, DescribeResult, ExecuteResults, ResultRows }; From 98c5c3f9c0be03bb95dd233fd8e5749a61729078 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 31 Oct 2023 15:24:12 +0100 Subject: [PATCH 16/26] fix hrana cursor result builder --- libsql-server/src/hrana/cursor.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/libsql-server/src/hrana/cursor.rs b/libsql-server/src/hrana/cursor.rs index 3383e02ce1..428f966035 100644 --- a/libsql-server/src/hrana/cursor.rs +++ b/libsql-server/src/hrana/cursor.rs @@ -256,7 +256,11 @@ impl QueryResultBuilder for CursorResultBuilder { Ok(()) } - fn finish(&mut self, last_frame_no: Option) -> Result<(), QueryResultBuilderError> { + fn finish( + &mut self, + last_frame_no: Option, + _status: TxnStatus, + ) -> Result<(), QueryResultBuilderError> { self.emit_entry(Ok(SizedEntry { entry: proto::CursorEntry::ReplicationIndex { replication_index: last_frame_no, From 3ea8d5eb4c0cc4b8e51a0b16ca79d556a59d78fa Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 31 Oct 2023 15:24:59 +0100 Subject: [PATCH 17/26] fix http builder return type --- libsql-server/src/http/user/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libsql-server/src/http/user/mod.rs b/libsql-server/src/http/user/mod.rs index 6e42df567f..06bb53d554 100644 --- a/libsql-server/src/http/user/mod.rs +++ b/libsql-server/src/http/user/mod.rs @@ -132,7 +132,7 @@ async fn handle_query( let db = connection_maker.create().await?; let builder = JsonHttpPayloadBuilder::new(); - let (builder, _) = db + let builder = db .execute_batch_or_rollback(batch, auth, builder, query.replication_index) .await?; From e82aa175a19a21ea82cbe32c7747c3cc1e3cb65f Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 31 Oct 2023 15:27:15 +0100 Subject: [PATCH 18/26] rename State to TxnStatus --- libsql-server/src/connection/libsql.rs | 6 +++--- libsql-server/src/http/user/mod.rs | 10 +++++----- libsql-server/src/query_analysis.rs | 26 +++++++++++--------------- 3 files changed, 19 insertions(+), 23 deletions(-) diff --git a/libsql-server/src/connection/libsql.rs b/libsql-server/src/connection/libsql.rs index d5dffc069b..78f3abe3a3 100644 --- a/libsql-server/src/connection/libsql.rs +++ b/libsql-server/src/connection/libsql.rs @@ -5,7 +5,7 @@ use std::sync::Arc; use metrics::histogram; use parking_lot::{Mutex, RwLock}; -use rusqlite::{DatabaseName, ErrorCode, OpenFlags, StatementStatus}; +use rusqlite::{DatabaseName, ErrorCode, OpenFlags, StatementStatus, TransactionState}; use sqld_libsql_bindings::wal_hook::{TransparentMethods, WalMethodsHook}; use tokio::sync::{watch, Notify}; use tokio::time::{Duration, Instant}; @@ -294,7 +294,7 @@ struct TxnSlot { /// is stolen. conn: Arc>>, /// Time at which the transaction can be stolen - timeout_at: tokio::time::Instant, + created_at: tokio::time::Instant, /// The transaction lock was stolen is_stolen: AtomicBool, } @@ -621,7 +621,7 @@ impl Connection { let blocked = match query.stmt.kind { StmtKind::Read | StmtKind::TxnBegin | StmtKind::Other => config.block_reads, StmtKind::Write => config.block_reads || config.block_writes, - StmtKind::TxnEnd => false, + StmtKind::TxnEnd | StmtKind::Release | StmtKind::Savepoint => false, }; if blocked { return Err(Error::Blocked(config.block_reason.clone())); diff --git a/libsql-server/src/http/user/mod.rs b/libsql-server/src/http/user/mod.rs index 06bb53d554..74aa35a38e 100644 --- a/libsql-server/src/http/user/mod.rs +++ b/libsql-server/src/http/user/mod.rs @@ -37,7 +37,7 @@ use crate::http::user::types::HttpQuery; use crate::namespace::{MakeNamespace, NamespaceStore}; use crate::net::Accept; use crate::query::{self, Query}; -use crate::query_analysis::{predict_final_state, State, Statement}; +use crate::query_analysis::{predict_final_state, TxnStatus, Statement}; use crate::query_result_builder::QueryResultBuilder; use crate::rpc::proxy::rpc::proxy_server::{Proxy, ProxyServer}; use crate::rpc::replication_log::rpc::replication_log_server::ReplicationLog; @@ -108,15 +108,15 @@ fn parse_queries(queries: Vec) -> crate::Result> { )); } - match predict_final_state(State::Init, out.iter().map(|q| &q.stmt)) { - State::Txn => { + match predict_final_state(TxnStatus::Init, out.iter().map(|q| &q.stmt)) { + TxnStatus::Txn => { return Err(Error::QueryError( "interactive transaction not allowed in HTTP queries".to_string(), )) } - State::Init => (), + TxnStatus::Init => (), // maybe we should err here, but let's sqlite deal with that. - State::Invalid => (), + TxnStatus::Invalid => (), } Ok(out) diff --git a/libsql-server/src/query_analysis.rs b/libsql-server/src/query_analysis.rs index f8f807be41..174688112c 100644 --- a/libsql-server/src/query_analysis.rs +++ b/libsql-server/src/query_analysis.rs @@ -201,7 +201,7 @@ impl StmtKind { /// The state of a transaction for a series of statement #[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum State { +pub enum TxnStatus { /// The txn in an opened state Txn, /// The txn in a closed state @@ -210,20 +210,20 @@ pub enum State { Invalid, } -impl State { +impl TxnStatus { pub fn step(&mut self, kind: StmtKind) { *self = match (*self, kind) { - (State::Txn, StmtKind::TxnBegin) | (State::Init, StmtKind::TxnEnd) => State::Invalid, - (State::Txn, StmtKind::TxnEnd) => State::Init, + (TxnStatus::Txn, StmtKind::TxnBegin) | (TxnStatus::Init, StmtKind::TxnEnd) => TxnStatus::Invalid, + (TxnStatus::Txn, StmtKind::TxnEnd) => TxnStatus::Init, (state, StmtKind::Other | StmtKind::Write | StmtKind::Read) => state, - (State::Invalid, _) => State::Invalid, - (State::Init, StmtKind::TxnBegin) => State::Txn, - _ => State::Invalid, + (TxnStatus::Invalid, _) => TxnStatus::Invalid, + (TxnStatus::Init, StmtKind::TxnBegin) => TxnStatus::Txn, + _ => TxnStatus::Invalid, }; } pub fn reset(&mut self) { - *self = State::Init + *self = TxnStatus::Init } } @@ -307,11 +307,7 @@ impl Statement { pub fn is_read_only(&self) -> bool { matches!( self.kind, - StmtKind::Read - | StmtKind::TxnEnd - | StmtKind::TxnBegin - | StmtKind::Release - | StmtKind::Savepoint + StmtKind::Read | StmtKind::TxnEnd | StmtKind::TxnBegin ) } } @@ -319,9 +315,9 @@ impl Statement { /// Given a an initial state and an array of queries, attempts to predict what the final state will /// be pub fn predict_final_state<'a>( - mut state: State, + mut state: TxnStatus, stmts: impl Iterator, -) -> State { +) -> TxnStatus { for stmt in stmts { state.step(stmt.kind); } From 86b032825468a41d0d145ebe6721ba17ff91ef1d Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 31 Oct 2023 15:58:44 +0100 Subject: [PATCH 19/26] add test snapshot files --- ...__rpc__streaming_exec__test__describe.snap | 28 ++ ...streaming_exec__test__invalid_request.snap | 5 + ...ming_exec__test__perform_query_simple.snap | 72 +++++ ...ec__test__single_query_split_response.snap | 255 ++++++++++++++++++ .../tests__namespaces__create_namespace.snap | 5 + 5 files changed, 365 insertions(+) create mode 100644 libsql-server/src/rpc/snapshots/sqld__rpc__streaming_exec__test__describe.snap create mode 100644 libsql-server/src/rpc/snapshots/sqld__rpc__streaming_exec__test__invalid_request.snap create mode 100644 libsql-server/src/rpc/snapshots/sqld__rpc__streaming_exec__test__perform_query_simple.snap create mode 100644 libsql-server/src/rpc/snapshots/sqld__rpc__streaming_exec__test__single_query_split_response.snap create mode 100644 libsql-server/tests/namespaces/snapshots/tests__namespaces__create_namespace.snap diff --git a/libsql-server/src/rpc/snapshots/sqld__rpc__streaming_exec__test__describe.snap b/libsql-server/src/rpc/snapshots/sqld__rpc__streaming_exec__test__describe.snap new file mode 100644 index 0000000000..172a3ac14b --- /dev/null +++ b/libsql-server/src/rpc/snapshots/sqld__rpc__streaming_exec__test__describe.snap @@ -0,0 +1,28 @@ +--- +source: libsql-server/src/rpc/streaming_exec.rs +expression: stream.next().await.unwrap().unwrap() +--- +ExecResp { + request_id: 0, + response: Some( + DescribeResp( + DescribeResp { + params: [ + DescribeParam { + name: Some( + "$hello", + ), + }, + ], + cols: [ + DescribeCol { + name: "$hello", + decltype: None, + }, + ], + is_explain: false, + is_readonly: true, + }, + ), + ), +} diff --git a/libsql-server/src/rpc/snapshots/sqld__rpc__streaming_exec__test__invalid_request.snap b/libsql-server/src/rpc/snapshots/sqld__rpc__streaming_exec__test__invalid_request.snap new file mode 100644 index 0000000000..00d6fd952e --- /dev/null +++ b/libsql-server/src/rpc/snapshots/sqld__rpc__streaming_exec__test__invalid_request.snap @@ -0,0 +1,5 @@ +--- +source: libsql-server/src/rpc/streaming_exec.rs +expression: stream.next().await.unwrap().unwrap_err().to_string() +--- +status: InvalidArgument, message: "invalid request", details: [], metadata: MetadataMap { headers: {} } diff --git a/libsql-server/src/rpc/snapshots/sqld__rpc__streaming_exec__test__perform_query_simple.snap b/libsql-server/src/rpc/snapshots/sqld__rpc__streaming_exec__test__perform_query_simple.snap new file mode 100644 index 0000000000..0bd177ac9e --- /dev/null +++ b/libsql-server/src/rpc/snapshots/sqld__rpc__streaming_exec__test__perform_query_simple.snap @@ -0,0 +1,72 @@ +--- +source: libsql-server/src/rpc/streaming_exec.rs +expression: stream.next().await.unwrap().unwrap() +--- +ExecResp { + request_id: 0, + response: Some( + ProgramResp( + ProgramResp { + steps: [ + RespStep { + step: Some( + Init( + Init, + ), + ), + }, + RespStep { + step: Some( + BeginStep( + BeginStep, + ), + ), + }, + RespStep { + step: Some( + ColsDescription( + ColsDescription { + columns: [], + }, + ), + ), + }, + RespStep { + step: Some( + BeginRows( + BeginRows, + ), + ), + }, + RespStep { + step: Some( + FinishRows( + FinishRows, + ), + ), + }, + RespStep { + step: Some( + FinishStep( + FinishStep { + affected_row_count: 0, + last_insert_rowid: None, + }, + ), + ), + }, + RespStep { + step: Some( + Finish( + Finish { + last_frame_no: None, + state: Init, + }, + ), + ), + }, + ], + }, + ), + ), +} diff --git a/libsql-server/src/rpc/snapshots/sqld__rpc__streaming_exec__test__single_query_split_response.snap b/libsql-server/src/rpc/snapshots/sqld__rpc__streaming_exec__test__single_query_split_response.snap new file mode 100644 index 0000000000..c2707dffb3 --- /dev/null +++ b/libsql-server/src/rpc/snapshots/sqld__rpc__streaming_exec__test__single_query_split_response.snap @@ -0,0 +1,255 @@ +--- +source: libsql-server/src/rpc/streaming_exec.rs +expression: builder.into_ret() +--- +[ + Ok( + [ + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + [ + Text( + "something moderately long", + ), + ], + ], + ), +] diff --git a/libsql-server/tests/namespaces/snapshots/tests__namespaces__create_namespace.snap b/libsql-server/tests/namespaces/snapshots/tests__namespaces__create_namespace.snap new file mode 100644 index 0000000000..4e0aded2ba --- /dev/null +++ b/libsql-server/tests/namespaces/snapshots/tests__namespaces__create_namespace.snap @@ -0,0 +1,5 @@ +--- +source: libsql-server/tests/namespaces/mod.rs +expression: e.to_string() +--- +Hrana: `api error: `{"error":"Namespace `foo` doesn't exist"}`` From e23f1f25163dbccad0992a647b3bb423e75103ce Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 31 Oct 2023 16:43:39 +0100 Subject: [PATCH 20/26] add more exec-stream tests --- libsql-replication/build.rs | 1 + ...xec__test__perform_multiple_queries-2.snap | 56 +++++++++++++++++ ..._exec__test__perform_multiple_queries.snap | 56 +++++++++++++++++ ...ery_number_less_than_previous_query-2.snap | 11 ++++ ...query_number_less_than_previous_query.snap | 56 +++++++++++++++++ libsql-server/src/rpc/streaming_exec.rs | 62 ++++++++++++++++++- 6 files changed, 241 insertions(+), 1 deletion(-) create mode 100644 libsql-server/src/rpc/snapshots/sqld__rpc__streaming_exec__test__perform_multiple_queries-2.snap create mode 100644 libsql-server/src/rpc/snapshots/sqld__rpc__streaming_exec__test__perform_multiple_queries.snap create mode 100644 libsql-server/src/rpc/snapshots/sqld__rpc__streaming_exec__test__query_number_less_than_previous_query-2.snap create mode 100644 libsql-server/src/rpc/snapshots/sqld__rpc__streaming_exec__test__query_number_less_than_previous_query.snap diff --git a/libsql-replication/build.rs b/libsql-replication/build.rs index fbf6f79416..b82b69a98f 100644 --- a/libsql-replication/build.rs +++ b/libsql-replication/build.rs @@ -7,6 +7,7 @@ fn main() -> Result<(), Box> { config.bytes([".wal_log"]); tonic_build::configure() .protoc_arg("--experimental_allow_proto3_optional") + .type_attribute(".proxy", "#[derive(serde::Serialize, serde::Deserialize)]") .type_attribute(".proxy", "#[cfg_attr(test, derive(arbitrary::Arbitrary))]") .compile_with_config( config, diff --git a/libsql-server/src/rpc/snapshots/sqld__rpc__streaming_exec__test__perform_multiple_queries-2.snap b/libsql-server/src/rpc/snapshots/sqld__rpc__streaming_exec__test__perform_multiple_queries-2.snap new file mode 100644 index 0000000000..fd37c2e44c --- /dev/null +++ b/libsql-server/src/rpc/snapshots/sqld__rpc__streaming_exec__test__perform_multiple_queries-2.snap @@ -0,0 +1,56 @@ +--- +source: libsql-server/src/rpc/streaming_exec.rs +expression: stream.next().await.unwrap().unwrap() +--- +{ + "request_id": 1, + "response": { + "ProgramResp": { + "steps": [ + { + "step": { + "Init": {} + } + }, + { + "step": { + "BeginStep": {} + } + }, + { + "step": { + "ColsDescription": { + "columns": [] + } + } + }, + { + "step": { + "BeginRows": {} + } + }, + { + "step": { + "FinishRows": {} + } + }, + { + "step": { + "FinishStep": { + "affected_row_count": 1, + "last_insert_rowid": 1 + } + } + }, + { + "step": { + "Finish": { + "last_frame_no": null, + "state": 0 + } + } + } + ] + } + } +} diff --git a/libsql-server/src/rpc/snapshots/sqld__rpc__streaming_exec__test__perform_multiple_queries.snap b/libsql-server/src/rpc/snapshots/sqld__rpc__streaming_exec__test__perform_multiple_queries.snap new file mode 100644 index 0000000000..fd0a871c9d --- /dev/null +++ b/libsql-server/src/rpc/snapshots/sqld__rpc__streaming_exec__test__perform_multiple_queries.snap @@ -0,0 +1,56 @@ +--- +source: libsql-server/src/rpc/streaming_exec.rs +expression: stream.next().await.unwrap().unwrap() +--- +{ + "request_id": 0, + "response": { + "ProgramResp": { + "steps": [ + { + "step": { + "Init": {} + } + }, + { + "step": { + "BeginStep": {} + } + }, + { + "step": { + "ColsDescription": { + "columns": [] + } + } + }, + { + "step": { + "BeginRows": {} + } + }, + { + "step": { + "FinishRows": {} + } + }, + { + "step": { + "FinishStep": { + "affected_row_count": 0, + "last_insert_rowid": null + } + } + }, + { + "step": { + "Finish": { + "last_frame_no": null, + "state": 0 + } + } + } + ] + } + } +} diff --git a/libsql-server/src/rpc/snapshots/sqld__rpc__streaming_exec__test__query_number_less_than_previous_query-2.snap b/libsql-server/src/rpc/snapshots/sqld__rpc__streaming_exec__test__query_number_less_than_previous_query-2.snap new file mode 100644 index 0000000000..684d2cb827 --- /dev/null +++ b/libsql-server/src/rpc/snapshots/sqld__rpc__streaming_exec__test__query_number_less_than_previous_query-2.snap @@ -0,0 +1,11 @@ +--- +source: libsql-server/src/rpc/streaming_exec.rs +expression: resp +--- +Err( + Status { + code: InvalidArgument, + message: "received request with id less than last received request, closing stream", + source: None, + }, +) diff --git a/libsql-server/src/rpc/snapshots/sqld__rpc__streaming_exec__test__query_number_less_than_previous_query.snap b/libsql-server/src/rpc/snapshots/sqld__rpc__streaming_exec__test__query_number_less_than_previous_query.snap new file mode 100644 index 0000000000..fd0a871c9d --- /dev/null +++ b/libsql-server/src/rpc/snapshots/sqld__rpc__streaming_exec__test__query_number_less_than_previous_query.snap @@ -0,0 +1,56 @@ +--- +source: libsql-server/src/rpc/streaming_exec.rs +expression: stream.next().await.unwrap().unwrap() +--- +{ + "request_id": 0, + "response": { + "ProgramResp": { + "steps": [ + { + "step": { + "Init": {} + } + }, + { + "step": { + "BeginStep": {} + } + }, + { + "step": { + "ColsDescription": { + "columns": [] + } + } + }, + { + "step": { + "BeginRows": {} + } + }, + { + "step": { + "FinishRows": {} + } + }, + { + "step": { + "FinishStep": { + "affected_row_count": 0, + "last_insert_rowid": null + } + } + }, + { + "step": { + "Finish": { + "last_frame_no": null, + "state": 0 + } + } + } + ] + } + } +} diff --git a/libsql-server/src/rpc/streaming_exec.rs b/libsql-server/src/rpc/streaming_exec.rs index 420bc90fe7..9447e0719d 100644 --- a/libsql-server/src/rpc/streaming_exec.rs +++ b/libsql-server/src/rpc/streaming_exec.rs @@ -58,6 +58,8 @@ where let conn = Arc::new(conn); pin!(request_stream); + + let mut last_request_id = None; loop { tokio::select! { @@ -71,6 +73,16 @@ where } Ok(req) => { let request_id = req.request_id; + if let Some(last_req_id) = last_request_id { + if request_id <= last_req_id { + tracing::error!("received request with id less than last received request, closing stream"); + yield Err(Status::new(Code::InvalidArgument, "received request with id less than last received request, closing stream")); + return; + } + } + + last_request_id = Some(request_id); + match req.request { Some(Request::Execute(pgm)) => { let Ok(pgm) = @@ -343,7 +355,7 @@ impl QueryResultBuilder for StreamResponseBuilder { #[cfg(test)] pub mod test { - use insta::{assert_debug_snapshot, assert_snapshot}; + use insta::{assert_debug_snapshot, assert_snapshot, assert_json_snapshot}; use tempfile::tempdir; use tokio_stream::wrappers::ReceiverStream; @@ -501,6 +513,54 @@ pub mod test { assert_eq!(resp.request_id, 1); } + #[tokio::test] + async fn perform_multiple_queries() { + let tmp = tempdir().unwrap(); + let conn = LibSqlConnection::new_test(tmp.path()); + let (snd, rcv) = mpsc::channel(1); + let auth = Authenticated::Authorized(Authorized { + namespace: None, + permission: Permission::FullAccess, + }); + let stream = make_proxy_stream(conn, auth, ReceiverStream::new(rcv)); + + pin!(stream); + + // request 0 should be dropped, and request 1 should be processed instead + let req1 = exec_req_stmt("create table test (foo)", 0); + snd.send(Ok(req1)).await.unwrap(); + assert_json_snapshot!(stream.next().await.unwrap().unwrap()); + + let req2 = exec_req_stmt("insert into test values (12)", 1); + snd.send(Ok(req2)).await.unwrap(); + assert_json_snapshot!(stream.next().await.unwrap().unwrap()); + } + + #[tokio::test] + async fn query_number_less_than_previous_query() { + let tmp = tempdir().unwrap(); + let conn = LibSqlConnection::new_test(tmp.path()); + let (snd, rcv) = mpsc::channel(1); + let auth = Authenticated::Authorized(Authorized { + namespace: None, + permission: Permission::FullAccess, + }); + let stream = make_proxy_stream(conn, auth, ReceiverStream::new(rcv)); + + pin!(stream); + + // request 0 should be dropped, and request 1 should be processed instead + let req1 = exec_req_stmt("create table test (foo)", 0); + snd.send(Ok(req1)).await.unwrap(); + assert_json_snapshot!(stream.next().await.unwrap().unwrap()); + + let req2 = exec_req_stmt("insert into test values (12)", 0); + snd.send(Ok(req2)).await.unwrap(); + let resp = stream.next().await.unwrap(); + assert!(resp.is_err()); + assert_debug_snapshot!(resp); + } + #[tokio::test] async fn describe() { let tmp = tempdir().unwrap(); From a7fe6bd558241770cee693753dea407d7c8b82fc Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 31 Oct 2023 17:09:34 +0100 Subject: [PATCH 21/26] restore cargo lock after bad rebase --- Cargo.lock | 10 ++++++++++ libsql-replication/build.rs | 21 --------------------- libsql-replication/src/rpc.rs | 4 ++-- libsql-replication/tests/bootstrap.rs | 10 +--------- 4 files changed, 13 insertions(+), 32 deletions(-) delete mode 100644 libsql-replication/build.rs diff --git a/Cargo.lock b/Cargo.lock index 496b008d63..4bc838e0b3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1638,6 +1638,15 @@ dependencies = [ "syn 2.0.38", ] +[[package]] +name = "futures-option" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01141bf1c1a803403a2e7ae6a7eb20506c6bd2ebd209fc2424d05572a78af5f5" +dependencies = [ + "futures-core", +] + [[package]] name = "futures-sink" version = "0.3.29" @@ -3754,6 +3763,7 @@ dependencies = [ "fallible-iterator 0.3.0", "futures", "futures-core", + "futures-option", "hmac", "hyper", "hyper-rustls 0.24.1", diff --git a/libsql-replication/build.rs b/libsql-replication/build.rs deleted file mode 100644 index b82b69a98f..0000000000 --- a/libsql-replication/build.rs +++ /dev/null @@ -1,21 +0,0 @@ -use prost_build::Config; - -fn main() -> Result<(), Box> { - std::env::set_var("PROTOC", protobuf_src::protoc()); - - let mut config = Config::new(); - config.bytes([".wal_log"]); - tonic_build::configure() - .protoc_arg("--experimental_allow_proto3_optional") - .type_attribute(".proxy", "#[derive(serde::Serialize, serde::Deserialize)]") - .type_attribute(".proxy", "#[cfg_attr(test, derive(arbitrary::Arbitrary))]") - .compile_with_config( - config, - &["proto/replication_log.proto", "proto/proxy.proto"], - &["proto"], - )?; - - println!("cargo:rerun-if-changed=proto/"); - - Ok(()) -} diff --git a/libsql-replication/src/rpc.rs b/libsql-replication/src/rpc.rs index 667daad513..45842c0282 100644 --- a/libsql-replication/src/rpc.rs +++ b/libsql-replication/src/rpc.rs @@ -7,7 +7,7 @@ pub mod proxy { impl From> for RowValue { fn from(value: ValueRef<'_>) -> Self { use row_value::Value; - + let value = Some(match value { ValueRef::Null => Value::Null(true), ValueRef::Integer(i) => Value::Integer(i), @@ -15,7 +15,7 @@ pub mod proxy { ValueRef::Text(s) => Value::Text(String::from_utf8(s.to_vec()).unwrap()), ValueRef::Blob(b) => Value::Blob(b.to_vec()), }); - + RowValue { value } } } diff --git a/libsql-replication/tests/bootstrap.rs b/libsql-replication/tests/bootstrap.rs index 110786297e..4bc868f2e5 100644 --- a/libsql-replication/tests/bootstrap.rs +++ b/libsql-replication/tests/bootstrap.rs @@ -17,15 +17,7 @@ fn bootstrap() { .build_server(true) .build_transport(true) .out_dir(&out_dir) - .type_attribute(".proxy", "#[cfg_attr(test, derive(arbitrary::Arbitrary))]") - .field_attribute( - ".proxy.Value.data", - "#[cfg_attr(test, arbitrary(with = crate::test::arbitrary_rpc_value))]", - ) - .field_attribute( - ".proxy.ProgramReq.namespace", - "#[cfg_attr(test, arbitrary(with = crate::test::arbitrary_bytes))]", - ) + .type_attribute(".proxy", "#[derive(serde::Serialize, serde::Deserialize)]") .compile_with_config(config, iface_files, dirs) .unwrap(); From f63f358eafe4e40ba139f13eb0dbae718d7a59fb Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 31 Oct 2023 17:27:05 +0100 Subject: [PATCH 22/26] fix generated rpc code --- libsql-replication/src/generated/proxy.rs | 459 ++++++++++++++++++---- libsql-replication/src/rpc.rs | 4 +- 2 files changed, 377 insertions(+), 86 deletions(-) diff --git a/libsql-replication/src/generated/proxy.rs b/libsql-replication/src/generated/proxy.rs index 48c0fb545a..7b1633a781 100644 --- a/libsql-replication/src/generated/proxy.rs +++ b/libsql-replication/src/generated/proxy.rs @@ -1,4 +1,4 @@ -#[cfg_attr(test, derive(arbitrary::Arbitrary))] +#[derive(serde::Serialize, serde::Deserialize)] #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Queries { @@ -8,7 +8,7 @@ pub struct Queries { #[prost(string, tag = "2")] pub client_id: ::prost::alloc::string::String, } -#[cfg_attr(test, derive(arbitrary::Arbitrary))] +#[derive(serde::Serialize, serde::Deserialize)] #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Query { @@ -21,7 +21,7 @@ pub struct Query { } /// Nested message and enum types in `Query`. pub mod query { - #[cfg_attr(test, derive(arbitrary::Arbitrary))] + #[derive(serde::Serialize, serde::Deserialize)] #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] pub enum Params { @@ -31,14 +31,14 @@ pub mod query { Named(super::Named), } } -#[cfg_attr(test, derive(arbitrary::Arbitrary))] +#[derive(serde::Serialize, serde::Deserialize)] #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Positional { #[prost(message, repeated, tag = "1")] pub values: ::prost::alloc::vec::Vec, } -#[cfg_attr(test, derive(arbitrary::Arbitrary))] +#[derive(serde::Serialize, serde::Deserialize)] #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Named { @@ -47,7 +47,7 @@ pub struct Named { #[prost(message, repeated, tag = "2")] pub values: ::prost::alloc::vec::Vec, } -#[cfg_attr(test, derive(arbitrary::Arbitrary))] +#[derive(serde::Serialize, serde::Deserialize)] #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct QueryResult { @@ -56,7 +56,7 @@ pub struct QueryResult { } /// Nested message and enum types in `QueryResult`. pub mod query_result { - #[cfg_attr(test, derive(arbitrary::Arbitrary))] + #[derive(serde::Serialize, serde::Deserialize)] #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] pub enum RowResult { @@ -66,7 +66,7 @@ pub mod query_result { Row(super::ResultRows), } } -#[cfg_attr(test, derive(arbitrary::Arbitrary))] +#[derive(serde::Serialize, serde::Deserialize)] #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Error { @@ -79,7 +79,7 @@ pub struct Error { } /// Nested message and enum types in `Error`. pub mod error { - #[cfg_attr(test, derive(arbitrary::Arbitrary))] + #[derive(serde::Serialize, serde::Deserialize)] #[derive( Clone, Copy, @@ -105,25 +105,25 @@ pub mod error { /// (if the ProtoBuf definition does not change) and safe for programmatic use. pub fn as_str_name(&self) -> &'static str { match self { - ErrorCode::SqlError => "SQLError", - ErrorCode::TxBusy => "TxBusy", - ErrorCode::TxTimeout => "TxTimeout", - ErrorCode::Internal => "Internal", + ErrorCode::SqlError => "SQL_ERROR", + ErrorCode::TxBusy => "TX_BUSY", + ErrorCode::TxTimeout => "TX_TIMEOUT", + ErrorCode::Internal => "INTERNAL", } } /// Creates an enum from field names used in the ProtoBuf definition. pub fn from_str_name(value: &str) -> ::core::option::Option { match value { - "SQLError" => Some(Self::SqlError), - "TxBusy" => Some(Self::TxBusy), - "TxTimeout" => Some(Self::TxTimeout), - "Internal" => Some(Self::Internal), + "SQL_ERROR" => Some(Self::SqlError), + "TX_BUSY" => Some(Self::TxBusy), + "TX_TIMEOUT" => Some(Self::TxTimeout), + "INTERNAL" => Some(Self::Internal), _ => None, } } } } -#[cfg_attr(test, derive(arbitrary::Arbitrary))] +#[derive(serde::Serialize, serde::Deserialize)] #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ResultRows { @@ -136,7 +136,7 @@ pub struct ResultRows { #[prost(int64, optional, tag = "4")] pub last_insert_rowid: ::core::option::Option, } -#[cfg_attr(test, derive(arbitrary::Arbitrary))] +#[derive(serde::Serialize, serde::Deserialize)] #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct DescribeRequest { @@ -145,7 +145,7 @@ pub struct DescribeRequest { #[prost(string, tag = "2")] pub stmt: ::prost::alloc::string::String, } -#[cfg_attr(test, derive(arbitrary::Arbitrary))] +#[derive(serde::Serialize, serde::Deserialize)] #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct DescribeResult { @@ -154,7 +154,7 @@ pub struct DescribeResult { } /// Nested message and enum types in `DescribeResult`. pub mod describe_result { - #[cfg_attr(test, derive(arbitrary::Arbitrary))] + #[derive(serde::Serialize, serde::Deserialize)] #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] pub enum DescribeResult { @@ -164,7 +164,7 @@ pub mod describe_result { Description(super::Description), } } -#[cfg_attr(test, derive(arbitrary::Arbitrary))] +#[derive(serde::Serialize, serde::Deserialize)] #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Description { @@ -175,23 +175,22 @@ pub struct Description { #[prost(uint64, tag = "3")] pub param_count: u64, } -#[cfg_attr(test, derive(arbitrary::Arbitrary))] +#[derive(serde::Serialize, serde::Deserialize)] #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Value { /// / bincode encoded Value #[prost(bytes = "vec", tag = "1")] - #[cfg_attr(test, arbitrary(with = crate::test::arbitrary_rpc_value))] pub data: ::prost::alloc::vec::Vec, } -#[cfg_attr(test, derive(arbitrary::Arbitrary))] +#[derive(serde::Serialize, serde::Deserialize)] #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Row { #[prost(message, repeated, tag = "1")] pub values: ::prost::alloc::vec::Vec, } -#[cfg_attr(test, derive(arbitrary::Arbitrary))] +#[derive(serde::Serialize, serde::Deserialize)] #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Column { @@ -200,81 +199,38 @@ pub struct Column { #[prost(string, optional, tag = "3")] pub decltype: ::core::option::Option<::prost::alloc::string::String>, } -#[cfg_attr(test, derive(arbitrary::Arbitrary))] +#[derive(serde::Serialize, serde::Deserialize)] #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct DisconnectMessage { #[prost(string, tag = "1")] pub client_id: ::prost::alloc::string::String, } -#[cfg_attr(test, derive(arbitrary::Arbitrary))] +#[derive(serde::Serialize, serde::Deserialize)] #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Ack {} -#[cfg_attr(test, derive(arbitrary::Arbitrary))] +#[derive(serde::Serialize, serde::Deserialize)] #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ExecuteResults { #[prost(message, repeated, tag = "1")] pub results: ::prost::alloc::vec::Vec, /// / State after executing the queries - #[prost(enumeration = "execute_results::State", tag = "2")] + #[prost(enumeration = "State", tag = "2")] pub state: i32, /// / Primary frame_no after executing the request. #[prost(uint64, optional, tag = "3")] pub current_frame_no: ::core::option::Option, } -/// Nested message and enum types in `ExecuteResults`. -pub mod execute_results { - #[cfg_attr(test, derive(arbitrary::Arbitrary))] - #[derive( - Clone, - Copy, - Debug, - PartialEq, - Eq, - Hash, - PartialOrd, - Ord, - ::prost::Enumeration - )] - #[repr(i32)] - pub enum State { - Init = 0, - Invalid = 1, - Txn = 2, - } - impl State { - /// String value of the enum field names used in the ProtoBuf definition. - /// - /// The values are not transformed in any way and thus are considered stable - /// (if the ProtoBuf definition does not change) and safe for programmatic use. - pub fn as_str_name(&self) -> &'static str { - match self { - State::Init => "Init", - State::Invalid => "Invalid", - State::Txn => "Txn", - } - } - /// Creates an enum from field names used in the ProtoBuf definition. - pub fn from_str_name(value: &str) -> ::core::option::Option { - match value { - "Init" => Some(Self::Init), - "Invalid" => Some(Self::Invalid), - "Txn" => Some(Self::Txn), - _ => None, - } - } - } -} -#[cfg_attr(test, derive(arbitrary::Arbitrary))] +#[derive(serde::Serialize, serde::Deserialize)] #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Program { #[prost(message, repeated, tag = "1")] pub steps: ::prost::alloc::vec::Vec, } -#[cfg_attr(test, derive(arbitrary::Arbitrary))] +#[derive(serde::Serialize, serde::Deserialize)] #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Step { @@ -283,7 +239,7 @@ pub struct Step { #[prost(message, optional, tag = "2")] pub query: ::core::option::Option, } -#[cfg_attr(test, derive(arbitrary::Arbitrary))] +#[derive(serde::Serialize, serde::Deserialize)] #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Cond { @@ -292,7 +248,7 @@ pub struct Cond { } /// Nested message and enum types in `Cond`. pub mod cond { - #[cfg_attr(test, derive(arbitrary::Arbitrary))] + #[derive(serde::Serialize, serde::Deserialize)] #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] pub enum Cond { @@ -310,46 +266,46 @@ pub mod cond { IsAutocommit(super::IsAutocommitCond), } } -#[cfg_attr(test, derive(arbitrary::Arbitrary))] +#[derive(serde::Serialize, serde::Deserialize)] #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct OkCond { #[prost(int64, tag = "1")] pub step: i64, } -#[cfg_attr(test, derive(arbitrary::Arbitrary))] +#[derive(serde::Serialize, serde::Deserialize)] #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ErrCond { #[prost(int64, tag = "1")] pub step: i64, } -#[cfg_attr(test, derive(arbitrary::Arbitrary))] +#[derive(serde::Serialize, serde::Deserialize)] #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct NotCond { #[prost(message, optional, boxed, tag = "1")] pub cond: ::core::option::Option<::prost::alloc::boxed::Box>, } -#[cfg_attr(test, derive(arbitrary::Arbitrary))] +#[derive(serde::Serialize, serde::Deserialize)] #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct AndCond { #[prost(message, repeated, tag = "1")] pub conds: ::prost::alloc::vec::Vec, } -#[cfg_attr(test, derive(arbitrary::Arbitrary))] +#[derive(serde::Serialize, serde::Deserialize)] #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct OrCond { #[prost(message, repeated, tag = "1")] pub conds: ::prost::alloc::vec::Vec, } -#[cfg_attr(test, derive(arbitrary::Arbitrary))] +#[derive(serde::Serialize, serde::Deserialize)] #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct IsAutocommitCond {} -#[cfg_attr(test, derive(arbitrary::Arbitrary))] +#[derive(serde::Serialize, serde::Deserialize)] #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ProgramReq { @@ -358,6 +314,262 @@ pub struct ProgramReq { #[prost(message, optional, tag = "2")] pub pgm: ::core::option::Option, } +/// / Streaming exec request +#[derive(serde::Serialize, serde::Deserialize)] +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ExecReq { + /// / id of the request. The response will contain this id. + #[prost(uint32, tag = "1")] + pub request_id: u32, + #[prost(oneof = "exec_req::Request", tags = "2, 3")] + pub request: ::core::option::Option, +} +/// Nested message and enum types in `ExecReq`. +pub mod exec_req { + #[derive(serde::Serialize, serde::Deserialize)] + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum Request { + #[prost(message, tag = "2")] + Execute(super::StreamProgramReq), + #[prost(message, tag = "3")] + Describe(super::StreamDescribeReq), + } +} +/// / Describe request for the streaming protocol +#[derive(serde::Serialize, serde::Deserialize)] +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct StreamProgramReq { + #[prost(message, optional, tag = "1")] + pub pgm: ::core::option::Option, +} +/// / descibre request for the streaming protocol +#[derive(serde::Serialize, serde::Deserialize)] +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct StreamDescribeReq { + #[prost(string, tag = "1")] + pub stmt: ::prost::alloc::string::String, +} +/// / Request response types +#[derive(serde::Serialize, serde::Deserialize)] +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Init {} +#[derive(serde::Serialize, serde::Deserialize)] +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct BeginStep {} +#[derive(serde::Serialize, serde::Deserialize)] +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct FinishStep { + #[prost(uint64, tag = "1")] + pub affected_row_count: u64, + #[prost(int64, optional, tag = "2")] + pub last_insert_rowid: ::core::option::Option, +} +#[derive(serde::Serialize, serde::Deserialize)] +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct StepError { + #[prost(message, optional, tag = "1")] + pub error: ::core::option::Option, +} +#[derive(serde::Serialize, serde::Deserialize)] +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ColsDescription { + #[prost(message, repeated, tag = "1")] + pub columns: ::prost::alloc::vec::Vec, +} +#[derive(serde::Serialize, serde::Deserialize)] +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct RowValue { + #[prost(oneof = "row_value::Value", tags = "1, 2, 3, 4, 5")] + pub value: ::core::option::Option, +} +/// Nested message and enum types in `RowValue`. +pub mod row_value { + #[derive(serde::Serialize, serde::Deserialize)] + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum Value { + #[prost(string, tag = "1")] + Text(::prost::alloc::string::String), + #[prost(int64, tag = "2")] + Integer(i64), + #[prost(double, tag = "3")] + Real(f64), + #[prost(bytes, tag = "4")] + Blob(::prost::alloc::vec::Vec), + /// null if present + #[prost(bool, tag = "5")] + Null(bool), + } +} +#[derive(serde::Serialize, serde::Deserialize)] +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct BeginRows {} +#[derive(serde::Serialize, serde::Deserialize)] +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct BeginRow {} +#[derive(serde::Serialize, serde::Deserialize)] +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct AddRowValue { + #[prost(message, optional, tag = "1")] + pub val: ::core::option::Option, +} +#[derive(serde::Serialize, serde::Deserialize)] +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct FinishRow {} +#[derive(serde::Serialize, serde::Deserialize)] +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct FinishRows {} +#[derive(serde::Serialize, serde::Deserialize)] +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Finish { + #[prost(uint64, optional, tag = "1")] + pub last_frame_no: ::core::option::Option, + #[prost(enumeration = "State", tag = "2")] + pub state: i32, +} +/// / Stream execx dexcribe response messages +#[derive(serde::Serialize, serde::Deserialize)] +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct DescribeParam { + #[prost(string, optional, tag = "1")] + pub name: ::core::option::Option<::prost::alloc::string::String>, +} +#[derive(serde::Serialize, serde::Deserialize)] +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct DescribeCol { + #[prost(string, tag = "1")] + pub name: ::prost::alloc::string::String, + #[prost(string, optional, tag = "2")] + pub decltype: ::core::option::Option<::prost::alloc::string::String>, +} +#[derive(serde::Serialize, serde::Deserialize)] +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct DescribeResp { + #[prost(message, repeated, tag = "1")] + pub params: ::prost::alloc::vec::Vec, + #[prost(message, repeated, tag = "2")] + pub cols: ::prost::alloc::vec::Vec, + #[prost(bool, tag = "3")] + pub is_explain: bool, + #[prost(bool, tag = "4")] + pub is_readonly: bool, +} +#[derive(serde::Serialize, serde::Deserialize)] +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct RespStep { + #[prost(oneof = "resp_step::Step", tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11")] + pub step: ::core::option::Option, +} +/// Nested message and enum types in `RespStep`. +pub mod resp_step { + #[derive(serde::Serialize, serde::Deserialize)] + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum Step { + #[prost(message, tag = "1")] + Init(super::Init), + #[prost(message, tag = "2")] + BeginStep(super::BeginStep), + #[prost(message, tag = "3")] + FinishStep(super::FinishStep), + #[prost(message, tag = "4")] + StepError(super::StepError), + #[prost(message, tag = "5")] + ColsDescription(super::ColsDescription), + #[prost(message, tag = "6")] + BeginRows(super::BeginRows), + #[prost(message, tag = "7")] + BeginRow(super::BeginRow), + #[prost(message, tag = "8")] + AddRowValue(super::AddRowValue), + #[prost(message, tag = "9")] + FinishRow(super::FinishRow), + #[prost(message, tag = "10")] + FinishRows(super::FinishRows), + #[prost(message, tag = "11")] + Finish(super::Finish), + } +} +#[derive(serde::Serialize, serde::Deserialize)] +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ProgramResp { + #[prost(message, repeated, tag = "1")] + pub steps: ::prost::alloc::vec::Vec, +} +#[derive(serde::Serialize, serde::Deserialize)] +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ExecResp { + #[prost(uint32, tag = "1")] + pub request_id: u32, + #[prost(oneof = "exec_resp::Response", tags = "2, 3, 4")] + pub response: ::core::option::Option, +} +/// Nested message and enum types in `ExecResp`. +pub mod exec_resp { + #[derive(serde::Serialize, serde::Deserialize)] + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum Response { + #[prost(message, tag = "2")] + ProgramResp(super::ProgramResp), + #[prost(message, tag = "3")] + DescribeResp(super::DescribeResp), + #[prost(message, tag = "4")] + Error(super::Error), + } +} +#[derive(serde::Serialize, serde::Deserialize)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum State { + Init = 0, + Invalid = 1, + Txn = 2, +} +impl State { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + State::Init => "INIT", + State::Invalid => "INVALID", + State::Txn => "TXN", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "INIT" => Some(Self::Init), + "INVALID" => Some(Self::Invalid), + "TXN" => Some(Self::Txn), + _ => None, + } + } +} /// Generated client implementations. pub mod proxy_client { #![allow(unused_variables, dead_code, missing_docs, clippy::let_unit_value)] @@ -443,6 +655,29 @@ pub mod proxy_client { self.inner = self.inner.max_encoding_message_size(limit); self } + pub async fn stream_exec( + &mut self, + request: impl tonic::IntoStreamingRequest, + ) -> std::result::Result< + tonic::Response>, + tonic::Status, + > { + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static("/proxy.Proxy/StreamExec"); + let mut req = request.into_streaming_request(); + req.extensions_mut().insert(GrpcMethod::new("proxy.Proxy", "StreamExec")); + self.inner.streaming(req, path, codec).await + } + /// Deprecated: pub async fn execute( &mut self, request: impl tonic::IntoRequest, @@ -509,6 +744,17 @@ pub mod proxy_server { /// Generated trait containing gRPC methods that should be implemented for use with ProxyServer. #[async_trait] pub trait Proxy: Send + Sync + 'static { + /// Server streaming response type for the StreamExec method. + type StreamExecStream: tonic::codegen::tokio_stream::Stream< + Item = std::result::Result, + > + + Send + + 'static; + async fn stream_exec( + &self, + request: tonic::Request>, + ) -> std::result::Result, tonic::Status>; + /// Deprecated: async fn execute( &self, request: tonic::Request, @@ -601,6 +847,51 @@ pub mod proxy_server { fn call(&mut self, req: http::Request) -> Self::Future { let inner = self.inner.clone(); match req.uri().path() { + "/proxy.Proxy/StreamExec" => { + #[allow(non_camel_case_types)] + struct StreamExecSvc(pub Arc); + impl tonic::server::StreamingService + for StreamExecSvc { + type Response = super::ExecResp; + type ResponseStream = T::StreamExecStream; + type Future = BoxFuture< + tonic::Response, + tonic::Status, + >; + fn call( + &mut self, + request: tonic::Request>, + ) -> Self::Future { + let inner = Arc::clone(&self.0); + let fut = async move { + ::stream_exec(&inner, request).await + }; + Box::pin(fut) + } + } + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let inner = inner.0; + let method = StreamExecSvc(inner); + let codec = tonic::codec::ProstCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config( + accept_compression_encodings, + send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.streaming(method, req).await; + Ok(res) + }; + Box::pin(fut) + } "/proxy.Proxy/Execute" => { #[allow(non_camel_case_types)] struct ExecuteSvc(pub Arc); diff --git a/libsql-replication/src/rpc.rs b/libsql-replication/src/rpc.rs index 45842c0282..e062e8280f 100644 --- a/libsql-replication/src/rpc.rs +++ b/libsql-replication/src/rpc.rs @@ -1,9 +1,9 @@ pub mod proxy { #![allow(clippy::all)] include!("generated/proxy.rs"); - + use sqld_libsql_bindings::rusqlite::types::ValueRef; - + impl From> for RowValue { fn from(value: ValueRef<'_>) -> Self { use row_value::Value; From 4f617afc53ee11b0c5e1db65f356c918b91ce653 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 31 Oct 2023 17:40:09 +0100 Subject: [PATCH 23/26] fix clippy and fmt --- libsql-replication/src/rpc.rs | 8 ++++---- libsql-server/src/http/user/mod.rs | 2 +- libsql-server/src/query_analysis.rs | 4 +++- libsql-server/src/rpc/streaming_exec.rs | 4 ++-- libsql-shell/src/main.rs | 6 ++---- xtask/src/main.rs | 2 +- 6 files changed, 13 insertions(+), 13 deletions(-) diff --git a/libsql-replication/src/rpc.rs b/libsql-replication/src/rpc.rs index e062e8280f..667daad513 100644 --- a/libsql-replication/src/rpc.rs +++ b/libsql-replication/src/rpc.rs @@ -1,13 +1,13 @@ pub mod proxy { #![allow(clippy::all)] include!("generated/proxy.rs"); - + use sqld_libsql_bindings::rusqlite::types::ValueRef; - + impl From> for RowValue { fn from(value: ValueRef<'_>) -> Self { use row_value::Value; - + let value = Some(match value { ValueRef::Null => Value::Null(true), ValueRef::Integer(i) => Value::Integer(i), @@ -15,7 +15,7 @@ pub mod proxy { ValueRef::Text(s) => Value::Text(String::from_utf8(s.to_vec()).unwrap()), ValueRef::Blob(b) => Value::Blob(b.to_vec()), }); - + RowValue { value } } } diff --git a/libsql-server/src/http/user/mod.rs b/libsql-server/src/http/user/mod.rs index 74aa35a38e..3958d78a84 100644 --- a/libsql-server/src/http/user/mod.rs +++ b/libsql-server/src/http/user/mod.rs @@ -37,7 +37,7 @@ use crate::http::user::types::HttpQuery; use crate::namespace::{MakeNamespace, NamespaceStore}; use crate::net::Accept; use crate::query::{self, Query}; -use crate::query_analysis::{predict_final_state, TxnStatus, Statement}; +use crate::query_analysis::{predict_final_state, Statement, TxnStatus}; use crate::query_result_builder::QueryResultBuilder; use crate::rpc::proxy::rpc::proxy_server::{Proxy, ProxyServer}; use crate::rpc::replication_log::rpc::replication_log_server::ReplicationLog; diff --git a/libsql-server/src/query_analysis.rs b/libsql-server/src/query_analysis.rs index 174688112c..0555a2ae75 100644 --- a/libsql-server/src/query_analysis.rs +++ b/libsql-server/src/query_analysis.rs @@ -213,7 +213,9 @@ pub enum TxnStatus { impl TxnStatus { pub fn step(&mut self, kind: StmtKind) { *self = match (*self, kind) { - (TxnStatus::Txn, StmtKind::TxnBegin) | (TxnStatus::Init, StmtKind::TxnEnd) => TxnStatus::Invalid, + (TxnStatus::Txn, StmtKind::TxnBegin) | (TxnStatus::Init, StmtKind::TxnEnd) => { + TxnStatus::Invalid + } (TxnStatus::Txn, StmtKind::TxnEnd) => TxnStatus::Init, (state, StmtKind::Other | StmtKind::Write | StmtKind::Read) => state, (TxnStatus::Invalid, _) => TxnStatus::Invalid, diff --git a/libsql-server/src/rpc/streaming_exec.rs b/libsql-server/src/rpc/streaming_exec.rs index 9447e0719d..0424fe96be 100644 --- a/libsql-server/src/rpc/streaming_exec.rs +++ b/libsql-server/src/rpc/streaming_exec.rs @@ -58,7 +58,7 @@ where let conn = Arc::new(conn); pin!(request_stream); - + let mut last_request_id = None; loop { @@ -355,7 +355,7 @@ impl QueryResultBuilder for StreamResponseBuilder { #[cfg(test)] pub mod test { - use insta::{assert_debug_snapshot, assert_snapshot, assert_json_snapshot}; + use insta::{assert_debug_snapshot, assert_json_snapshot, assert_snapshot}; use tempfile::tempdir; use tokio_stream::wrappers::ReceiverStream; diff --git a/libsql-shell/src/main.rs b/libsql-shell/src/main.rs index ce05b5af44..aaefb06114 100644 --- a/libsql-shell/src/main.rs +++ b/libsql-shell/src/main.rs @@ -426,10 +426,8 @@ impl Shell { })?; let mut mapped_rows = vec![]; - for row in rows { - if let Ok(r) = row { - mapped_rows.push(r); - } + for row in rows.flatten() { + mapped_rows.push(row); } mapped_rows }; diff --git a/xtask/src/main.rs b/xtask/src/main.rs index 69a86ca6da..997f8c1611 100644 --- a/xtask/src/main.rs +++ b/xtask/src/main.rs @@ -45,7 +45,7 @@ fn build() -> Result<()> { fn run_cargo(cmd: &[&str]) -> Result<()> { let mut out = Command::new("cargo") - .args(&cmd[..]) + .args(cmd) .spawn() .context("spawn")?; From 6ab684564b2ca9e15d37f6e9cd98080002a87e96 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 31 Oct 2023 17:45:45 +0100 Subject: [PATCH 24/26] remote future-option dep --- Cargo.lock | 10 ---------- libsql-server/Cargo.toml | 1 - libsql-server/src/rpc/streaming_exec.rs | 11 ++++++----- xtask/src/main.rs | 5 +---- 4 files changed, 7 insertions(+), 20 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4bc838e0b3..496b008d63 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1638,15 +1638,6 @@ dependencies = [ "syn 2.0.38", ] -[[package]] -name = "futures-option" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01141bf1c1a803403a2e7ae6a7eb20506c6bd2ebd209fc2424d05572a78af5f5" -dependencies = [ - "futures-core", -] - [[package]] name = "futures-sink" version = "0.3.29" @@ -3763,7 +3754,6 @@ dependencies = [ "fallible-iterator 0.3.0", "futures", "futures-core", - "futures-option", "hmac", "hyper", "hyper-rustls 0.24.1", diff --git a/libsql-server/Cargo.toml b/libsql-server/Cargo.toml index 6f500ab786..e8e006cbff 100644 --- a/libsql-server/Cargo.toml +++ b/libsql-server/Cargo.toml @@ -26,7 +26,6 @@ enclose = "1.1" fallible-iterator = "0.3.0" futures = "0.3.25" futures-core = "0.3" -futures-option = "0.2.0" hmac = "0.12" hyper = { version = "0.14.23", features = ["http2"] } hyper-rustls = { git = "https://github.com/rustls/hyper-rustls.git", rev = "163b3f5" } diff --git a/libsql-server/src/rpc/streaming_exec.rs b/libsql-server/src/rpc/streaming_exec.rs index 0424fe96be..51e4235f42 100644 --- a/libsql-server/src/rpc/streaming_exec.rs +++ b/libsql-server/src/rpc/streaming_exec.rs @@ -1,8 +1,9 @@ +use std::future::poll_fn; use std::sync::Arc; +use std::task::Poll; use futures_core::future::BoxFuture; use futures_core::Stream; -use futures_option::OptionExt; use libsql_replication::rpc::proxy::exec_req::Request; use libsql_replication::rpc::proxy::exec_resp::{self, Response}; use libsql_replication::rpc::proxy::resp_step::Step; @@ -53,7 +54,7 @@ where C: Connection, { async_stream::stream! { - let mut current_request_fut: Option, u32)>> = None; + let mut current_request_fut: BoxFuture<'static, (crate::Result<()>, u32)> = Box::pin(poll_fn(|_| Poll::Pending)); let (snd, mut recv) = mpsc::channel(1); let conn = Arc::new(conn); @@ -107,7 +108,7 @@ where (ret, request_id) }; - current_request_fut.replace(Box::pin(fut)); + current_request_fut = Box::pin(fut); } Some(Request::Describe(StreamDescribeReq { stmt })) => { let auth = auth.clone(); @@ -135,7 +136,7 @@ where (ret, request_id) }; - current_request_fut.replace(Box::pin(fut)); + current_request_fut = Box::pin(fut); }, None => { @@ -149,7 +150,7 @@ where Some(res) = recv.recv() => { yield Ok(res); }, - (ret, request_id) = current_request_fut.current(), if current_request_fut.is_some() => { + (ret, request_id) = &mut current_request_fut => { if let Err(e) = ret { yield Ok(ExecResp { request_id, response: Some(Response::Error(e.into())) }) } diff --git a/xtask/src/main.rs b/xtask/src/main.rs index 997f8c1611..2862589393 100644 --- a/xtask/src/main.rs +++ b/xtask/src/main.rs @@ -44,10 +44,7 @@ fn build() -> Result<()> { } fn run_cargo(cmd: &[&str]) -> Result<()> { - let mut out = Command::new("cargo") - .args(cmd) - .spawn() - .context("spawn")?; + let mut out = Command::new("cargo").args(cmd).spawn().context("spawn")?; let exit = out.wait().context("wait")?; From 26696c0582460249b5c742e6c6c36db408a4ce5f Mon Sep 17 00:00:00 2001 From: ad hoc Date: Wed, 1 Nov 2023 10:31:28 +0100 Subject: [PATCH 25/26] remove default type on WriteProxyConnection --- libsql-server/src/connection/mod.rs | 2 +- libsql-server/src/connection/write_proxy.rs | 10 ++++++---- libsql-server/src/database.rs | 6 +++--- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/libsql-server/src/connection/mod.rs b/libsql-server/src/connection/mod.rs index 43334bc194..871a4e6fd1 100644 --- a/libsql-server/src/connection/mod.rs +++ b/libsql-server/src/connection/mod.rs @@ -50,7 +50,7 @@ pub trait Connection: Send + Sync + 'static { let mut steps = make_batch_program(batch); if !steps.is_empty() { - // We add a conditional rollback step if the last step was not sucessful. + // We add a conditional rollback step if the last step was not successful. steps.push(Step { query: Query { stmt: Statement::parse("ROLLBACK").next().unwrap().unwrap(), diff --git a/libsql-server/src/connection/write_proxy.rs b/libsql-server/src/connection/write_proxy.rs index f7c96b8efc..0882af4780 100644 --- a/libsql-server/src/connection/write_proxy.rs +++ b/libsql-server/src/connection/write_proxy.rs @@ -32,6 +32,8 @@ use super::program::DescribeResponse; use super::Connection; use super::{MakeConnection, Program}; +pub type RpcStream = Streaming; + pub struct MakeWriteProxyConn { client: ProxyClient, stats: Arc, @@ -85,7 +87,7 @@ impl MakeWriteProxyConn { #[async_trait::async_trait] impl MakeConnection for MakeWriteProxyConn { - type Connection = WriteProxyConnection; + type Connection = WriteProxyConnection; async fn create(&self) -> Result { let db = WriteProxyConnection::new( self.client.clone(), @@ -104,7 +106,7 @@ impl MakeConnection for MakeWriteProxyConn { } } -pub struct WriteProxyConnection> { +pub struct WriteProxyConnection { /// Lazily initialized read connection read_conn: LibSqlConnection, write_proxy: ProxyClient, @@ -122,7 +124,7 @@ pub struct WriteProxyConnection> { remote_conn: Mutex>>, } -impl WriteProxyConnection { +impl WriteProxyConnection { #[allow(clippy::too_many_arguments)] async fn new( write_proxy: ProxyClient, @@ -389,7 +391,7 @@ where } #[async_trait::async_trait] -impl Connection for WriteProxyConnection { +impl Connection for WriteProxyConnection { async fn execute_program( &self, pgm: Program, diff --git a/libsql-server/src/database.rs b/libsql-server/src/database.rs index 60ca8e4b12..291f361908 100644 --- a/libsql-server/src/database.rs +++ b/libsql-server/src/database.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use crate::connection::libsql::LibSqlConnection; -use crate::connection::write_proxy::WriteProxyConnection; +use crate::connection::write_proxy::{RpcStream, WriteProxyConnection}; use crate::connection::{Connection, MakeConnection, TrackedConnection}; use crate::replication::{ReplicationLogger, ReplicationLoggerHook}; @@ -15,11 +15,11 @@ pub trait Database: Sync + Send + 'static { pub struct ReplicaDatabase { pub connection_maker: - Arc>>, + Arc>>>, } impl Database for ReplicaDatabase { - type Connection = TrackedConnection; + type Connection = TrackedConnection>; fn connection_maker(&self) -> Arc> { self.connection_maker.clone() From 01a2a5f7646361ccf6b99e9cd0a6b2088a997c87 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Fri, 3 Nov 2023 11:16:07 +0100 Subject: [PATCH 26/26] fix future re-poll --- libsql-server/src/rpc/streaming_exec.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/libsql-server/src/rpc/streaming_exec.rs b/libsql-server/src/rpc/streaming_exec.rs index 51e4235f42..ca0e515bf5 100644 --- a/libsql-server/src/rpc/streaming_exec.rs +++ b/libsql-server/src/rpc/streaming_exec.rs @@ -54,7 +54,8 @@ where C: Connection, { async_stream::stream! { - let mut current_request_fut: BoxFuture<'static, (crate::Result<()>, u32)> = Box::pin(poll_fn(|_| Poll::Pending)); + let never = || Box::pin(poll_fn(|_| Poll::Pending)); + let mut current_request_fut: BoxFuture<'static, (crate::Result<()>, u32)> = never(); let (snd, mut recv) = mpsc::channel(1); let conn = Arc::new(conn); @@ -151,6 +152,7 @@ where yield Ok(res); }, (ret, request_id) = &mut current_request_fut => { + current_request_fut = never(); if let Err(e) = ret { yield Ok(ExecResp { request_id, response: Some(Response::Error(e.into())) }) }