From 6f6629bb3eb7025d41c796d1b5909a077d60b8e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damir=20Jeli=C4=87?= Date: Wed, 8 Jan 2025 15:51:01 +0100 Subject: [PATCH] refactor(crypto): Use the simplified locks across the crypto crate --- .../src/backups/keys/backup.rs | 7 +- .../src/gossiping/machine.rs | 53 ++++++------ crates/matrix-sdk-crypto/src/gossiping/mod.rs | 10 ++- .../src/identities/device.rs | 10 +-- .../matrix-sdk-crypto/src/identities/user.rs | 41 ++++----- crates/matrix-sdk-crypto/src/machine/mod.rs | 3 +- .../src/olm/group_sessions/outbound.rs | 86 +++++++++---------- .../src/session_manager/group_sessions/mod.rs | 22 ++--- .../group_sessions/share_strategy.rs | 4 +- .../src/session_manager/sessions.rs | 50 ++++------- crates/matrix-sdk-crypto/src/store/caches.rs | 16 ++-- .../src/store/memorystore.rs | 79 ++++++++--------- crates/matrix-sdk-crypto/src/store/mod.rs | 15 ++-- .../src/verification/cache.rs | 29 +++---- .../src/verification/machine.rs | 13 ++- .../src/verification/sas/sas_state.rs | 33 ++++--- 16 files changed, 213 insertions(+), 258 deletions(-) diff --git a/crates/matrix-sdk-crypto/src/backups/keys/backup.rs b/crates/matrix-sdk-crypto/src/backups/keys/backup.rs index 83409d950b6..eac48b52b63 100644 --- a/crates/matrix-sdk-crypto/src/backups/keys/backup.rs +++ b/crates/matrix-sdk-crypto/src/backups/keys/backup.rs @@ -12,8 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::sync::{Arc, Mutex}; +use std::sync::Arc; +use matrix_sdk_common::locks::Mutex; use ruma::{ api::client::backup::{EncryptedSessionDataInit, KeyBackupData, KeyBackupDataInit}, serde::Base64, @@ -87,7 +88,7 @@ impl MegolmV1BackupKey { /// Get the backup version that this key is used with, if any. pub fn backup_version(&self) -> Option { - self.inner.version.lock().unwrap().clone() + self.inner.version.lock().clone() } /// Set the backup version that this `MegolmV1BackupKey` will be used with. @@ -95,7 +96,7 @@ impl MegolmV1BackupKey { /// The key won't be able to encrypt room keys unless a version has been /// set. pub fn set_version(&self, version: String) { - *self.inner.version.lock().unwrap() = Some(version); + *self.inner.version.lock() = Some(version); } /// Export the given inbound group session, and encrypt the data, ready for diff --git a/crates/matrix-sdk-crypto/src/gossiping/machine.rs b/crates/matrix-sdk-crypto/src/gossiping/machine.rs index 5b7a7055ab9..31eb23f23d2 100644 --- a/crates/matrix-sdk-crypto/src/gossiping/machine.rs +++ b/crates/matrix-sdk-crypto/src/gossiping/machine.rs @@ -25,10 +25,11 @@ use std::{ mem, sync::{ atomic::{AtomicBool, Ordering}, - Arc, RwLock as StdRwLock, + Arc, }, }; +use matrix_sdk_common::locks::RwLock as StdRwLock; use ruma::{ api::client::keys::claim_keys::v3::Request as KeysClaimRequest, events::secret::request::{ @@ -168,14 +169,13 @@ impl GossipMachine { ) -> Result, CryptoStoreError> { let mut key_requests = self.load_outgoing_requests().await?; let key_forwards: Vec = - self.inner.outgoing_requests.read().unwrap().values().cloned().collect(); + self.inner.outgoing_requests.read().values().cloned().collect(); key_requests.extend(key_forwards); let users_for_key_claim: BTreeMap<_, _> = self .inner .users_for_key_claim .read() - .unwrap() .iter() .map(|(key, value)| { let device_map = value @@ -213,7 +213,7 @@ impl GossipMachine { trace!("Received a secret request event from ourselves, ignoring") } else { let request_info = event.to_request_info(); - self.inner.incoming_key_requests.write().unwrap().insert(request_info, event); + self.inner.incoming_key_requests.write().insert(request_info, event); } } @@ -229,8 +229,7 @@ impl GossipMachine { ) -> OlmResult> { let mut changed_sessions = Vec::new(); - let incoming_key_requests = - mem::take(&mut *self.inner.incoming_key_requests.write().unwrap()); + let incoming_key_requests = mem::take(&mut *self.inner.incoming_key_requests.write()); for event in incoming_key_requests.values() { if let Some(s) = match event { @@ -254,7 +253,6 @@ impl GossipMachine { self.inner .users_for_key_claim .write() - .unwrap() .entry(device.user_id().to_owned()) .or_default() .insert(device.device_id().into()); @@ -275,7 +273,7 @@ impl GossipMachine { /// * `device_id` - The device ID of the device that got the Olm session. pub fn retry_keyshare(&self, user_id: &UserId, device_id: &DeviceId) { if let Entry::Occupied(mut e) = - self.inner.users_for_key_claim.write().unwrap().entry(user_id.to_owned()) + self.inner.users_for_key_claim.write().entry(user_id.to_owned()) { e.get_mut().remove(device_id); @@ -284,7 +282,7 @@ impl GossipMachine { } } - let mut incoming_key_requests = self.inner.incoming_key_requests.write().unwrap(); + let mut incoming_key_requests = self.inner.incoming_key_requests.write(); for (key, event) in self.inner.wait_queue.remove(user_id, device_id) { incoming_key_requests.entry(key).or_insert(event); } @@ -555,7 +553,7 @@ impl GossipMachine { request_id: request.txn_id.clone(), request: Arc::new(request.into()), }; - self.inner.outgoing_requests.write().unwrap().insert(request.request_id.clone(), request); + self.inner.outgoing_requests.write().insert(request.request_id.clone(), request); Ok(used_session) } @@ -581,7 +579,7 @@ impl GossipMachine { request_id: request.txn_id.clone(), request: Arc::new(request.into()), }; - self.inner.outgoing_requests.write().unwrap().insert(request.request_id.clone(), request); + self.inner.outgoing_requests.write().insert(request.request_id.clone(), request); Ok(used_session) } @@ -824,7 +822,7 @@ impl GossipMachine { self.save_outgoing_key_info(info).await?; } - self.inner.outgoing_requests.write().unwrap().remove(id); + self.inner.outgoing_requests.write().remove(id); Ok(()) } @@ -840,13 +838,13 @@ impl GossipMachine { "Successfully received a secret, removing the request" ); - self.inner.outgoing_requests.write().unwrap().remove(&key_info.request_id); + self.inner.outgoing_requests.write().remove(&key_info.request_id); // TODO return the key info instead of deleting it so the sync handler // can delete it in one transaction. self.delete_key_info(key_info).await?; let request = key_info.to_cancellation(self.device_id()); - self.inner.outgoing_requests.write().unwrap().insert(request.request_id.clone(), request); + self.inner.outgoing_requests.write().insert(request.request_id.clone(), request); Ok(()) } @@ -1511,7 +1509,6 @@ mod tests { .inner .outgoing_requests .read() - .unwrap() .first_key_value() .map(|(_, r)| r.request_id.clone()) .unwrap(); @@ -1692,7 +1689,7 @@ mod tests { alice_machine.mark_outgoing_request_as_sent(&request.request_id).await.unwrap(); // Bob doesn't have any outgoing requests. - assert!(bob_machine.inner.outgoing_requests.read().unwrap().is_empty()); + assert!(bob_machine.inner.outgoing_requests.read().is_empty()); // Receive the room key request from alice. bob_machine.receive_incoming_key_request(&event); @@ -1702,7 +1699,7 @@ mod tests { bob_machine.collect_incoming_key_requests(&bob_cache).await.unwrap(); } // Now bob does have an outgoing request. - assert!(!bob_machine.inner.outgoing_requests.read().unwrap().is_empty()); + assert!(!bob_machine.inner.outgoing_requests.read().is_empty()); // Get the request and convert it to a encrypted to-device event. let requests = bob_machine.outgoing_to_device_requests().await.unwrap(); @@ -1774,7 +1771,7 @@ mod tests { alice_machine.mark_outgoing_request_as_sent(&request.request_id).await.unwrap(); // Bob doesn't have any outgoing requests. - assert!(bob_machine.inner.outgoing_requests.read().unwrap().is_empty()); + assert!(bob_machine.inner.outgoing_requests.read().is_empty()); // Receive the room key request from alice. bob_machine.receive_incoming_key_request(&event); @@ -1783,7 +1780,7 @@ mod tests { bob_machine.collect_incoming_key_requests(&bob_cache).await.unwrap(); } // Now bob does have an outgoing request. - assert!(!bob_machine.inner.outgoing_requests.read().unwrap().is_empty()); + assert!(!bob_machine.inner.outgoing_requests.read().is_empty()); // Get the request and convert it to a encrypted to-device event. let requests = bob_machine.outgoing_to_device_requests().await.unwrap(); @@ -1875,13 +1872,13 @@ mod tests { }; // No secret found - assert!(alice_machine.inner.outgoing_requests.read().unwrap().is_empty()); + assert!(alice_machine.inner.outgoing_requests.read().is_empty()); alice_machine.receive_incoming_secret_request(&event); { let alice_cache = alice_machine.inner.store.cache().await.unwrap(); alice_machine.collect_incoming_key_requests(&alice_cache).await.unwrap(); } - assert!(alice_machine.inner.outgoing_requests.read().unwrap().is_empty()); + assert!(alice_machine.inner.outgoing_requests.read().is_empty()); // No device found alice_machine.inner.store.reset_cross_signing_identity().await; @@ -1890,7 +1887,7 @@ mod tests { let alice_cache = alice_machine.inner.store.cache().await.unwrap(); alice_machine.collect_incoming_key_requests(&alice_cache).await.unwrap(); } - assert!(alice_machine.inner.outgoing_requests.read().unwrap().is_empty()); + assert!(alice_machine.inner.outgoing_requests.read().is_empty()); alice_machine.inner.store.save_device_data(&[bob_device]).await.unwrap(); @@ -1901,7 +1898,7 @@ mod tests { let alice_cache = alice_machine.inner.store.cache().await.unwrap(); alice_machine.collect_incoming_key_requests(&alice_cache).await.unwrap(); } - assert!(alice_machine.inner.outgoing_requests.read().unwrap().is_empty()); + assert!(alice_machine.inner.outgoing_requests.read().is_empty()); let event = RumaToDeviceEvent { sender: alice_id().to_owned(), @@ -1918,7 +1915,7 @@ mod tests { let alice_cache = alice_machine.inner.store.cache().await.unwrap(); alice_machine.collect_incoming_key_requests(&alice_cache).await.unwrap(); } - assert!(alice_machine.inner.outgoing_requests.read().unwrap().is_empty()); + assert!(alice_machine.inner.outgoing_requests.read().is_empty()); // We need a trusted device, otherwise we won't serve secrets alice_device.set_trust_state(LocalTrust::Verified); @@ -1929,7 +1926,7 @@ mod tests { let alice_cache = alice_machine.inner.store.cache().await.unwrap(); alice_machine.collect_incoming_key_requests(&alice_cache).await.unwrap(); } - assert!(!alice_machine.inner.outgoing_requests.read().unwrap().is_empty()); + assert!(!alice_machine.inner.outgoing_requests.read().is_empty()); } #[async_test] @@ -2053,7 +2050,7 @@ mod tests { // Bob doesn't have any outgoing requests. assert!(bob_machine.outgoing_to_device_requests().await.unwrap().is_empty()); - assert!(bob_machine.inner.users_for_key_claim.read().unwrap().is_empty()); + assert!(bob_machine.inner.users_for_key_claim.read().is_empty()); assert!(bob_machine.inner.wait_queue.is_empty()); // Receive the room key request from alice. @@ -2068,7 +2065,7 @@ mod tests { bob_machine.outgoing_to_device_requests().await.unwrap()[0].request(), AnyOutgoingRequest::KeysClaim(_) ); - assert!(!bob_machine.inner.users_for_key_claim.read().unwrap().is_empty()); + assert!(!bob_machine.inner.users_for_key_claim.read().is_empty()); assert!(!bob_machine.inner.wait_queue.is_empty()); let (alice_session, bob_session) = alice_machine @@ -2096,7 +2093,7 @@ mod tests { bob_machine.inner.store.save_sessions(&[bob_session]).await.unwrap(); bob_machine.retry_keyshare(alice_id(), alice_device_id()); - assert!(bob_machine.inner.users_for_key_claim.read().unwrap().is_empty()); + assert!(bob_machine.inner.users_for_key_claim.read().is_empty()); { let bob_cache = bob_machine.inner.store.cache().await.unwrap(); bob_machine.collect_incoming_key_requests(&bob_cache).await.unwrap(); diff --git a/crates/matrix-sdk-crypto/src/gossiping/mod.rs b/crates/matrix-sdk-crypto/src/gossiping/mod.rs index 2593091b2eb..f608e5fb36c 100644 --- a/crates/matrix-sdk-crypto/src/gossiping/mod.rs +++ b/crates/matrix-sdk-crypto/src/gossiping/mod.rs @@ -16,10 +16,11 @@ mod machine; use std::{ collections::{BTreeMap, BTreeSet}, - sync::{Arc, RwLock as StdRwLock}, + sync::Arc, }; pub(crate) use machine::GossipMachine; +use matrix_sdk_common::locks::RwLock as StdRwLock; use ruma::{ events::{ room_key_request::{Action, ToDeviceRoomKeyRequestEventContent}, @@ -323,7 +324,7 @@ impl WaitQueue { #[cfg(all(test, feature = "automatic-room-key-forwarding"))] fn is_empty(&self) -> bool { - let read_guard = self.inner.read().unwrap(); + let read_guard = self.inner.read(); read_guard.requests_ids_waiting.is_empty() && read_guard.requests_waiting_for_session.is_empty() } @@ -337,13 +338,14 @@ impl WaitQueue { ); let ids_waiting_key = (device.user_id().to_owned(), device.device_id().into()); - let mut write_guard = self.inner.write().unwrap(); + let mut write_guard = self.inner.write(); write_guard.requests_waiting_for_session.insert(requests_waiting_key, event); write_guard.requests_ids_waiting.entry(ids_waiting_key).or_default().insert(request_id); } fn remove(&self, user_id: &UserId, device_id: &DeviceId) -> Vec<(RequestInfo, RequestEvent)> { - let mut write_guard = self.inner.write().unwrap(); + let mut write_guard = self.inner.write(); + write_guard .requests_ids_waiting .remove(&(user_id.to_owned(), device_id.into())) diff --git a/crates/matrix-sdk-crypto/src/identities/device.rs b/crates/matrix-sdk-crypto/src/identities/device.rs index da84d068a41..cea61624e16 100644 --- a/crates/matrix-sdk-crypto/src/identities/device.rs +++ b/crates/matrix-sdk-crypto/src/identities/device.rs @@ -17,11 +17,11 @@ use std::{ ops::Deref, sync::{ atomic::{AtomicBool, Ordering}, - Arc, RwLock, + Arc, }, }; -use matrix_sdk_common::deserialized_responses::WithheldCode; +use matrix_sdk_common::{deserialized_responses::WithheldCode, locks::RwLock}; use ruma::{ api::client::keys::upload_signatures::v3::Request as SignatureUploadRequest, events::{key::verification::VerificationMethod, AnyToDeviceEventContent}, @@ -470,7 +470,7 @@ impl Device { ) -> OlmResult> { let (used_session, raw_encrypted) = self.encrypt(event_type, content).await?; - // perist the used session + // Persist the used session self.verification_machine .store .save_changes(Changes { sessions: vec![used_session], ..Default::default() }) @@ -626,7 +626,7 @@ impl DeviceData { /// Get the trust state of the device. pub fn local_trust_state(&self) -> LocalTrust { - *self.trust_state.read().unwrap() + *self.trust_state.read() } /// Is the device locally marked as trusted. @@ -646,7 +646,7 @@ impl DeviceData { /// Note: This should only done in the crypto store where the trust state /// can be stored. pub(crate) fn set_trust_state(&self, state: LocalTrust) { - *self.trust_state.write().unwrap() = state; + *self.trust_state.write() = state; } pub(crate) fn mark_withheld_code_as_sent(&self) { diff --git a/crates/matrix-sdk-crypto/src/identities/user.rs b/crates/matrix-sdk-crypto/src/identities/user.rs index 75f927fd17f..4a4c788dc60 100644 --- a/crates/matrix-sdk-crypto/src/identities/user.rs +++ b/crates/matrix-sdk-crypto/src/identities/user.rs @@ -17,11 +17,12 @@ use std::{ ops::{Deref, DerefMut}, sync::{ atomic::{AtomicBool, Ordering}, - Arc, RwLock, + Arc, }, }; use as_variant::as_variant; +use matrix_sdk_common::locks::RwLock; use ruma::{ api::client::keys::upload_signatures::v3::Request as SignatureUploadRequest, events::{ @@ -684,7 +685,7 @@ impl From for OtherUserIdentityDataSerializer { user_id: value.user_id.clone(), master_key: value.master_key().to_owned(), self_signing_key: value.self_signing_key().to_owned(), - pinned_master_key: value.pinned_master_key.read().unwrap().clone(), + pinned_master_key: value.pinned_master_key.read().clone(), previously_verified: value.previously_verified.load(Ordering::SeqCst), }; OtherUserIdentityDataSerializer { @@ -787,7 +788,7 @@ impl OtherUserIdentityData { /// which is not verified and is in pin violation. See /// [`OtherUserIdentity::identity_needs_user_approval`]. pub(crate) fn pin(&self) { - let mut m = self.pinned_master_key.write().unwrap(); + let mut m = self.pinned_master_key.write(); *m = self.master_key.as_ref().clone() } @@ -828,7 +829,7 @@ impl OtherUserIdentityData { /// accept and pin the new identity, perform a verification, or /// stop communications. pub(crate) fn has_pin_violation(&self) -> bool { - let pinned_master_key = self.pinned_master_key.read().unwrap(); + let pinned_master_key = self.pinned_master_key.read(); pinned_master_key.get_first_key() != self.master_key().get_first_key() } @@ -858,7 +859,7 @@ impl OtherUserIdentityData { // the previous pinned master key. // This identity will have a pin violation until the new master key is pinned // (see `has_pin_violation()`). - let pinned_master_key = self.pinned_master_key.read().unwrap().clone(); + let pinned_master_key = self.pinned_master_key.read().clone(); // Check if the new master_key is signed by our own **verified** // user_signing_key. If the identity was verified we remember it. @@ -947,7 +948,7 @@ impl PartialEq for OwnUserIdentityData { && self.master_key == other.master_key && self.self_signing_key == other.self_signing_key && self.user_signing_key == other.user_signing_key - && *self.verified.read().unwrap() == *other.verified.read().unwrap() + && *self.verified.read() == *other.verified.read() && self.master_key.signatures() == other.master_key.signatures() } } @@ -1067,12 +1068,12 @@ impl OwnUserIdentityData { /// Mark our identity as verified. pub fn mark_as_verified(&self) { - *self.verified.write().unwrap() = OwnUserIdentityVerifiedState::Verified; + *self.verified.write() = OwnUserIdentityVerifiedState::Verified; } /// Mark our identity as unverified. pub(crate) fn mark_as_unverified(&self) { - let mut guard = self.verified.write().unwrap(); + let mut guard = self.verified.write(); if *guard == OwnUserIdentityVerifiedState::Verified { *guard = OwnUserIdentityVerifiedState::VerificationViolation; } @@ -1080,7 +1081,7 @@ impl OwnUserIdentityData { /// Check if our identity is verified. pub fn is_verified(&self) -> bool { - *self.verified.read().unwrap() == OwnUserIdentityVerifiedState::Verified + *self.verified.read() == OwnUserIdentityVerifiedState::Verified } /// True if we verified our own identity at some point in the past. @@ -1089,7 +1090,7 @@ impl OwnUserIdentityData { /// [`OwnUserIdentityData::withdraw_verification()`]. pub fn was_previously_verified(&self) -> bool { matches!( - *self.verified.read().unwrap(), + *self.verified.read(), OwnUserIdentityVerifiedState::Verified | OwnUserIdentityVerifiedState::VerificationViolation ) @@ -1101,7 +1102,7 @@ impl OwnUserIdentityData { /// reported to the user. In order to remove this notice users have to /// verify again or to withdraw the verification requirement. pub fn withdraw_verification(&self) { - let mut guard = self.verified.write().unwrap(); + let mut guard = self.verified.write(); if *guard == OwnUserIdentityVerifiedState::VerificationViolation { *guard = OwnUserIdentityVerifiedState::NeverVerified; } @@ -1117,7 +1118,7 @@ impl OwnUserIdentityData { /// - Or by withdrawing the verification requirement /// [`OwnUserIdentity::withdraw_verification`]. pub fn has_verification_violation(&self) -> bool { - *self.verified.read().unwrap() == OwnUserIdentityVerifiedState::VerificationViolation + *self.verified.read() == OwnUserIdentityVerifiedState::VerificationViolation } /// Update the identity with a new master key and self signing key. @@ -1523,7 +1524,7 @@ pub(crate) mod tests { }); let migrated: OtherUserIdentityData = serde_json::from_value(serialized_value).unwrap(); - let pinned_master_key = migrated.pinned_master_key.read().unwrap(); + let pinned_master_key = migrated.pinned_master_key.read(); assert_eq!(*pinned_master_key, migrated.master_key().clone()); // Serialize back @@ -1547,12 +1548,12 @@ pub(crate) mod tests { // Set `"verified": false` *json.get_mut("verified").unwrap() = false.into(); let id: OwnUserIdentityData = serde_json::from_value(json.clone()).unwrap(); - assert_eq!(*id.verified.read().unwrap(), OwnUserIdentityVerifiedState::NeverVerified); + assert_eq!(*id.verified.read(), OwnUserIdentityVerifiedState::NeverVerified); // Tweak the json to have `"verified": true`, and repeat *json.get_mut("verified").unwrap() = true.into(); let id: OwnUserIdentityData = serde_json::from_value(json.clone()).unwrap(); - assert_eq!(*id.verified.read().unwrap(), OwnUserIdentityVerifiedState::Verified); + assert_eq!(*id.verified.read(), OwnUserIdentityVerifiedState::Verified); } #[test] @@ -1565,10 +1566,7 @@ pub(crate) mod tests { let id: OwnUserIdentityData = serde_json::from_value(json.clone()).unwrap(); // Then the value is correctly populated - assert_eq!( - *id.verified.read().unwrap(), - OwnUserIdentityVerifiedState::VerificationViolation - ); + assert_eq!(*id.verified.read(), OwnUserIdentityVerifiedState::VerificationViolation); } #[test] @@ -1581,10 +1579,7 @@ pub(crate) mod tests { let id: OwnUserIdentityData = serde_json::from_value(json.clone()).unwrap(); // Then the old value is re-interpreted as VerificationViolation - assert_eq!( - *id.verified.read().unwrap(), - OwnUserIdentityVerifiedState::VerificationViolation - ); + assert_eq!(*id.verified.read(), OwnUserIdentityVerifiedState::VerificationViolation); } #[test] diff --git a/crates/matrix-sdk-crypto/src/machine/mod.rs b/crates/matrix-sdk-crypto/src/machine/mod.rs index aa5bbc9a00e..fc79eea1631 100644 --- a/crates/matrix-sdk-crypto/src/machine/mod.rs +++ b/crates/matrix-sdk-crypto/src/machine/mod.rs @@ -14,7 +14,7 @@ use std::{ collections::{BTreeMap, HashMap, HashSet}, - sync::{Arc, RwLock as StdRwLock}, + sync::Arc, time::Duration, }; @@ -25,6 +25,7 @@ use matrix_sdk_common::{ UnableToDecryptReason, UnsignedDecryptionResult, UnsignedEventLocation, VerificationLevel, VerificationState, }, + locks::RwLock as StdRwLock, BoxFuture, }; use ruma::{ diff --git a/crates/matrix-sdk-crypto/src/olm/group_sessions/outbound.rs b/crates/matrix-sdk-crypto/src/olm/group_sessions/outbound.rs index 1c5539d7203..95a5949de20 100644 --- a/crates/matrix-sdk-crypto/src/olm/group_sessions/outbound.rs +++ b/crates/matrix-sdk-crypto/src/olm/group_sessions/outbound.rs @@ -18,12 +18,12 @@ use std::{ fmt, sync::{ atomic::{AtomicBool, AtomicU64, Ordering}, - Arc, RwLock as StdRwLock, + Arc, }, time::Duration, }; -use matrix_sdk_common::deserialized_responses::WithheldCode; +use matrix_sdk_common::{deserialized_responses::WithheldCode, locks::RwLock as StdRwLock}; use ruma::{ events::{ room::{encryption::RoomEncryptionEventContent, history_visibility::HistoryVisibility}, @@ -274,7 +274,7 @@ impl OutboundGroupSession { request: Arc, share_infos: ShareInfoSet, ) { - self.to_share_with_set.write().unwrap().insert(request_id, (request, share_infos)); + self.to_share_with_set.write().insert(request_id, (request, share_infos)); } /// Create a new `m.room_key.withheld` event content with the given code for @@ -310,7 +310,7 @@ impl OutboundGroupSession { ) -> BTreeMap> { let mut no_olm_devices = BTreeMap::new(); - let removed = self.to_share_with_set.write().unwrap().remove(request_id); + let removed = self.to_share_with_set.write().remove(request_id); if let Some((to_device, request)) = removed { let recipients: BTreeMap<&UserId, BTreeSet<&DeviceId>> = request .iter() @@ -332,10 +332,10 @@ impl OutboundGroupSession { .collect(); no_olm_devices.insert(user_id.to_owned(), no_olms); - self.shared_with_set.write().unwrap().entry(user_id).or_default().extend(info); + self.shared_with_set.write().entry(user_id).or_default().extend(info); } - if self.to_share_with_set.read().unwrap().is_empty() { + if self.to_share_with_set.read().is_empty() { debug!( session_id = self.session_id(), room_id = ?self.room_id, @@ -347,7 +347,7 @@ impl OutboundGroupSession { } } else { let request_ids: Vec = - self.to_share_with_set.read().unwrap().keys().map(|k| k.to_string()).collect(); + self.to_share_with_set.read().keys().map(|k| k.to_string()).collect(); error!( all_request_ids = ?request_ids, @@ -540,22 +540,21 @@ impl OutboundGroupSession { /// Has or will the session be shared with the given user/device pair. pub(crate) fn is_shared_with(&self, device: &DeviceData) -> ShareState { // Check if we shared the session. - let shared_state = - self.shared_with_set.read().unwrap().get(device.user_id()).and_then(|d| { - d.get(device.device_id()).map(|s| match s { - ShareInfo::Shared(s) => { - if device.curve25519_key() == Some(s.sender_key) { - ShareState::Shared { - message_index: s.message_index, - olm_wedging_index: s.olm_wedging_index, - } - } else { - ShareState::SharedButChangedSenderKey + let shared_state = self.shared_with_set.read().get(device.user_id()).and_then(|d| { + d.get(device.device_id()).map(|s| match s { + ShareInfo::Shared(s) => { + if device.curve25519_key() == Some(s.sender_key) { + ShareState::Shared { + message_index: s.message_index, + olm_wedging_index: s.olm_wedging_index, } + } else { + ShareState::SharedButChangedSenderKey } - ShareInfo::Withheld(_) => ShareState::NotShared, - }) - }); + } + ShareInfo::Withheld(_) => ShareState::NotShared, + }) + }); if let Some(state) = shared_state { state @@ -565,24 +564,23 @@ impl OutboundGroupSession { // Find the first request that contains the given user id and // device ID. - let shared = - self.to_share_with_set.read().unwrap().values().find_map(|(_, share_info)| { - let d = share_info.get(device.user_id())?; - let info = d.get(device.device_id())?; - Some(match info { - ShareInfo::Shared(info) => { - if device.curve25519_key() == Some(info.sender_key) { - ShareState::Shared { - message_index: info.message_index, - olm_wedging_index: info.olm_wedging_index, - } - } else { - ShareState::SharedButChangedSenderKey + let shared = self.to_share_with_set.read().values().find_map(|(_, share_info)| { + let d = share_info.get(device.user_id())?; + let info = d.get(device.device_id())?; + Some(match info { + ShareInfo::Shared(info) => { + if device.curve25519_key() == Some(info.sender_key) { + ShareState::Shared { + message_index: info.message_index, + olm_wedging_index: info.olm_wedging_index, } + } else { + ShareState::SharedButChangedSenderKey } - ShareInfo::Withheld(_) => ShareState::NotShared, - }) - }); + } + ShareInfo::Withheld(_) => ShareState::NotShared, + }) + }); shared.unwrap_or(ShareState::NotShared) } @@ -591,7 +589,6 @@ impl OutboundGroupSession { pub(crate) fn is_withheld_to(&self, device: &DeviceData, code: &WithheldCode) -> bool { self.shared_with_set .read() - .unwrap() .get(device.user_id()) .and_then(|d| { let info = d.get(device.device_id())?; @@ -603,7 +600,7 @@ impl OutboundGroupSession { // Find the first request that contains the given user id and // device ID. - self.to_share_with_set.read().unwrap().values().any(|(_, share_info)| { + self.to_share_with_set.read().values().any(|(_, share_info)| { share_info .get(device.user_id()) .and_then(|d| d.get(device.device_id())) @@ -622,7 +619,7 @@ impl OutboundGroupSession { sender_key: Curve25519PublicKey, index: u32, ) { - self.shared_with_set.write().unwrap().entry(user_id.to_owned()).or_default().insert( + self.shared_with_set.write().entry(user_id.to_owned()).or_default().insert( device_id.to_owned(), ShareInfo::new_shared(sender_key, index, Default::default()), ); @@ -641,7 +638,6 @@ impl OutboundGroupSession { ShareInfo::new_shared(sender_key, self.message_index().await, Default::default()); self.shared_with_set .write() - .unwrap() .entry(user_id.to_owned()) .or_default() .insert(device_id.to_owned(), share_info); @@ -650,12 +646,12 @@ impl OutboundGroupSession { /// Get the list of requests that need to be sent out for this session to be /// marked as shared. pub(crate) fn pending_requests(&self) -> Vec> { - self.to_share_with_set.read().unwrap().values().map(|(req, _)| req.clone()).collect() + self.to_share_with_set.read().values().map(|(req, _)| req.clone()).collect() } /// Get the list of request ids this session is waiting for to be sent out. pub(crate) fn pending_request_ids(&self) -> Vec { - self.to_share_with_set.read().unwrap().keys().cloned().collect() + self.to_share_with_set.read().keys().cloned().collect() } /// Restore a Session from a previously pickled string. @@ -717,8 +713,8 @@ impl OutboundGroupSession { message_count: self.message_count.load(Ordering::SeqCst), shared: self.shared(), invalidated: self.invalidated(), - shared_with_set: self.shared_with_set.read().unwrap().clone(), - requests: self.to_share_with_set.read().unwrap().clone(), + shared_with_set: self.shared_with_set.read().clone(), + requests: self.to_share_with_set.read().clone(), } } } diff --git a/crates/matrix-sdk-crypto/src/session_manager/group_sessions/mod.rs b/crates/matrix-sdk-crypto/src/session_manager/group_sessions/mod.rs index b6ddeaaa290..2e148471a6d 100644 --- a/crates/matrix-sdk-crypto/src/session_manager/group_sessions/mod.rs +++ b/crates/matrix-sdk-crypto/src/session_manager/group_sessions/mod.rs @@ -17,12 +17,14 @@ mod share_strategy; use std::{ collections::{BTreeMap, BTreeSet}, fmt::Debug, - sync::{Arc, RwLock as StdRwLock}, + sync::Arc, }; use futures_util::future::join_all; use itertools::Itertools; -use matrix_sdk_common::{deserialized_responses::WithheldCode, executor::spawn}; +use matrix_sdk_common::{ + deserialized_responses::WithheldCode, executor::spawn, locks::RwLock as StdRwLock, +}; use ruma::{ events::{AnyMessageLikeEventContent, ToDeviceEventType}, serde::Raw, @@ -60,7 +62,7 @@ impl GroupSessionCache { } pub(crate) fn insert(&self, session: OutboundGroupSession) { - self.sessions.write().unwrap().insert(session.room_id().to_owned(), session); + self.sessions.write().insert(session.room_id().to_owned(), session); } /// Either get a session for the given room from the cache or load it from @@ -72,20 +74,20 @@ impl GroupSessionCache { pub async fn get_or_load(&self, room_id: &RoomId) -> Option { // Get the cached session, if there isn't one load one from the store // and put it in the cache. - if let Some(s) = self.sessions.read().unwrap().get(room_id) { + if let Some(s) = self.sessions.read().get(room_id) { return Some(s.clone()); } match self.store.get_outbound_group_session(room_id).await { Ok(Some(s)) => { { - let mut sessions_being_shared = self.sessions_being_shared.write().unwrap(); + let mut sessions_being_shared = self.sessions_being_shared.write(); for request_id in s.pending_request_ids() { sessions_being_shared.insert(request_id, s.clone()); } } - self.sessions.write().unwrap().insert(room_id.to_owned(), s.clone()); + self.sessions.write().insert(room_id.to_owned(), s.clone()); Some(s) } @@ -104,20 +106,20 @@ impl GroupSessionCache { /// * `room_id` - The id of the room for which we should get the outbound /// group session. fn get(&self, room_id: &RoomId) -> Option { - self.sessions.read().unwrap().get(room_id).cloned() + self.sessions.read().get(room_id).cloned() } /// Returns whether any session is withheld with the given device and code. fn has_session_withheld_to(&self, device: &DeviceData, code: &WithheldCode) -> bool { - self.sessions.read().unwrap().values().any(|s| s.is_withheld_to(device, code)) + self.sessions.read().values().any(|s| s.is_withheld_to(device, code)) } fn remove_from_being_shared(&self, id: &TransactionId) -> Option { - self.sessions_being_shared.write().unwrap().remove(id) + self.sessions_being_shared.write().remove(id) } fn mark_as_being_shared(&self, id: OwnedTransactionId, session: OutboundGroupSession) { - self.sessions_being_shared.write().unwrap().insert(id, session); + self.sessions_being_shared.write().insert(id, session); } } diff --git a/crates/matrix-sdk-crypto/src/session_manager/group_sessions/share_strategy.rs b/crates/matrix-sdk-crypto/src/session_manager/group_sessions/share_strategy.rs index f3e5b074b00..f3d6efeaed7 100644 --- a/crates/matrix-sdk-crypto/src/session_manager/group_sessions/share_strategy.rs +++ b/crates/matrix-sdk-crypto/src/session_manager/group_sessions/share_strategy.rs @@ -125,7 +125,7 @@ pub(crate) async fn collect_session_recipients( trace!(?users, ?settings, "Calculating group session recipients"); let users_shared_with: BTreeSet = - outbound.shared_with_set.read().unwrap().keys().cloned().collect(); + outbound.shared_with_set.read().keys().cloned().collect(); let users_shared_with: BTreeSet<&UserId> = users_shared_with.iter().map(Deref::deref).collect(); @@ -339,7 +339,7 @@ fn is_session_overshared_for_user( let recipient_device_ids: BTreeSet<&DeviceId> = recipient_devices.iter().map(|d| d.device_id()).collect(); - let guard = outbound_session.shared_with_set.read().unwrap(); + let guard = outbound_session.shared_with_set.read(); let Some(shared) = guard.get(user_id) else { return false; diff --git a/crates/matrix-sdk-crypto/src/session_manager/sessions.rs b/crates/matrix-sdk-crypto/src/session_manager/sessions.rs index e1e3aab1725..e7bdc5ba0ed 100644 --- a/crates/matrix-sdk-crypto/src/session_manager/sessions.rs +++ b/crates/matrix-sdk-crypto/src/session_manager/sessions.rs @@ -14,11 +14,11 @@ use std::{ collections::{BTreeMap, BTreeSet}, - sync::{Arc, RwLock as StdRwLock}, + sync::Arc, time::Duration, }; -use matrix_sdk_common::failures_cache::FailuresCache; +use matrix_sdk_common::{failures_cache::FailuresCache, locks::RwLock as StdRwLock}; use ruma::{ api::client::keys::claim_keys::v3::{ Request as KeysClaimRequest, Response as KeysClaimResponse, @@ -98,7 +98,7 @@ impl SessionManager { /// Mark the outgoing request as sent. pub fn mark_outgoing_request_as_sent(&self, id: &TransactionId) { - self.outgoing_to_device_requests.write().unwrap().remove(id); + self.outgoing_to_device_requests.write().remove(id); } pub async fn mark_device_as_wedged( @@ -121,13 +121,11 @@ impl SessionManager { if should_unwedge { self.users_for_key_claim .write() - .unwrap() .entry(device.user_id().to_owned()) .or_default() .insert(device.device_id().into()); self.wedged_devices .write() - .unwrap() .entry(device.user_id().to_owned()) .or_default() .insert(device.device_id().into()); @@ -142,7 +140,6 @@ impl SessionManager { pub fn is_device_wedged(&self, device: &DeviceData) -> bool { self.wedged_devices .read() - .unwrap() .get(device.user_id()) .is_some_and(|d| d.contains(device.device_id())) } @@ -151,13 +148,7 @@ impl SessionManager { /// /// If the device was wedged this will queue up a dummy to-device message. async fn check_if_unwedged(&self, user_id: &UserId, device_id: &DeviceId) -> OlmResult<()> { - if self - .wedged_devices - .write() - .unwrap() - .get_mut(user_id) - .is_some_and(|d| d.remove(device_id)) - { + if self.wedged_devices.write().get_mut(user_id).is_some_and(|d| d.remove(device_id)) { if let Some(device) = self.store.get_device(user_id, device_id).await? { let (_, content) = device.encrypt("m.dummy", ToDeviceDummyEventContent::new()).await?; @@ -176,7 +167,6 @@ impl SessionManager { self.outgoing_to_device_requests .write() - .unwrap() .insert(request.request_id.clone(), request); } } @@ -278,7 +268,7 @@ impl SessionManager { // Add the list of sessions that for some reason automatically need to // create an Olm session. - for (user, device_ids) in self.users_for_key_claim.read().unwrap().iter() { + for (user, device_ids) in self.users_for_key_claim.read().iter() { missing_session_devices_by_user.entry(user.to_owned()).or_default().extend( device_ids .iter() @@ -319,12 +309,12 @@ impl SessionManager { // stash the details of the request so that we can refer to it when handling the // response - *(self.current_key_claim_request.write().unwrap()) = result.clone(); + *(self.current_key_claim_request.write()) = result.clone(); Ok(result) } fn is_user_timed_out(&self, user_id: &UserId, device_id: &DeviceId) -> bool { - self.failed_devices.read().unwrap().get(user_id).is_some_and(|d| d.contains(device_id)) + self.failed_devices.read().get(user_id).is_some_and(|d| d.contains(device_id)) } /// This method will try to figure out for which devices a one-time key was @@ -354,7 +344,7 @@ impl SessionManager { ) { // First check that the response is for the request we were expecting. let request = { - let mut guard = self.current_key_claim_request.write().unwrap(); + let mut guard = self.current_key_claim_request.write(); let expected_request_id = guard.as_ref().map(|e| e.0.as_ref()); if Some(request_id) == expected_request_id { @@ -416,7 +406,7 @@ impl SessionManager { "Tried to create new Olm sessions, but the signed one-time key was missing for some devices", ); - let mut failed_devices_lock = self.failed_devices.write().unwrap(); + let mut failed_devices_lock = self.failed_devices.write(); for (user_id, device_set) in missing_devices_by_user { failed_devices_lock.entry(user_id.clone()).or_default().extend(device_set); @@ -542,7 +532,6 @@ impl SessionManager { self.failed_devices .write() - .unwrap() .entry(user_id.to_owned()) .or_default() .insert(device_id.to_owned()); @@ -573,7 +562,7 @@ impl SessionManager { info!(sessions = ?new_sessions, "Established new Olm sessions"); for (user, device_map) in new_sessions { - if let Some(user_cache) = self.failed_devices.read().unwrap().get(user) { + if let Some(user_cache) = self.failed_devices.read().get(user) { user_cache.remove(device_map.into_keys()); } } @@ -597,14 +586,9 @@ impl SessionManager { #[cfg(test)] mod tests { - use std::{ - collections::BTreeMap, - iter, - ops::Deref, - sync::{Arc, RwLock as StdRwLock}, - time::Duration, - }; + use std::{collections::BTreeMap, iter, ops::Deref, sync::Arc, time::Duration}; + use matrix_sdk_common::locks::RwLock as StdRwLock; use matrix_sdk_test::{async_test, ruma_response_from_json}; use ruma::{ api::client::keys::claim_keys::v3::Response as KeyClaimResponse, device_id, @@ -861,11 +845,11 @@ mod tests { let curve_key = bob_device.curve25519_key().unwrap(); - assert!(!manager.users_for_key_claim.read().unwrap().contains_key(bob.user_id())); + assert!(!manager.users_for_key_claim.read().contains_key(bob.user_id())); assert!(!manager.is_device_wedged(&bob_device)); manager.mark_device_as_wedged(bob_device.user_id(), curve_key).await.unwrap(); assert!(manager.is_device_wedged(&bob_device)); - assert!(manager.users_for_key_claim.read().unwrap().contains_key(bob.user_id())); + assert!(manager.users_for_key_claim.read().contains_key(bob.user_id())); let (txn_id, request) = manager.get_missing_sessions(iter::once(bob.user_id())).await.unwrap().unwrap(); @@ -885,13 +869,13 @@ mod tests { let response = KeyClaimResponse::new(one_time_keys); - assert!(manager.outgoing_to_device_requests.read().unwrap().is_empty()); + assert!(manager.outgoing_to_device_requests.read().is_empty()); manager.receive_keys_claim_response(&txn_id, &response).await.unwrap(); assert!(!manager.is_device_wedged(&bob_device)); assert!(manager.get_missing_sessions(iter::once(bob.user_id())).await.unwrap().is_none()); - assert!(!manager.outgoing_to_device_requests.read().unwrap().is_empty()) + assert!(!manager.outgoing_to_device_requests.read().is_empty()) } #[async_test] @@ -1012,7 +996,6 @@ mod tests { manager .failed_devices .write() - .unwrap() .get(alice) .unwrap() .expire(&alice_account.device_id().to_owned()); @@ -1027,7 +1010,6 @@ mod tests { assert!(manager .failed_devices .read() - .unwrap() .get(alice) .unwrap() .failure_count(alice_account.device_id()) diff --git a/crates/matrix-sdk-crypto/src/store/caches.rs b/crates/matrix-sdk-crypto/src/store/caches.rs index 3b45a62ad40..ca3954edf0c 100644 --- a/crates/matrix-sdk-crypto/src/store/caches.rs +++ b/crates/matrix-sdk-crypto/src/store/caches.rs @@ -22,10 +22,11 @@ use std::{ fmt::Display, sync::{ atomic::{AtomicBool, Ordering}, - Arc, RwLock as StdRwLock, Weak, + Arc, Weak, }, }; +use matrix_sdk_common::locks::RwLock as StdRwLock; use ruma::{DeviceId, OwnedDeviceId, OwnedRoomId, OwnedUserId, RoomId, UserId}; use serde::{Deserialize, Serialize}; use tokio::sync::{Mutex, RwLock}; @@ -104,7 +105,6 @@ impl GroupSessionStore { pub fn add(&self, session: InboundGroupSession) -> bool { self.entries .write() - .unwrap() .entry(session.room_id().to_owned()) .or_default() .insert(session.session_id().to_owned(), session) @@ -113,12 +113,12 @@ impl GroupSessionStore { /// Get all the group sessions the store knows about. pub fn get_all(&self) -> Vec { - self.entries.read().unwrap().values().flat_map(HashMap::values).cloned().collect() + self.entries.read().values().flat_map(HashMap::values).cloned().collect() } /// Get the number of `InboundGroupSession`s we have. pub fn count(&self) -> usize { - self.entries.read().unwrap().values().map(HashMap::len).sum() + self.entries.read().values().map(HashMap::len).sum() } /// Get a inbound group session from our store. @@ -128,7 +128,7 @@ impl GroupSessionStore { /// /// * `session_id` - The unique id of the session. pub fn get(&self, room_id: &RoomId, session_id: &str) -> Option { - self.entries.read().unwrap().get(room_id)?.get(session_id).cloned() + self.entries.read().get(room_id)?.get(session_id).cloned() } } @@ -151,7 +151,6 @@ impl DeviceStore { let user_id = device.user_id(); self.entries .write() - .unwrap() .entry(user_id.to_owned()) .or_default() .insert(device.device_id().into(), device) @@ -160,7 +159,7 @@ impl DeviceStore { /// Get the device with the given device_id and belonging to the given user. pub fn get(&self, user_id: &UserId, device_id: &DeviceId) -> Option { - Some(self.entries.read().unwrap().get(user_id)?.get(device_id)?.clone()) + Some(self.entries.read().get(user_id)?.get(device_id)?.clone()) } /// Remove the device with the given device_id and belonging to the given @@ -168,14 +167,13 @@ impl DeviceStore { /// /// Returns the device if it was removed, None if it wasn't in the store. pub fn remove(&self, user_id: &UserId, device_id: &DeviceId) -> Option { - self.entries.write().unwrap().get_mut(user_id)?.remove(device_id) + self.entries.write().get_mut(user_id)?.remove(device_id) } /// Get a read-only view over all devices of the given user. pub fn user_devices(&self, user_id: &UserId) -> HashMap { self.entries .write() - .unwrap() .entry(user_id.to_owned()) .or_default() .iter() diff --git a/crates/matrix-sdk-crypto/src/store/memorystore.rs b/crates/matrix-sdk-crypto/src/store/memorystore.rs index 4a4e16dc53c..1e6b4ae7fff 100644 --- a/crates/matrix-sdk-crypto/src/store/memorystore.rs +++ b/crates/matrix-sdk-crypto/src/store/memorystore.rs @@ -15,11 +15,12 @@ use std::{ collections::{BTreeMap, HashMap, HashSet}, convert::Infallible, - sync::RwLock as StdRwLock, }; use async_trait::async_trait; -use matrix_sdk_common::store_locks::memory_store_helper::try_take_leased_lock; +use matrix_sdk_common::{ + locks::RwLock as StdRwLock, store_locks::memory_store_helper::try_take_leased_lock, +}; use ruma::{ events::secret::request::SecretName, time::Instant, DeviceId, OwnedDeviceId, OwnedRoomId, OwnedTransactionId, OwnedUserId, RoomId, TransactionId, UserId, @@ -143,7 +144,7 @@ impl MemoryStore { } fn save_sessions(&self, sessions: Vec) { - let mut session_store = self.sessions.write().unwrap(); + let mut session_store = self.sessions.write(); for session in sessions { let entry = session_store.entry(session.sender_key().to_base64()).or_default(); @@ -159,12 +160,11 @@ impl MemoryStore { fn save_outbound_group_sessions(&self, sessions: Vec) { self.outbound_group_sessions .write() - .unwrap() .extend(sessions.into_iter().map(|s| (s.room_id().to_owned(), s))); } fn save_private_identity(&self, private_identity: Option) { - *self.private_identity.write().unwrap() = private_identity; + *self.private_identity.write() = private_identity; } /// Return all the [`InboundGroupSession`]s we have, paired with the @@ -176,7 +176,6 @@ impl MemoryStore { let lookup = |s: &InboundGroupSession| { self.inbound_group_sessions_backed_up_to .read() - .unwrap() .get(&s.room_id)? .get(s.session_id()) .cloned() @@ -202,11 +201,11 @@ impl CryptoStore for MemoryStore { type Error = Infallible; async fn load_account(&self) -> Result> { - Ok(self.account.read().unwrap().as_ref().map(|acc| acc.deep_clone())) + Ok(self.account.read().as_ref().map(|acc| acc.deep_clone())) } async fn load_identity(&self) -> Result> { - Ok(self.private_identity.read().unwrap().clone()) + Ok(self.private_identity.read().clone()) } async fn next_batch_token(&self) -> Result> { @@ -215,7 +214,7 @@ impl CryptoStore for MemoryStore { async fn save_pending_changes(&self, changes: PendingChanges) -> Result<()> { if let Some(account) = changes.account { - *self.account.write().unwrap() = Some(account); + *self.account.write() = Some(account); } Ok(()) @@ -232,7 +231,7 @@ impl CryptoStore for MemoryStore { self.delete_devices(changes.devices.deleted); { - let mut identities = self.identities.write().unwrap(); + let mut identities = self.identities.write(); for identity in changes.identities.new.into_iter().chain(changes.identities.changed) { identities.insert( identity.user_id().to_owned(), @@ -242,15 +241,15 @@ impl CryptoStore for MemoryStore { } { - let mut olm_hashes = self.olm_hashes.write().unwrap(); + let mut olm_hashes = self.olm_hashes.write(); for hash in changes.message_hashes { olm_hashes.entry(hash.sender_key.to_owned()).or_default().insert(hash.hash.clone()); } } { - let mut outgoing_key_requests = self.outgoing_key_requests.write().unwrap(); - let mut key_requests_by_info = self.key_requests_by_info.write().unwrap(); + let mut outgoing_key_requests = self.outgoing_key_requests.write(); + let mut key_requests_by_info = self.key_requests_by_info.write(); for key_request in changes.key_requests { let id = key_request.request_id.clone(); @@ -275,14 +274,14 @@ impl CryptoStore for MemoryStore { } { - let mut secret_inbox = self.secret_inbox.write().unwrap(); + let mut secret_inbox = self.secret_inbox.write(); for secret in changes.secrets { secret_inbox.entry(secret.secret_name.to_string()).or_default().push(secret); } } { - let mut direct_withheld_info = self.direct_withheld_info.write().unwrap(); + let mut direct_withheld_info = self.direct_withheld_info.write(); for (room_id, data) in changes.withheld_session_info { for (session_id, event) in data { direct_withheld_info @@ -298,7 +297,7 @@ impl CryptoStore for MemoryStore { } if !changes.room_settings.is_empty() { - let mut settings = self.room_settings.write().unwrap(); + let mut settings = self.room_settings.write(); settings.extend(changes.room_settings); } @@ -311,7 +310,7 @@ impl CryptoStore for MemoryStore { backed_up_to_version: Option<&str>, ) -> Result<()> { let mut inbound_group_sessions_backed_up_to = - self.inbound_group_sessions_backed_up_to.write().unwrap(); + self.inbound_group_sessions_backed_up_to.write(); for session in sessions { let room_id = session.room_id(); @@ -339,7 +338,7 @@ impl CryptoStore for MemoryStore { } async fn get_sessions(&self, sender_key: &str) -> Result>> { - Ok(self.sessions.read().unwrap().get(sender_key).cloned()) + Ok(self.sessions.read().get(sender_key).cloned()) } async fn get_inbound_group_session( @@ -358,7 +357,6 @@ impl CryptoStore for MemoryStore { Ok(self .direct_withheld_info .read() - .unwrap() .get(room_id) .and_then(|e| Some(e.get(session_id)?.to_owned()))) } @@ -458,7 +456,7 @@ impl CryptoStore for MemoryStore { room_and_session_ids: &[(&RoomId, &str)], ) -> Result<()> { let mut inbound_group_sessions_backed_up_to = - self.inbound_group_sessions_backed_up_to.write().unwrap(); + self.inbound_group_sessions_backed_up_to.write(); for &(room_id, session_id) in room_and_session_ids { let session = self.inbound_group_sessions.get(room_id, session_id); @@ -506,15 +504,15 @@ impl CryptoStore for MemoryStore { &self, room_id: &RoomId, ) -> Result> { - Ok(self.outbound_group_sessions.read().unwrap().get(room_id).cloned()) + Ok(self.outbound_group_sessions.read().get(room_id).cloned()) } async fn load_tracked_users(&self) -> Result> { - Ok(self.tracked_users.read().unwrap().values().cloned().collect()) + Ok(self.tracked_users.read().values().cloned().collect()) } async fn save_tracked_users(&self, tracked_users: &[(&UserId, bool)]) -> Result<()> { - self.tracked_users.write().unwrap().extend(tracked_users.iter().map(|(user_id, dirty)| { + self.tracked_users.write().extend(tracked_users.iter().map(|(user_id, dirty)| { let user_id: OwnedUserId = user_id.to_owned().into(); (user_id.clone(), TrackedUser { user_id, dirty: *dirty }) })); @@ -543,7 +541,7 @@ impl CryptoStore for MemoryStore { } async fn get_user_identity(&self, user_id: &UserId) -> Result> { - let serialized = self.identities.read().unwrap().get(user_id).cloned(); + let serialized = self.identities.read().get(user_id).cloned(); match serialized { None => Ok(None), Some(serialized) => { @@ -557,7 +555,6 @@ impl CryptoStore for MemoryStore { Ok(self .olm_hashes .write() - .unwrap() .entry(message_hash.sender_key.to_owned()) .or_default() .contains(&message_hash.hash)) @@ -567,7 +564,7 @@ impl CryptoStore for MemoryStore { &self, request_id: &TransactionId, ) -> Result> { - Ok(self.outgoing_key_requests.read().unwrap().get(request_id).cloned()) + Ok(self.outgoing_key_requests.read().get(request_id).cloned()) } async fn get_secret_request_by_info( @@ -579,16 +576,14 @@ impl CryptoStore for MemoryStore { Ok(self .key_requests_by_info .read() - .unwrap() .get(&key_info_string) - .and_then(|i| self.outgoing_key_requests.read().unwrap().get(i).cloned())) + .and_then(|i| self.outgoing_key_requests.read().get(i).cloned())) } async fn get_unsent_secret_requests(&self) -> Result> { Ok(self .outgoing_key_requests .read() - .unwrap() .values() .filter(|req| !req.sent_out) .cloned() @@ -596,10 +591,10 @@ impl CryptoStore for MemoryStore { } async fn delete_outgoing_secret_requests(&self, request_id: &TransactionId) -> Result<()> { - let req = self.outgoing_key_requests.write().unwrap().remove(request_id); + let req = self.outgoing_key_requests.write().remove(request_id); if let Some(i) = req { let key_info_string = encode_key_info(&i.info); - self.key_requests_by_info.write().unwrap().remove(&key_info_string); + self.key_requests_by_info.write().remove(&key_info_string); } Ok(()) @@ -609,36 +604,30 @@ impl CryptoStore for MemoryStore { &self, secret_name: &SecretName, ) -> Result> { - Ok(self - .secret_inbox - .write() - .unwrap() - .entry(secret_name.to_string()) - .or_default() - .to_owned()) + Ok(self.secret_inbox.write().entry(secret_name.to_string()).or_default().to_owned()) } async fn delete_secrets_from_inbox(&self, secret_name: &SecretName) -> Result<()> { - self.secret_inbox.write().unwrap().remove(secret_name.as_str()); + self.secret_inbox.write().remove(secret_name.as_str()); Ok(()) } async fn get_room_settings(&self, room_id: &RoomId) -> Result> { - Ok(self.room_settings.read().unwrap().get(room_id).cloned()) + Ok(self.room_settings.read().get(room_id).cloned()) } async fn get_custom_value(&self, key: &str) -> Result>> { - Ok(self.custom_values.read().unwrap().get(key).cloned()) + Ok(self.custom_values.read().get(key).cloned()) } async fn set_custom_value(&self, key: &str, value: Vec) -> Result<()> { - self.custom_values.write().unwrap().insert(key.to_owned(), value); + self.custom_values.write().insert(key.to_owned(), value); Ok(()) } async fn remove_custom_value(&self, key: &str) -> Result<()> { - self.custom_values.write().unwrap().remove(key); + self.custom_values.write().remove(key); Ok(()) } @@ -648,7 +637,7 @@ impl CryptoStore for MemoryStore { key: &str, holder: &str, ) -> Result { - Ok(try_take_leased_lock(&mut self.leases.write().unwrap(), lease_duration_ms, key, holder)) + Ok(try_take_leased_lock(&mut self.leases.write(), lease_duration_ms, key, holder)) } } @@ -1167,7 +1156,7 @@ mod integration_tests { impl MemoryStore { fn get_static_account(&self) -> Option { - self.account.read().unwrap().as_ref().map(|acc| acc.static_data().clone()) + self.account.read().as_ref().map(|acc| acc.static_data().clone()) } } diff --git a/crates/matrix-sdk-crypto/src/store/mod.rs b/crates/matrix-sdk-crypto/src/store/mod.rs index 662829f8c35..8f990670d4c 100644 --- a/crates/matrix-sdk-crypto/src/store/mod.rs +++ b/crates/matrix-sdk-crypto/src/store/mod.rs @@ -43,13 +43,14 @@ use std::{ fmt::Debug, ops::Deref, pin::pin, - sync::{atomic::Ordering, Arc, RwLock as StdRwLock}, + sync::{atomic::Ordering, Arc}, time::Duration, }; use as_variant::as_variant; use futures_core::Stream; use futures_util::StreamExt; +use matrix_sdk_common::locks::RwLock as StdRwLock; use ruma::{ encryption::KeyUsage, events::secret::request::SecretName, DeviceId, OwnedDeviceId, OwnedRoomId, OwnedUserId, UserId, @@ -155,7 +156,7 @@ impl KeyQueryManager { let tracked_users = cache.store.load_tracked_users().await?; let mut query_users_lock = self.users_for_key_query.lock().await; - let mut tracked_users_cache = cache.tracked_users.write().unwrap(); + let mut tracked_users_cache = cache.tracked_users.write(); for user in tracked_users { tracked_users_cache.insert(user.user_id.to_owned()); @@ -245,7 +246,7 @@ impl SyncedKeyQueryManager<'_> { let mut key_query_lock = self.manager.users_for_key_query.lock().await; { - let mut tracked_users = self.cache.tracked_users.write().unwrap(); + let mut tracked_users = self.cache.tracked_users.write(); for user_id in users { if tracked_users.insert(user_id.to_owned()) { key_query_lock.insert_user(user_id); @@ -271,7 +272,7 @@ impl SyncedKeyQueryManager<'_> { let mut key_query_lock = self.manager.users_for_key_query.lock().await; { - let tracked_users = &self.cache.tracked_users.read().unwrap(); + let tracked_users = &self.cache.tracked_users.read(); for user_id in users { if tracked_users.contains(user_id) { key_query_lock.insert_user(user_id); @@ -297,7 +298,7 @@ impl SyncedKeyQueryManager<'_> { let mut key_query_lock = self.manager.users_for_key_query.lock().await; { - let tracked_users = self.cache.tracked_users.read().unwrap(); + let tracked_users = self.cache.tracked_users.read(); for user_id in users { if tracked_users.contains(user_id) { let clean = key_query_lock.maybe_remove_user(user_id, sequence_number); @@ -330,7 +331,7 @@ impl SyncedKeyQueryManager<'_> { /// See the docs for [`crate::OlmMachine::tracked_users()`]. pub fn tracked_users(&self) -> HashSet { - self.cache.tracked_users.read().unwrap().iter().cloned().collect() + self.cache.tracked_users.read().iter().cloned().collect() } /// Mark the given user as being tracked for device lists, and mark that it @@ -340,7 +341,7 @@ impl SyncedKeyQueryManager<'_> { /// next time [`Store::users_for_key_query()`] is called. pub async fn mark_user_as_changed(&self, user: &UserId) -> Result<()> { self.manager.users_for_key_query.lock().await.insert_user(user); - self.cache.tracked_users.write().unwrap().insert(user.to_owned()); + self.cache.tracked_users.write().insert(user.to_owned()); self.cache.store.save_tracked_users(&[(user, true)]).await } diff --git a/crates/matrix-sdk-crypto/src/verification/cache.rs b/crates/matrix-sdk-crypto/src/verification/cache.rs index f5e23344a5d..6601cb3d4d0 100644 --- a/crates/matrix-sdk-crypto/src/verification/cache.rs +++ b/crates/matrix-sdk-crypto/src/verification/cache.rs @@ -12,12 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{ - collections::BTreeMap, - sync::{Arc, RwLock as StdRwLock}, -}; +use std::{collections::BTreeMap, sync::Arc}; use as_variant::as_variant; +use matrix_sdk_common::locks::RwLock as StdRwLock; use ruma::{DeviceId, OwnedTransactionId, OwnedUserId, TransactionId, UserId}; #[cfg(feature = "qrcode")] use tracing::debug; @@ -69,7 +67,7 @@ impl VerificationCache { #[cfg(test)] #[allow(dead_code)] pub fn is_empty(&self) -> bool { - self.inner.verification.read().unwrap().values().all(|m| m.is_empty()) + self.inner.verification.read().values().all(|m| m.is_empty()) } /// Add a new `Verification` object to the cache, this will cancel any @@ -78,7 +76,7 @@ impl VerificationCache { pub fn insert(&self, verification: impl Into) { let verification = verification.into(); - let mut verification_write_guard = self.inner.verification.write().unwrap(); + let mut verification_write_guard = self.inner.verification.write(); let user_verifications = verification_write_guard.entry(verification.other_user().to_owned()).or_default(); @@ -150,22 +148,21 @@ impl VerificationCache { self.inner .verification .write() - .unwrap() .entry(verification.other_user().to_owned()) .or_default() .insert(verification.flow_id().to_owned(), verification.clone()); } pub fn get(&self, sender: &UserId, flow_id: &str) -> Option { - self.inner.verification.read().unwrap().get(sender)?.get(flow_id).cloned() + self.inner.verification.read().get(sender)?.get(flow_id).cloned() } pub fn outgoing_requests(&self) -> Vec { - self.inner.outgoing_requests.read().unwrap().values().cloned().collect() + self.inner.outgoing_requests.read().values().cloned().collect() } pub fn garbage_collect(&self) -> Vec { - let verification = &mut self.inner.verification.write().unwrap(); + let verification = &mut self.inner.verification.write(); for user_verification in verification.values_mut() { user_verification.retain(|_, s| !(s.is_done() || s.is_cancelled())); @@ -186,7 +183,7 @@ impl VerificationCache { pub fn add_request(&self, request: OutgoingRequest) { trace!("Adding an outgoing request {:?}", request); - self.inner.outgoing_requests.write().unwrap().insert(request.request_id.clone(), request); + self.inner.outgoing_requests.write().insert(request.request_id.clone(), request); } pub fn add_verification_request(&self, request: OutgoingVerificationRequest) { @@ -211,7 +208,7 @@ impl VerificationCache { "Storing the request info, waiting for the request to be marked as sent" ); - self.inner.flow_ids_waiting_for_response.write().unwrap().insert( + self.inner.flow_ids_waiting_for_response.write().insert( request_info.request_id.to_owned(), (recipient.to_owned(), request_info.flow_id), ); @@ -235,7 +232,7 @@ impl VerificationCache { request: Arc::new(request.into()), }; - self.inner.outgoing_requests.write().unwrap().insert(request_id, request); + self.inner.outgoing_requests.write().insert(request_id, request); } OutgoingContent::Room(r, c) => { @@ -247,18 +244,18 @@ impl VerificationCache { request_id: request_id.clone(), }; - self.inner.outgoing_requests.write().unwrap().insert(request_id, request); + self.inner.outgoing_requests.write().insert(request_id, request); } } } pub fn mark_request_as_sent(&self, request_id: &TransactionId) { - if let Some(request_id) = self.inner.outgoing_requests.write().unwrap().remove(request_id) { + if let Some(request_id) = self.inner.outgoing_requests.write().remove(request_id) { trace!(?request_id, "Marking a verification HTTP request as sent"); } if let Some((user_id, flow_id)) = - self.inner.flow_ids_waiting_for_response.read().unwrap().get(request_id) + self.inner.flow_ids_waiting_for_response.read().get(request_id) { if let Some(verification) = self.get(user_id, flow_id.as_str()) { match verification { diff --git a/crates/matrix-sdk-crypto/src/verification/machine.rs b/crates/matrix-sdk-crypto/src/verification/machine.rs index 5e7c0c15cf0..49d7b80ae14 100644 --- a/crates/matrix-sdk-crypto/src/verification/machine.rs +++ b/crates/matrix-sdk-crypto/src/verification/machine.rs @@ -12,11 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{ - collections::HashMap, - sync::{Arc, RwLock as StdRwLock}, -}; +use std::{collections::HashMap, sync::Arc}; +use matrix_sdk_common::locks::RwLock as StdRwLock; use ruma::{ events::{ key::verification::VerificationMethod, AnyToDeviceEvent, AnyToDeviceEventContent, @@ -153,13 +151,12 @@ impl VerificationMachine { user_id: &UserId, flow_id: impl AsRef, ) -> Option { - self.requests.read().unwrap().get(user_id)?.get(flow_id.as_ref()).cloned() + self.requests.read().get(user_id)?.get(flow_id.as_ref()).cloned() } pub fn get_requests(&self, user_id: &UserId) -> Vec { self.requests .read() - .unwrap() .get(user_id) .map(|v| v.iter().map(|(_, value)| value.clone()).collect()) .unwrap_or_default() @@ -174,7 +171,7 @@ impl VerificationMachine { return; } - let mut requests = self.requests.write().unwrap(); + let mut requests = self.requests.write(); let user_requests = requests.entry(request.other_user().to_owned()).or_default(); // Cancel all the old verifications requests as well as the new one we @@ -247,7 +244,7 @@ impl VerificationMachine { let mut events = vec![]; let mut requests: Vec = { - let mut requests = self.requests.write().unwrap(); + let mut requests = self.requests.write(); for user_verification in requests.values_mut() { user_verification.retain(|_, v| !(v.is_done() || v.is_cancelled())); diff --git a/crates/matrix-sdk-crypto/src/verification/sas/sas_state.rs b/crates/matrix-sdk-crypto/src/verification/sas/sas_state.rs index 295335bb7f2..958cf6a004d 100644 --- a/crates/matrix-sdk-crypto/src/verification/sas/sas_state.rs +++ b/crates/matrix-sdk-crypto/src/verification/sas/sas_state.rs @@ -12,12 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{ - matches, - sync::{Arc, Mutex}, - time::Duration, -}; +use std::{matches, sync::Arc, time::Duration}; +use matrix_sdk_common::locks::Mutex; use ruma::{ events::{ key::verification::{ @@ -356,7 +353,7 @@ impl SasState { let their_public_key = Curve25519PublicKey::from_slice(content.public_key().as_bytes()) .map_err(|_| CancelCode::from("Invalid public key"))?; - if let Some(sas) = self.inner.lock().unwrap().take() { + if let Some(sas) = self.inner.lock().take() { sas.diffie_hellman(their_public_key).map_err(|_| "Invalid public key".into()) } else { Err(CancelCode::UnexpectedMessage) @@ -1127,7 +1124,7 @@ impl SasState { /// second element the English description of the emoji. pub fn get_emoji(&self) -> [Emoji; 7] { get_emoji( - &self.state.sas.lock().unwrap(), + &self.state.sas.lock(), &self.ids, self.verification_flow_id.as_str(), self.state.we_started, @@ -1140,7 +1137,7 @@ impl SasState { /// numbers can be converted to a unique emoji defined by the spec. pub fn get_emoji_index(&self) -> [u8; 7] { get_emoji_index( - &self.state.sas.lock().unwrap(), + &self.state.sas.lock(), &self.ids, self.verification_flow_id.as_str(), self.state.we_started, @@ -1153,7 +1150,7 @@ impl SasState { /// the short auth string. pub fn get_decimal(&self) -> (u16, u16, u16) { get_decimal( - &self.state.sas.lock().unwrap(), + &self.state.sas.lock(), &self.ids, self.verification_flow_id.as_str(), self.state.we_started, @@ -1175,7 +1172,7 @@ impl SasState { self.check_event(sender, content.flow_id()).map_err(|c| self.clone().cancel(true, c))?; let (devices, master_keys) = receive_mac_event( - &self.state.sas.lock().unwrap(), + &self.state.sas.lock(), &self.ids, self.verification_flow_id.as_str(), sender, @@ -1239,7 +1236,7 @@ impl SasState { self.check_event(sender, content.flow_id()).map_err(|c| self.clone().cancel(true, c))?; let (devices, master_keys) = receive_mac_event( - &self.state.sas.lock().unwrap(), + &self.state.sas.lock(), &self.ids, self.verification_flow_id.as_str(), sender, @@ -1283,7 +1280,7 @@ impl SasState { self.check_event(sender, content.flow_id()).map_err(|c| self.clone().cancel(true, c))?; let (devices, master_keys) = receive_mac_event( - &self.state.sas.lock().unwrap(), + &self.state.sas.lock(), &self.ids, self.verification_flow_id.as_str(), sender, @@ -1315,7 +1312,7 @@ impl SasState { /// The content needs to be automatically sent to the other side. pub fn as_content(&self) -> OutgoingContent { get_mac_content( - &self.state.sas.lock().unwrap(), + &self.state.sas.lock(), &self.ids, &self.verification_flow_id, self.state.accepted_protocols.message_auth_code, @@ -1376,7 +1373,7 @@ impl SasState { /// second element the English description of the emoji. pub fn get_emoji(&self) -> [Emoji; 7] { get_emoji( - &self.state.sas.lock().unwrap(), + &self.state.sas.lock(), &self.ids, self.verification_flow_id.as_str(), self.state.we_started, @@ -1389,7 +1386,7 @@ impl SasState { /// numbers can be converted to a unique emoji defined by the spec. pub fn get_emoji_index(&self) -> [u8; 7] { get_emoji_index( - &self.state.sas.lock().unwrap(), + &self.state.sas.lock(), &self.ids, self.verification_flow_id.as_str(), self.state.we_started, @@ -1402,7 +1399,7 @@ impl SasState { /// the short auth string. pub fn get_decimal(&self) -> (u16, u16, u16) { get_decimal( - &self.state.sas.lock().unwrap(), + &self.state.sas.lock(), &self.ids, self.verification_flow_id.as_str(), self.state.we_started, @@ -1417,7 +1414,7 @@ impl SasState { /// wasn't already sent. pub fn as_content(&self) -> OutgoingContent { get_mac_content( - &self.state.sas.lock().unwrap(), + &self.state.sas.lock(), &self.ids, &self.verification_flow_id, self.state.accepted_protocols.message_auth_code, @@ -1480,7 +1477,7 @@ impl SasState { /// wasn't already sent. pub fn as_content(&self) -> OutgoingContent { get_mac_content( - &self.state.sas.lock().unwrap(), + &self.state.sas.lock(), &self.ids, &self.verification_flow_id, self.state.accepted_protocols.message_auth_code,