diff --git a/libsql-server/src/auth/mod.rs b/libsql-server/src/auth/mod.rs index 871c9e96d4..c89725c570 100644 --- a/libsql-server/src/auth/mod.rs +++ b/libsql-server/src/auth/mod.rs @@ -15,7 +15,7 @@ pub use parsers::{parse_http_auth_header, parse_http_basic_auth_arg, parse_jwt_k pub use permission::Permission; pub use user_auth_strategies::{Disabled, HttpBasic, Jwt, UserAuthContext, UserAuthStrategy}; -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct Auth { pub user_strategy: Arc, } diff --git a/libsql-server/src/auth/user_auth_strategies/disabled.rs b/libsql-server/src/auth/user_auth_strategies/disabled.rs index ef9aae9062..b8dd85859b 100644 --- a/libsql-server/src/auth/user_auth_strategies/disabled.rs +++ b/libsql-server/src/auth/user_auth_strategies/disabled.rs @@ -1,6 +1,7 @@ use super::{UserAuthContext, UserAuthStrategy}; use crate::auth::{AuthError, Authenticated}; +#[derive(Debug)] pub struct Disabled {} impl UserAuthStrategy for Disabled { diff --git a/libsql-server/src/auth/user_auth_strategies/http_basic.rs b/libsql-server/src/auth/user_auth_strategies/http_basic.rs index f42605d92c..f0cfbbf202 100644 --- a/libsql-server/src/auth/user_auth_strategies/http_basic.rs +++ b/libsql-server/src/auth/user_auth_strategies/http_basic.rs @@ -5,6 +5,7 @@ use crate::auth::{ use super::{UserAuthContext, UserAuthStrategy}; +#[derive(Debug)] pub struct HttpBasic { credential: String, } diff --git a/libsql-server/src/auth/user_auth_strategies/jwt.rs b/libsql-server/src/auth/user_auth_strategies/jwt.rs index e55b18b30d..9c6893d1d7 100644 --- a/libsql-server/src/auth/user_auth_strategies/jwt.rs +++ b/libsql-server/src/auth/user_auth_strategies/jwt.rs @@ -1,3 +1,5 @@ +use std::fmt::{self, Debug, Formatter}; + use chrono::{DateTime, Utc}; use crate::{ @@ -15,6 +17,12 @@ pub struct Jwt { keys: Vec, } +impl Debug for Jwt { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("Jwt").finish() + } +} + impl UserAuthStrategy for Jwt { fn authenticate(&self, ctx: UserAuthContext) -> Result { tracing::trace!("executing jwt auth"); diff --git a/libsql-server/src/auth/user_auth_strategies/mod.rs b/libsql-server/src/auth/user_auth_strategies/mod.rs index 7223a587f3..1ee2534c52 100644 --- a/libsql-server/src/auth/user_auth_strategies/mod.rs +++ b/libsql-server/src/auth/user_auth_strategies/mod.rs @@ -55,7 +55,7 @@ impl UserAuthContext { } } -pub trait UserAuthStrategy: Sync + Send { +pub trait UserAuthStrategy: Sync + Send + std::fmt::Debug { /// Returns a list of fields required by the stragegy. /// Every strategy implementation should override this function if it requires input to work. /// Strategy implementations should validate the content of provided fields. diff --git a/libsql-server/src/lib.rs b/libsql-server/src/lib.rs index dd7ea2bffe..fed533832c 100644 --- a/libsql-server/src/lib.rs +++ b/libsql-server/src/lib.rs @@ -129,11 +129,12 @@ pub(crate) static BLOCKING_RT: Lazy = Lazy::new(|| { type Result = std::result::Result; type StatsSender = mpsc::Sender<(NamespaceName, MetaStoreHandle, Weak)>; type MakeReplicationSvc = Box< - dyn FnOnce( + dyn Fn( NamespaceStore, Option, Option, bool, + bool, ) -> BoxReplicationService + Send + 'static, @@ -620,9 +621,10 @@ where let replication_service = make_replication_svc( namespace_store.clone(), - None, + Some(user_auth_strategy.clone()), idle_shutdown_kicker.clone(), false, + true, ); task_manager.spawn_until_shutdown(run_rpc_server( @@ -630,7 +632,7 @@ where config.acceptor, config.tls_config, idle_shutdown_kicker.clone(), - replication_service, + replication_service, // internal replicaton service )); } @@ -658,12 +660,12 @@ where .await?; } - let replication_svc = ReplicationLogService::new( + let replication_svc = make_replication_svc( namespace_store.clone(), - idle_shutdown_kicker.clone(), Some(user_auth_strategy.clone()), - self.disable_namespaces, + idle_shutdown_kicker.clone(), true, + false, // external replication service ); let proxy_svc = ProxyService::new( @@ -936,9 +938,9 @@ where let make_replication_svc = Box::new({ let registry = registry.clone(); let disable_namespaces = self.disable_namespaces; - move |store, user_auth, _, _| -> BoxReplicationService { + move |store, user_auth, _, _, _| -> BoxReplicationService { Box::new(LibsqlReplicationService::new( - registry, + registry.clone(), store, user_auth, disable_namespaces, @@ -1023,13 +1025,19 @@ where let make_replication_svc = Box::new({ let disable_namespaces = self.disable_namespaces; - move |store, client_auth, idle_shutdown, collect_stats| -> BoxReplicationService { + move |store, + client_auth, + idle_shutdown, + collect_stats, + is_internal| + -> BoxReplicationService { Box::new(ReplicationLogService::new( store, idle_shutdown, client_auth, disable_namespaces, collect_stats, + is_internal, )) } }); @@ -1055,13 +1063,19 @@ where let make_replication_svc = Box::new({ let disable_namespaces = self.disable_namespaces; - move |store, client_auth, idle_shutdown, collect_stats| -> BoxReplicationService { + move |store, + client_auth, + idle_shutdown, + collect_stats, + is_internal| + -> BoxReplicationService { Box::new(ReplicationLogService::new( store, idle_shutdown, client_auth, disable_namespaces, collect_stats, + is_internal, )) } }); diff --git a/libsql-server/src/rpc/replication/replication_log.rs b/libsql-server/src/rpc/replication/replication_log.rs index b2cb1d1bfc..356d3d0f16 100644 --- a/libsql-server/src/rpc/replication/replication_log.rs +++ b/libsql-server/src/rpc/replication/replication_log.rs @@ -37,6 +37,9 @@ pub struct ReplicationLogService { disable_namespaces: bool, session_token: Bytes, collect_stats: bool, + // whether this is an internal service. If it is an internal service, auth is checked for + // proxied requests + service_internal: bool, //deprecated: generation_id: Uuid, @@ -52,6 +55,7 @@ impl ReplicationLogService { user_auth_strategy: Option, disable_namespaces: bool, collect_stats: bool, + service_internal: bool, ) -> Self { let session_token = Uuid::new_v4().to_string().into(); Self { @@ -63,6 +67,7 @@ impl ReplicationLogService { collect_stats, generation_id: Uuid::new_v4(), replicas_with_hello: Default::default(), + service_internal, } } @@ -71,14 +76,20 @@ impl ReplicationLogService { req: &tonic::Request, namespace: NamespaceName, ) -> Result<(), Status> { - super::auth::authenticate( - &self.namespaces, - req, - namespace, - &self.user_auth_strategy, - true, - ) - .await + if self.service_internal && req.metadata().get("libsql-proxied").is_some() + || !self.service_internal + { + super::auth::authenticate( + &self.namespaces, + req, + namespace, + &self.user_auth_strategy, + true, + ) + .await + } else { + Ok(()) + } } fn verify_session_token( diff --git a/libsql-server/src/rpc/replication/replication_log_proxy.rs b/libsql-server/src/rpc/replication/replication_log_proxy.rs index 30cad39915..2a7f80ca06 100644 --- a/libsql-server/src/rpc/replication/replication_log_proxy.rs +++ b/libsql-server/src/rpc/replication/replication_log_proxy.rs @@ -1,4 +1,5 @@ use hyper::Uri; +use tonic::metadata::AsciiMetadataValue; use tonic::{transport::Channel, Status}; use super::replication_log::rpc::replication_log_client::ReplicationLogClient; @@ -19,6 +20,12 @@ impl ReplicationLogProxyService { } } +fn mark_proxied(mut req: tonic::Request) -> tonic::Request { + req.metadata_mut() + .insert("libsql-proxied", AsciiMetadataValue::from_static("true")); + req +} + #[tonic::async_trait] impl ReplicationLog for ReplicationLogProxyService { type LogEntriesStream = tonic::codec::Streaming; @@ -29,7 +36,7 @@ impl ReplicationLog for ReplicationLogProxyService { req: tonic::Request, ) -> Result, Status> { let mut client = self.client.clone(); - client.log_entries(req).await + client.log_entries(mark_proxied(req)).await } async fn batch_log_entries( @@ -37,7 +44,7 @@ impl ReplicationLog for ReplicationLogProxyService { req: tonic::Request, ) -> Result, Status> { let mut client = self.client.clone(); - client.batch_log_entries(req).await + client.batch_log_entries(mark_proxied(req)).await } async fn hello( @@ -45,7 +52,7 @@ impl ReplicationLog for ReplicationLogProxyService { req: tonic::Request, ) -> Result, Status> { let mut client = self.client.clone(); - client.hello(req).await + client.hello(mark_proxied(req)).await } async fn snapshot( @@ -53,6 +60,6 @@ impl ReplicationLog for ReplicationLogProxyService { req: tonic::Request, ) -> Result, Status> { let mut client = self.client.clone(); - client.snapshot(req).await + client.snapshot(mark_proxied(req)).await } } diff --git a/libsql-server/tests/embedded_replica/mod.rs b/libsql-server/tests/embedded_replica/mod.rs index 2d5f8c0de0..bf5296ecb3 100644 --- a/libsql-server/tests/embedded_replica/mod.rs +++ b/libsql-server/tests/embedded_replica/mod.rs @@ -6,11 +6,15 @@ use std::path::PathBuf; use std::sync::Arc; use std::time::{Duration, Instant}; +use crate::common::auth::encode; use crate::common::http::Client; use crate::common::net::{init_tracing, SimServer, TestServer, TurmoilAcceptor, TurmoilConnector}; -use crate::common::snapshot_metrics; +use crate::common::{self, snapshot_metrics}; use libsql::Database; -use libsql_server::config::{AdminApiConfig, DbConfig, RpcServerConfig, UserApiConfig}; +use libsql_server::auth::{user_auth_strategies, Auth}; +use libsql_server::config::{ + AdminApiConfig, DbConfig, RpcClientConfig, RpcServerConfig, UserApiConfig, +}; use serde_json::json; use tempfile::tempdir; use tokio::sync::Notify; @@ -1362,3 +1366,149 @@ fn replicated_return() { sim.run().unwrap(); } + +#[test] +fn replicate_auth() { + init_tracing(); + let mut sim = Builder::new() + .simulation_duration(Duration::from_secs(1000)) + .build(); + + let (encoding, decoding) = common::auth::key_pair(); + sim.host("primary", { + let decoding = decoding.clone(); + move || { + let decoding = decoding.clone(); + async move { + let tmp = tempdir()?; + let jwt_keys = + vec![jsonwebtoken::DecodingKey::from_ed_components(&decoding).unwrap()]; + let auth = Auth::new(user_auth_strategies::Jwt::new(jwt_keys)); + let server = TestServer { + path: tmp.path().to_owned().into(), + user_api_config: UserApiConfig { + hrana_ws_acceptor: None, + auth_strategy: auth, + ..Default::default() + }, + admin_api_config: Some(AdminApiConfig { + acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 9090)).await?, + connector: TurmoilConnector, + disable_metrics: true, + }), + rpc_server_config: Some(RpcServerConfig { + acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 4567)).await?, + tls_config: None, + }), + ..Default::default() + }; + + server.start_sim(8080).await?; + + Ok(()) + } + } + }); + + sim.host("replica", { + let decoding = decoding.clone(); + move || { + let decoding = decoding.clone(); + async move { + let tmp = tempdir()?; + let jwt_keys = + vec![jsonwebtoken::DecodingKey::from_ed_components(&decoding).unwrap()]; + let auth = Auth::new(user_auth_strategies::Jwt::new(jwt_keys)); + let server = TestServer { + path: tmp.path().to_owned().into(), + user_api_config: UserApiConfig { + hrana_ws_acceptor: None, + auth_strategy: auth, + ..Default::default() + }, + admin_api_config: Some(AdminApiConfig { + acceptor: TurmoilAcceptor::bind(([0, 0, 0, 0], 9090)).await?, + connector: TurmoilConnector, + disable_metrics: true, + }), + rpc_client_config: Some(RpcClientConfig { + remote_url: "http://primary:4567".into(), + connector: TurmoilConnector, + tls_config: None, + }), + ..Default::default() + }; + + server.start_sim(8080).await?; + + Ok(()) + } + } + }); + + sim.client("client", async move { + let token = encode( + &serde_json::json!({ + "id": "default", + }), + &encoding, + ); + + // no auth + let tmp = tempdir().unwrap(); + let db = Database::open_with_remote_sync_connector( + tmp.path().join("embedded").to_str().unwrap(), + "http://primary:8080", + "", + TurmoilConnector, + false, + None, + ) + .await?; + + assert!(db.sync().await.is_err()); + + let tmp = tempdir().unwrap(); + let db = Database::open_with_remote_sync_connector( + tmp.path().join("embedded").to_str().unwrap(), + "http://replica:8080", + "", + TurmoilConnector, + false, + None, + ) + .await?; + + assert!(db.sync().await.is_err()); + + // auth + let tmp = tempdir().unwrap(); + let db = Database::open_with_remote_sync_connector( + tmp.path().join("embedded").to_str().unwrap(), + "http://primary:8080", + token.clone(), + TurmoilConnector, + false, + None, + ) + .await?; + + assert!(db.sync().await.is_ok()); + + let tmp = tempdir().unwrap(); + let db = Database::open_with_remote_sync_connector( + tmp.path().join("embedded").to_str().unwrap(), + "http://replica:8080", + token.clone(), + TurmoilConnector, + false, + None, + ) + .await?; + + assert!(db.sync().await.is_ok()); + Ok(()) + }); + + sim.run().unwrap(); +}