Skip to content

Commit

Permalink
add support for basic auth for admin api
Browse files Browse the repository at this point in the history
  • Loading branch information
MarinPostma committed Sep 3, 2024
1 parent c055656 commit 78abfe9
Show file tree
Hide file tree
Showing 12 changed files with 151 additions and 6 deletions.
1 change: 1 addition & 0 deletions libsql-server/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ pub struct AdminApiConfig<A = AddrIncoming, C = HttpsConnector<HttpConnector>> {
pub acceptor: A,
pub connector: C,
pub disable_metrics: bool,
pub auth_key: Option<String>,
}

#[derive(Clone)]
Expand Down
41 changes: 36 additions & 5 deletions libsql-server/src/http/admin/mod.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use anyhow::Context as _;
use axum::body::StreamBody;
use axum::extract::{FromRef, Path, State};
use axum::middleware::Next;
use axum::routing::delete;
use axum::Json;
use chrono::NaiveDateTime;
use futures::{SinkExt, StreamExt, TryStreamExt};
use hyper::{Body, Request};
use hyper::{Body, Request, StatusCode};
use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle};
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -64,6 +65,7 @@ pub async fn run<A, C>(
connector: C,
disable_metrics: bool,
shutdown: Arc<Notify>,
auth: Option<Arc<str>>,
) -> anyhow::Result<()>
where
A: crate::net::Accept,
Expand Down Expand Up @@ -162,15 +164,15 @@ where
)
.route("/v1/diagnostics", get(handle_diagnostics))
.route("/metrics", get(handle_metrics))
.route("/profile/heap/enable", post(enable_profile_heap))
.route("/profile/heap/disable/:id", post(disable_profile_heap))
.route("/profile/heap/:id", delete(delete_profile_heap))
.with_state(Arc::new(AppState {
namespaces,
connector,
user_http_server,
metrics,
}))
.route("/profile/heap/enable", post(enable_profile_heap))
.route("/profile/heap/disable/:id", post(disable_profile_heap))
.route("/profile/heap/:id", delete(delete_profile_heap))
.layer(
tower_http::trace::TraceLayer::new_for_http()
.on_request(trace_request)
Expand All @@ -179,7 +181,8 @@ where
.level(tracing::Level::DEBUG)
.latency_unit(tower_http::LatencyUnit::Micros),
),
);
)
.layer(axum::middleware::from_fn_with_state(auth, auth_middleware));

hyper::server::Server::builder(acceptor)
.serve(router.into_make_service())
Expand All @@ -190,6 +193,34 @@ where
Ok(())
}

async fn auth_middleware<B>(
State(auth): State<Option<Arc<str>>>,
request: Request<B>,
next: Next<B>,
) -> Result<axum::response::Response, StatusCode> {
if let Some(ref auth) = auth {
let Some(auth_header) = request.headers().get("authorization") else {
return Err(StatusCode::UNAUTHORIZED);
};
let Ok(auth_str) = std::str::from_utf8(auth_header.as_bytes()) else {
return Err(StatusCode::UNAUTHORIZED);
};

let mut split = auth_str.split_whitespace();
match split.next() {
Some(s) if s.trim().eq_ignore_ascii_case("basic") => (),
_ => return Err(StatusCode::UNAUTHORIZED),
}

match split.next() {
Some(s) if s.trim() == auth.as_ref() => (),
_ => return Err(StatusCode::UNAUTHORIZED),
}
}

Ok(next.run(request).await)
}

async fn handle_get_index() -> &'static str {
"Welcome to the sqld admin API"
}
Expand Down
2 changes: 2 additions & 0 deletions libsql-server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,7 @@ where
acceptor,
connector,
disable_metrics,
auth_key,
}) = self.admin_api_config
{
task_manager.spawn_with_shutdown_notify(|shutdown| {
Expand All @@ -330,6 +331,7 @@ where
connector,
disable_metrics,
shutdown,
auth_key.map(Into::into),
)
});
}
Expand Down
5 changes: 5 additions & 0 deletions libsql-server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,10 @@ struct Cli {
/// Enables the main runtime deadlock monitor: if the main runtime deadlocks, logs an error
#[clap(long)]
enable_deadlock_monitor: bool,

/// Auth key for the admin API
#[clap(long, env = "LIBSQL_ADMIN_AUTH_KEY", requires = "admin_listen_addr")]
admin_auth_key: Option<String>,
}

#[derive(clap::Subcommand, Debug)]
Expand Down Expand Up @@ -468,6 +472,7 @@ async fn make_admin_api_config(config: &Cli) -> anyhow::Result<Option<AdminApiCo
acceptor,
connector,
disable_metrics: false,
auth_key: config.admin_auth_key.clone(),
}))
}
None => Ok(None),
Expand Down
2 changes: 2 additions & 0 deletions libsql-server/tests/cluster/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ pub fn make_cluster(sim: &mut Sim, num_replica: usize, disable_namespaces: bool)
acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 9090)).await?,
connector: TurmoilConnector,
disable_metrics: true,
auth_key: None,
}),
rpc_server_config: Some(RpcServerConfig {
acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 4567)).await?,
Expand Down Expand Up @@ -64,6 +65,7 @@ pub fn make_cluster(sim: &mut Sim, num_replica: usize, disable_namespaces: bool)
acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 9090)).await?,
connector: TurmoilConnector,
disable_metrics: true,
auth_key: None,
}),
rpc_client_config: Some(RpcClientConfig {
remote_url: "http://primary:4567".into(),
Expand Down
6 changes: 6 additions & 0 deletions libsql-server/tests/cluster/replica_restart.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ fn replica_restart() {
acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 9090)).await?,
connector: TurmoilConnector,
disable_metrics: true,
auth_key: None,
}),
rpc_server_config: Some(RpcServerConfig {
acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 4567)).await?,
Expand Down Expand Up @@ -67,6 +68,7 @@ fn replica_restart() {
acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 9090)).await.unwrap(),
connector: TurmoilConnector,
disable_metrics: true,
auth_key: None,
}),
rpc_client_config: Some(RpcClientConfig {
remote_url: "http://primary:4567".into(),
Expand Down Expand Up @@ -187,6 +189,7 @@ fn primary_regenerate_log_no_replica_restart() {
acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 9090)).await.unwrap(),
connector: TurmoilConnector,
disable_metrics: true,
auth_key: None,
}),
rpc_server_config: Some(RpcServerConfig {
acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 4567)).await.unwrap(),
Expand Down Expand Up @@ -241,6 +244,7 @@ fn primary_regenerate_log_no_replica_restart() {
acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 9090)).await.unwrap(),
connector: TurmoilConnector,
disable_metrics: true,
auth_key: None,
}),
rpc_client_config: Some(RpcClientConfig {
remote_url: "http://primary:4567".into(),
Expand Down Expand Up @@ -365,6 +369,7 @@ fn primary_regenerate_log_with_replica_restart() {
acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 9090)).await.unwrap(),
connector: TurmoilConnector,
disable_metrics: true,
auth_key: None,
}),
rpc_server_config: Some(RpcServerConfig {
acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 4567)).await.unwrap(),
Expand Down Expand Up @@ -421,6 +426,7 @@ fn primary_regenerate_log_with_replica_restart() {
acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 9090)).await.unwrap(),
connector: TurmoilConnector,
disable_metrics: true,
auth_key: None,
}),
rpc_client_config: Some(RpcClientConfig {
remote_url: "http://primary:4567".into(),
Expand Down
4 changes: 4 additions & 0 deletions libsql-server/tests/cluster/replication.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ fn apply_partial_snapshot() {
acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 9090)).await.unwrap(),
connector: TurmoilConnector,
disable_metrics: true,
auth_key: None,
}),
rpc_server_config: Some(RpcServerConfig {
acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 5050)).await.unwrap(),
Expand Down Expand Up @@ -71,6 +72,7 @@ fn apply_partial_snapshot() {
acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 9090)).await.unwrap(),
connector: TurmoilConnector,
disable_metrics: true,
auth_key: None,
}),
rpc_client_config: Some(RpcClientConfig {
remote_url: "http://primary:5050".into(),
Expand Down Expand Up @@ -167,6 +169,7 @@ fn replica_lazy_creation() {
acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 9090)).await.unwrap(),
connector: TurmoilConnector,
disable_metrics: true,
auth_key: None,
}),
rpc_server_config: Some(RpcServerConfig {
acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 5050)).await.unwrap(),
Expand Down Expand Up @@ -197,6 +200,7 @@ fn replica_lazy_creation() {
acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 9090)).await.unwrap(),
connector: TurmoilConnector,
disable_metrics: true,
auth_key: None,
}),
rpc_client_config: Some(RpcClientConfig {
remote_url: "http://primary:5050".into(),
Expand Down
19 changes: 18 additions & 1 deletion libsql-server/tests/common/http.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use axum::http::HeaderName;
use bytes::Bytes;
use hyper::Body;
use serde::{de::DeserializeOwned, Serialize};
Expand Down Expand Up @@ -41,11 +42,27 @@ impl Client {
}

pub(crate) async fn post<T: Serialize>(&self, url: &str, body: T) -> anyhow::Result<Response> {
self.post_with_headers(url, &[], body).await
}

pub(crate) async fn post_with_headers<T: Serialize>(
&self,
url: &str,
headers: &[(HeaderName, &str)],
body: T,
) -> anyhow::Result<Response> {
let bytes: Bytes = serde_json::to_vec(&body)?.into();
let body = Body::from(bytes);
let request = hyper::Request::post(url)
let mut request = hyper::Request::post(url)
.header("Content-Type", "application/json")
.body(body)?;

for (key, val) in headers {
request
.headers_mut()
.insert(key.clone(), val.parse().unwrap());
}

let resp = self.0.request(request).await?;

if resp.status().is_server_error() {
Expand Down
6 changes: 6 additions & 0 deletions libsql-server/tests/embedded_replica/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ fn make_primary(sim: &mut Sim, path: PathBuf) {
acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 9090)).await?,
connector: TurmoilConnector,
disable_metrics: false,
auth_key: None,
}),
rpc_server_config: Some(RpcServerConfig {
acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 4567)).await?,
Expand Down Expand Up @@ -408,6 +409,7 @@ fn replica_primary_reset() {
acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 9090)).await.unwrap(),
connector: TurmoilConnector,
disable_metrics: true,
auth_key: None,
}),
rpc_server_config: Some(RpcServerConfig {
acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 4567)).await.unwrap(),
Expand Down Expand Up @@ -692,6 +694,7 @@ fn replicate_with_snapshots() {
acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 9090)).await.unwrap(),
connector: TurmoilConnector,
disable_metrics: true,
auth_key: None,
}),
rpc_server_config: Some(RpcServerConfig {
acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 4567)).await.unwrap(),
Expand Down Expand Up @@ -1266,6 +1269,7 @@ fn replicated_return() {
acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 9090)).await.unwrap(),
connector: TurmoilConnector,
disable_metrics: true,
auth_key: None,
}),
rpc_server_config: Some(RpcServerConfig {
acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 4567)).await.unwrap(),
Expand Down Expand Up @@ -1395,6 +1399,7 @@ fn replicate_auth() {
acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 9090)).await?,
connector: TurmoilConnector,
disable_metrics: true,
auth_key: None,
}),
rpc_server_config: Some(RpcServerConfig {
acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 4567)).await?,
Expand Down Expand Up @@ -1430,6 +1435,7 @@ fn replicate_auth() {
acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 9090)).await?,
connector: TurmoilConnector,
disable_metrics: true,
auth_key: None,
}),
rpc_client_config: Some(RpcClientConfig {
remote_url: "http://primary:4567".into(),
Expand Down
1 change: 1 addition & 0 deletions libsql-server/tests/namespaces/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ fn make_primary(sim: &mut Sim, path: PathBuf) {
acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 9090)).await?,
connector: TurmoilConnector,
disable_metrics: true,
auth_key: None,
}),
rpc_server_config: Some(RpcServerConfig {
acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 4567)).await?,
Expand Down
67 changes: 67 additions & 0 deletions libsql-server/tests/standalone/admin.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
use std::time::Duration;

use hyper::StatusCode;
use libsql_server::config::{AdminApiConfig, UserApiConfig};
use s3s::header::AUTHORIZATION;
use serde_json::json;
use tempfile::tempdir;

use crate::common::{
http::Client,
net::{SimServer as _, TestServer, TurmoilAcceptor, TurmoilConnector},
};

#[test]
fn admin_auth() {
let mut sim = turmoil::Builder::new()
.simulation_duration(Duration::from_secs(1000))
.build();

sim.host("primary", || async move {
let tmp = tempdir().unwrap();
let server = TestServer {
path: tmp.path().to_owned().into(),
user_api_config: UserApiConfig {
hrana_ws_acceptor: None,
..Default::default()
},
admin_api_config: Some(AdminApiConfig {
acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 9090)).await.unwrap(),
connector: TurmoilConnector,
disable_metrics: true,
auth_key: Some("secretkey".into()),
}),
disable_namespaces: false,
..Default::default()
};
server.start_sim(8080).await?;
Ok(())
});

sim.client("test", async {
let client = Client::new();

assert_eq!(
client
.post("http://primary:9090/v1/namespaces/foo/create", json!({}))
.await
.unwrap()
.status(),
StatusCode::UNAUTHORIZED
);
assert!(client
.post_with_headers(
"http://primary:9090/v1/namespaces/foo/create",
&[(AUTHORIZATION, "basic secretkey")],
json!({})
)
.await
.unwrap()
.status()
.is_success());

Ok(())
});

sim.run().unwrap();
}
Loading

0 comments on commit 78abfe9

Please sign in to comment.