Skip to content

Commit

Permalink
server: fix auth and add ws/http/grpc auth tests (#1577)
Browse files Browse the repository at this point in the history
* server: fix auth and add ws/http/grpc auth tests

* fix grpc parser
  • Loading branch information
LucioFranco authored Jul 19, 2024
1 parent a566230 commit de6dc81
Show file tree
Hide file tree
Showing 6 changed files with 171 additions and 7 deletions.
33 changes: 28 additions & 5 deletions libsql-server/src/auth/parsers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,22 @@ pub(crate) fn parse_grpc_auth_header(
) -> Result<UserAuthContext> {
let mut context = UserAuthContext::empty();

let mut auth_header_seen = false;

if required_fields.is_empty() {
return Ok(context);
}

for field in required_fields.iter() {
metadata
.get(*field)
.ok_or_else(|| AuthError::AuthHeaderNotFound)
.and_then(|h| h.to_str().map_err(|_| AuthError::AuthHeaderNonAscii))
.map(|v| context.add_field(field, v.into()))?;
if let Some(h) = metadata.get(*field) {
let v = h.to_str().map_err(|_| AuthError::AuthHeaderNonAscii)?;
context.add_field(field, v.into());
auth_header_seen = true;
}
}

if !auth_header_seen {
return Err(AuthError::AuthHeaderNotFound.into());
}

Ok(context)
Expand Down Expand Up @@ -106,6 +116,19 @@ mod tests {
);
}

#[test]
fn parse_grpc_auth_header_with_multiple_required_fields() {
let mut map = tonic::metadata::MetadataMap::new();
map.insert(GRPC_AUTH_HEADER, "bearer 123".parse().unwrap());
let required_fields = vec!["authorization".into(), "x-authorization".into()];
let context = parse_grpc_auth_header(&map, &required_fields).unwrap();

assert_eq!(
context.get_field("x-authorization"),
Some(&"bearer 123".to_string())
);
}

#[test]
fn parse_grpc_auth_header_error_non_ascii() {
let mut map = tonic::metadata::MetadataMap::new();
Expand Down
1 change: 0 additions & 1 deletion libsql-server/src/auth/user_auth_strategies/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ pub trait UserAuthStrategy: Sync + Send {
///
/// The caller is responsible for providing at least one of these fields in UserAuthContext.
/// The caller should assume the strategy will not work if none of the required fields is provided.
///
fn required_fields(&self) -> Vec<&'static str> {
vec![]
}
Expand Down
4 changes: 3 additions & 1 deletion libsql-server/src/hrana/ws/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,10 @@ pub(super) async fn handle_hello(
.map(Auth::new)
.unwrap_or_else(|| server.user_auth_strategy.clone());

let token = jwt.map(|t| format!("Bearer {}", t));

let context: UserAuthContext =
build_context(jwt, &auth_strategy.user_strategy.required_fields());
build_context(token, &auth_strategy.user_strategy.required_fields());

auth_strategy
.authenticate(context)
Expand Down
3 changes: 3 additions & 0 deletions libsql-server/tests/auth/jwt_key.pem
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
-----BEGIN PUBLIC KEY-----
MCowBQYDK2VwAyEAbz/oEg1rMRGY12X4Q0GXioX1hXaM69o9kp7h0eCLD/E=
-----END PUBLIC KEY-----
136 changes: 136 additions & 0 deletions libsql-server/tests/auth/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
//! Test hrana related functionalities
#![allow(deprecated)]

use futures::SinkExt as _;
use libsql::Database;
use libsql_server::{
auth::{user_auth_strategies, Auth},
config::UserApiConfig,
};
use tempfile::tempdir;
use tokio_stream::StreamExt;
use tokio_tungstenite::{
client_async,
tungstenite::{self, client::IntoClientRequest},
};
use turmoil::net::TcpStream;

use crate::common::net::{init_tracing, SimServer, TestServer, TurmoilConnector};

const TEST_JWT_KEY: &str = "eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJleHAiOjE3MjE2NTIwNTB9.5XhUDHQhtShszTssjjUzVuJA3r-031mT4inVkvYEYz64sOCxnNpZUZdVF-CmZ4t-JTSXFlm8ddscBgkhccBxDg";

async fn make_standalone_server() -> Result<(), Box<dyn std::error::Error>> {
let jwt_pem = include_bytes!("jwt_key.pem");
let jwt_keys = vec![jsonwebtoken::DecodingKey::from_ed_pem(jwt_pem).unwrap()];

init_tracing();
let tmp = tempdir()?;
let server = TestServer {
path: tmp.path().to_owned().into(),
user_api_config: UserApiConfig {
hrana_ws_acceptor: None,
auth_strategy: Auth::new(user_auth_strategies::Jwt::new(jwt_keys)),
..Default::default()
},
..Default::default()
};

server.start_sim(8080).await?;

Ok(())
}

#[test]
fn http_hrana() {
let mut sim = turmoil::Builder::new().build();
sim.host("primary", make_standalone_server);
sim.client("client", async {
let db = Database::open_remote_with_connector(
"http://primary:8080",
TEST_JWT_KEY,
TurmoilConnector,
)?;
let conn = db.connect()?;

conn.execute("create table t(x text)", ()).await?;

Ok(())
});

sim.run().unwrap();
}

#[test]
fn embedded_replica() {
let tmp_embedded = tempdir().unwrap();
let tmp_embedded_path = tmp_embedded.path().to_owned();

let mut sim = turmoil::Builder::new().build();
sim.host("primary", make_standalone_server);
sim.client("client", async move {
let path = tmp_embedded_path.join("embedded");

let db = Database::open_with_remote_sync_connector(
path.to_str().unwrap(),
"http://primary:8080",
TEST_JWT_KEY,
TurmoilConnector,
false,
None,
)
.await?;

let conn = db.connect()?;

conn.execute("create table t(x text)", ()).await?;

Ok(())
});

sim.run().unwrap();
}

#[test]
fn ws_hrana() {
let mut sim = turmoil::Builder::new().build();
sim.host("primary", make_standalone_server);
sim.client("client", async {
let url = "ws://primary:8080";

let req = url.into_client_request().unwrap();

let conn = TcpStream::connect("primary:8080").await.unwrap();

let (mut ws, _) = client_async(req, conn).await.unwrap();

#[derive(serde::Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ClientMsg {
Hello { jwt: Option<String> },
}

#[derive(serde::Deserialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ServerMsg {
HelloOk {},
}

let msg = ClientMsg::Hello {
jwt: Some(TEST_JWT_KEY.to_string()),
};

let msg_data = serde_json::to_string(&msg).unwrap();

ws.send(tungstenite::Message::Text(msg_data)).await.unwrap();

let Some(tungstenite::Message::Text(msg)) = ws.try_next().await.unwrap() else {
panic!("wrong message type");
};

serde_json::from_str::<ServerMsg>(&msg).unwrap();

Ok(())
});

sim.run().unwrap();
}
1 change: 1 addition & 0 deletions libsql-server/tests/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#[macro_use]
mod common;

mod auth;
mod cluster;
mod embedded_replica;
mod hrana;
Expand Down

0 comments on commit de6dc81

Please sign in to comment.