diff --git a/libsql-server/src/auth/parsers.rs b/libsql-server/src/auth/parsers.rs index 0fa208ca66..458750d61e 100644 --- a/libsql-server/src/auth/parsers.rs +++ b/libsql-server/src/auth/parsers.rs @@ -47,12 +47,22 @@ pub(crate) fn parse_grpc_auth_header( ) -> Result { let mut context = UserAuthContext::empty(); + let mut auth_header_seen = false; + + if required_fields.is_empty() { + return Ok(context); + } + for field in required_fields.iter() { - metadata - .get(*field) - .ok_or_else(|| AuthError::AuthHeaderNotFound) - .and_then(|h| h.to_str().map_err(|_| AuthError::AuthHeaderNonAscii)) - .map(|v| context.add_field(field, v.into()))?; + if let Some(h) = metadata.get(*field) { + let v = h.to_str().map_err(|_| AuthError::AuthHeaderNonAscii)?; + context.add_field(field, v.into()); + auth_header_seen = true; + } + } + + if !auth_header_seen { + return Err(AuthError::AuthHeaderNotFound.into()); } Ok(context) @@ -106,6 +116,19 @@ mod tests { ); } + #[test] + fn parse_grpc_auth_header_with_multiple_required_fields() { + let mut map = tonic::metadata::MetadataMap::new(); + map.insert(GRPC_AUTH_HEADER, "bearer 123".parse().unwrap()); + let required_fields = vec!["authorization".into(), "x-authorization".into()]; + let context = parse_grpc_auth_header(&map, &required_fields).unwrap(); + + assert_eq!( + context.get_field("x-authorization"), + Some(&"bearer 123".to_string()) + ); + } + #[test] fn parse_grpc_auth_header_error_non_ascii() { let mut map = tonic::metadata::MetadataMap::new(); diff --git a/libsql-server/src/auth/user_auth_strategies/mod.rs b/libsql-server/src/auth/user_auth_strategies/mod.rs index 73aadf6630..7223a587f3 100644 --- a/libsql-server/src/auth/user_auth_strategies/mod.rs +++ b/libsql-server/src/auth/user_auth_strategies/mod.rs @@ -62,7 +62,6 @@ pub trait UserAuthStrategy: Sync + Send { /// /// The caller is responsible for providing at least one of these fields in UserAuthContext. /// The caller should assume the strategy will not work if none of the required fields is provided. - /// fn required_fields(&self) -> Vec<&'static str> { vec![] } diff --git a/libsql-server/src/hrana/ws/session.rs b/libsql-server/src/hrana/ws/session.rs index 16ceab5253..983a84f1c6 100644 --- a/libsql-server/src/hrana/ws/session.rs +++ b/libsql-server/src/hrana/ws/session.rs @@ -104,8 +104,10 @@ pub(super) async fn handle_hello( .map(Auth::new) .unwrap_or_else(|| server.user_auth_strategy.clone()); + let token = jwt.map(|t| format!("Bearer {}", t)); + let context: UserAuthContext = - build_context(jwt, &auth_strategy.user_strategy.required_fields()); + build_context(token, &auth_strategy.user_strategy.required_fields()); auth_strategy .authenticate(context) diff --git a/libsql-server/tests/auth/jwt_key.pem b/libsql-server/tests/auth/jwt_key.pem new file mode 100644 index 0000000000..ac9591cf12 --- /dev/null +++ b/libsql-server/tests/auth/jwt_key.pem @@ -0,0 +1,3 @@ +-----BEGIN PUBLIC KEY----- +MCowBQYDK2VwAyEAbz/oEg1rMRGY12X4Q0GXioX1hXaM69o9kp7h0eCLD/E= +-----END PUBLIC KEY----- diff --git a/libsql-server/tests/auth/mod.rs b/libsql-server/tests/auth/mod.rs new file mode 100644 index 0000000000..5834d44c33 --- /dev/null +++ b/libsql-server/tests/auth/mod.rs @@ -0,0 +1,136 @@ +//! Test hrana related functionalities +#![allow(deprecated)] + +use futures::SinkExt as _; +use libsql::Database; +use libsql_server::{ + auth::{user_auth_strategies, Auth}, + config::UserApiConfig, +}; +use tempfile::tempdir; +use tokio_stream::StreamExt; +use tokio_tungstenite::{ + client_async, + tungstenite::{self, client::IntoClientRequest}, +}; +use turmoil::net::TcpStream; + +use crate::common::net::{init_tracing, SimServer, TestServer, TurmoilConnector}; + +const TEST_JWT_KEY: &str = "eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJleHAiOjE3MjE2NTIwNTB9.5XhUDHQhtShszTssjjUzVuJA3r-031mT4inVkvYEYz64sOCxnNpZUZdVF-CmZ4t-JTSXFlm8ddscBgkhccBxDg"; + +async fn make_standalone_server() -> Result<(), Box> { + let jwt_pem = include_bytes!("jwt_key.pem"); + let jwt_keys = vec![jsonwebtoken::DecodingKey::from_ed_pem(jwt_pem).unwrap()]; + + init_tracing(); + let tmp = tempdir()?; + let server = TestServer { + path: tmp.path().to_owned().into(), + user_api_config: UserApiConfig { + hrana_ws_acceptor: None, + auth_strategy: Auth::new(user_auth_strategies::Jwt::new(jwt_keys)), + ..Default::default() + }, + ..Default::default() + }; + + server.start_sim(8080).await?; + + Ok(()) +} + +#[test] +fn http_hrana() { + let mut sim = turmoil::Builder::new().build(); + sim.host("primary", make_standalone_server); + sim.client("client", async { + let db = Database::open_remote_with_connector( + "http://primary:8080", + TEST_JWT_KEY, + TurmoilConnector, + )?; + let conn = db.connect()?; + + conn.execute("create table t(x text)", ()).await?; + + Ok(()) + }); + + sim.run().unwrap(); +} + +#[test] +fn embedded_replica() { + let tmp_embedded = tempdir().unwrap(); + let tmp_embedded_path = tmp_embedded.path().to_owned(); + + let mut sim = turmoil::Builder::new().build(); + sim.host("primary", make_standalone_server); + sim.client("client", async move { + let path = tmp_embedded_path.join("embedded"); + + let db = Database::open_with_remote_sync_connector( + path.to_str().unwrap(), + "http://primary:8080", + TEST_JWT_KEY, + TurmoilConnector, + false, + None, + ) + .await?; + + let conn = db.connect()?; + + conn.execute("create table t(x text)", ()).await?; + + Ok(()) + }); + + sim.run().unwrap(); +} + +#[test] +fn ws_hrana() { + let mut sim = turmoil::Builder::new().build(); + sim.host("primary", make_standalone_server); + sim.client("client", async { + let url = "ws://primary:8080"; + + let req = url.into_client_request().unwrap(); + + let conn = TcpStream::connect("primary:8080").await.unwrap(); + + let (mut ws, _) = client_async(req, conn).await.unwrap(); + + #[derive(serde::Serialize, Debug)] + #[serde(tag = "type", rename_all = "snake_case")] + pub enum ClientMsg { + Hello { jwt: Option }, + } + + #[derive(serde::Deserialize, Debug)] + #[serde(tag = "type", rename_all = "snake_case")] + pub enum ServerMsg { + HelloOk {}, + } + + let msg = ClientMsg::Hello { + jwt: Some(TEST_JWT_KEY.to_string()), + }; + + let msg_data = serde_json::to_string(&msg).unwrap(); + + ws.send(tungstenite::Message::Text(msg_data)).await.unwrap(); + + let Some(tungstenite::Message::Text(msg)) = ws.try_next().await.unwrap() else { + panic!("wrong message type"); + }; + + serde_json::from_str::(&msg).unwrap(); + + Ok(()) + }); + + sim.run().unwrap(); +} diff --git a/libsql-server/tests/tests.rs b/libsql-server/tests/tests.rs index 497814660c..ab475df546 100644 --- a/libsql-server/tests/tests.rs +++ b/libsql-server/tests/tests.rs @@ -3,6 +3,7 @@ #[macro_use] mod common; +mod auth; mod cluster; mod embedded_replica; mod hrana;