diff --git a/libsql-server/src/config.rs b/libsql-server/src/config.rs index 49320258b9..795323366e 100644 --- a/libsql-server/src/config.rs +++ b/libsql-server/src/config.rs @@ -83,6 +83,7 @@ pub struct AdminApiConfig> { pub acceptor: A, pub connector: C, pub disable_metrics: bool, + pub auth_key: Option, } #[derive(Clone)] diff --git a/libsql-server/src/http/admin/mod.rs b/libsql-server/src/http/admin/mod.rs index 908b66e3f2..63423bd4e3 100644 --- a/libsql-server/src/http/admin/mod.rs +++ b/libsql-server/src/http/admin/mod.rs @@ -1,11 +1,12 @@ use anyhow::Context as _; use axum::body::StreamBody; use axum::extract::{FromRef, Path, State}; +use axum::middleware::Next; use axum::routing::delete; use axum::Json; use chrono::NaiveDateTime; use futures::{SinkExt, StreamExt, TryStreamExt}; -use hyper::{Body, Request}; +use hyper::{Body, Request, StatusCode}; use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle}; use parking_lot::Mutex; use serde::{Deserialize, Serialize}; @@ -64,6 +65,7 @@ pub async fn run( connector: C, disable_metrics: bool, shutdown: Arc, + auth: Option>, ) -> anyhow::Result<()> where A: crate::net::Accept, @@ -162,15 +164,15 @@ where ) .route("/v1/diagnostics", get(handle_diagnostics)) .route("/metrics", get(handle_metrics)) + .route("/profile/heap/enable", post(enable_profile_heap)) + .route("/profile/heap/disable/:id", post(disable_profile_heap)) + .route("/profile/heap/:id", delete(delete_profile_heap)) .with_state(Arc::new(AppState { namespaces, connector, user_http_server, metrics, })) - .route("/profile/heap/enable", post(enable_profile_heap)) - .route("/profile/heap/disable/:id", post(disable_profile_heap)) - .route("/profile/heap/:id", delete(delete_profile_heap)) .layer( tower_http::trace::TraceLayer::new_for_http() .on_request(trace_request) @@ -179,7 +181,8 @@ where .level(tracing::Level::DEBUG) .latency_unit(tower_http::LatencyUnit::Micros), ), - ); + ) + .layer(axum::middleware::from_fn_with_state(auth, auth_middleware)); hyper::server::Server::builder(acceptor) .serve(router.into_make_service()) @@ -190,6 +193,34 @@ where Ok(()) } +async fn auth_middleware( + State(auth): State>>, + request: Request, + next: Next, +) -> Result { + if let Some(ref auth) = auth { + let Some(auth_header) = request.headers().get("authorization") else { + return Err(StatusCode::UNAUTHORIZED); + }; + let Ok(auth_str) = std::str::from_utf8(auth_header.as_bytes()) else { + return Err(StatusCode::UNAUTHORIZED); + }; + + let mut split = auth_str.split_whitespace(); + match split.next() { + Some(s) if s.trim().eq_ignore_ascii_case("basic") => (), + _ => return Err(StatusCode::UNAUTHORIZED), + } + + match split.next() { + Some(s) if s.trim() == auth.as_ref() => (), + _ => return Err(StatusCode::UNAUTHORIZED), + } + } + + Ok(next.run(request).await) +} + async fn handle_get_index() -> &'static str { "Welcome to the sqld admin API" } diff --git a/libsql-server/src/lib.rs b/libsql-server/src/lib.rs index 3c0892c6ce..cd415eef33 100644 --- a/libsql-server/src/lib.rs +++ b/libsql-server/src/lib.rs @@ -320,6 +320,7 @@ where acceptor, connector, disable_metrics, + auth_key, }) = self.admin_api_config { task_manager.spawn_with_shutdown_notify(|shutdown| { @@ -330,6 +331,7 @@ where connector, disable_metrics, shutdown, + auth_key.map(Into::into), ) }); } diff --git a/libsql-server/src/main.rs b/libsql-server/src/main.rs index 038d34b22d..eb24126d6e 100644 --- a/libsql-server/src/main.rs +++ b/libsql-server/src/main.rs @@ -276,6 +276,10 @@ struct Cli { /// Enables the main runtime deadlock monitor: if the main runtime deadlocks, logs an error #[clap(long)] enable_deadlock_monitor: bool, + + /// Auth key for the admin API + #[clap(long, env = "LIBSQL_ADMIN_AUTH_KEY", requires = "admin_listen_addr")] + admin_auth_key: Option, } #[derive(clap::Subcommand, Debug)] @@ -468,6 +472,7 @@ async fn make_admin_api_config(config: &Cli) -> anyhow::Result Ok(None), diff --git a/libsql-server/tests/cluster/mod.rs b/libsql-server/tests/cluster/mod.rs index 46f9801dff..a8ae930613 100644 --- a/libsql-server/tests/cluster/mod.rs +++ b/libsql-server/tests/cluster/mod.rs @@ -34,6 +34,7 @@ pub fn make_cluster(sim: &mut Sim, num_replica: usize, disable_namespaces: bool) acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 9090)).await?, connector: TurmoilConnector, disable_metrics: true, + auth_key: None, }), rpc_server_config: Some(RpcServerConfig { acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 4567)).await?, @@ -64,6 +65,7 @@ pub fn make_cluster(sim: &mut Sim, num_replica: usize, disable_namespaces: bool) acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 9090)).await?, connector: TurmoilConnector, disable_metrics: true, + auth_key: None, }), rpc_client_config: Some(RpcClientConfig { remote_url: "http://primary:4567".into(), diff --git a/libsql-server/tests/cluster/replica_restart.rs b/libsql-server/tests/cluster/replica_restart.rs index 11c78d8ced..e8bcd21fcd 100644 --- a/libsql-server/tests/cluster/replica_restart.rs +++ b/libsql-server/tests/cluster/replica_restart.rs @@ -34,6 +34,7 @@ fn replica_restart() { acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 9090)).await?, connector: TurmoilConnector, disable_metrics: true, + auth_key: None, }), rpc_server_config: Some(RpcServerConfig { acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 4567)).await?, @@ -67,6 +68,7 @@ fn replica_restart() { acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 9090)).await.unwrap(), connector: TurmoilConnector, disable_metrics: true, + auth_key: None, }), rpc_client_config: Some(RpcClientConfig { remote_url: "http://primary:4567".into(), @@ -187,6 +189,7 @@ fn primary_regenerate_log_no_replica_restart() { acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 9090)).await.unwrap(), connector: TurmoilConnector, disable_metrics: true, + auth_key: None, }), rpc_server_config: Some(RpcServerConfig { acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 4567)).await.unwrap(), @@ -241,6 +244,7 @@ fn primary_regenerate_log_no_replica_restart() { acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 9090)).await.unwrap(), connector: TurmoilConnector, disable_metrics: true, + auth_key: None, }), rpc_client_config: Some(RpcClientConfig { remote_url: "http://primary:4567".into(), @@ -365,6 +369,7 @@ fn primary_regenerate_log_with_replica_restart() { acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 9090)).await.unwrap(), connector: TurmoilConnector, disable_metrics: true, + auth_key: None, }), rpc_server_config: Some(RpcServerConfig { acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 4567)).await.unwrap(), @@ -421,6 +426,7 @@ fn primary_regenerate_log_with_replica_restart() { acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 9090)).await.unwrap(), connector: TurmoilConnector, disable_metrics: true, + auth_key: None, }), rpc_client_config: Some(RpcClientConfig { remote_url: "http://primary:4567".into(), diff --git a/libsql-server/tests/cluster/replication.rs b/libsql-server/tests/cluster/replication.rs index 920614fea3..f77ea01d19 100644 --- a/libsql-server/tests/cluster/replication.rs +++ b/libsql-server/tests/cluster/replication.rs @@ -40,6 +40,7 @@ fn apply_partial_snapshot() { acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 9090)).await.unwrap(), connector: TurmoilConnector, disable_metrics: true, + auth_key: None, }), rpc_server_config: Some(RpcServerConfig { acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 5050)).await.unwrap(), @@ -71,6 +72,7 @@ fn apply_partial_snapshot() { acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 9090)).await.unwrap(), connector: TurmoilConnector, disable_metrics: true, + auth_key: None, }), rpc_client_config: Some(RpcClientConfig { remote_url: "http://primary:5050".into(), @@ -167,6 +169,7 @@ fn replica_lazy_creation() { acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 9090)).await.unwrap(), connector: TurmoilConnector, disable_metrics: true, + auth_key: None, }), rpc_server_config: Some(RpcServerConfig { acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 5050)).await.unwrap(), @@ -197,6 +200,7 @@ fn replica_lazy_creation() { acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 9090)).await.unwrap(), connector: TurmoilConnector, disable_metrics: true, + auth_key: None, }), rpc_client_config: Some(RpcClientConfig { remote_url: "http://primary:5050".into(), diff --git a/libsql-server/tests/common/http.rs b/libsql-server/tests/common/http.rs index b746de9782..8716a60503 100644 --- a/libsql-server/tests/common/http.rs +++ b/libsql-server/tests/common/http.rs @@ -1,3 +1,4 @@ +use axum::http::HeaderName; use bytes::Bytes; use hyper::Body; use serde::{de::DeserializeOwned, Serialize}; @@ -41,11 +42,27 @@ impl Client { } pub(crate) async fn post(&self, url: &str, body: T) -> anyhow::Result { + self.post_with_headers(url, &[], body).await + } + + pub(crate) async fn post_with_headers( + &self, + url: &str, + headers: &[(HeaderName, &str)], + body: T, + ) -> anyhow::Result { let bytes: Bytes = serde_json::to_vec(&body)?.into(); let body = Body::from(bytes); - let request = hyper::Request::post(url) + let mut request = hyper::Request::post(url) .header("Content-Type", "application/json") .body(body)?; + + for (key, val) in headers { + request + .headers_mut() + .insert(key.clone(), val.parse().unwrap()); + } + let resp = self.0.request(request).await?; if resp.status().is_server_error() { diff --git a/libsql-server/tests/embedded_replica/mod.rs b/libsql-server/tests/embedded_replica/mod.rs index bf5296ecb3..0288c9a117 100644 --- a/libsql-server/tests/embedded_replica/mod.rs +++ b/libsql-server/tests/embedded_replica/mod.rs @@ -55,6 +55,7 @@ fn make_primary(sim: &mut Sim, path: PathBuf) { acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 9090)).await?, connector: TurmoilConnector, disable_metrics: false, + auth_key: None, }), rpc_server_config: Some(RpcServerConfig { acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 4567)).await?, @@ -408,6 +409,7 @@ fn replica_primary_reset() { acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 9090)).await.unwrap(), connector: TurmoilConnector, disable_metrics: true, + auth_key: None, }), rpc_server_config: Some(RpcServerConfig { acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 4567)).await.unwrap(), @@ -692,6 +694,7 @@ fn replicate_with_snapshots() { acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 9090)).await.unwrap(), connector: TurmoilConnector, disable_metrics: true, + auth_key: None, }), rpc_server_config: Some(RpcServerConfig { acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 4567)).await.unwrap(), @@ -1266,6 +1269,7 @@ fn replicated_return() { acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 9090)).await.unwrap(), connector: TurmoilConnector, disable_metrics: true, + auth_key: None, }), rpc_server_config: Some(RpcServerConfig { acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 4567)).await.unwrap(), @@ -1395,6 +1399,7 @@ fn replicate_auth() { acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 9090)).await?, connector: TurmoilConnector, disable_metrics: true, + auth_key: None, }), rpc_server_config: Some(RpcServerConfig { acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 4567)).await?, @@ -1430,6 +1435,7 @@ fn replicate_auth() { acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 9090)).await?, connector: TurmoilConnector, disable_metrics: true, + auth_key: None, }), rpc_client_config: Some(RpcClientConfig { remote_url: "http://primary:4567".into(), diff --git a/libsql-server/tests/namespaces/mod.rs b/libsql-server/tests/namespaces/mod.rs index 2979991b1a..37b373b76e 100644 --- a/libsql-server/tests/namespaces/mod.rs +++ b/libsql-server/tests/namespaces/mod.rs @@ -29,6 +29,7 @@ fn make_primary(sim: &mut Sim, path: PathBuf) { acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 9090)).await?, connector: TurmoilConnector, disable_metrics: true, + auth_key: None, }), rpc_server_config: Some(RpcServerConfig { acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 4567)).await?, diff --git a/libsql-server/tests/standalone/admin.rs b/libsql-server/tests/standalone/admin.rs new file mode 100644 index 0000000000..ed70315c6b --- /dev/null +++ b/libsql-server/tests/standalone/admin.rs @@ -0,0 +1,67 @@ +use std::time::Duration; + +use hyper::StatusCode; +use libsql_server::config::{AdminApiConfig, UserApiConfig}; +use s3s::header::AUTHORIZATION; +use serde_json::json; +use tempfile::tempdir; + +use crate::common::{ + http::Client, + net::{SimServer as _, TestServer, TurmoilAcceptor, TurmoilConnector}, +}; + +#[test] +fn admin_auth() { + let mut sim = turmoil::Builder::new() + .simulation_duration(Duration::from_secs(1000)) + .build(); + + sim.host("primary", || async move { + let tmp = tempdir().unwrap(); + let server = TestServer { + path: tmp.path().to_owned().into(), + user_api_config: UserApiConfig { + hrana_ws_acceptor: None, + ..Default::default() + }, + admin_api_config: Some(AdminApiConfig { + acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 9090)).await.unwrap(), + connector: TurmoilConnector, + disable_metrics: true, + auth_key: Some("secretkey".into()), + }), + disable_namespaces: false, + ..Default::default() + }; + server.start_sim(8080).await?; + Ok(()) + }); + + sim.client("test", async { + let client = Client::new(); + + assert_eq!( + client + .post("http://primary:9090/v1/namespaces/foo/create", json!({})) + .await + .unwrap() + .status(), + StatusCode::UNAUTHORIZED + ); + assert!(client + .post_with_headers( + "http://primary:9090/v1/namespaces/foo/create", + &[(AUTHORIZATION, "basic secretkey")], + json!({}) + ) + .await + .unwrap() + .status() + .is_success()); + + Ok(()) + }); + + sim.run().unwrap(); +} diff --git a/libsql-server/tests/standalone/mod.rs b/libsql-server/tests/standalone/mod.rs index 1cc82441ca..0b3c631e75 100644 --- a/libsql-server/tests/standalone/mod.rs +++ b/libsql-server/tests/standalone/mod.rs @@ -17,6 +17,7 @@ use libsql_server::config::{AdminApiConfig, UserApiConfig}; use common::net::{init_tracing, TestServer, TurmoilConnector}; +mod admin; mod attach; mod auth; @@ -33,6 +34,7 @@ async fn make_standalone_server() -> Result<(), Box> { acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 9090)).await.unwrap(), connector: TurmoilConnector, disable_metrics: true, + auth_key: None, }), disable_namespaces: false, ..Default::default() @@ -355,6 +357,7 @@ fn dirty_startup_dont_prevent_namespace_creation() { acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 9090)).await.unwrap(), connector: TurmoilConnector, disable_metrics: true, + auth_key: None, }), disable_default_namespace: true, disable_namespaces: false,