diff --git a/Cargo.lock b/Cargo.lock index 03552059e60..0574e915bca 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3297,6 +3297,7 @@ dependencies = [ "anyhow", "as_variant", "async-compat", + "async-trait", "extension-trait", "eyeball-im", "futures-util", diff --git a/bindings/matrix-sdk-ffi/Cargo.toml b/bindings/matrix-sdk-ffi/Cargo.toml index b771e6f5759..f4a9cc28675 100644 --- a/bindings/matrix-sdk-ffi/Cargo.toml +++ b/bindings/matrix-sdk-ffi/Cargo.toml @@ -24,6 +24,7 @@ vergen = { version = "8.1.3", features = ["build", "git", "gitcl"] } anyhow = { workspace = true } as_variant = { workspace = true } async-compat = "0.2.1" +async-trait = { workspace = true } eyeball-im = { workspace = true } extension-trait = "1.0.1" futures-util = { workspace = true } diff --git a/bindings/matrix-sdk-ffi/src/client.rs b/bindings/matrix-sdk-ffi/src/client.rs index af8ddb4389e..fb4a98559c2 100644 --- a/bindings/matrix-sdk-ffi/src/client.rs +++ b/bindings/matrix-sdk-ffi/src/client.rs @@ -7,6 +7,7 @@ use std::{ use anyhow::{anyhow, Context as _}; use matrix_sdk::{ + authentication::{ReloadSessionCallback, SaveSessionCallback, SessionCallbackError}, media::{ MediaFileHandle as SdkMediaFileHandle, MediaFormat, MediaRequestParameters, MediaThumbnailSettings, @@ -153,9 +154,13 @@ pub trait ClientDelegate: Sync + Send { } #[matrix_sdk_ffi_macros::export(callback_interface)] +#[async_trait::async_trait] pub trait ClientSessionDelegate: Sync + Send { - fn retrieve_session_from_keychain(&self, user_id: String) -> Result; - fn save_session_in_keychain(&self, session: Session); + async fn retrieve_session_from_keychain<'a>( + &'a self, + user_id: String, + ) -> Result; + async fn save_session_in_keychain<'a>(&'a self, session: Session); } #[matrix_sdk_ffi_macros::export(callback_interface)] @@ -186,6 +191,41 @@ impl From for TransmissionProgress { } } +struct FfiReloadSessionCallback { + session_delegate: Arc, +} + +#[async_trait::async_trait] +impl ReloadSessionCallback for FfiReloadSessionCallback { + async fn reload_session( + &self, + client: matrix_sdk::Client, + ) -> Result { + let user_id = client.user_id().context("user isn't logged in")?; + let session = + self.session_delegate.retrieve_session_from_keychain(user_id.to_string()).await?; + let auth_session = TryInto::::try_into(session)?; + match auth_session { + AuthSession::Oidc(session) => Ok(SessionTokens::Oidc(session.user.tokens)), + AuthSession::Matrix(session) => Ok(SessionTokens::Matrix(session.tokens)), + _ => Err(anyhow!("unsupported session type").into()), + } + } +} + +struct FfiSaveSessionCallback { + session_delegate: Arc, +} + +#[async_trait::async_trait] +impl SaveSessionCallback for FfiSaveSessionCallback { + async fn save_session(&self, client: matrix_sdk::Client) -> Result<(), SessionCallbackError> { + let session = Client::session_inner(client)?; + self.session_delegate.save_session_in_keychain(session).await; + Ok(()) + } +} + #[derive(uniffi::Object)] pub struct Client { pub(crate) inner: AsyncRuntimeDropped, @@ -238,21 +278,8 @@ impl Client { if let Some(session_delegate) = session_delegate { client.inner.set_session_callbacks( - { - let session_delegate = session_delegate.clone(); - Box::new(move |client| { - let session_delegate = session_delegate.clone(); - let user_id = client.user_id().context("user isn't logged in")?; - Ok(Self::retrieve_session(session_delegate, user_id)?) - }) - }, - { - let session_delegate = session_delegate.clone(); - Box::new(move |client| { - let session_delegate = session_delegate.clone(); - Ok(Self::save_session(session_delegate, client)?) - }) - }, + Box::new(FfiReloadSessionCallback { session_delegate: session_delegate.clone() }), + Box::new(FfiSaveSessionCallback { session_delegate: session_delegate.clone() }), )?; } @@ -1247,19 +1274,6 @@ impl Client { } } - fn retrieve_session( - session_delegate: Arc, - user_id: &UserId, - ) -> anyhow::Result { - let session = session_delegate.retrieve_session_from_keychain(user_id.to_string())?; - let auth_session = TryInto::::try_into(session)?; - match auth_session { - AuthSession::Oidc(session) => Ok(SessionTokens::Oidc(session.user.tokens)), - AuthSession::Matrix(session) => Ok(SessionTokens::Matrix(session.tokens)), - _ => anyhow::bail!("Unexpected session kind."), - } - } - fn session_inner(client: matrix_sdk::Client) -> Result { let auth_api = client.auth_api().context("Missing authentication API")?; @@ -1268,15 +1282,6 @@ impl Client { Session::new(auth_api, homeserver_url, sliding_sync_version.into()) } - - fn save_session( - session_delegate: Arc, - client: matrix_sdk::Client, - ) -> anyhow::Result<()> { - let session = Self::session_inner(client)?; - session_delegate.save_session_in_keychain(session); - Ok(()) - } } #[derive(uniffi::Record)] diff --git a/crates/matrix-sdk/src/authentication/mod.rs b/crates/matrix-sdk/src/authentication/mod.rs index 408de86805e..6fcba73695c 100644 --- a/crates/matrix-sdk/src/authentication/mod.rs +++ b/crates/matrix-sdk/src/authentication/mod.rs @@ -43,12 +43,21 @@ pub enum SessionTokens { Oidc(oidc::OidcSessionTokens), } -pub(crate) type SessionCallbackError = Box; -pub(crate) type SaveSessionCallback = - dyn Fn(Client) -> Result<(), SessionCallbackError> + Send + Sync; -pub(crate) type ReloadSessionCallback = - dyn Fn(Client) -> Result + Send + Sync; +/// An error that results from setting the session callback. +pub type SessionCallbackError = Box; +/// Save the session tokens from the source of truth. +#[async_trait::async_trait] +pub trait SaveSessionCallback: Send + Sync { + /// Save the session tokens from the source of truth. + async fn save_session(&self, client: Client) -> Result<(), SessionCallbackError>; +} +/// Reload the session tokens from the source of truth. +#[async_trait::async_trait] +pub trait ReloadSessionCallback: Send + Sync { + /// Reload the session tokens from the source of truth. + async fn reload_session(&self, client: Client) -> Result; +} /// All the data relative to authentication, and that must be shared between a /// client and all its children. pub(crate) struct AuthCtx { @@ -74,7 +83,7 @@ pub(crate) struct AuthCtx { /// current session tokens. /// /// This is required only in multiple processes setups. - pub(crate) reload_session_callback: OnceCell>, + pub(crate) reload_session_callback: OnceCell>, /// A callback to save a session back into the app's secure storage. /// @@ -83,7 +92,7 @@ pub(crate) struct AuthCtx { /// /// Internal invariant: this must be called only after `set_session_tokens` /// has been called, not before. - pub(crate) save_session_callback: OnceCell>, + pub(crate) save_session_callback: OnceCell>, } /// An enum over all the possible authentication APIs. diff --git a/crates/matrix-sdk/src/client/mod.rs b/crates/matrix-sdk/src/client/mod.rs index f2daad94aee..b2d29b74eb1 100644 --- a/crates/matrix-sdk/src/client/mod.rs +++ b/crates/matrix-sdk/src/client/mod.rs @@ -2370,8 +2370,8 @@ impl Client { /// while [`Self::subscribe_to_session_changes`] provides an async update. pub fn set_session_callbacks( &self, - reload_session_callback: Box, - save_session_callback: Box, + reload_session_callback: Box, + save_session_callback: Box, ) -> Result<()> { self.inner .auth_ctx diff --git a/crates/matrix-sdk/src/matrix_auth/mod.rs b/crates/matrix-sdk/src/matrix_auth/mod.rs index 8beb76de189..3be80b1dcf4 100644 --- a/crates/matrix-sdk/src/matrix_auth/mod.rs +++ b/crates/matrix-sdk/src/matrix_auth/mod.rs @@ -548,7 +548,8 @@ impl MatrixAuth { if let Some(save_session_callback) = self.client.inner.auth_ctx.save_session_callback.get() { - if let Err(err) = save_session_callback(self.client.clone()) { + if let Err(err) = save_session_callback.save_session(self.client.clone()).await + { error!("when saving session after refresh: {err}"); } } diff --git a/crates/matrix-sdk/src/oidc/cross_process.rs b/crates/matrix-sdk/src/oidc/cross_process.rs index 87c057d1740..97a37d066e7 100644 --- a/crates/matrix-sdk/src/oidc/cross_process.rs +++ b/crates/matrix-sdk/src/oidc/cross_process.rs @@ -264,6 +264,9 @@ mod tests { use super::compute_session_hash; use crate::{ + authentication::{ + ReloadSessionCallback, SaveSessionCallback, SessionCallbackError, SessionTokens, + }, oidc::{ backend::mock::{MockImpl, ISSUER_URL}, cross_process::SessionHash, @@ -272,9 +275,32 @@ mod tests { Oidc, OidcSessionTokens, }, test_utils::test_client_builder, - Error, + Client, Error, }; + struct TestReloadSessionCallback { + tokens: OidcSessionTokens, + } + + #[async_trait::async_trait] + impl ReloadSessionCallback for TestReloadSessionCallback { + async fn reload_session( + &self, + _client: Client, + ) -> Result { + Ok(SessionTokens::Oidc(self.tokens.clone())) + } + } + + struct TestSaveSessionCallback {} + + #[async_trait::async_trait] + impl SaveSessionCallback for TestSaveSessionCallback { + async fn save_session(&self, _client: Client) -> Result<(), SessionCallbackError> { + panic!("save_session_callback shouldn't be called here") + } + } + #[async_test] async fn test_restore_session_lock() -> Result<(), Error> { // Create a client that will use sqlite databases. @@ -295,12 +321,8 @@ mod tests { client.oidc().enable_cross_process_refresh_lock("test".to_owned()).await?; client.set_session_callbacks( - Box::new({ - // This is only called because of extra checks in the code. - let tokens = tokens.clone(); - move |_| Ok(crate::authentication::SessionTokens::Oidc(tokens.clone())) - }), - Box::new(|_| panic!("save_session_callback shouldn't be called here")), + Box::new(TestReloadSessionCallback { tokens: tokens.clone() }), + Box::new(TestSaveSessionCallback {}), )?; let session_hash = compute_session_hash(&tokens); @@ -521,12 +543,8 @@ mod tests { let oidc = unrestored_oidc; unrestored_client.set_session_callbacks( - Box::new({ - // This is only called because of extra checks in the code. - let tokens = next_tokens.clone(); - move |_| Ok(crate::authentication::SessionTokens::Oidc(tokens.clone())) - }), - Box::new(|_| panic!("save_session_callback shouldn't be called here")), + Box::new(TestReloadSessionCallback { tokens: next_tokens.clone() }), + Box::new(TestSaveSessionCallback {}), )?; oidc.restore_session(tests::mock_session(prev_tokens.clone())).await?; @@ -558,12 +576,8 @@ mod tests { } client.set_session_callbacks( - Box::new({ - // This is only called because of extra checks in the code. - let tokens = next_tokens.clone(); - move |_| Ok(crate::authentication::SessionTokens::Oidc(tokens.clone())) - }), - Box::new(|_| panic!("save_session_callback shouldn't be called here")), + Box::new(TestReloadSessionCallback { tokens: next_tokens.clone() }), + Box::new(TestSaveSessionCallback {}), )?; oidc.refresh_access_token().await?; diff --git a/crates/matrix-sdk/src/oidc/mod.rs b/crates/matrix-sdk/src/oidc/mod.rs index 5fb525d75a1..790fe6ca0fe 100644 --- a/crates/matrix-sdk/src/oidc/mod.rs +++ b/crates/matrix-sdk/src/oidc/mod.rs @@ -1015,7 +1015,7 @@ impl Oidc { .get() .ok_or(CrossProcessRefreshLockError::MissingReloadSession)?; - match callback(self.client.clone()) { + match callback.reload_session(self.client.clone()).await { Ok(tokens) => { let crate::authentication::SessionTokens::Oidc(tokens) = tokens else { return Err(CrossProcessRefreshLockError::InvalidSessionTokens); @@ -1337,7 +1337,7 @@ impl Oidc { { // Satisfies the save_session_callback invariant: set_session_tokens has // been called just above. - if let Err(err) = save_session_callback(self.client.clone()) { + if let Err(err) = save_session_callback.save_session(self.client.clone()).await { error!("when saving session after refresh: {err}"); } } diff --git a/crates/matrix-sdk/tests/integration/refresh_token.rs b/crates/matrix-sdk/tests/integration/refresh_token.rs index 6067698c68b..e0b6439490c 100644 --- a/crates/matrix-sdk/tests/integration/refresh_token.rs +++ b/crates/matrix-sdk/tests/integration/refresh_token.rs @@ -7,6 +7,9 @@ use assert_matches::assert_matches; use assert_matches2::assert_let; use futures_util::StreamExt; use matrix_sdk::{ + authentication::{ + ReloadSessionCallback, SaveSessionCallback, SessionCallbackError, SessionTokens, + }, config::RequestConfig, executor::spawn, matrix_auth::{MatrixSession, MatrixSessionTokens}, @@ -163,6 +166,30 @@ async fn test_no_refresh_token() { assert_matches!(res, Err(RefreshTokenError::RefreshTokenRequired)); } +struct TestReloadSessionCallback {} + +#[async_trait::async_trait] +impl ReloadSessionCallback for TestReloadSessionCallback { + async fn reload_session( + &self, + _client: matrix_sdk::Client, + ) -> Result { + panic!("reload session never called") + } +} + +struct TestSaveSessionCallback { + counter: Arc>, +} + +#[async_trait::async_trait] +impl SaveSessionCallback for TestSaveSessionCallback { + async fn save_session(&self, _client: matrix_sdk::Client) -> Result<(), SessionCallbackError> { + *self.counter.lock().unwrap() += 1; + Ok(()) + } +} + #[async_test] async fn test_refresh_token() { let (builder, server) = test_client_builder_with_server().await; @@ -176,13 +203,10 @@ async fn test_refresh_token() { let num_save_session_callback_calls = Arc::new(Mutex::new(0)); client - .set_session_callbacks(Box::new(|_| panic!("reload session never called")), { - let num_save_session_callback_calls = num_save_session_callback_calls.clone(); - Box::new(move |_client| { - *num_save_session_callback_calls.lock().unwrap() += 1; - Ok(()) - }) - }) + .set_session_callbacks( + Box::new(TestReloadSessionCallback {}), + Box::new(TestSaveSessionCallback { counter: num_save_session_callback_calls.clone() }), + ) .unwrap(); let mut session_changes = client.subscribe_to_session_changes();