diff --git a/libsql-server/src/auth/constants.rs b/libsql-server/src/auth/constants.rs index 125d583a70..71d8540daf 100644 --- a/libsql-server/src/auth/constants.rs +++ b/libsql-server/src/auth/constants.rs @@ -1,2 +1,3 @@ +pub(crate) static AUTH_HEADER: &str = "authorization"; pub(crate) static GRPC_AUTH_HEADER: &str = "x-authorization"; pub(crate) static GRPC_PROXY_AUTH_HEADER: &str = "x-proxy-authorization"; diff --git a/libsql-server/src/auth/mod.rs b/libsql-server/src/auth/mod.rs index acc4fecb86..871c9e96d4 100644 --- a/libsql-server/src/auth/mod.rs +++ b/libsql-server/src/auth/mod.rs @@ -27,10 +27,7 @@ impl Auth { } } - pub fn authenticate( - &self, - context: Result, - ) -> Result { + pub fn authenticate(&self, context: UserAuthContext) -> Result { self.user_strategy.authenticate(context) } } diff --git a/libsql-server/src/auth/parsers.rs b/libsql-server/src/auth/parsers.rs index ca07d3ffa3..0fa208ca66 100644 --- a/libsql-server/src/auth/parsers.rs +++ b/libsql-server/src/auth/parsers.rs @@ -1,4 +1,4 @@ -use crate::auth::{constants::GRPC_AUTH_HEADER, AuthError}; +use crate::auth::AuthError; use anyhow::{bail, Context as _, Result}; use axum::http::HeaderValue; @@ -41,12 +41,21 @@ pub fn parse_jwt_keys(data: &str) -> Result> { } } -pub(crate) fn parse_grpc_auth_header(metadata: &MetadataMap) -> Result { - metadata - .get(GRPC_AUTH_HEADER) - .ok_or(AuthError::AuthHeaderNotFound) - .and_then(|h| h.to_str().map_err(|_| AuthError::AuthHeaderNonAscii)) - .and_then(|t| UserAuthContext::from_auth_str(t)) +pub(crate) fn parse_grpc_auth_header( + metadata: &MetadataMap, + required_fields: &Vec<&'static str>, +) -> Result { + let mut context = UserAuthContext::empty(); + + for field in required_fields.iter() { + metadata + .get(*field) + .ok_or_else(|| AuthError::AuthHeaderNotFound) + .and_then(|h| h.to_str().map_err(|_| AuthError::AuthHeaderNonAscii)) + .map(|v| context.add_field(field, v.into()))?; + } + + Ok(context) } pub fn parse_http_auth_header<'a>( @@ -78,6 +87,7 @@ mod tests { use hyper::header::AUTHORIZATION; use crate::auth::authorized::Scopes; + use crate::auth::constants::GRPC_AUTH_HEADER; use crate::auth::user_auth_strategies::jwt::Token; use crate::auth::{parse_http_auth_header, parse_jwt_keys, AuthError}; @@ -86,19 +96,13 @@ mod tests { #[test] fn parse_grpc_auth_header_returns_valid_context() { let mut map = tonic::metadata::MetadataMap::new(); - map.insert("x-authorization", "bearer 123".parse().unwrap()); - let context = parse_grpc_auth_header(&map).unwrap(); - assert_eq!(context.scheme().as_ref().unwrap(), "bearer"); - assert_eq!(context.token().as_ref().unwrap(), "123"); - } + map.insert(GRPC_AUTH_HEADER, "bearer 123".parse().unwrap()); + let required_fields = vec!["x-authorization".into()]; + let context = parse_grpc_auth_header(&map, &required_fields).unwrap(); - #[test] - fn parse_grpc_auth_header_error_no_header() { - let map = tonic::metadata::MetadataMap::new(); - let result = parse_grpc_auth_header(&map); assert_eq!( - result.unwrap_err().to_string(), - "Expected authorization header but none given" + context.get_field("x-authorization"), + Some(&"bearer 123".to_string()) ); } @@ -106,21 +110,11 @@ mod tests { fn parse_grpc_auth_header_error_non_ascii() { let mut map = tonic::metadata::MetadataMap::new(); map.insert("x-authorization", "bearer I❤NY".parse().unwrap()); - let result = parse_grpc_auth_header(&map); + let required_fields = vec!["x-authorization".into()]; + let result = parse_grpc_auth_header(&map, &required_fields); assert_eq!(result.unwrap_err().to_string(), "Non-ASCII auth header") } - #[test] - fn parse_grpc_auth_header_error_malformed_auth_str() { - let mut map = tonic::metadata::MetadataMap::new(); - map.insert("x-authorization", "bearer123".parse().unwrap()); - let result = parse_grpc_auth_header(&map); - assert_eq!( - result.unwrap_err().to_string(), - "Auth string does not conform to ' ' form" - ) - } - #[test] fn parse_http_auth_header_returns_auth_header_param_when_valid() { assert_eq!( diff --git a/libsql-server/src/auth/user_auth_strategies/disabled.rs b/libsql-server/src/auth/user_auth_strategies/disabled.rs index b95d52c061..ef9aae9062 100644 --- a/libsql-server/src/auth/user_auth_strategies/disabled.rs +++ b/libsql-server/src/auth/user_auth_strategies/disabled.rs @@ -4,10 +4,7 @@ use crate::auth::{AuthError, Authenticated}; pub struct Disabled {} impl UserAuthStrategy for Disabled { - fn authenticate( - &self, - _context: Result, - ) -> Result { + fn authenticate(&self, _context: UserAuthContext) -> Result { tracing::trace!("executing disabled auth"); Ok(Authenticated::FullAccess) } @@ -26,7 +23,7 @@ mod tests { #[test] fn authenticates() { let strategy = Disabled::new(); - let context = Ok(UserAuthContext::empty()); + let context = UserAuthContext::empty(); assert!(matches!( strategy.authenticate(context).unwrap(), diff --git a/libsql-server/src/auth/user_auth_strategies/http_basic.rs b/libsql-server/src/auth/user_auth_strategies/http_basic.rs index fbb45d0912..f42605d92c 100644 --- a/libsql-server/src/auth/user_auth_strategies/http_basic.rs +++ b/libsql-server/src/auth/user_auth_strategies/http_basic.rs @@ -1,4 +1,7 @@ -use crate::auth::{AuthError, Authenticated}; +use crate::auth::{ + constants::{AUTH_HEADER, GRPC_AUTH_HEADER}, + AuthError, Authenticated, +}; use super::{UserAuthContext, UserAuthStrategy}; @@ -7,27 +10,30 @@ pub struct HttpBasic { } impl UserAuthStrategy for HttpBasic { - fn authenticate( - &self, - context: Result, - ) -> Result { + fn authenticate(&self, ctx: UserAuthContext) -> Result { tracing::trace!("executing http basic auth"); + let auth_str = ctx + .get_field(AUTH_HEADER) + .or_else(|| ctx.get_field(GRPC_AUTH_HEADER)); + + let (_, token) = auth_str + .ok_or(AuthError::AuthHeaderNotFound) + .map(|s| s.split_once(' ').ok_or(AuthError::AuthStringMalformed)) + .and_then(|o| o)?; // NOTE: this naive comparison may leak information about the `expected_value` // using a timing attack let expected_value = self.credential.trim_end_matches('='); - - let creds_match = match context?.token { - Some(s) => s.contains(expected_value), - None => expected_value.is_empty(), - }; - + let creds_match = token.contains(expected_value); if creds_match { return Ok(Authenticated::FullAccess); } - Err(AuthError::BasicRejected) } + + fn required_fields(&self) -> Vec<&'static str> { + vec![AUTH_HEADER, GRPC_AUTH_HEADER] + } } impl HttpBasic { @@ -48,7 +54,7 @@ mod tests { #[test] fn authenticates_with_valid_credential() { - let context = Ok(UserAuthContext::basic(CREDENTIAL)); + let context = UserAuthContext::basic(CREDENTIAL); assert!(matches!( strategy().authenticate(context).unwrap(), @@ -59,7 +65,7 @@ mod tests { #[test] fn authenticates_with_valid_trimmed_credential() { let credential = CREDENTIAL.trim_end_matches('='); - let context = Ok(UserAuthContext::basic(credential)); + let context = UserAuthContext::basic(credential); assert!(matches!( strategy().authenticate(context).unwrap(), @@ -69,7 +75,7 @@ mod tests { #[test] fn errors_when_credentials_do_not_match() { - let context = Ok(UserAuthContext::basic("abc")); + let context = UserAuthContext::basic("abc"); assert_eq!( strategy().authenticate(context).unwrap_err(), diff --git a/libsql-server/src/auth/user_auth_strategies/jwt.rs b/libsql-server/src/auth/user_auth_strategies/jwt.rs index 5bc60faa03..320952bdce 100644 --- a/libsql-server/src/auth/user_auth_strategies/jwt.rs +++ b/libsql-server/src/auth/user_auth_strategies/jwt.rs @@ -1,7 +1,11 @@ use chrono::{DateTime, Utc}; use crate::{ - auth::{authenticated::LegacyAuth, AuthError, Authenticated, Authorized, Permission}, + auth::{ + authenticated::LegacyAuth, + constants::{AUTH_HEADER, GRPC_AUTH_HEADER}, + AuthError, Authenticated, Authorized, Permission, + }, namespace::NamespaceName, }; @@ -12,21 +16,16 @@ pub struct Jwt { } impl UserAuthStrategy for Jwt { - fn authenticate( - &self, - context: Result, - ) -> Result { + fn authenticate(&self, ctx: UserAuthContext) -> Result { tracing::trace!("executing jwt auth"); + let auth_str = ctx + .get_field(AUTH_HEADER) + .or_else(|| ctx.get_field(GRPC_AUTH_HEADER)) + .ok_or_else(|| AuthError::AuthHeaderNotFound)?; - let ctx = context?; - - let UserAuthContext { - scheme: Some(scheme), - token: Some(token), - } = ctx - else { - return Err(AuthError::HttpAuthHeaderInvalid); - }; + let (scheme, token) = auth_str + .split_once(' ') + .ok_or(AuthError::AuthStringMalformed)?; if !scheme.eq_ignore_ascii_case("bearer") { return Err(AuthError::HttpAuthHeaderUnsupportedScheme); @@ -34,6 +33,10 @@ impl UserAuthStrategy for Jwt { validate_any_jwt(&self.keys, &token) } + + fn required_fields(&self) -> Vec<&'static str> { + vec![AUTH_HEADER, GRPC_AUTH_HEADER] + } } impl Jwt { @@ -190,7 +193,7 @@ mod tests { }; let token = encode(&token, &enc); - let context = Ok(UserAuthContext::bearer(token.as_str())); + let context = UserAuthContext::bearer(token.as_str()); assert!(matches!( strategy(dec).authenticate(context).unwrap(), @@ -212,7 +215,7 @@ mod tests { }; let token = encode(&token, &enc); - let context = Ok(UserAuthContext::bearer(token.as_str())); + let context = UserAuthContext::bearer(token.as_str()); let Authenticated::Legacy(a) = strategy(dec).authenticate(context).unwrap() else { panic!() @@ -225,7 +228,7 @@ mod tests { #[test] fn errors_when_jwt_token_invalid() { let (_enc, dec) = generate_key_pair(); - let context = Ok(UserAuthContext::bearer("abc")); + let context = UserAuthContext::bearer("abc"); assert_eq!( strategy(dec).authenticate(context).unwrap_err(), @@ -245,7 +248,7 @@ mod tests { let token = encode(&token, &enc); - let context = Ok(UserAuthContext::bearer(token.as_str())); + let context = UserAuthContext::bearer(token.as_str()); assert_eq!( strategy(dec).authenticate(context).unwrap_err(), @@ -267,7 +270,7 @@ mod tests { let token = encode(&token, &enc); - let context = Ok(UserAuthContext::bearer(token.as_str())); + let context = UserAuthContext::bearer(token.as_str()); let Authenticated::Authorized(a) = strategy(dec).authenticate(context).unwrap() else { panic!() @@ -304,7 +307,7 @@ mod tests { for enc in multi_enc.iter() { let token = encode(&token, &enc); - let context = Ok(UserAuthContext::bearer(token.as_str())); + let context = UserAuthContext::bearer(token.as_str()); let Authenticated::Authorized(a) = strategy.authenticate(context).unwrap() else { panic!() @@ -331,7 +334,7 @@ mod tests { }); let token = encode(&token, &enc); - let context = Ok(UserAuthContext::bearer(token.as_str())); + let context = UserAuthContext::bearer(token.as_str()); assert_eq!( strategy_with_multiple(multi_dec) @@ -352,7 +355,7 @@ mod tests { }; let token = encode(&token, &multi_enc[0]); - let context = Ok(UserAuthContext::bearer(token.as_str())); + let context = UserAuthContext::bearer(token.as_str()); assert_eq!( strategy_with_multiple(multi_dec) @@ -373,7 +376,7 @@ mod tests { }; let token = encode(&token, &multi_enc[2]); - let context = Ok(UserAuthContext::bearer(token.as_str())); + let context = UserAuthContext::bearer(token.as_str()); assert_eq!( strategy_with_multiple(multi_dec) diff --git a/libsql-server/src/auth/user_auth_strategies/mod.rs b/libsql-server/src/auth/user_auth_strategies/mod.rs index 4f0f2ef786..73aadf6630 100644 --- a/libsql-server/src/auth/user_auth_strategies/mod.rs +++ b/libsql-server/src/auth/user_auth_strategies/mod.rs @@ -3,58 +3,39 @@ pub mod http_basic; pub mod jwt; pub use disabled::Disabled; +use hashbrown::HashMap; pub use http_basic::HttpBasic; pub use jwt::Jwt; -use super::{AuthError, Authenticated}; +use super::{constants::AUTH_HEADER, AuthError, Authenticated}; #[derive(Debug)] pub struct UserAuthContext { - scheme: Option, - token: Option, + custom_fields: HashMap<&'static str, String>, } impl UserAuthContext { - pub fn scheme(&self) -> &Option { - &self.scheme - } - - pub fn token(&self) -> &Option { - &self.token - } - pub fn empty() -> UserAuthContext { UserAuthContext { - scheme: None, - token: None, + custom_fields: HashMap::new(), } } pub fn basic(creds: &str) -> UserAuthContext { UserAuthContext { - scheme: Some("Basic".into()), - token: Some(creds.into()), + custom_fields: HashMap::from([(AUTH_HEADER, format!("Basic {creds}"))]), } } pub fn bearer(token: &str) -> UserAuthContext { UserAuthContext { - scheme: Some("Bearer".into()), - token: Some(token.into()), - } - } - - pub fn bearer_opt(token: Option) -> UserAuthContext { - UserAuthContext { - scheme: Some("Bearer".into()), - token: token, + custom_fields: HashMap::from([(AUTH_HEADER, format!("Bearer {token}"))]), } } pub fn new(scheme: &str, token: &str) -> UserAuthContext { UserAuthContext { - scheme: Some(scheme.into()), - token: Some(token.into()), + custom_fields: HashMap::from([(AUTH_HEADER, format!("{scheme} {token}"))]), } } @@ -64,11 +45,33 @@ impl UserAuthContext { .ok_or(AuthError::AuthStringMalformed)?; Ok(UserAuthContext::new(scheme, token)) } + + pub fn add_field(&mut self, key: &'static str, value: String) { + self.custom_fields.insert(key, value.into()); + } + + pub fn get_field(&self, key: &'static str) -> Option<&String> { + return self.custom_fields.get(key); + } } pub trait UserAuthStrategy: Sync + Send { - fn authenticate( - &self, - context: Result, - ) -> Result; + /// Returns a list of fields required by the stragegy. + /// Every strategy implementation should override this function if it requires input to work. + /// Strategy implementations should validate the content of provided fields. + /// + /// The caller is responsible for providing at least one of these fields in UserAuthContext. + /// The caller should assume the strategy will not work if none of the required fields is provided. + /// + fn required_fields(&self) -> Vec<&'static str> { + vec![] + } + + /// Performs authentication of the user and returns Authenticated witness if successful. + /// Returns respective AuthError communicating the reason for failure. + /// Assumes the context input contains at least one of the fields specified in required_fields() + /// + /// Warning: this function deals with sensitive information. + /// Implementer should be very careful about what information they chose to log or provide in AuthError message. + fn authenticate(&self, context: UserAuthContext) -> Result; } diff --git a/libsql-server/src/hrana/ws/session.rs b/libsql-server/src/hrana/ws/session.rs index 4f4288ae8d..16ceab5253 100644 --- a/libsql-server/src/hrana/ws/session.rs +++ b/libsql-server/src/hrana/ws/session.rs @@ -7,6 +7,7 @@ use tokio::sync::{mpsc, oneshot}; use super::super::{batch, cursor, stmt, ProtocolError, Version}; use super::{proto, Server}; +use crate::auth::constants::AUTH_HEADER; use crate::auth::user_auth_strategies::UserAuthContext; use crate::auth::{Auth, AuthError, Authenticated, Jwt}; use crate::connection::{Connection as _, RequestContext}; @@ -98,14 +99,27 @@ pub(super) async fn handle_hello( .with(namespace.clone(), |ns| ns.jwt_keys()) .await??; - namespace_jwt_keys + let auth_strategy = namespace_jwt_keys .map(Jwt::new) .map(Auth::new) - .unwrap_or_else(|| server.user_auth_strategy.clone()) - .authenticate(Ok(UserAuthContext::bearer_opt(jwt))) + .unwrap_or_else(|| server.user_auth_strategy.clone()); + + let context: UserAuthContext = + build_context(jwt, &auth_strategy.user_strategy.required_fields()); + + auth_strategy + .authenticate(context) .map_err(|err| anyhow!(ResponseError::Auth { source: err })) } +fn build_context(jwt: Option, required_fields: &Vec<&'static str>) -> UserAuthContext { + let mut ctx = UserAuthContext::empty(); + if required_fields.contains(&AUTH_HEADER) && jwt.is_some() { + ctx.add_field(AUTH_HEADER, jwt.unwrap()); + } + ctx +} + pub(super) async fn handle_request( server: &Server, session: &mut Session, diff --git a/libsql-server/src/http/user/extract.rs b/libsql-server/src/http/user/extract.rs index bf4bec1861..89dd529150 100644 --- a/libsql-server/src/http/user/extract.rs +++ b/libsql-server/src/http/user/extract.rs @@ -1,7 +1,7 @@ use axum::extract::FromRequestParts; use crate::{ - auth::{Auth, AuthError, Jwt, UserAuthContext}, + auth::{Auth, Jwt}, connection::RequestContext, }; @@ -26,21 +26,15 @@ impl FromRequestParts for RequestContext { .with(namespace.clone(), |ns| ns.jwt_keys()) .await??; - let context = parts - .headers - .get(hyper::header::AUTHORIZATION) - .ok_or(AuthError::AuthHeaderNotFound) - .and_then(|h| h.to_str().map_err(|_| AuthError::AuthHeaderNonAscii)) - .and_then(|t| UserAuthContext::from_auth_str(t)); - - let authenticated = namespace_jwt_keys + let auth = namespace_jwt_keys .map(Jwt::new) .map(Auth::new) - .unwrap_or_else(|| state.user_auth_strategy.clone()) - .authenticate(context)?; + .unwrap_or_else(|| state.user_auth_strategy.clone()); + + let context = super::build_context(&parts.headers, &auth.user_strategy.required_fields()); Ok(Self::new( - authenticated, + auth.authenticate(context)?, namespace, state.namespaces.meta_store().clone(), )) diff --git a/libsql-server/src/http/user/mod.rs b/libsql-server/src/http/user/mod.rs index 54e9a2b307..463eb4cd42 100644 --- a/libsql-server/src/http/user/mod.rs +++ b/libsql-server/src/http/user/mod.rs @@ -472,20 +472,36 @@ impl FromRequestParts for Authenticated { .with(ns.clone(), |ns| ns.jwt_keys()) .await??; - let context = parts - .headers - .get(hyper::header::AUTHORIZATION) - .ok_or(AuthError::AuthHeaderNotFound) - .and_then(|h| h.to_str().map_err(|_| AuthError::AuthHeaderNonAscii)) - .and_then(|t| UserAuthContext::from_auth_str(t)); - - let authenticated = namespace_jwt_keys + let auth = namespace_jwt_keys .map(Jwt::new) .map(Auth::new) - .unwrap_or_else(|| state.user_auth_strategy.clone()) - .authenticate(context)?; - Ok(authenticated) + .unwrap_or_else(|| state.user_auth_strategy.clone()); + + let context = build_context(&parts.headers, &auth.user_strategy.required_fields()); + + Ok(auth.authenticate(context)?) + } +} + +fn build_context( + headers: &hyper::HeaderMap, + required_fields: &Vec<&'static str>, +) -> UserAuthContext { + let mut ctx = headers + .get(hyper::header::AUTHORIZATION) + .ok_or(AuthError::AuthHeaderNotFound) + .and_then(|h| h.to_str().map_err(|_| AuthError::AuthHeaderNonAscii)) + .and_then(|t| UserAuthContext::from_auth_str(t)) + .unwrap_or(UserAuthContext::empty()); + + for field in required_fields.iter() { + headers + .get(field.to_string()) + .map(|h| h.to_str().ok()) + .and_then(|t| t.map(|s| ctx.add_field(field, s.into()))); } + + ctx } impl FromRef for Auth { diff --git a/libsql-server/src/rpc/proxy.rs b/libsql-server/src/rpc/proxy.rs index 5a4565ca4e..44ed0429ff 100644 --- a/libsql-server/src/rpc/proxy.rs +++ b/libsql-server/src/rpc/proxy.rs @@ -327,7 +327,11 @@ impl ProxyService { }; let auth = if let Some(auth) = auth { - let context = parse_grpc_auth_header(req.metadata()); + let context = + parse_grpc_auth_header(req.metadata(), &auth.user_strategy.required_fields()) + .map_err(|e| { + tonic::Status::internal(format!("Error parsing auth header: {}", e)) + })?; auth.authenticate(context)? } else { Authenticated::from_proxy_grpc_request(req)? diff --git a/libsql-server/src/rpc/replica_proxy.rs b/libsql-server/src/rpc/replica_proxy.rs index 08cedb2c44..f6efec2739 100644 --- a/libsql-server/src/rpc/replica_proxy.rs +++ b/libsql-server/src/rpc/replica_proxy.rs @@ -45,7 +45,7 @@ impl ReplicaProxyService { let namespace_jwt_keys = jwt_result.and_then(|s| s); - let auth_strategy = match namespace_jwt_keys { + let auth = match namespace_jwt_keys { Ok(Some(key)) => Ok(Auth::new(Jwt::new(key))), Ok(None) | Err(crate::error::Error::NamespaceDoesntExist(_)) => { Ok(self.user_auth_strategy.clone()) @@ -56,10 +56,12 @@ impl ReplicaProxyService { ))), }?; - let auth_context = parse_grpc_auth_header(req.metadata()); - auth_strategy - .authenticate(auth_context)? - .upgrade_grpc_request(req); + let auth_context = + parse_grpc_auth_header(req.metadata(), &auth.user_strategy.required_fields()).map_err( + |e| tonic::Status::internal(format!("Error parsing auth header: {}", e)), + )?; + auth.authenticate(auth_context)?.upgrade_grpc_request(req); + return Ok(()); } } diff --git a/libsql-server/src/rpc/replication_log.rs b/libsql-server/src/rpc/replication_log.rs index 105c96aa7c..c0b216739e 100644 --- a/libsql-server/src/rpc/replication_log.rs +++ b/libsql-server/src/rpc/replication_log.rs @@ -94,8 +94,12 @@ impl ReplicationLogService { }; if let Some(auth) = auth { - let user_credential = parse_grpc_auth_header(req.metadata()); - auth.authenticate(user_credential)?; + let context = + parse_grpc_auth_header(req.metadata(), &auth.user_strategy.required_fields()) + .map_err(|e| { + tonic::Status::internal(format!("Error parsing auth header: {}", e)) + })?; + auth.authenticate(context)?; } Ok(())