Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make callback interactions with ClientSessionDelegate async #4549

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions bindings/matrix-sdk-ffi/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ vergen = { version = "8.1.3", features = ["build", "git", "gitcl"] }
anyhow = { workspace = true }
as_variant = { workspace = true }
async-compat = "0.2.1"
async-trait = { workspace = true }
eyeball-im = { workspace = true }
extension-trait = "1.0.1"
futures-util = { workspace = true }
Expand Down
83 changes: 44 additions & 39 deletions bindings/matrix-sdk-ffi/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use std::{

use anyhow::{anyhow, Context as _};
use matrix_sdk::{
authentication::{ReloadSessionCallback, SaveSessionCallback, SessionCallbackError},
media::{
MediaFileHandle as SdkMediaFileHandle, MediaFormat, MediaRequestParameters,
MediaThumbnailSettings,
Expand Down Expand Up @@ -153,9 +154,13 @@ pub trait ClientDelegate: Sync + Send {
}

#[matrix_sdk_ffi_macros::export(callback_interface)]
#[async_trait::async_trait]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we still need async_trait with our minimum required version of Rust? I'm not sure.

pub trait ClientSessionDelegate: Sync + Send {
fn retrieve_session_from_keychain(&self, user_id: String) -> Result<Session, ClientError>;
fn save_session_in_keychain(&self, session: Session);
async fn retrieve_session_from_keychain<'a>(
&'a self,
user_id: String,
) -> Result<Session, ClientError>;
async fn save_session_in_keychain<'a>(&'a self, session: Session);
}

#[matrix_sdk_ffi_macros::export(callback_interface)]
Expand Down Expand Up @@ -186,6 +191,41 @@ impl From<matrix_sdk::TransmissionProgress> for TransmissionProgress {
}
}

struct FfiReloadSessionCallback {
session_delegate: Arc<dyn ClientSessionDelegate>,
}

#[async_trait::async_trait]
impl ReloadSessionCallback for FfiReloadSessionCallback {
async fn reload_session(
&self,
client: matrix_sdk::Client,
) -> Result<SessionTokens, SessionCallbackError> {
let user_id = client.user_id().context("user isn't logged in")?;
let session =
self.session_delegate.retrieve_session_from_keychain(user_id.to_string()).await?;
let auth_session = TryInto::<AuthSession>::try_into(session)?;
match auth_session {
AuthSession::Oidc(session) => Ok(SessionTokens::Oidc(session.user.tokens)),
AuthSession::Matrix(session) => Ok(SessionTokens::Matrix(session.tokens)),
_ => Err(anyhow!("unsupported session type").into()),
}
}
}

struct FfiSaveSessionCallback {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please, don't use an Ffi prefix for type names. If it collides with the SDK names, you can rename the SDK names while importing names, by adding an Sdk prefix, as we do in other modules of this crate. Thus:

use matrix_sdk::authentication::SaveSessionCallback as SdkSaveSessionCallback;

session_delegate: Arc<dyn ClientSessionDelegate>,
}

#[async_trait::async_trait]
impl SaveSessionCallback for FfiSaveSessionCallback {
async fn save_session(&self, client: matrix_sdk::Client) -> Result<(), SessionCallbackError> {
let session = Client::session_inner(client)?;
self.session_delegate.save_session_in_keychain(session).await;
Ok(())
}
}

#[derive(uniffi::Object)]
pub struct Client {
pub(crate) inner: AsyncRuntimeDropped<MatrixClient>,
Expand Down Expand Up @@ -238,21 +278,8 @@ impl Client {

if let Some(session_delegate) = session_delegate {
client.inner.set_session_callbacks(
{
let session_delegate = session_delegate.clone();
Box::new(move |client| {
let session_delegate = session_delegate.clone();
let user_id = client.user_id().context("user isn't logged in")?;
Ok(Self::retrieve_session(session_delegate, user_id)?)
})
},
{
let session_delegate = session_delegate.clone();
Box::new(move |client| {
let session_delegate = session_delegate.clone();
Ok(Self::save_session(session_delegate, client)?)
})
},
Box::new(FfiReloadSessionCallback { session_delegate: session_delegate.clone() }),
Box::new(FfiSaveSessionCallback { session_delegate: session_delegate.clone() }),
)?;
}

Expand Down Expand Up @@ -1247,19 +1274,6 @@ impl Client {
}
}

fn retrieve_session(
session_delegate: Arc<dyn ClientSessionDelegate>,
user_id: &UserId,
) -> anyhow::Result<SessionTokens> {
let session = session_delegate.retrieve_session_from_keychain(user_id.to_string())?;
let auth_session = TryInto::<AuthSession>::try_into(session)?;
match auth_session {
AuthSession::Oidc(session) => Ok(SessionTokens::Oidc(session.user.tokens)),
AuthSession::Matrix(session) => Ok(SessionTokens::Matrix(session.tokens)),
_ => anyhow::bail!("Unexpected session kind."),
}
}

fn session_inner(client: matrix_sdk::Client) -> Result<Session, ClientError> {
let auth_api = client.auth_api().context("Missing authentication API")?;

Expand All @@ -1268,15 +1282,6 @@ impl Client {

Session::new(auth_api, homeserver_url, sliding_sync_version.into())
}

fn save_session(
session_delegate: Arc<dyn ClientSessionDelegate>,
client: matrix_sdk::Client,
) -> anyhow::Result<()> {
let session = Self::session_inner(client)?;
session_delegate.save_session_in_keychain(session);
Ok(())
}
}

#[derive(uniffi::Record)]
Expand Down
23 changes: 16 additions & 7 deletions crates/matrix-sdk/src/authentication/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,21 @@ pub enum SessionTokens {
Oidc(oidc::OidcSessionTokens),
}

pub(crate) type SessionCallbackError = Box<dyn std::error::Error + Send + Sync>;
pub(crate) type SaveSessionCallback =
dyn Fn(Client) -> Result<(), SessionCallbackError> + Send + Sync;
pub(crate) type ReloadSessionCallback =
dyn Fn(Client) -> Result<SessionTokens, SessionCallbackError> + Send + Sync;
/// An error that results from setting the session callback.
pub type SessionCallbackError = Box<dyn std::error::Error + Send + Sync>;
/// Save the session tokens from the source of truth.
#[async_trait::async_trait]
pub trait SaveSessionCallback: Send + Sync {
/// Save the session tokens from the source of truth.
async fn save_session(&self, client: Client) -> Result<(), SessionCallbackError>;
}

/// Reload the session tokens from the source of truth.
#[async_trait::async_trait]
pub trait ReloadSessionCallback: Send + Sync {
/// Reload the session tokens from the source of truth.
async fn reload_session(&self, client: Client) -> Result<SessionTokens, SessionCallbackError>;
}
/// All the data relative to authentication, and that must be shared between a
/// client and all its children.
pub(crate) struct AuthCtx {
Expand All @@ -74,7 +83,7 @@ pub(crate) struct AuthCtx {
/// current session tokens.
///
/// This is required only in multiple processes setups.
pub(crate) reload_session_callback: OnceCell<Box<ReloadSessionCallback>>,
pub(crate) reload_session_callback: OnceCell<Box<dyn ReloadSessionCallback>>,

/// A callback to save a session back into the app's secure storage.
///
Expand All @@ -83,7 +92,7 @@ pub(crate) struct AuthCtx {
///
/// Internal invariant: this must be called only after `set_session_tokens`
/// has been called, not before.
pub(crate) save_session_callback: OnceCell<Box<SaveSessionCallback>>,
pub(crate) save_session_callback: OnceCell<Box<dyn SaveSessionCallback>>,
}

/// An enum over all the possible authentication APIs.
Expand Down
4 changes: 2 additions & 2 deletions crates/matrix-sdk/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2370,8 +2370,8 @@ impl Client {
/// while [`Self::subscribe_to_session_changes`] provides an async update.
pub fn set_session_callbacks(
&self,
reload_session_callback: Box<ReloadSessionCallback>,
save_session_callback: Box<SaveSessionCallback>,
reload_session_callback: Box<dyn ReloadSessionCallback>,
save_session_callback: Box<dyn SaveSessionCallback>,
) -> Result<()> {
self.inner
.auth_ctx
Expand Down
3 changes: 2 additions & 1 deletion crates/matrix-sdk/src/matrix_auth/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,8 @@ impl MatrixAuth {
if let Some(save_session_callback) =
self.client.inner.auth_ctx.save_session_callback.get()
{
if let Err(err) = save_session_callback(self.client.clone()) {
if let Err(err) = save_session_callback.save_session(self.client.clone()).await
{
error!("when saving session after refresh: {err}");
}
}
Expand Down
52 changes: 33 additions & 19 deletions crates/matrix-sdk/src/oidc/cross_process.rs
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,9 @@ mod tests {

use super::compute_session_hash;
use crate::{
authentication::{
ReloadSessionCallback, SaveSessionCallback, SessionCallbackError, SessionTokens,
},
oidc::{
backend::mock::{MockImpl, ISSUER_URL},
cross_process::SessionHash,
Expand All @@ -272,9 +275,32 @@ mod tests {
Oidc, OidcSessionTokens,
},
test_utils::test_client_builder,
Error,
Client, Error,
};

struct TestReloadSessionCallback {
tokens: OidcSessionTokens,
}

#[async_trait::async_trait]
impl ReloadSessionCallback for TestReloadSessionCallback {
async fn reload_session(
&self,
_client: Client,
) -> Result<SessionTokens, SessionCallbackError> {
Ok(SessionTokens::Oidc(self.tokens.clone()))
}
}

struct TestSaveSessionCallback {}

#[async_trait::async_trait]
impl SaveSessionCallback for TestSaveSessionCallback {
async fn save_session(&self, _client: Client) -> Result<(), SessionCallbackError> {
panic!("save_session_callback shouldn't be called here")
}
}

#[async_test]
async fn test_restore_session_lock() -> Result<(), Error> {
// Create a client that will use sqlite databases.
Expand All @@ -295,12 +321,8 @@ mod tests {
client.oidc().enable_cross_process_refresh_lock("test".to_owned()).await?;

client.set_session_callbacks(
Box::new({
// This is only called because of extra checks in the code.
let tokens = tokens.clone();
move |_| Ok(crate::authentication::SessionTokens::Oidc(tokens.clone()))
}),
Box::new(|_| panic!("save_session_callback shouldn't be called here")),
Box::new(TestReloadSessionCallback { tokens: tokens.clone() }),
Box::new(TestSaveSessionCallback {}),
)?;

let session_hash = compute_session_hash(&tokens);
Expand Down Expand Up @@ -521,12 +543,8 @@ mod tests {
let oidc = unrestored_oidc;

unrestored_client.set_session_callbacks(
Box::new({
// This is only called because of extra checks in the code.
let tokens = next_tokens.clone();
move |_| Ok(crate::authentication::SessionTokens::Oidc(tokens.clone()))
}),
Box::new(|_| panic!("save_session_callback shouldn't be called here")),
Box::new(TestReloadSessionCallback { tokens: next_tokens.clone() }),
Box::new(TestSaveSessionCallback {}),
)?;

oidc.restore_session(tests::mock_session(prev_tokens.clone())).await?;
Expand Down Expand Up @@ -558,12 +576,8 @@ mod tests {
}

client.set_session_callbacks(
Box::new({
// This is only called because of extra checks in the code.
let tokens = next_tokens.clone();
move |_| Ok(crate::authentication::SessionTokens::Oidc(tokens.clone()))
}),
Box::new(|_| panic!("save_session_callback shouldn't be called here")),
Box::new(TestReloadSessionCallback { tokens: next_tokens.clone() }),
Box::new(TestSaveSessionCallback {}),
)?;

oidc.refresh_access_token().await?;
Expand Down
4 changes: 2 additions & 2 deletions crates/matrix-sdk/src/oidc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1015,7 +1015,7 @@ impl Oidc {
.get()
.ok_or(CrossProcessRefreshLockError::MissingReloadSession)?;

match callback(self.client.clone()) {
match callback.reload_session(self.client.clone()).await {
Ok(tokens) => {
let crate::authentication::SessionTokens::Oidc(tokens) = tokens else {
return Err(CrossProcessRefreshLockError::InvalidSessionTokens);
Expand Down Expand Up @@ -1337,7 +1337,7 @@ impl Oidc {
{
// Satisfies the save_session_callback invariant: set_session_tokens has
// been called just above.
if let Err(err) = save_session_callback(self.client.clone()) {
if let Err(err) = save_session_callback.save_session(self.client.clone()).await {
error!("when saving session after refresh: {err}");
}
}
Expand Down
38 changes: 31 additions & 7 deletions crates/matrix-sdk/tests/integration/refresh_token.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ use assert_matches::assert_matches;
use assert_matches2::assert_let;
use futures_util::StreamExt;
use matrix_sdk::{
authentication::{
ReloadSessionCallback, SaveSessionCallback, SessionCallbackError, SessionTokens,
},
config::RequestConfig,
executor::spawn,
matrix_auth::{MatrixSession, MatrixSessionTokens},
Expand Down Expand Up @@ -163,6 +166,30 @@ async fn test_no_refresh_token() {
assert_matches!(res, Err(RefreshTokenError::RefreshTokenRequired));
}

struct TestReloadSessionCallback {}

#[async_trait::async_trait]
impl ReloadSessionCallback for TestReloadSessionCallback {
async fn reload_session(
&self,
_client: matrix_sdk::Client,
) -> Result<SessionTokens, SessionCallbackError> {
panic!("reload session never called")
}
}

struct TestSaveSessionCallback {
counter: Arc<Mutex<u64>>,
}

#[async_trait::async_trait]
impl SaveSessionCallback for TestSaveSessionCallback {
async fn save_session(&self, _client: matrix_sdk::Client) -> Result<(), SessionCallbackError> {
*self.counter.lock().unwrap() += 1;
Ok(())
}
}

#[async_test]
async fn test_refresh_token() {
let (builder, server) = test_client_builder_with_server().await;
Expand All @@ -176,13 +203,10 @@ async fn test_refresh_token() {

let num_save_session_callback_calls = Arc::new(Mutex::new(0));
client
.set_session_callbacks(Box::new(|_| panic!("reload session never called")), {
let num_save_session_callback_calls = num_save_session_callback_calls.clone();
Box::new(move |_client| {
*num_save_session_callback_calls.lock().unwrap() += 1;
Ok(())
})
})
.set_session_callbacks(
Box::new(TestReloadSessionCallback {}),
Box::new(TestSaveSessionCallback { counter: num_save_session_callback_calls.clone() }),
)
.unwrap();

let mut session_changes = client.subscribe_to_session_changes();
Expand Down
Loading