diff --git a/accounts-db/src/account_locks.rs b/accounts-db/src/account_locks.rs index e72a3107026731..ecd784033c68c1 100644 --- a/accounts-db/src/account_locks.rs +++ b/accounts-db/src/account_locks.rs @@ -7,15 +7,15 @@ use { pubkey::Pubkey, transaction::{TransactionError, MAX_TX_ACCOUNT_LOCKS}, }, - std::{cell::RefCell, thread::ThreadId}, + std::{cell::RefCell, collections::hash_map, thread::ThreadId}, }; const HANA_FEATURE_PLACEHOLDER: bool = false; #[derive(Debug, Default)] pub struct AccountLocks { - write_locks: AHashMap, - readonly_locks: AHashMap>, + write_locks: AHashMap, + readonly_locks: AHashMap>, } impl AccountLocks { @@ -65,12 +65,12 @@ impl AccountLocks { #[cfg_attr(feature = "dev-context-only-utils", qualifiers(pub))] fn is_locked_readonly(&self, key: &Pubkey) -> bool { let thread_id = std::thread::current().id(); - for (locking_thread_id, readonly_locks) in &self.readonly_locks { + for (locking_thread_id, thread_readonly_locks) in &self.readonly_locks { if HANA_FEATURE_PLACEHOLDER && locking_thread_id == &thread_id { continue; } - if readonly_locks.contains(key) { + if thread_readonly_locks.contains_key(key) { return true; } } @@ -82,10 +82,7 @@ impl AccountLocks { fn is_locked_write(&self, key: &Pubkey) -> bool { if HANA_FEATURE_PLACEHOLDER { let thread_id = std::thread::current().id(); - match self.write_locks.get(key) { - Some(locking_thread_id) if locking_thread_id != &thread_id => true, - _ => false, - } + matches!(self.write_locks.get(key), Some((locking_thread_id, _)) if locking_thread_id != &thread_id) } else { self.write_locks.contains_key(key) } @@ -103,33 +100,65 @@ impl AccountLocks { fn lock_readonly(&mut self, key: &Pubkey) { let thread_id = std::thread::current().id(); - self.readonly_locks + *self + .readonly_locks .entry(thread_id) .or_default() - .insert(*key); + .entry(*key) + .or_default() += 1; } fn lock_write(&mut self, key: &Pubkey) { let thread_id = std::thread::current().id(); - self.write_locks.insert(*key, thread_id); + if let hash_map::Entry::Occupied(mut occupied_entry) = self.write_locks.entry(*key) { + let (locking_thread_id, count) = occupied_entry.get_mut(); + debug_assert!( + *locking_thread_id == thread_id, + "Attempted to steal a write lock from another thread." + ); + *count += 1; + } else { + self.write_locks.insert(*key, (thread_id, 1)); + } } - // HANA TODO i removed counting read locks so this may be None, consider whether to add it back fn unlock_readonly(&mut self, key: &Pubkey) { let thread_id = std::thread::current().id(); - if let Some(readonly_locks) = self.readonly_locks.get_mut(&thread_id) { - readonly_locks.remove(key); + if let Some(hash_map::Entry::Occupied(mut occupied_entry)) = self + .readonly_locks + .get_mut(&thread_id) + .map(|thread_readonly_locks| thread_readonly_locks.entry(*key)) + { + let count = occupied_entry.get_mut(); + *count -= 1; + if *count == 0 { + occupied_entry.remove_entry(); + } } else { debug_assert!( false, - "Attempted to remove a readonly lock for a thread that never created any." + "Attempted to remove a read-lock for a key that wasn't read-locked" ); } } - // HANA TODO i did not impl counting (same thread) write locks so this may be None, consider whether to add it back fn unlock_write(&mut self, key: &Pubkey) { - self.write_locks.remove(key); + if let hash_map::Entry::Occupied(mut occupied_entry) = self.write_locks.entry(*key) { + let (locking_thread_id, count) = occupied_entry.get_mut(); + debug_assert!( + *locking_thread_id == std::thread::current().id(), + "Attempted to steal a write lock from another thread." + ); + *count -= 1; + if *count == 0 { + occupied_entry.remove_entry(); + } + } else { + debug_assert!( + false, + "Attempted to remove a write-lock for a key that wasn't write-locked" + ); + } } }