Skip to content

Commit

Permalink
Make callback interactions with ClientSessionDelegate async
Browse files Browse the repository at this point in the history
  • Loading branch information
Daniel Salinas authored and Daniel Salinas committed Jan 17, 2025
1 parent 2bd8c56 commit 05c0649
Show file tree
Hide file tree
Showing 9 changed files with 130 additions and 76 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions bindings/matrix-sdk-ffi/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ vergen = { version = "8.1.3", features = ["build", "git", "gitcl"] }
anyhow = { workspace = true }
as_variant = { workspace = true }
async-compat = "0.2.1"
async-trait = { workspace = true }
eyeball-im = { workspace = true }
extension-trait = "1.0.1"
futures-util = { workspace = true }
Expand Down
81 changes: 42 additions & 39 deletions bindings/matrix-sdk-ffi/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use std::{

use anyhow::{anyhow, Context as _};
use matrix_sdk::{
authentication::{ReloadSessionCallback, SaveSessionCallback, SessionCallbackError},
media::{
MediaFileHandle as SdkMediaFileHandle, MediaFormat, MediaRequestParameters,
MediaThumbnailSettings,
Expand Down Expand Up @@ -153,9 +154,11 @@ pub trait ClientDelegate: Sync + Send {
}

#[matrix_sdk_ffi_macros::export(callback_interface)]
#[async_trait::async_trait]
pub trait ClientSessionDelegate: Sync + Send {
fn retrieve_session_from_keychain(&self, user_id: String) -> Result<Session, ClientError>;
fn save_session_in_keychain(&self, session: Session);
async fn retrieve_session_from_keychain(&self, user_id: String)
-> Result<Session, ClientError>;
async fn save_session_in_keychain(&self, session: Session);
}

#[matrix_sdk_ffi_macros::export(callback_interface)]
Expand Down Expand Up @@ -186,6 +189,41 @@ impl From<matrix_sdk::TransmissionProgress> for TransmissionProgress {
}
}

struct FfiReloadSessionCallback {
session_delegate: Arc<dyn ClientSessionDelegate>,
}

#[async_trait::async_trait]
impl ReloadSessionCallback for FfiReloadSessionCallback {
async fn reload_session(
&self,
client: matrix_sdk::Client,
) -> Result<SessionTokens, SessionCallbackError> {
let user_id = client.user_id().context("user isn't logged in")?;
let session =
self.session_delegate.retrieve_session_from_keychain(user_id.to_string()).await?;
let auth_session = TryInto::<AuthSession>::try_into(session)?;
match auth_session {
AuthSession::Oidc(session) => Ok(SessionTokens::Oidc(session.user.tokens)),
AuthSession::Matrix(session) => Ok(SessionTokens::Matrix(session.tokens)),
_ => Err(anyhow!("unsupported session type").into()),
}
}
}

struct FfiSaveSessionCallback {
session_delegate: Arc<dyn ClientSessionDelegate>,
}

#[async_trait::async_trait]
impl SaveSessionCallback for FfiSaveSessionCallback {
async fn save_session(&self, client: matrix_sdk::Client) -> Result<(), SessionCallbackError> {
let session = Client::session_inner(client)?;
self.session_delegate.save_session_in_keychain(session).await;
Ok(())
}
}

#[derive(uniffi::Object)]
pub struct Client {
pub(crate) inner: AsyncRuntimeDropped<MatrixClient>,
Expand Down Expand Up @@ -238,21 +276,8 @@ impl Client {

if let Some(session_delegate) = session_delegate {
client.inner.set_session_callbacks(
{
let session_delegate = session_delegate.clone();
Box::new(move |client| {
let session_delegate = session_delegate.clone();
let user_id = client.user_id().context("user isn't logged in")?;
Ok(Self::retrieve_session(session_delegate, user_id)?)
})
},
{
let session_delegate = session_delegate.clone();
Box::new(move |client| {
let session_delegate = session_delegate.clone();
Ok(Self::save_session(session_delegate, client)?)
})
},
Box::new(FfiReloadSessionCallback { session_delegate: session_delegate.clone() }),
Box::new(FfiSaveSessionCallback { session_delegate: session_delegate.clone() }),
)?;
}

Expand Down Expand Up @@ -1247,19 +1272,6 @@ impl Client {
}
}

fn retrieve_session(
session_delegate: Arc<dyn ClientSessionDelegate>,
user_id: &UserId,
) -> anyhow::Result<SessionTokens> {
let session = session_delegate.retrieve_session_from_keychain(user_id.to_string())?;
let auth_session = TryInto::<AuthSession>::try_into(session)?;
match auth_session {
AuthSession::Oidc(session) => Ok(SessionTokens::Oidc(session.user.tokens)),
AuthSession::Matrix(session) => Ok(SessionTokens::Matrix(session.tokens)),
_ => anyhow::bail!("Unexpected session kind."),
}
}

fn session_inner(client: matrix_sdk::Client) -> Result<Session, ClientError> {
let auth_api = client.auth_api().context("Missing authentication API")?;

Expand All @@ -1268,15 +1280,6 @@ impl Client {

Session::new(auth_api, homeserver_url, sliding_sync_version.into())
}

fn save_session(
session_delegate: Arc<dyn ClientSessionDelegate>,
client: matrix_sdk::Client,
) -> anyhow::Result<()> {
let session = Self::session_inner(client)?;
session_delegate.save_session_in_keychain(session);
Ok(())
}
}

#[derive(uniffi::Record)]
Expand Down
23 changes: 16 additions & 7 deletions crates/matrix-sdk/src/authentication/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,21 @@ pub enum SessionTokens {
Oidc(oidc::OidcSessionTokens),
}

pub(crate) type SessionCallbackError = Box<dyn std::error::Error + Send + Sync>;
pub(crate) type SaveSessionCallback =
dyn Fn(Client) -> Result<(), SessionCallbackError> + Send + Sync;
pub(crate) type ReloadSessionCallback =
dyn Fn(Client) -> Result<SessionTokens, SessionCallbackError> + Send + Sync;
/// An error that results from setting the session callback.
pub type SessionCallbackError = Box<dyn std::error::Error + Send + Sync>;
/// Save the session tokens from the source of truth.
#[async_trait::async_trait]
pub trait SaveSessionCallback: Send + Sync {
/// Save the session tokens from the source of truth.
async fn save_session(&self, client: Client) -> Result<(), SessionCallbackError>;
}

/// Reload the session tokens from the source of truth.
#[async_trait::async_trait]
pub trait ReloadSessionCallback: Send + Sync {
/// Reload the session tokens from the source of truth.
async fn reload_session(&self, client: Client) -> Result<SessionTokens, SessionCallbackError>;
}
/// All the data relative to authentication, and that must be shared between a
/// client and all its children.
pub(crate) struct AuthCtx {
Expand All @@ -74,7 +83,7 @@ pub(crate) struct AuthCtx {
/// current session tokens.
///
/// This is required only in multiple processes setups.
pub(crate) reload_session_callback: OnceCell<Box<ReloadSessionCallback>>,
pub(crate) reload_session_callback: OnceCell<Box<dyn ReloadSessionCallback>>,

/// A callback to save a session back into the app's secure storage.
///
Expand All @@ -83,7 +92,7 @@ pub(crate) struct AuthCtx {
///
/// Internal invariant: this must be called only after `set_session_tokens`
/// has been called, not before.
pub(crate) save_session_callback: OnceCell<Box<SaveSessionCallback>>,
pub(crate) save_session_callback: OnceCell<Box<dyn SaveSessionCallback>>,
}

/// An enum over all the possible authentication APIs.
Expand Down
4 changes: 2 additions & 2 deletions crates/matrix-sdk/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2370,8 +2370,8 @@ impl Client {
/// while [`Self::subscribe_to_session_changes`] provides an async update.
pub fn set_session_callbacks(
&self,
reload_session_callback: Box<ReloadSessionCallback>,
save_session_callback: Box<SaveSessionCallback>,
reload_session_callback: Box<dyn ReloadSessionCallback>,
save_session_callback: Box<dyn SaveSessionCallback>,
) -> Result<()> {
self.inner
.auth_ctx
Expand Down
3 changes: 2 additions & 1 deletion crates/matrix-sdk/src/matrix_auth/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,8 @@ impl MatrixAuth {
if let Some(save_session_callback) =
self.client.inner.auth_ctx.save_session_callback.get()
{
if let Err(err) = save_session_callback(self.client.clone()) {
if let Err(err) = save_session_callback.save_session(self.client.clone()).await
{
error!("when saving session after refresh: {err}");
}
}
Expand Down
51 changes: 33 additions & 18 deletions crates/matrix-sdk/src/oidc/cross_process.rs
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ mod tests {

use super::compute_session_hash;
use crate::{
authentication::{ReloadSessionCallback, SaveSessionCallback, SessionCallbackError},
oidc::{
backend::mock::{MockImpl, ISSUER_URL},
cross_process::SessionHash,
Expand All @@ -275,6 +276,32 @@ mod tests {
Error,
};

struct TestReloadSessionCallback {
tokens: OidcSessionTokens,
}

#[async_trait::async_trait]
impl ReloadSessionCallback for TestReloadSessionCallback {
async fn reload_session(
&self,
client: matrix_sdk::Client,
) -> Result<SessionTokens, SessionCallbackError> {
crate::authentication::SessionTokens::Oidc(self.tokens.clone())
}
}

struct TestSaveSessionCallback {}

#[async_trait::async_trait]
impl SaveSessionCallback for TestSaveSessionCallback {
async fn save_session(
&self,
client: matrix_sdk::Client,
) -> Result<(), SessionCallbackError> {
panic!("save_session_callback shouldn't be called here")
}
}

#[async_test]
async fn test_restore_session_lock() -> Result<(), Error> {
// Create a client that will use sqlite databases.
Expand All @@ -295,12 +322,8 @@ mod tests {
client.oidc().enable_cross_process_refresh_lock("test".to_owned()).await?;

client.set_session_callbacks(
Box::new({
// This is only called because of extra checks in the code.
let tokens = tokens.clone();
move |_| Ok(crate::authentication::SessionTokens::Oidc(tokens.clone()))
}),
Box::new(|_| panic!("save_session_callback shouldn't be called here")),
Box::new(TestReloadSessionCallback { tokens }),
Box::new(TestSaveSessionCallback {}),
)?;

let session_hash = compute_session_hash(&tokens);
Expand Down Expand Up @@ -521,12 +544,8 @@ mod tests {
let oidc = unrestored_oidc;

unrestored_client.set_session_callbacks(
Box::new({
// This is only called because of extra checks in the code.
let tokens = next_tokens.clone();
move |_| Ok(crate::authentication::SessionTokens::Oidc(tokens.clone()))
}),
Box::new(|_| panic!("save_session_callback shouldn't be called here")),
Box::new(TestReloadSessionCallback { tokens: next_tokens }),
Box::new(TestSaveSessionCallback {}),
)?;

oidc.restore_session(tests::mock_session(prev_tokens.clone())).await?;
Expand Down Expand Up @@ -558,12 +577,8 @@ mod tests {
}

client.set_session_callbacks(
Box::new({
// This is only called because of extra checks in the code.
let tokens = next_tokens.clone();
move |_| Ok(crate::authentication::SessionTokens::Oidc(tokens.clone()))
}),
Box::new(|_| panic!("save_session_callback shouldn't be called here")),
Box::new(TestReloadSessionCallback { tokens: next_tokens }),
Box::new(TestSaveSessionCallback {}),
)?;

oidc.refresh_access_token().await?;
Expand Down
4 changes: 2 additions & 2 deletions crates/matrix-sdk/src/oidc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1015,7 +1015,7 @@ impl Oidc {
.get()
.ok_or(CrossProcessRefreshLockError::MissingReloadSession)?;

match callback(self.client.clone()) {
match callback.reload_session(self.client.clone()).await {
Ok(tokens) => {
let crate::authentication::SessionTokens::Oidc(tokens) = tokens else {
return Err(CrossProcessRefreshLockError::InvalidSessionTokens);
Expand Down Expand Up @@ -1337,7 +1337,7 @@ impl Oidc {
{
// Satisfies the save_session_callback invariant: set_session_tokens has
// been called just above.
if let Err(err) = save_session_callback(self.client.clone()) {
if let Err(err) = save_session_callback.save_session(self.client.clone()).await {
error!("when saving session after refresh: {err}");
}
}
Expand Down
38 changes: 31 additions & 7 deletions crates/matrix-sdk/tests/integration/refresh_token.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ use assert_matches::assert_matches;
use assert_matches2::assert_let;
use futures_util::StreamExt;
use matrix_sdk::{
authentication::{
ReloadSessionCallback, SaveSessionCallback, SessionCallbackError, SessionTokens,
},
config::RequestConfig,
executor::spawn,
matrix_auth::{MatrixSession, MatrixSessionTokens},
Expand Down Expand Up @@ -163,6 +166,30 @@ async fn test_no_refresh_token() {
assert_matches!(res, Err(RefreshTokenError::RefreshTokenRequired));
}

struct TestReloadSessionCallback {}

#[async_trait::async_trait]
impl ReloadSessionCallback for TestReloadSessionCallback {
async fn reload_session(
&self,
_client: matrix_sdk::Client,
) -> Result<SessionTokens, SessionCallbackError> {
panic!("reload session never called")
}
}

struct TestSaveSessionCallback {
counter: Arc<Mutex<u64>>,
}

#[async_trait::async_trait]
impl SaveSessionCallback for TestSaveSessionCallback {
async fn save_session(&self, _client: matrix_sdk::Client) -> Result<(), SessionCallbackError> {
*self.counter.lock().unwrap() += 1;
Ok(())
}
}

#[async_test]
async fn test_refresh_token() {
let (builder, server) = test_client_builder_with_server().await;
Expand All @@ -176,13 +203,10 @@ async fn test_refresh_token() {

let num_save_session_callback_calls = Arc::new(Mutex::new(0));
client
.set_session_callbacks(Box::new(|_| panic!("reload session never called")), {
let num_save_session_callback_calls = num_save_session_callback_calls.clone();
Box::new(move |_client| {
*num_save_session_callback_calls.lock().unwrap() += 1;
Ok(())
})
})
.set_session_callbacks(
Box::new(TestReloadSessionCallback {}),
Box::new(TestSaveSessionCallback { counter: num_save_session_callback_calls.clone() }),
)
.unwrap();

let mut session_changes = client.subscribe_to_session_changes();
Expand Down

0 comments on commit 05c0649

Please sign in to comment.