Skip to content

Commit

Permalink
Merge pull request #1708 from tursodatabase/test-replication-proxy
Browse files Browse the repository at this point in the history
test replication proxy
  • Loading branch information
MarinPostma authored Aug 29, 2024
2 parents e94f7bd + 2b25748 commit eb1b39e
Show file tree
Hide file tree
Showing 9 changed files with 218 additions and 26 deletions.
2 changes: 1 addition & 1 deletion libsql-server/src/auth/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ pub use parsers::{parse_http_auth_header, parse_http_basic_auth_arg, parse_jwt_k
pub use permission::Permission;
pub use user_auth_strategies::{Disabled, HttpBasic, Jwt, UserAuthContext, UserAuthStrategy};

#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct Auth {
pub user_strategy: Arc<dyn UserAuthStrategy + Send + Sync>,
}
Expand Down
1 change: 1 addition & 0 deletions libsql-server/src/auth/user_auth_strategies/disabled.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use super::{UserAuthContext, UserAuthStrategy};
use crate::auth::{AuthError, Authenticated};

#[derive(Debug)]
pub struct Disabled {}

impl UserAuthStrategy for Disabled {
Expand Down
1 change: 1 addition & 0 deletions libsql-server/src/auth/user_auth_strategies/http_basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::auth::{

use super::{UserAuthContext, UserAuthStrategy};

#[derive(Debug)]
pub struct HttpBasic {
credential: String,
}
Expand Down
8 changes: 8 additions & 0 deletions libsql-server/src/auth/user_auth_strategies/jwt.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::fmt::{self, Debug, Formatter};

use chrono::{DateTime, Utc};

use crate::{
Expand All @@ -15,6 +17,12 @@ pub struct Jwt {
keys: Vec<jsonwebtoken::DecodingKey>,
}

impl Debug for Jwt {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.debug_struct("Jwt").finish()
}
}

impl UserAuthStrategy for Jwt {
fn authenticate(&self, ctx: UserAuthContext) -> Result<Authenticated, AuthError> {
tracing::trace!("executing jwt auth");
Expand Down
2 changes: 1 addition & 1 deletion libsql-server/src/auth/user_auth_strategies/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ impl UserAuthContext {
}
}

pub trait UserAuthStrategy: Sync + Send {
pub trait UserAuthStrategy: Sync + Send + std::fmt::Debug {
/// Returns a list of fields required by the stragegy.
/// Every strategy implementation should override this function if it requires input to work.
/// Strategy implementations should validate the content of provided fields.
Expand Down
34 changes: 24 additions & 10 deletions libsql-server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,12 @@ pub(crate) static BLOCKING_RT: Lazy<Runtime> = Lazy::new(|| {
type Result<T, E = Error> = std::result::Result<T, E>;
type StatsSender = mpsc::Sender<(NamespaceName, MetaStoreHandle, Weak<Stats>)>;
type MakeReplicationSvc = Box<
dyn FnOnce(
dyn Fn(
NamespaceStore,
Option<Auth>,
Option<IdleShutdownKicker>,
bool,
bool,
) -> BoxReplicationService
+ Send
+ 'static,
Expand Down Expand Up @@ -620,17 +621,18 @@ where

let replication_service = make_replication_svc(
namespace_store.clone(),
None,
Some(user_auth_strategy.clone()),
idle_shutdown_kicker.clone(),
false,
true,
);

task_manager.spawn_until_shutdown(run_rpc_server(
proxy_service,
config.acceptor,
config.tls_config,
idle_shutdown_kicker.clone(),
replication_service,
replication_service, // internal replicaton service
));
}

Expand Down Expand Up @@ -658,12 +660,12 @@ where
.await?;
}

let replication_svc = ReplicationLogService::new(
let replication_svc = make_replication_svc(
namespace_store.clone(),
idle_shutdown_kicker.clone(),
Some(user_auth_strategy.clone()),
self.disable_namespaces,
idle_shutdown_kicker.clone(),
true,
false, // external replication service
);

let proxy_svc = ProxyService::new(
Expand Down Expand Up @@ -936,9 +938,9 @@ where
let make_replication_svc = Box::new({
let registry = registry.clone();
let disable_namespaces = self.disable_namespaces;
move |store, user_auth, _, _| -> BoxReplicationService {
move |store, user_auth, _, _, _| -> BoxReplicationService {
Box::new(LibsqlReplicationService::new(
registry,
registry.clone(),
store,
user_auth,
disable_namespaces,
Expand Down Expand Up @@ -1023,13 +1025,19 @@ where

let make_replication_svc = Box::new({
let disable_namespaces = self.disable_namespaces;
move |store, client_auth, idle_shutdown, collect_stats| -> BoxReplicationService {
move |store,
client_auth,
idle_shutdown,
collect_stats,
is_internal|
-> BoxReplicationService {
Box::new(ReplicationLogService::new(
store,
idle_shutdown,
client_auth,
disable_namespaces,
collect_stats,
is_internal,
))
}
});
Expand All @@ -1055,13 +1063,19 @@ where

let make_replication_svc = Box::new({
let disable_namespaces = self.disable_namespaces;
move |store, client_auth, idle_shutdown, collect_stats| -> BoxReplicationService {
move |store,
client_auth,
idle_shutdown,
collect_stats,
is_internal|
-> BoxReplicationService {
Box::new(ReplicationLogService::new(
store,
idle_shutdown,
client_auth,
disable_namespaces,
collect_stats,
is_internal,
))
}
});
Expand Down
27 changes: 19 additions & 8 deletions libsql-server/src/rpc/replication/replication_log.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ pub struct ReplicationLogService {
disable_namespaces: bool,
session_token: Bytes,
collect_stats: bool,
// whether this is an internal service. If it is an internal service, auth is checked for
// proxied requests
service_internal: bool,

//deprecated:
generation_id: Uuid,
Expand All @@ -52,6 +55,7 @@ impl ReplicationLogService {
user_auth_strategy: Option<Auth>,
disable_namespaces: bool,
collect_stats: bool,
service_internal: bool,
) -> Self {
let session_token = Uuid::new_v4().to_string().into();
Self {
Expand All @@ -63,6 +67,7 @@ impl ReplicationLogService {
collect_stats,
generation_id: Uuid::new_v4(),
replicas_with_hello: Default::default(),
service_internal,
}
}

Expand All @@ -71,14 +76,20 @@ impl ReplicationLogService {
req: &tonic::Request<T>,
namespace: NamespaceName,
) -> Result<(), Status> {
super::auth::authenticate(
&self.namespaces,
req,
namespace,
&self.user_auth_strategy,
true,
)
.await
if self.service_internal && req.metadata().get("libsql-proxied").is_some()
|| !self.service_internal
{
super::auth::authenticate(
&self.namespaces,
req,
namespace,
&self.user_auth_strategy,
true,
)
.await
} else {
Ok(())
}
}

fn verify_session_token<R>(
Expand Down
15 changes: 11 additions & 4 deletions libsql-server/src/rpc/replication/replication_log_proxy.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use hyper::Uri;
use tonic::metadata::AsciiMetadataValue;
use tonic::{transport::Channel, Status};

use super::replication_log::rpc::replication_log_client::ReplicationLogClient;
Expand All @@ -19,6 +20,12 @@ impl ReplicationLogProxyService {
}
}

fn mark_proxied<T>(mut req: tonic::Request<T>) -> tonic::Request<T> {
req.metadata_mut()
.insert("libsql-proxied", AsciiMetadataValue::from_static("true"));
req
}

#[tonic::async_trait]
impl ReplicationLog for ReplicationLogProxyService {
type LogEntriesStream = tonic::codec::Streaming<Frame>;
Expand All @@ -29,30 +36,30 @@ impl ReplicationLog for ReplicationLogProxyService {
req: tonic::Request<LogOffset>,
) -> Result<tonic::Response<Self::LogEntriesStream>, Status> {
let mut client = self.client.clone();
client.log_entries(req).await
client.log_entries(mark_proxied(req)).await
}

async fn batch_log_entries(
&self,
req: tonic::Request<LogOffset>,
) -> Result<tonic::Response<Frames>, Status> {
let mut client = self.client.clone();
client.batch_log_entries(req).await
client.batch_log_entries(mark_proxied(req)).await
}

async fn hello(
&self,
req: tonic::Request<HelloRequest>,
) -> Result<tonic::Response<HelloResponse>, Status> {
let mut client = self.client.clone();
client.hello(req).await
client.hello(mark_proxied(req)).await
}

async fn snapshot(
&self,
req: tonic::Request<LogOffset>,
) -> Result<tonic::Response<Self::SnapshotStream>, Status> {
let mut client = self.client.clone();
client.snapshot(req).await
client.snapshot(mark_proxied(req)).await
}
}
Loading

0 comments on commit eb1b39e

Please sign in to comment.