Skip to content

Commit

Permalink
implemented custom headers passing in auth layer (#1472)
Browse files Browse the repository at this point in the history
  • Loading branch information
shopifyski authored Jul 12, 2024
1 parent 4b21878 commit ec9aa56
Show file tree
Hide file tree
Showing 13 changed files with 176 additions and 141 deletions.
1 change: 1 addition & 0 deletions libsql-server/src/auth/constants.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
pub(crate) static AUTH_HEADER: &str = "authorization";
pub(crate) static GRPC_AUTH_HEADER: &str = "x-authorization";
pub(crate) static GRPC_PROXY_AUTH_HEADER: &str = "x-proxy-authorization";
5 changes: 1 addition & 4 deletions libsql-server/src/auth/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,7 @@ impl Auth {
}
}

pub fn authenticate(
&self,
context: Result<UserAuthContext, AuthError>,
) -> Result<Authenticated, AuthError> {
pub fn authenticate(&self, context: UserAuthContext) -> Result<Authenticated, AuthError> {
self.user_strategy.authenticate(context)
}
}
54 changes: 24 additions & 30 deletions libsql-server/src/auth/parsers.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::auth::{constants::GRPC_AUTH_HEADER, AuthError};
use crate::auth::AuthError;

use anyhow::{bail, Context as _, Result};
use axum::http::HeaderValue;
Expand Down Expand Up @@ -41,12 +41,21 @@ pub fn parse_jwt_keys(data: &str) -> Result<Vec<jsonwebtoken::DecodingKey>> {
}
}

pub(crate) fn parse_grpc_auth_header(metadata: &MetadataMap) -> Result<UserAuthContext, AuthError> {
metadata
.get(GRPC_AUTH_HEADER)
.ok_or(AuthError::AuthHeaderNotFound)
.and_then(|h| h.to_str().map_err(|_| AuthError::AuthHeaderNonAscii))
.and_then(|t| UserAuthContext::from_auth_str(t))
pub(crate) fn parse_grpc_auth_header(
metadata: &MetadataMap,
required_fields: &Vec<&'static str>,
) -> Result<UserAuthContext> {
let mut context = UserAuthContext::empty();

for field in required_fields.iter() {
metadata
.get(*field)
.ok_or_else(|| AuthError::AuthHeaderNotFound)
.and_then(|h| h.to_str().map_err(|_| AuthError::AuthHeaderNonAscii))
.map(|v| context.add_field(field, v.into()))?;
}

Ok(context)
}

pub fn parse_http_auth_header<'a>(
Expand Down Expand Up @@ -78,6 +87,7 @@ mod tests {
use hyper::header::AUTHORIZATION;

use crate::auth::authorized::Scopes;
use crate::auth::constants::GRPC_AUTH_HEADER;
use crate::auth::user_auth_strategies::jwt::Token;
use crate::auth::{parse_http_auth_header, parse_jwt_keys, AuthError};

Expand All @@ -86,41 +96,25 @@ mod tests {
#[test]
fn parse_grpc_auth_header_returns_valid_context() {
let mut map = tonic::metadata::MetadataMap::new();
map.insert("x-authorization", "bearer 123".parse().unwrap());
let context = parse_grpc_auth_header(&map).unwrap();
assert_eq!(context.scheme().as_ref().unwrap(), "bearer");
assert_eq!(context.token().as_ref().unwrap(), "123");
}
map.insert(GRPC_AUTH_HEADER, "bearer 123".parse().unwrap());
let required_fields = vec!["x-authorization".into()];
let context = parse_grpc_auth_header(&map, &required_fields).unwrap();

#[test]
fn parse_grpc_auth_header_error_no_header() {
let map = tonic::metadata::MetadataMap::new();
let result = parse_grpc_auth_header(&map);
assert_eq!(
result.unwrap_err().to_string(),
"Expected authorization header but none given"
context.get_field("x-authorization"),
Some(&"bearer 123".to_string())
);
}

#[test]
fn parse_grpc_auth_header_error_non_ascii() {
let mut map = tonic::metadata::MetadataMap::new();
map.insert("x-authorization", "bearer I❤NY".parse().unwrap());
let result = parse_grpc_auth_header(&map);
let required_fields = vec!["x-authorization".into()];
let result = parse_grpc_auth_header(&map, &required_fields);
assert_eq!(result.unwrap_err().to_string(), "Non-ASCII auth header")
}

#[test]
fn parse_grpc_auth_header_error_malformed_auth_str() {
let mut map = tonic::metadata::MetadataMap::new();
map.insert("x-authorization", "bearer123".parse().unwrap());
let result = parse_grpc_auth_header(&map);
assert_eq!(
result.unwrap_err().to_string(),
"Auth string does not conform to '<scheme> <token>' form"
)
}

#[test]
fn parse_http_auth_header_returns_auth_header_param_when_valid() {
assert_eq!(
Expand Down
7 changes: 2 additions & 5 deletions libsql-server/src/auth/user_auth_strategies/disabled.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,7 @@ use crate::auth::{AuthError, Authenticated};
pub struct Disabled {}

impl UserAuthStrategy for Disabled {
fn authenticate(
&self,
_context: Result<UserAuthContext, AuthError>,
) -> Result<Authenticated, AuthError> {
fn authenticate(&self, _context: UserAuthContext) -> Result<Authenticated, AuthError> {
tracing::trace!("executing disabled auth");
Ok(Authenticated::FullAccess)
}
Expand All @@ -26,7 +23,7 @@ mod tests {
#[test]
fn authenticates() {
let strategy = Disabled::new();
let context = Ok(UserAuthContext::empty());
let context = UserAuthContext::empty();

assert!(matches!(
strategy.authenticate(context).unwrap(),
Expand Down
36 changes: 21 additions & 15 deletions libsql-server/src/auth/user_auth_strategies/http_basic.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use crate::auth::{AuthError, Authenticated};
use crate::auth::{
constants::{AUTH_HEADER, GRPC_AUTH_HEADER},
AuthError, Authenticated,
};

use super::{UserAuthContext, UserAuthStrategy};

Expand All @@ -7,27 +10,30 @@ pub struct HttpBasic {
}

impl UserAuthStrategy for HttpBasic {
fn authenticate(
&self,
context: Result<UserAuthContext, AuthError>,
) -> Result<Authenticated, AuthError> {
fn authenticate(&self, ctx: UserAuthContext) -> Result<Authenticated, AuthError> {
tracing::trace!("executing http basic auth");
let auth_str = ctx
.get_field(AUTH_HEADER)
.or_else(|| ctx.get_field(GRPC_AUTH_HEADER));

let (_, token) = auth_str
.ok_or(AuthError::AuthHeaderNotFound)
.map(|s| s.split_once(' ').ok_or(AuthError::AuthStringMalformed))
.and_then(|o| o)?;

// NOTE: this naive comparison may leak information about the `expected_value`
// using a timing attack
let expected_value = self.credential.trim_end_matches('=');

let creds_match = match context?.token {
Some(s) => s.contains(expected_value),
None => expected_value.is_empty(),
};

let creds_match = token.contains(expected_value);
if creds_match {
return Ok(Authenticated::FullAccess);
}

Err(AuthError::BasicRejected)
}

fn required_fields(&self) -> Vec<&'static str> {
vec![AUTH_HEADER, GRPC_AUTH_HEADER]
}
}

impl HttpBasic {
Expand All @@ -48,7 +54,7 @@ mod tests {

#[test]
fn authenticates_with_valid_credential() {
let context = Ok(UserAuthContext::basic(CREDENTIAL));
let context = UserAuthContext::basic(CREDENTIAL);

assert!(matches!(
strategy().authenticate(context).unwrap(),
Expand All @@ -59,7 +65,7 @@ mod tests {
#[test]
fn authenticates_with_valid_trimmed_credential() {
let credential = CREDENTIAL.trim_end_matches('=');
let context = Ok(UserAuthContext::basic(credential));
let context = UserAuthContext::basic(credential);

assert!(matches!(
strategy().authenticate(context).unwrap(),
Expand All @@ -69,7 +75,7 @@ mod tests {

#[test]
fn errors_when_credentials_do_not_match() {
let context = Ok(UserAuthContext::basic("abc"));
let context = UserAuthContext::basic("abc");

assert_eq!(
strategy().authenticate(context).unwrap_err(),
Expand Down
49 changes: 26 additions & 23 deletions libsql-server/src/auth/user_auth_strategies/jwt.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
use chrono::{DateTime, Utc};

use crate::{
auth::{authenticated::LegacyAuth, AuthError, Authenticated, Authorized, Permission},
auth::{
authenticated::LegacyAuth,
constants::{AUTH_HEADER, GRPC_AUTH_HEADER},
AuthError, Authenticated, Authorized, Permission,
},
namespace::NamespaceName,
};

Expand All @@ -12,28 +16,27 @@ pub struct Jwt {
}

impl UserAuthStrategy for Jwt {
fn authenticate(
&self,
context: Result<UserAuthContext, AuthError>,
) -> Result<Authenticated, AuthError> {
fn authenticate(&self, ctx: UserAuthContext) -> Result<Authenticated, AuthError> {
tracing::trace!("executing jwt auth");
let auth_str = ctx
.get_field(AUTH_HEADER)
.or_else(|| ctx.get_field(GRPC_AUTH_HEADER))
.ok_or_else(|| AuthError::AuthHeaderNotFound)?;

let ctx = context?;

let UserAuthContext {
scheme: Some(scheme),
token: Some(token),
} = ctx
else {
return Err(AuthError::HttpAuthHeaderInvalid);
};
let (scheme, token) = auth_str
.split_once(' ')
.ok_or(AuthError::AuthStringMalformed)?;

if !scheme.eq_ignore_ascii_case("bearer") {
return Err(AuthError::HttpAuthHeaderUnsupportedScheme);
}

validate_any_jwt(&self.keys, &token)
}

fn required_fields(&self) -> Vec<&'static str> {
vec![AUTH_HEADER, GRPC_AUTH_HEADER]
}
}

impl Jwt {
Expand Down Expand Up @@ -190,7 +193,7 @@ mod tests {
};
let token = encode(&token, &enc);

let context = Ok(UserAuthContext::bearer(token.as_str()));
let context = UserAuthContext::bearer(token.as_str());

assert!(matches!(
strategy(dec).authenticate(context).unwrap(),
Expand All @@ -212,7 +215,7 @@ mod tests {
};
let token = encode(&token, &enc);

let context = Ok(UserAuthContext::bearer(token.as_str()));
let context = UserAuthContext::bearer(token.as_str());

let Authenticated::Legacy(a) = strategy(dec).authenticate(context).unwrap() else {
panic!()
Expand All @@ -225,7 +228,7 @@ mod tests {
#[test]
fn errors_when_jwt_token_invalid() {
let (_enc, dec) = generate_key_pair();
let context = Ok(UserAuthContext::bearer("abc"));
let context = UserAuthContext::bearer("abc");

assert_eq!(
strategy(dec).authenticate(context).unwrap_err(),
Expand All @@ -245,7 +248,7 @@ mod tests {

let token = encode(&token, &enc);

let context = Ok(UserAuthContext::bearer(token.as_str()));
let context = UserAuthContext::bearer(token.as_str());

assert_eq!(
strategy(dec).authenticate(context).unwrap_err(),
Expand All @@ -267,7 +270,7 @@ mod tests {

let token = encode(&token, &enc);

let context = Ok(UserAuthContext::bearer(token.as_str()));
let context = UserAuthContext::bearer(token.as_str());

let Authenticated::Authorized(a) = strategy(dec).authenticate(context).unwrap() else {
panic!()
Expand Down Expand Up @@ -304,7 +307,7 @@ mod tests {
for enc in multi_enc.iter() {
let token = encode(&token, &enc);

let context = Ok(UserAuthContext::bearer(token.as_str()));
let context = UserAuthContext::bearer(token.as_str());

let Authenticated::Authorized(a) = strategy.authenticate(context).unwrap() else {
panic!()
Expand All @@ -331,7 +334,7 @@ mod tests {
});
let token = encode(&token, &enc);

let context = Ok(UserAuthContext::bearer(token.as_str()));
let context = UserAuthContext::bearer(token.as_str());

assert_eq!(
strategy_with_multiple(multi_dec)
Expand All @@ -352,7 +355,7 @@ mod tests {
};
let token = encode(&token, &multi_enc[0]);

let context = Ok(UserAuthContext::bearer(token.as_str()));
let context = UserAuthContext::bearer(token.as_str());

assert_eq!(
strategy_with_multiple(multi_dec)
Expand All @@ -373,7 +376,7 @@ mod tests {
};
let token = encode(&token, &multi_enc[2]);

let context = Ok(UserAuthContext::bearer(token.as_str()));
let context = UserAuthContext::bearer(token.as_str());

assert_eq!(
strategy_with_multiple(multi_dec)
Expand Down
Loading

0 comments on commit ec9aa56

Please sign in to comment.