From 417c8bf9f1d6eaed38a06d723a97a19de41a977c Mon Sep 17 00:00:00 2001 From: ad hoc Date: Mon, 2 Sep 2024 10:26:05 +0200 Subject: [PATCH 1/5] generate admin_shell service --- Cargo.lock | 2 + libsql-server/Cargo.toml | 2 + libsql-server/proto/admin_shell.proto | 42 +++ libsql-server/src/generated/admin_shell.rs | 365 +++++++++++++++++++++ libsql-server/src/http/admin/mod.rs | 1 + libsql-server/tests/bootstrap.rs | 34 ++ 6 files changed, 446 insertions(+) create mode 100644 libsql-server/proto/admin_shell.proto create mode 100644 libsql-server/src/generated/admin_shell.rs create mode 100644 libsql-server/tests/bootstrap.rs diff --git a/Cargo.lock b/Cargo.lock index cc5ba428fa..9faa2c69f4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3205,6 +3205,7 @@ dependencies = [ "priority-queue 1.4.0", "proptest", "prost", + "prost-build", "rand", "regex", "reqwest", @@ -3227,6 +3228,7 @@ dependencies = [ "tokio-tungstenite", "tokio-util", "tonic 0.11.0", + "tonic-build 0.11.0", "tonic-web", "tower", "tower-http 0.3.5", diff --git a/libsql-server/Cargo.toml b/libsql-server/Cargo.toml index 4bcaa2a933..b117d8ac47 100644 --- a/libsql-server/Cargo.toml +++ b/libsql-server/Cargo.toml @@ -112,6 +112,8 @@ metrics-util = "0.15" s3s = "0.8.1" s3s-fs = "0.8.1" ring = { version = "0.17.8", features = ["std"] } +tonic-build = "0.11" +prost-build = "0.12" [build-dependencies] vergen = { version = "8", features = ["build", "git", "gitcl"] } diff --git a/libsql-server/proto/admin_shell.proto b/libsql-server/proto/admin_shell.proto new file mode 100644 index 0000000000..ffa8218e54 --- /dev/null +++ b/libsql-server/proto/admin_shell.proto @@ -0,0 +1,42 @@ +syntax = "proto3"; + +package admin_shell; + +message Query { + string query = 1; +} + +message Value { + oneof value { + Null null = 1; + double real = 2; + int64 integer = 3; + string text = 4; + bytes blob = 5; + } +} + +message Null {} + +message Row { + repeated Value values = 1; +} + +message Rows { + repeated Row rows = 1; +} + +message Error { + string Error = 1; +} + +message Response { + oneof resp { + Rows rows = 1; + Error error = 2; + } +} + +service AdminShellService { + rpc Shell(stream Query) returns (stream Response) {} +} diff --git a/libsql-server/src/generated/admin_shell.rs b/libsql-server/src/generated/admin_shell.rs new file mode 100644 index 0000000000..0b49a45c96 --- /dev/null +++ b/libsql-server/src/generated/admin_shell.rs @@ -0,0 +1,365 @@ +// This file is @generated by prost-build. +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Query { + #[prost(string, tag = "1")] + pub query: ::prost::alloc::string::String, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Value { + #[prost(oneof = "value::Value", tags = "1, 2, 3, 4, 5")] + pub value: ::core::option::Option, +} +/// Nested message and enum types in `Value`. +pub mod value { + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum Value { + #[prost(message, tag = "1")] + Null(super::Null), + #[prost(double, tag = "2")] + Real(f64), + #[prost(int64, tag = "3")] + Integer(i64), + #[prost(string, tag = "4")] + Text(::prost::alloc::string::String), + #[prost(bytes, tag = "5")] + Blob(::prost::alloc::vec::Vec), + } +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Null {} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Row { + #[prost(message, repeated, tag = "1")] + pub values: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Rows { + #[prost(message, repeated, tag = "1")] + pub rows: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Error { + #[prost(string, tag = "1")] + pub error: ::prost::alloc::string::String, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Response { + #[prost(oneof = "response::Resp", tags = "1, 2")] + pub resp: ::core::option::Option, +} +/// Nested message and enum types in `Response`. +pub mod response { + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum Resp { + #[prost(message, tag = "1")] + Rows(super::Rows), + #[prost(message, tag = "2")] + Error(super::Error), + } +} +/// Generated client implementations. +pub mod admin_shell_service_client { + #![allow(unused_variables, dead_code, missing_docs, clippy::let_unit_value)] + use tonic::codegen::*; + use tonic::codegen::http::Uri; + #[derive(Debug, Clone)] + pub struct AdminShellServiceClient { + inner: tonic::client::Grpc, + } + impl AdminShellServiceClient { + /// Attempt to create a new client by connecting to a given endpoint. + pub async fn connect(dst: D) -> Result + where + D: TryInto, + D::Error: Into, + { + let conn = tonic::transport::Endpoint::new(dst)?.connect().await?; + Ok(Self::new(conn)) + } + } + impl AdminShellServiceClient + where + T: tonic::client::GrpcService, + T::Error: Into, + T::ResponseBody: Body + Send + 'static, + ::Error: Into + Send, + { + pub fn new(inner: T) -> Self { + let inner = tonic::client::Grpc::new(inner); + Self { inner } + } + pub fn with_origin(inner: T, origin: Uri) -> Self { + let inner = tonic::client::Grpc::with_origin(inner, origin); + Self { inner } + } + pub fn with_interceptor( + inner: T, + interceptor: F, + ) -> AdminShellServiceClient> + where + F: tonic::service::Interceptor, + T::ResponseBody: Default, + T: tonic::codegen::Service< + http::Request, + Response = http::Response< + >::ResponseBody, + >, + >, + , + >>::Error: Into + Send + Sync, + { + AdminShellServiceClient::new(InterceptedService::new(inner, interceptor)) + } + /// Compress requests with the given encoding. + /// + /// This requires the server to support it otherwise it might respond with an + /// error. + #[must_use] + pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.inner = self.inner.send_compressed(encoding); + self + } + /// Enable decompressing responses. + #[must_use] + pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.inner = self.inner.accept_compressed(encoding); + self + } + /// Limits the maximum size of a decoded message. + /// + /// Default: `4MB` + #[must_use] + pub fn max_decoding_message_size(mut self, limit: usize) -> Self { + self.inner = self.inner.max_decoding_message_size(limit); + self + } + /// Limits the maximum size of an encoded message. + /// + /// Default: `usize::MAX` + #[must_use] + pub fn max_encoding_message_size(mut self, limit: usize) -> Self { + self.inner = self.inner.max_encoding_message_size(limit); + self + } + pub async fn shell( + &mut self, + request: impl tonic::IntoStreamingRequest, + ) -> std::result::Result< + tonic::Response>, + tonic::Status, + > { + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static( + "/admin_shell.AdminShellService/Shell", + ); + let mut req = request.into_streaming_request(); + req.extensions_mut() + .insert(GrpcMethod::new("admin_shell.AdminShellService", "Shell")); + self.inner.streaming(req, path, codec).await + } + } +} +/// Generated server implementations. +pub mod admin_shell_service_server { + #![allow(unused_variables, dead_code, missing_docs, clippy::let_unit_value)] + use tonic::codegen::*; + /// Generated trait containing gRPC methods that should be implemented for use with AdminShellServiceServer. + #[async_trait] + pub trait AdminShellService: Send + Sync + 'static { + /// Server streaming response type for the Shell method. + type ShellStream: tonic::codegen::tokio_stream::Stream< + Item = std::result::Result, + > + + Send + + 'static; + async fn shell( + &self, + request: tonic::Request>, + ) -> std::result::Result, tonic::Status>; + } + #[derive(Debug)] + pub struct AdminShellServiceServer { + inner: _Inner, + accept_compression_encodings: EnabledCompressionEncodings, + send_compression_encodings: EnabledCompressionEncodings, + max_decoding_message_size: Option, + max_encoding_message_size: Option, + } + struct _Inner(Arc); + impl AdminShellServiceServer { + pub fn new(inner: T) -> Self { + Self::from_arc(Arc::new(inner)) + } + pub fn from_arc(inner: Arc) -> Self { + let inner = _Inner(inner); + Self { + inner, + accept_compression_encodings: Default::default(), + send_compression_encodings: Default::default(), + max_decoding_message_size: None, + max_encoding_message_size: None, + } + } + pub fn with_interceptor( + inner: T, + interceptor: F, + ) -> InterceptedService + where + F: tonic::service::Interceptor, + { + InterceptedService::new(Self::new(inner), interceptor) + } + /// Enable decompressing requests with the given encoding. + #[must_use] + pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.accept_compression_encodings.enable(encoding); + self + } + /// Compress responses with the given encoding, if the client supports it. + #[must_use] + pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.send_compression_encodings.enable(encoding); + self + } + /// Limits the maximum size of a decoded message. + /// + /// Default: `4MB` + #[must_use] + pub fn max_decoding_message_size(mut self, limit: usize) -> Self { + self.max_decoding_message_size = Some(limit); + self + } + /// Limits the maximum size of an encoded message. + /// + /// Default: `usize::MAX` + #[must_use] + pub fn max_encoding_message_size(mut self, limit: usize) -> Self { + self.max_encoding_message_size = Some(limit); + self + } + } + impl tonic::codegen::Service> for AdminShellServiceServer + where + T: AdminShellService, + B: Body + Send + 'static, + B::Error: Into + Send + 'static, + { + type Response = http::Response; + type Error = std::convert::Infallible; + type Future = BoxFuture; + fn poll_ready( + &mut self, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + fn call(&mut self, req: http::Request) -> Self::Future { + let inner = self.inner.clone(); + match req.uri().path() { + "/admin_shell.AdminShellService/Shell" => { + #[allow(non_camel_case_types)] + struct ShellSvc(pub Arc); + impl< + T: AdminShellService, + > tonic::server::StreamingService for ShellSvc { + type Response = super::Response; + type ResponseStream = T::ShellStream; + type Future = BoxFuture< + tonic::Response, + tonic::Status, + >; + fn call( + &mut self, + request: tonic::Request>, + ) -> Self::Future { + let inner = Arc::clone(&self.0); + let fut = async move { + ::shell(&inner, request).await + }; + Box::pin(fut) + } + } + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let inner = inner.0; + let method = ShellSvc(inner); + let codec = tonic::codec::ProstCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config( + accept_compression_encodings, + send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.streaming(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + _ => { + Box::pin(async move { + Ok( + http::Response::builder() + .status(200) + .header("grpc-status", "12") + .header("content-type", "application/grpc") + .body(empty_body()) + .unwrap(), + ) + }) + } + } + } + } + impl Clone for AdminShellServiceServer { + fn clone(&self) -> Self { + let inner = self.inner.clone(); + Self { + inner, + accept_compression_encodings: self.accept_compression_encodings, + send_compression_encodings: self.send_compression_encodings, + max_decoding_message_size: self.max_decoding_message_size, + max_encoding_message_size: self.max_encoding_message_size, + } + } + } + impl Clone for _Inner { + fn clone(&self) -> Self { + Self(Arc::clone(&self.0)) + } + } + impl std::fmt::Debug for _Inner { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self.0) + } + } + impl tonic::server::NamedService + for AdminShellServiceServer { + const NAME: &'static str = "admin_shell.AdminShellService"; + } +} diff --git a/libsql-server/src/http/admin/mod.rs b/libsql-server/src/http/admin/mod.rs index 63423bd4e3..73ad051b01 100644 --- a/libsql-server/src/http/admin/mod.rs +++ b/libsql-server/src/http/admin/mod.rs @@ -31,6 +31,7 @@ use crate::net::Connector; use crate::LIBSQL_PAGE_SIZE; pub mod stats; +mod admin_shell; #[derive(Clone)] struct Metrics { diff --git a/libsql-server/tests/bootstrap.rs b/libsql-server/tests/bootstrap.rs new file mode 100644 index 0000000000..82344b22cb --- /dev/null +++ b/libsql-server/tests/bootstrap.rs @@ -0,0 +1,34 @@ +use std::process::Command; +use std::path::PathBuf; + +#[test] +fn bootstrap() { + let iface_files = &[ + "proto/admin_shell.proto", + ]; + let dirs = &["proto"]; + + let out_dir = PathBuf::from(std::env!("CARGO_MANIFEST_DIR")) + .join("src") + .join("generated"); + + let config = prost_build::Config::new(); + + tonic_build::configure() + .build_client(true) + .build_server(true) + .build_transport(true) + .out_dir(&out_dir) + .compile_with_config(config, iface_files, dirs) + .unwrap(); + + let status = Command::new("git") + .arg("diff") + .arg("--exit-code") + .arg("--") + .arg(&out_dir) + .status() + .unwrap(); + + assert!(status.success(), "You should commit the protobuf files"); +} From bd06333f5c3b505a63adb400141d4a2109228532 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Mon, 2 Sep 2024 13:52:42 +0200 Subject: [PATCH 2/5] instroduce admin shell and client --- Cargo.lock | 21 +++++++++++++++++++++ libsql-server/Cargo.toml | 1 + libsql-server/src/admin_shell.rs | 4 ++++ libsql-server/src/http/admin/mod.rs | 11 +++++++++-- libsql-server/src/lib.rs | 1 + 5 files changed, 36 insertions(+), 2 deletions(-) create mode 100644 libsql-server/src/admin_shell.rs diff --git a/Cargo.lock b/Cargo.lock index 9faa2c69f4..bf09fa2edd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1347,6 +1347,7 @@ dependencies = [ "encode_unicode", "lazy_static", "libc", + "unicode-width", "windows-sys 0.52.0", ] @@ -1822,6 +1823,19 @@ dependencies = [ "syn 2.0.70", ] +[[package]] +name = "dialoguer" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "658bce805d770f407bc62102fca7c2c64ceef2fbcb2b8bd19d2765ce093980de" +dependencies = [ + "console", + "shell-words", + "tempfile", + "thiserror", + "zeroize", +] + [[package]] name = "digest" version = "0.10.7" @@ -3167,6 +3181,7 @@ dependencies = [ "console-subscriber", "crc", "crossbeam", + "dialoguer", "enclose", "env_logger", "fallible-iterator 0.3.0", @@ -5218,6 +5233,12 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "shell-words" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24188a676b6ae68c3b2cb3a01be17fbf7240ce009799bb56d5b1409051e78fde" + [[package]] name = "shellexpand" version = "2.1.2" diff --git a/libsql-server/Cargo.toml b/libsql-server/Cargo.toml index b117d8ac47..edeb5988d4 100644 --- a/libsql-server/Cargo.toml +++ b/libsql-server/Cargo.toml @@ -95,6 +95,7 @@ tar = "0.4.41" aws-config = "1" aws-sdk-s3 = "1" aws-smithy-runtime = "1.6.2" +dialoguer = { version = "0.11.0", features = ["history"] } [dev-dependencies] arbitrary = { version = "1.3.0", features = ["derive_arbitrary"] } diff --git a/libsql-server/src/admin_shell.rs b/libsql-server/src/admin_shell.rs new file mode 100644 index 0000000000..7b25ca6883 --- /dev/null +++ b/libsql-server/src/admin_shell.rs @@ -0,0 +1,4 @@ +use std::fmt::Display; +use std::pin::Pin; + +use bytes::Bytes; diff --git a/libsql-server/src/http/admin/mod.rs b/libsql-server/src/http/admin/mod.rs index 73ad051b01..00fdc6d13a 100644 --- a/libsql-server/src/http/admin/mod.rs +++ b/libsql-server/src/http/admin/mod.rs @@ -31,7 +31,6 @@ use crate::net::Connector; use crate::LIBSQL_PAGE_SIZE; pub mod stats; -mod admin_shell; #[derive(Clone)] struct Metrics { @@ -169,7 +168,7 @@ where .route("/profile/heap/disable/:id", post(disable_profile_heap)) .route("/profile/heap/:id", delete(delete_profile_heap)) .with_state(Arc::new(AppState { - namespaces, + namespaces: namespaces.clone(), connector, user_http_server, metrics, @@ -185,6 +184,14 @@ where ) .layer(axum::middleware::from_fn_with_state(auth, auth_middleware)); + let admin_shell = crate::admin_shell::make_svc(namespaces.clone()); + let grpc_router = tonic::transport::Server::builder() + .accept_http1(true) + .add_service(tonic_web::enable(admin_shell)) + .into_router(); + + let router = router.merge(grpc_router); + hyper::server::Server::builder(acceptor) .serve(router.into_make_service()) .with_graceful_shutdown(shutdown.notified()) diff --git a/libsql-server/src/lib.rs b/libsql-server/src/lib.rs index 3263db560b..dd752f17a6 100644 --- a/libsql-server/src/lib.rs +++ b/libsql-server/src/lib.rs @@ -84,6 +84,7 @@ use self::net::AddrIncoming; use self::replication::script_backup_manager::{CommandHandler, ScriptBackupManager}; use self::schema::SchedulerHandle; +pub mod admin_shell; pub mod auth; mod broadcaster; pub mod config; From a76b7dbcf96c171e2681adabe5d61a11758b25ed Mon Sep 17 00:00:00 2001 From: ad hoc Date: Mon, 2 Sep 2024 13:53:28 +0200 Subject: [PATCH 3/5] fixup! generate admin_shell service --- libsql-server/tests/bootstrap.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/libsql-server/tests/bootstrap.rs b/libsql-server/tests/bootstrap.rs index 82344b22cb..a464f53288 100644 --- a/libsql-server/tests/bootstrap.rs +++ b/libsql-server/tests/bootstrap.rs @@ -1,11 +1,9 @@ -use std::process::Command; use std::path::PathBuf; +use std::process::Command; #[test] fn bootstrap() { - let iface_files = &[ - "proto/admin_shell.proto", - ]; + let iface_files = &["proto/admin_shell.proto"]; let dirs = &["proto"]; let out_dir = PathBuf::from(std::env!("CARGO_MANIFEST_DIR")) From e38cc1e1ed63851c5836fd70a9ef4e68007d02f9 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Mon, 2 Sep 2024 13:53:32 +0200 Subject: [PATCH 4/5] add admin-shell subcommand --- libsql-server/src/admin_shell.rs | 230 +++++++++++++++++++++++++++++++ libsql-server/src/main.rs | 33 ++++- 2 files changed, 256 insertions(+), 7 deletions(-) diff --git a/libsql-server/src/admin_shell.rs b/libsql-server/src/admin_shell.rs index 7b25ca6883..a8d750106a 100644 --- a/libsql-server/src/admin_shell.rs +++ b/libsql-server/src/admin_shell.rs @@ -2,3 +2,233 @@ use std::fmt::Display; use std::pin::Pin; use bytes::Bytes; +use dialoguer::BasicHistory; +use rusqlite::types::ValueRef; +use tokio_stream::{Stream, StreamExt as _}; +use tonic::metadata::BinaryMetadataValue; + +use crate::connection::Connection as _; +use crate::database::Connection; +use crate::namespace::{NamespaceName, NamespaceStore}; + +use self::rpc::admin_shell_service_server::{AdminShellService, AdminShellServiceServer}; +use self::rpc::response::Resp; +use self::rpc::Null; + +mod rpc { + #![allow(clippy::all)] + include!("generated/admin_shell.rs"); +} + +pub(crate) fn make_svc(namespace_store: NamespaceStore) -> AdminShellServiceServer { + let admin_shell = AdminShell::new(namespace_store); + rpc::admin_shell_service_server::AdminShellServiceServer::new(admin_shell) +} + +pub(super) struct AdminShell { + namespace_store: NamespaceStore, +} + +impl AdminShell { + fn new(namespace_store: NamespaceStore) -> Self { + Self { namespace_store } + } + + async fn with_namespace( + &self, + ns: Bytes, + queries: impl Stream>, + ) -> anyhow::Result>> { + let namespace = NamespaceName::from_bytes(ns).unwrap(); + let connection_maker = self + .namespace_store + .with(namespace, |ns| ns.db.connection_maker()) + .await?; + let connection = connection_maker.create().await?; + Ok(run_shell(connection, queries)) + } +} + +fn run_shell( + conn: Connection, + queries: impl Stream>, +) -> impl Stream> { + async_stream::stream! { + tokio::pin!(queries); + while let Some(q) = queries.next().await { + let Ok(q) = q else { break }; + let res = tokio::task::block_in_place(|| { + conn.with_raw(move |conn| { + run_one(conn, q.query) + }) + }); + + yield res + } + } +} + +fn run_one(conn: &mut rusqlite::Connection, q: String) -> Result { + match try_run_one(conn, q) { + Ok(resp) => Ok(resp), + Err(e) => Ok(rpc::Response { + resp: Some(Resp::Error(rpc::Error { + error: e.to_string(), + })), + }), + } +} + +fn try_run_one(conn: &mut rusqlite::Connection, q: String) -> anyhow::Result { + let mut stmt = conn.prepare(&q)?; + let col_count = stmt.column_count(); + let mut rows = stmt.query(())?; + let mut out_rows = Vec::new(); + while let Some(row) = rows.next()? { + let mut out_row = Vec::with_capacity(col_count); + for i in 0..col_count { + let rpc_value = match row.get_ref(i).unwrap() { + ValueRef::Null => rpc::value::Value::Null(Null {}), + ValueRef::Integer(i) => rpc::value::Value::Integer(i), + ValueRef::Real(x) => rpc::value::Value::Real(x), + ValueRef::Text(s) => rpc::value::Value::Text(String::from_utf8(s.to_vec())?), + ValueRef::Blob(b) => rpc::value::Value::Blob(b.to_vec()), + }; + out_row.push(rpc::Value { + value: Some(rpc_value), + }); + } + out_rows.push(rpc::Row { values: out_row }); + } + + Ok(rpc::Response { + resp: Some(Resp::Rows(rpc::Rows { rows: out_rows })), + }) +} + +#[async_trait::async_trait] +impl AdminShellService for AdminShell { + type ShellStream = Pin> + Send>>; + + async fn shell( + &self, + request: tonic::Request>, + ) -> std::result::Result, tonic::Status> { + let Some(namespace) = request.metadata().get_bin("x-namespace-bin") else { + return Err(tonic::Status::new( + tonic::Code::InvalidArgument, + "missing namespace", + )); + }; + let Ok(ns_bytes) = namespace.to_bytes() else { + return Err(tonic::Status::new( + tonic::Code::InvalidArgument, + "bad namespace encoding", + )); + }; + + match self.with_namespace(ns_bytes, request.into_inner()).await { + Ok(s) => Ok(tonic::Response::new(Box::pin(s))), + Err(e) => Err(tonic::Status::new( + tonic::Code::FailedPrecondition, + e.to_string(), + )), + } + } +} + +pub struct AdminShellClient { + remote_url: String, +} + +impl AdminShellClient { + pub fn new(remote_url: String) -> Self { + Self { remote_url } + } + + pub async fn run_namespace(&self, namespace: &str) -> anyhow::Result<()> { + let namespace = NamespaceName::from_string(namespace.to_string())?; + let mut client = rpc::admin_shell_service_client::AdminShellServiceClient::connect( + self.remote_url.clone(), + ) + .await?; + let (sender, receiver) = tokio::sync::mpsc::channel(1); + let req_stream = tokio_stream::wrappers::ReceiverStream::new(receiver); + + let mut req = tonic::Request::new(req_stream); + req.metadata_mut().insert_bin( + "x-namespace-bin", + BinaryMetadataValue::from_bytes(namespace.as_slice()), + ); + let mut resp_stream = client.shell(req).await?.into_inner(); + + let mut history = BasicHistory::new(); + loop { + // this is blocking, but the shell runs in it's own process with no other tasks, so + // that's ok + let prompt = dialoguer::Input::::new() + .with_prompt("> ") + .history_with(&mut history) + .interact_text(); + + match prompt { + Ok(query) => { + let q = rpc::Query { query }; + sender.send(q).await?; + match resp_stream.next().await { + Some(Ok(rpc::Response { + resp: Some(rpc::response::Resp::Rows(rows)), + })) => { + println!("{}", RowsFormatter(rows)); + } + Some(Ok(rpc::Response { + resp: Some(rpc::response::Resp::Error(rpc::Error { error })), + })) => { + println!("query error: {error}"); + } + Some(Err(e)) => { + println!("rpc error: {}", e.message()); + break; + } + _ => break, + } + } + Err(e) => println!("error: {e}"), + } + } + + Ok(()) + } +} + +struct RowsFormatter(rpc::Rows); + +impl Display for RowsFormatter { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + for row in self.0.rows.iter() { + let mut is_first = true; + for value in row.values.iter() { + if !is_first { + f.write_str(", ")?; + } + is_first = false; + + match value.value.as_ref().unwrap() { + rpc::value::Value::Null(_) => f.write_str("null")?, + rpc::value::Value::Real(x) => write!(f, "{x}")?, + rpc::value::Value::Integer(i) => write!(f, "{i}")?, + rpc::value::Value::Text(s) => f.write_str(&s)?, + rpc::value::Value::Blob(b) => { + for x in b { + write!(f, "{x:0x}")? + } + } + } + } + + writeln!(f)?; + } + + Ok(()) + } +} diff --git a/libsql-server/src/main.rs b/libsql-server/src/main.rs index 7295bf9d16..70c23728fe 100644 --- a/libsql-server/src/main.rs +++ b/libsql-server/src/main.rs @@ -301,16 +301,18 @@ struct Cli { default_value = "8" )] sync_conccurency: usize, + + #[clap(subcommand)] + subcommand: Option, } #[derive(clap::Subcommand, Debug)] enum UtilsSubcommands { - Dump { - #[clap(long)] - /// Path at which to write the dump - path: Option, + AdminShell { + #[clap(long, default_value = "http://127.0.0.1:9090")] + admin_api_url: String, #[clap(long)] - namespace: String, + namespace: Option, }, } @@ -710,6 +712,25 @@ async fn build_server(config: &Cli) -> anyhow::Result { #[tokio::main] async fn main() -> Result<()> { + let args = Cli::parse(); + + if let Some(ref subcommand) = args.subcommand { + match subcommand { + UtilsSubcommands::AdminShell { + admin_api_url, + namespace, + } => { + let client = + libsql_server::admin_shell::AdminShellClient::new(admin_api_url.clone()); + if let Some(ns) = namespace { + client.run_namespace(ns).await?; + } + } + } + + return Ok(()); + } + if std::env::var("RUST_LOG").is_err() { std::env::set_var("RUST_LOG", "info"); } @@ -730,8 +751,6 @@ async fn main() -> Result<()> { ) .init(); - let args = Cli::parse(); - args.print_welcome_message(); let server = build_server(&args).await?; server.start().await?; From 85d677bc30245f72d2ea4ba820caada72eb4215a Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 3 Sep 2024 19:24:17 +0200 Subject: [PATCH 5/5] handle auth --- libsql-server/src/admin_shell.rs | 16 +++++++++++++--- libsql-server/src/http/admin/mod.rs | 7 ++++--- libsql-server/src/main.rs | 9 +++++++-- 3 files changed, 24 insertions(+), 8 deletions(-) diff --git a/libsql-server/src/admin_shell.rs b/libsql-server/src/admin_shell.rs index a8d750106a..e97b272d72 100644 --- a/libsql-server/src/admin_shell.rs +++ b/libsql-server/src/admin_shell.rs @@ -1,11 +1,12 @@ use std::fmt::Display; use std::pin::Pin; +use std::str::FromStr; use bytes::Bytes; use dialoguer::BasicHistory; use rusqlite::types::ValueRef; use tokio_stream::{Stream, StreamExt as _}; -use tonic::metadata::BinaryMetadataValue; +use tonic::metadata::{AsciiMetadataValue, BinaryMetadataValue}; use crate::connection::Connection as _; use crate::database::Connection; @@ -139,11 +140,12 @@ impl AdminShellService for AdminShell { pub struct AdminShellClient { remote_url: String, + auth: Option, } impl AdminShellClient { - pub fn new(remote_url: String) -> Self { - Self { remote_url } + pub fn new(remote_url: String, auth: Option) -> Self { + Self { remote_url, auth } } pub async fn run_namespace(&self, namespace: &str) -> anyhow::Result<()> { @@ -160,6 +162,14 @@ impl AdminShellClient { "x-namespace-bin", BinaryMetadataValue::from_bytes(namespace.as_slice()), ); + + if let Some(ref auth) = self.auth { + req.metadata_mut().insert( + "authorization", + AsciiMetadataValue::from_str(&format!("basic {auth}")).unwrap(), + ); + } + let mut resp_stream = client.shell(req).await?.into_inner(); let mut history = BasicHistory::new(); diff --git a/libsql-server/src/http/admin/mod.rs b/libsql-server/src/http/admin/mod.rs index 00fdc6d13a..4b774afa0d 100644 --- a/libsql-server/src/http/admin/mod.rs +++ b/libsql-server/src/http/admin/mod.rs @@ -181,8 +181,7 @@ where .level(tracing::Level::DEBUG) .latency_unit(tower_http::LatencyUnit::Micros), ), - ) - .layer(axum::middleware::from_fn_with_state(auth, auth_middleware)); + ); let admin_shell = crate::admin_shell::make_svc(namespaces.clone()); let grpc_router = tonic::transport::Server::builder() @@ -190,7 +189,9 @@ where .add_service(tonic_web::enable(admin_shell)) .into_router(); - let router = router.merge(grpc_router); + let router = router + .merge(grpc_router) + .layer(axum::middleware::from_fn_with_state(auth, auth_middleware)); hyper::server::Server::builder(acceptor) .serve(router.into_make_service()) diff --git a/libsql-server/src/main.rs b/libsql-server/src/main.rs index 70c23728fe..9d505f5dfa 100644 --- a/libsql-server/src/main.rs +++ b/libsql-server/src/main.rs @@ -313,6 +313,8 @@ enum UtilsSubcommands { admin_api_url: String, #[clap(long)] namespace: Option, + #[clap(long)] + auth: Option, }, } @@ -719,9 +721,12 @@ async fn main() -> Result<()> { UtilsSubcommands::AdminShell { admin_api_url, namespace, + auth, } => { - let client = - libsql_server::admin_shell::AdminShellClient::new(admin_api_url.clone()); + let client = libsql_server::admin_shell::AdminShellClient::new( + admin_api_url.clone(), + auth.clone(), + ); if let Some(ns) = namespace { client.run_namespace(ns).await?; }