From 604d4611723fd114ad9a77b15192045e26afb2ae Mon Sep 17 00:00:00 2001 From: John Nunley Date: Fri, 22 Sep 2023 20:12:23 -0700 Subject: [PATCH] feat: Support blocking and non-blocking operations on the same mutex Signed-off-by: John Nunley --- Cargo.toml | 2 +- src/barrier.rs | 87 +++++++++++++-- src/mutex.rs | 63 ++++++++++- src/rwlock.rs | 173 +++++++++++++++++++++++------ src/rwlock/futures.rs | 251 +++++++++++++++++++++++++++++++++--------- src/rwlock/raw.rs | 49 ++++++--- src/semaphore.rs | 54 +++++++-- tests/barrier.rs | 40 +++++++ tests/mutex.rs | 8 ++ tests/rwlock.rs | 10 ++ tests/semaphore.rs | 11 ++ 11 files changed, 627 insertions(+), 121 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 74fe924..776d152 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,7 @@ exclude = ["/.*"] [dependencies] event-listener = { version = "3.0.0", default-features = false } -event-listener-strategy = { version = "0.2.0", default-features = false } +event-listener-strategy = { version = "0.3.0", default-features = false } pin-project-lite = "0.2.11" [features] diff --git a/src/barrier.rs b/src/barrier.rs index 9259603..84b0742 100644 --- a/src/barrier.rs +++ b/src/barrier.rs @@ -1,9 +1,9 @@ use event_listener::{Event, EventListener}; +use event_listener_strategy::{easy_wrapper, EventListenerFuture, Strategy}; use core::fmt; -use core::future::Future; use core::pin::Pin; -use core::task::{Context, Poll}; +use core::task::Poll; use crate::futures::Lock; use crate::Mutex; @@ -79,18 +79,67 @@ impl Barrier { /// } /// ``` pub fn wait(&self) -> BarrierWait<'_> { - BarrierWait { + BarrierWait::_new(BarrierWaitInner { barrier: self, lock: Some(self.state.lock()), evl: EventListener::new(&self.event), state: WaitState::Initial, - } + }) + } + + /// Blocks the current thread until all tasks reach this point. + /// + /// Barriers are reusable after all tasks have synchronized, and can be used continuously. + /// + /// Returns a [`BarrierWaitResult`] indicating whether this task is the "leader", meaning the + /// last task to call this method. + /// + /// # Blocking + /// + /// Rather than using asynchronous waiting, like the [`wait`] method, this method will + /// block the current thread until the wait is complete. + /// + /// This method should not be used in an asynchronous context. It is intended to be + /// used in a way that a barrier can be used in both asynchronous and synchronous contexts. + /// Calling this method in an asynchronous context may result in a deadlock. + /// + /// # Examples + /// + /// ``` + /// use async_lock::Barrier; + /// use futures_lite::future; + /// use std::sync::Arc; + /// use std::thread; + /// + /// let barrier = Arc::new(Barrier::new(5)); + /// + /// for _ in 0..5 { + /// let b = barrier.clone(); + /// thread::spawn(move || { + /// // The same messages will be printed together. + /// // There will NOT be interleaving of "before" and "after". + /// println!("before wait"); + /// b.wait_blocking(); + /// println!("after wait"); + /// }); + /// } + /// ``` + #[cfg(all(feature = "std", not(target_family = "wasm")))] + pub fn wait_blocking(&self) -> BarrierWaitResult { + self.wait().wait() } } +easy_wrapper! { + /// The future returned by [`Barrier::wait()`]. + pub struct BarrierWait<'a>(BarrierWaitInner<'a> => BarrierWaitResult); + #[cfg(all(feature = "std", not(target_family = "wasm")))] + pub(crate) wait(); +} + pin_project_lite::pin_project! { /// The future returned by [`Barrier::wait()`]. - pub struct BarrierWait<'a> { + struct BarrierWaitInner<'a> { // The barrier to wait on. barrier: &'a Barrier, @@ -124,18 +173,27 @@ enum WaitState { Reacquiring { local_gen: u64 }, } -impl Future for BarrierWait<'_> { +impl EventListenerFuture for BarrierWaitInner<'_> { type Output = BarrierWaitResult; - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + fn poll_with_strategy<'a, S: Strategy<'a>>( + self: Pin<&mut Self>, + strategy: &mut S, + cx: &mut S::Context, + ) -> Poll { let mut this = self.project(); loop { match this.state { WaitState::Initial => { // See if the lock is ready yet. - let mut state = ready!(this.lock.as_mut().as_pin_mut().unwrap().poll(cx)); - this.lock.set(None); + let mut state = ready!(this + .lock + .as_mut() + .as_pin_mut() + .unwrap() + .poll_with_strategy(strategy, cx)); + this.lock.as_mut().set(None); let local_gen = state.generation_id; state.count += 1; @@ -154,10 +212,10 @@ impl Future for BarrierWait<'_> { } WaitState::Waiting { local_gen } => { - ready!(this.evl.as_mut().poll(cx)); + ready!(strategy.poll(this.evl.as_mut(), cx)); // We are now re-acquiring the mutex. - this.lock.set(Some(this.barrier.state.lock())); + this.lock.as_mut().set(Some(this.barrier.state.lock())); *this.state = WaitState::Reacquiring { local_gen: *local_gen, }; @@ -165,7 +223,12 @@ impl Future for BarrierWait<'_> { WaitState::Reacquiring { local_gen } => { // Acquire the local state again. - let state = ready!(this.lock.as_mut().as_pin_mut().unwrap().poll(cx)); + let state = ready!(this + .lock + .as_mut() + .as_pin_mut() + .unwrap() + .poll_with_strategy(strategy, cx)); this.lock.set(None); if *local_gen == state.generation_id && state.count < this.barrier.n { diff --git a/src/mutex.rs b/src/mutex.rs index a463dae..384e1c2 100644 --- a/src/mutex.rs +++ b/src/mutex.rs @@ -112,6 +112,34 @@ impl Mutex { }) } + /// Acquires the mutex using the blocking strategy. + /// + /// Returns a guard that releases the mutex when dropped. + /// + /// # Blocking + /// + /// Rather than using asynchronous waiting, like the [`lock`] method, this method will + /// block the current thread until the lock is acquired. + /// + /// This method should not be used in an asynchronous context. It is intended to be + /// used in a way that a mutex can be used in both asynchronous and synchronous contexts. + /// Calling this method in an asynchronous context may result in a deadlock. + /// + /// # Examples + /// + /// ``` + /// use async_lock::Mutex; + /// + /// let mutex = Mutex::new(10); + /// let guard = mutex.lock_blocking(); + /// assert_eq!(*guard, 10); + /// ``` + #[cfg(all(feature = "std", not(target_family = "wasm")))] + #[inline] + pub fn lock_blocking(&self) -> MutexGuard<'_, T> { + self.lock().wait() + } + /// Attempts to acquire the mutex. /// /// If the mutex could not be acquired at this time, then [`None`] is returned. Otherwise, a @@ -199,6 +227,35 @@ impl Mutex { }) } + /// Acquires the mutex and clones a reference to it using the blocking strategy. + /// + /// Returns an owned guard that releases the mutex when dropped. + /// + /// # Blocking + /// + /// Rather than using asynchronous waiting, like the [`lock_arc`] method, this method will + /// block the current thread until the lock is acquired. + /// + /// This method should not be used in an asynchronous context. It is intended to be + /// used in a way that a mutex can be used in both asynchronous and synchronous contexts. + /// Calling this method in an asynchronous context may result in a deadlock. + /// + /// # Examples + /// + /// ``` + /// use async_lock::Mutex; + /// use std::sync::Arc; + /// + /// let mutex = Arc::new(Mutex::new(10)); + /// let guard = mutex.lock_arc_blocking(); + /// assert_eq!(*guard, 10); + /// ``` + #[cfg(all(feature = "std", not(target_family = "wasm")))] + #[inline] + pub fn lock_arc_blocking(self: &Arc) -> MutexGuardArc { + self.lock_arc().wait() + } + /// Attempts to acquire the mutex and clone a reference to it. /// /// If the mutex could not be acquired at this time, then [`None`] is returned. Otherwise, an @@ -291,7 +348,7 @@ impl<'a, T: ?Sized> EventListenerFuture for LockInner<'a, T> { #[inline] fn poll_with_strategy<'x, S: event_listener_strategy::Strategy<'x>>( - self: Pin<&'x mut Self>, + self: Pin<&mut Self>, strategy: &mut S, context: &mut S::Context, ) -> Poll { @@ -350,7 +407,7 @@ impl EventListenerFuture for LockArcInnards { type Output = MutexGuardArc; fn poll_with_strategy<'a, S: event_listener_strategy::Strategy<'a>>( - mut self: Pin<&'a mut Self>, + mut self: Pin<&mut Self>, strategy: &mut S, context: &mut S::Context, ) -> Poll { @@ -459,7 +516,7 @@ impl>> EventListenerFuture for AcquireSlow #[cold] fn poll_with_strategy<'a, S: event_listener_strategy::Strategy<'a>>( - mut self: Pin<&'a mut Self>, + mut self: Pin<&mut Self>, strategy: &mut S, context: &mut S::Context, ) -> Poll { diff --git a/src/rwlock.rs b/src/rwlock.rs index ea6e000..0dca9f2 100644 --- a/src/rwlock.rs +++ b/src/rwlock.rs @@ -145,10 +145,7 @@ impl RwLock { /// ``` #[inline] pub fn read_arc<'a>(self: &'a Arc) -> ReadArc<'a, T> { - ReadArc { - raw: self.raw.read(), - lock: self, - } + ReadArc::new(self.raw.read(), self) } } @@ -207,10 +204,41 @@ impl RwLock { /// ``` #[inline] pub fn read(&self) -> Read<'_, T> { - Read { - raw: self.raw.read(), - value: self.value.get(), - } + Read::new(self.raw.read(), self.value.get()) + } + + /// Acquires a read lock. + /// + /// Returns a guard that releases the lock when dropped. + /// + /// Note that attempts to acquire a read lock will block if there are also concurrent attempts + /// to acquire a write lock. + /// + /// # Blocking + /// + /// Rather than using asynchronous waiting, like the [`read`] method, this method will + /// block the current thread until the read lock is acquired. + /// + /// This method should not be used in an asynchronous context. It is intended to be + /// used in a way that a lock can be used in both asynchronous and synchronous contexts. + /// Calling this method in an asynchronous context may result in a deadlock. + /// + /// # Examples + /// + /// ``` + /// use async_lock::RwLock; + /// + /// let lock = RwLock::new(1); + /// + /// let reader = lock.read_blocking(); + /// assert_eq!(*reader, 1); + /// + /// assert!(lock.try_read().is_some()); + /// ``` + #[cfg(all(feature = "std", not(target_family = "wasm")))] + #[inline] + pub fn read_blocking(&self) -> RwLockReadGuard<'_, T> { + self.read().wait() } /// Attempts to acquire a read lock with the possiblity to upgrade to a write lock. @@ -277,10 +305,46 @@ impl RwLock { /// ``` #[inline] pub fn upgradable_read(&self) -> UpgradableRead<'_, T> { - UpgradableRead { - raw: self.raw.upgradable_read(), - value: self.value.get(), - } + UpgradableRead::new(self.raw.upgradable_read(), self.value.get()) + } + + /// Attempts to acquire a read lock with the possiblity to upgrade to a write lock. + /// + /// Returns a guard that releases the lock when dropped. + /// + /// Upgradable read lock reserves the right to be upgraded to a write lock, which means there + /// can be at most one upgradable read lock at a time. + /// + /// Note that attempts to acquire an upgradable read lock will block if there are concurrent + /// attempts to acquire another upgradable read lock or a write lock. + /// + /// # Blocking + /// + /// Rather than using asynchronous waiting, like the [`upgradable_read`] method, this method will + /// block the current thread until the read lock is acquired. + /// + /// This method should not be used in an asynchronous context. It is intended to be + /// used in a way that a lock can be used in both asynchronous and synchronous contexts. + /// Calling this method in an asynchronous context may result in a deadlock. + /// + /// # Examples + /// + /// ``` + /// use async_lock::{RwLock, RwLockUpgradableReadGuard}; + /// + /// let lock = RwLock::new(1); + /// + /// let reader = lock.upgradable_read_blocking(); + /// assert_eq!(*reader, 1); + /// assert_eq!(*lock.try_read().unwrap(), 1); + /// + /// let mut writer = RwLockUpgradableReadGuard::upgrade_blocking(reader); + /// *writer = 2; + /// ``` + #[cfg(all(feature = "std", not(target_family = "wasm")))] + #[inline] + pub fn upgradable_read_blocking(&self) -> RwLockUpgradableReadGuard<'_, T> { + self.upgradable_read().wait() } /// Attempts to acquire an owned, reference-counted read lock with the possiblity to @@ -348,10 +412,7 @@ impl RwLock { /// ``` #[inline] pub fn upgradable_read_arc<'a>(self: &'a Arc) -> UpgradableReadArc<'a, T> { - UpgradableReadArc { - raw: self.raw.upgradable_read(), - lock: self, - } + UpgradableReadArc::new(self.raw.upgradable_read(), self) } /// Attempts to acquire a write lock. @@ -402,10 +463,36 @@ impl RwLock { /// ``` #[inline] pub fn write(&self) -> Write<'_, T> { - Write { - raw: self.raw.write(), - value: self.value.get(), - } + Write::new(self.raw.write(), self.value.get()) + } + + /// Acquires a write lock. + /// + /// Returns a guard that releases the lock when dropped. + /// + /// # Blocking + /// + /// Rather than using asynchronous waiting, like the [`write`] method, this method will + /// block the current thread until the write lock is acquired. + /// + /// This method should not be used in an asynchronous context. It is intended to be + /// used in a way that a lock can be used in both asynchronous and synchronous contexts. + /// Calling this method in an asynchronous context may result in a deadlock. + /// + /// # Examples + /// + /// ``` + /// use async_lock::RwLock; + /// + /// let lock = RwLock::new(1); + /// + /// let writer = lock.write_blocking(); + /// assert!(lock.try_read().is_none()); + /// ``` + #[cfg(all(feature = "std", not(target_family = "wasm")))] + #[inline] + pub fn write_blocking(&self) -> RwLockWriteGuard<'_, T> { + self.write().wait() } /// Attempts to acquire an owned, reference-counted write lock. @@ -455,10 +542,7 @@ impl RwLock { /// ``` #[inline] pub fn write_arc<'a>(self: &'a Arc) -> WriteArc<'a, T> { - WriteArc { - raw: self.raw.write(), - lock: self, - } + WriteArc::new(self.raw.write(), self) } /// Returns a mutable reference to the inner value. @@ -766,11 +850,36 @@ impl<'a, T: ?Sized> RwLockUpgradableReadGuard<'a, T> { pub fn upgrade(guard: Self) -> Upgrade<'a, T> { let reader = ManuallyDrop::new(guard); - Upgrade { + Upgrade::new( // SAFETY: `reader` is an upgradable read guard - raw: unsafe { reader.lock.upgrade() }, - value: reader.value, - } + unsafe { reader.lock.upgrade() }, + reader.value, + ) + } + + /// Upgrades into a write lock. + /// + /// # Blocking + /// + /// This function will block the current thread until it is able to acquire the write lock. + /// + /// # Examples + /// + /// ``` + /// use async_lock::{RwLock, RwLockUpgradableReadGuard}; + /// + /// let lock = RwLock::new(1); + /// + /// let reader = lock.upgradable_read_blocking(); + /// assert_eq!(*reader, 1); + /// + /// let mut writer = RwLockUpgradableReadGuard::upgrade_blocking(reader); + /// *writer = 2; + /// ``` + #[cfg(all(feature = "std", not(target_family = "wasm")))] + #[inline] + pub fn upgrade_blocking(guard: Self) -> RwLockWriteGuard<'a, T> { + RwLockUpgradableReadGuard::upgrade(guard).wait() } } @@ -951,9 +1060,11 @@ impl RwLockUpgradableReadGuardArc { // SAFETY: see above explanation. let raw: RawUpgrade<'static> = unsafe { mem::transmute(raw) }; - UpgradeArc { - raw: ManuallyDrop::new(raw), - lock: ManuallyDrop::new(Self::into_arc(guard)), + unsafe { + UpgradeArc::new( + ManuallyDrop::new(raw), + ManuallyDrop::new(Self::into_arc(guard)), + ) } } } diff --git a/src/rwlock/futures.rs b/src/rwlock/futures.rs index 613d30a..c759157 100644 --- a/src/rwlock/futures.rs +++ b/src/rwlock/futures.rs @@ -1,8 +1,7 @@ use core::fmt; -use core::future::Future; use core::mem::ManuallyDrop; use core::pin::Pin; -use core::task::{Context, Poll}; +use core::task::Poll; use alloc::sync::Arc; @@ -12,9 +11,18 @@ use super::{ RwLockUpgradableReadGuardArc, RwLockWriteGuard, RwLockWriteGuardArc, }; +use event_listener_strategy::{easy_wrapper, EventListenerFuture, Strategy}; + +easy_wrapper! { + /// The future returned by [`RwLock::read`]. + pub struct Read<'a, T: ?Sized>(ReadInner<'a, T> => RwLockReadGuard<'a, T>); + #[cfg(all(feature = "std", not(target_family = "wasm")))] + pub(crate) wait(); +} + pin_project_lite::pin_project! { /// The future returned by [`RwLock::read`]. - pub struct Read<'a, T: ?Sized> { + struct ReadInner<'a, T: ?Sized> { // Raw read lock acquisition future, doesn't depend on `T`. #[pin] pub(super) raw: RawRead<'a>, @@ -24,8 +32,15 @@ pin_project_lite::pin_project! { } } -unsafe impl Send for Read<'_, T> {} -unsafe impl Sync for Read<'_, T> {} +unsafe impl Send for ReadInner<'_, T> {} +unsafe impl Sync for ReadInner<'_, T> {} + +impl<'x, T: ?Sized> Read<'x, T> { + #[inline] + pub(super) fn new(raw: RawRead<'x>, value: *const T) -> Self { + Self::_new(ReadInner { raw, value }) + } +} impl fmt::Debug for Read<'_, T> { #[inline] @@ -34,13 +49,17 @@ impl fmt::Debug for Read<'_, T> { } } -impl<'a, T: ?Sized> Future for Read<'a, T> { +impl<'a, T: ?Sized> EventListenerFuture for ReadInner<'a, T> { type Output = RwLockReadGuard<'a, T>; #[inline] - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + fn poll_with_strategy<'x, S: Strategy<'x>>( + self: Pin<&mut Self>, + strategy: &mut S, + cx: &mut S::Context, + ) -> Poll { let mut this = self.project(); - ready!(this.raw.as_mut().poll(cx)); + ready!(this.raw.as_mut().poll_with_strategy(strategy, cx)); Poll::Ready(RwLockReadGuard { lock: this.raw.lock, @@ -49,9 +68,16 @@ impl<'a, T: ?Sized> Future for Read<'a, T> { } } +easy_wrapper! { + /// The future returned by [`RwLock::read_arc`]. + pub struct ReadArc<'a, T>(ReadArcInner<'a, T> => RwLockReadGuardArc); + #[cfg(all(feature = "std", not(target_family = "wasm")))] + pub(crate) wait(); +} + pin_project_lite::pin_project! { /// The future returned by [`RwLock::read_arc`]. - pub struct ReadArc<'a, T> { + struct ReadArcInner<'a, T> { // Raw read lock acquisition future, doesn't depend on `T`. #[pin] pub(super) raw: RawRead<'a>, @@ -61,8 +87,15 @@ pin_project_lite::pin_project! { } } -unsafe impl Send for ReadArc<'_, T> {} -unsafe impl Sync for ReadArc<'_, T> {} +unsafe impl Send for ReadArcInner<'_, T> {} +unsafe impl Sync for ReadArcInner<'_, T> {} + +impl<'x, T> ReadArc<'x, T> { + #[inline] + pub(super) fn new(raw: RawRead<'x>, lock: &'x Arc>) -> Self { + Self::_new(ReadArcInner { raw, lock }) + } +} impl fmt::Debug for ReadArc<'_, T> { #[inline] @@ -71,22 +104,35 @@ impl fmt::Debug for ReadArc<'_, T> { } } -impl<'a, T> Future for ReadArc<'a, T> { +impl<'a, T> EventListenerFuture for ReadArcInner<'a, T> { type Output = RwLockReadGuardArc; #[inline] - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + fn poll_with_strategy<'x, S: Strategy<'x>>( + self: Pin<&mut Self>, + strategy: &mut S, + cx: &mut S::Context, + ) -> Poll { let mut this = self.project(); - ready!(this.raw.as_mut().poll(cx)); + ready!(this.raw.as_mut().poll_with_strategy(strategy, cx)); // SAFETY: we just acquired a read lock Poll::Ready(unsafe { RwLockReadGuardArc::from_arc(this.lock.clone()) }) } } +easy_wrapper! { + /// The future returned by [`RwLock::upgradable_read`]. + pub struct UpgradableRead<'a, T: ?Sized>( + UpgradableReadInner<'a, T> => RwLockUpgradableReadGuard<'a, T> + ); + #[cfg(all(feature = "std", not(target_family = "wasm")))] + pub(crate) wait(); +} + pin_project_lite::pin_project! { /// The future returned by [`RwLock::upgradable_read`]. - pub struct UpgradableRead<'a, T: ?Sized> { + struct UpgradableReadInner<'a, T: ?Sized> { // Raw upgradable read lock acquisition future, doesn't depend on `T`. #[pin] pub(super) raw: RawUpgradableRead<'a>, @@ -97,8 +143,15 @@ pin_project_lite::pin_project! { } } -unsafe impl Send for UpgradableRead<'_, T> {} -unsafe impl Sync for UpgradableRead<'_, T> {} +unsafe impl Send for UpgradableReadInner<'_, T> {} +unsafe impl Sync for UpgradableReadInner<'_, T> {} + +impl<'x, T: ?Sized> UpgradableRead<'x, T> { + #[inline] + pub(super) fn new(raw: RawUpgradableRead<'x>, value: *mut T) -> Self { + Self::_new(UpgradableReadInner { raw, value }) + } +} impl fmt::Debug for UpgradableRead<'_, T> { #[inline] @@ -107,13 +160,17 @@ impl fmt::Debug for UpgradableRead<'_, T> { } } -impl<'a, T: ?Sized> Future for UpgradableRead<'a, T> { +impl<'a, T: ?Sized> EventListenerFuture for UpgradableReadInner<'a, T> { type Output = RwLockUpgradableReadGuard<'a, T>; #[inline] - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + fn poll_with_strategy<'x, S: Strategy<'x>>( + self: Pin<&mut Self>, + strategy: &mut S, + cx: &mut S::Context, + ) -> Poll { let mut this = self.project(); - ready!(this.raw.as_mut().poll(cx)); + ready!(this.raw.as_mut().poll_with_strategy(strategy, cx)); Poll::Ready(RwLockUpgradableReadGuard { lock: this.raw.lock, @@ -122,9 +179,18 @@ impl<'a, T: ?Sized> Future for UpgradableRead<'a, T> { } } +easy_wrapper! { + /// The future returned by [`RwLock::upgradable_read_arc`]. + pub struct UpgradableReadArc<'a, T: ?Sized>( + UpgradableReadArcInner<'a, T> => RwLockUpgradableReadGuardArc + ); + #[cfg(all(feature = "std", not(target_family = "wasm")))] + pub(crate) wait(); +} + pin_project_lite::pin_project! { /// The future returned by [`RwLock::upgradable_read_arc`]. - pub struct UpgradableReadArc<'a, T: ?Sized> { + struct UpgradableReadArcInner<'a, T: ?Sized> { // Raw upgradable read lock acquisition future, doesn't depend on `T`. #[pin] pub(super) raw: RawUpgradableRead<'a>, @@ -133,8 +199,15 @@ pin_project_lite::pin_project! { } } -unsafe impl Send for UpgradableReadArc<'_, T> {} -unsafe impl Sync for UpgradableReadArc<'_, T> {} +unsafe impl Send for UpgradableReadArcInner<'_, T> {} +unsafe impl Sync for UpgradableReadArcInner<'_, T> {} + +impl<'x, T: ?Sized> UpgradableReadArc<'x, T> { + #[inline] + pub(super) fn new(raw: RawUpgradableRead<'x>, lock: &'x Arc>) -> Self { + Self::_new(UpgradableReadArcInner { raw, lock }) + } +} impl fmt::Debug for UpgradableReadArc<'_, T> { #[inline] @@ -143,22 +216,33 @@ impl fmt::Debug for UpgradableReadArc<'_, T> { } } -impl<'a, T: ?Sized> Future for UpgradableReadArc<'a, T> { +impl<'a, T: ?Sized> EventListenerFuture for UpgradableReadArcInner<'a, T> { type Output = RwLockUpgradableReadGuardArc; #[inline] - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + fn poll_with_strategy<'x, S: Strategy<'x>>( + self: Pin<&mut Self>, + strategy: &mut S, + cx: &mut S::Context, + ) -> Poll { let mut this = self.project(); - ready!(this.raw.as_mut().poll(cx)); + ready!(this.raw.as_mut().poll_with_strategy(strategy, cx)); Poll::Ready(RwLockUpgradableReadGuardArc { lock: this.lock.clone(), }) } } +easy_wrapper! { + /// The future returned by [`RwLock::write`]. + pub struct Write<'a, T: ?Sized>(WriteInner<'a, T> => RwLockWriteGuard<'a, T>); + #[cfg(all(feature = "std", not(target_family = "wasm")))] + pub(crate) wait(); +} + pin_project_lite::pin_project! { /// The future returned by [`RwLock::write`]. - pub struct Write<'a, T: ?Sized> { + struct WriteInner<'a, T: ?Sized> { // Raw write lock acquisition future, doesn't depend on `T`. #[pin] pub(super) raw: RawWrite<'a>, @@ -168,8 +252,15 @@ pin_project_lite::pin_project! { } } -unsafe impl Send for Write<'_, T> {} -unsafe impl Sync for Write<'_, T> {} +unsafe impl Send for WriteInner<'_, T> {} +unsafe impl Sync for WriteInner<'_, T> {} + +impl<'x, T: ?Sized> Write<'x, T> { + #[inline] + pub(super) fn new(raw: RawWrite<'x>, value: *mut T) -> Self { + Self::_new(WriteInner { raw, value }) + } +} impl fmt::Debug for Write<'_, T> { #[inline] @@ -178,13 +269,17 @@ impl fmt::Debug for Write<'_, T> { } } -impl<'a, T: ?Sized> Future for Write<'a, T> { +impl<'a, T: ?Sized> EventListenerFuture for WriteInner<'a, T> { type Output = RwLockWriteGuard<'a, T>; #[inline] - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + fn poll_with_strategy<'x, S: Strategy<'x>>( + self: Pin<&mut Self>, + strategy: &mut S, + cx: &mut S::Context, + ) -> Poll { let mut this = self.project(); - ready!(this.raw.as_mut().poll(cx)); + ready!(this.raw.as_mut().poll_with_strategy(strategy, cx)); Poll::Ready(RwLockWriteGuard { lock: this.raw.lock, @@ -193,9 +288,16 @@ impl<'a, T: ?Sized> Future for Write<'a, T> { } } +easy_wrapper! { + /// The future returned by [`RwLock::write_arc`]. + pub struct WriteArc<'a, T: ?Sized>(WriteArcInner<'a, T> => RwLockWriteGuardArc); + #[cfg(all(feature = "std", not(target_family = "wasm")))] + pub(crate) wait(); +} + pin_project_lite::pin_project! { /// The future returned by [`RwLock::write_arc`]. - pub struct WriteArc<'a, T: ?Sized> { + struct WriteArcInner<'a, T: ?Sized> { // Raw write lock acquisition future, doesn't depend on `T`. #[pin] pub(super) raw: RawWrite<'a>, @@ -204,8 +306,15 @@ pin_project_lite::pin_project! { } } -unsafe impl Send for WriteArc<'_, T> {} -unsafe impl Sync for WriteArc<'_, T> {} +unsafe impl Send for WriteArcInner<'_, T> {} +unsafe impl Sync for WriteArcInner<'_, T> {} + +impl<'x, T: ?Sized> WriteArc<'x, T> { + #[inline] + pub(super) fn new(raw: RawWrite<'x>, lock: &'x Arc>) -> Self { + Self::_new(WriteArcInner { raw, lock }) + } +} impl fmt::Debug for WriteArc<'_, T> { #[inline] @@ -214,13 +323,17 @@ impl fmt::Debug for WriteArc<'_, T> { } } -impl<'a, T: ?Sized> Future for WriteArc<'a, T> { +impl<'a, T: ?Sized> EventListenerFuture for WriteArcInner<'a, T> { type Output = RwLockWriteGuardArc; #[inline] - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + fn poll_with_strategy<'x, S: Strategy<'x>>( + self: Pin<&mut Self>, + strategy: &mut S, + cx: &mut S::Context, + ) -> Poll { let mut this = self.project(); - ready!(this.raw.as_mut().poll(cx)); + ready!(this.raw.as_mut().poll_with_strategy(strategy, cx)); Poll::Ready(RwLockWriteGuardArc { lock: this.lock.clone(), @@ -228,9 +341,16 @@ impl<'a, T: ?Sized> Future for WriteArc<'a, T> { } } +easy_wrapper! { + /// The future returned by [`RwLockUpgradableReadGuard::upgrade`]. + pub struct Upgrade<'a, T: ?Sized>(UpgradeInner<'a, T> => RwLockWriteGuard<'a, T>); + #[cfg(all(feature = "std", not(target_family = "wasm")))] + pub(crate) wait(); +} + pin_project_lite::pin_project! { /// The future returned by [`RwLockUpgradableReadGuard::upgrade`]. - pub struct Upgrade<'a, T: ?Sized> { + struct UpgradeInner<'a, T: ?Sized> { // Raw read lock upgrade future, doesn't depend on `T`. #[pin] pub(super) raw: RawUpgrade<'a>, @@ -240,8 +360,15 @@ pin_project_lite::pin_project! { } } -unsafe impl Send for Upgrade<'_, T> {} -unsafe impl Sync for Upgrade<'_, T> {} +unsafe impl Send for UpgradeInner<'_, T> {} +unsafe impl Sync for UpgradeInner<'_, T> {} + +impl<'x, T: ?Sized> Upgrade<'x, T> { + #[inline] + pub(super) fn new(raw: RawUpgrade<'x>, value: *mut T) -> Self { + Self::_new(UpgradeInner { raw, value }) + } +} impl fmt::Debug for Upgrade<'_, T> { #[inline] @@ -250,13 +377,17 @@ impl fmt::Debug for Upgrade<'_, T> { } } -impl<'a, T: ?Sized> Future for Upgrade<'a, T> { +impl<'a, T: ?Sized> EventListenerFuture for UpgradeInner<'a, T> { type Output = RwLockWriteGuard<'a, T>; #[inline] - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + fn poll_with_strategy<'x, S: Strategy<'x>>( + self: Pin<&mut Self>, + strategy: &mut S, + cx: &mut S::Context, + ) -> Poll { let mut this = self.project(); - let lock = ready!(this.raw.as_mut().poll(cx)); + let lock = ready!(this.raw.as_mut().poll_with_strategy(strategy, cx)); Poll::Ready(RwLockWriteGuard { lock, @@ -265,9 +396,16 @@ impl<'a, T: ?Sized> Future for Upgrade<'a, T> { } } +easy_wrapper! { + /// The future returned by [`RwLockUpgradableReadGuardArc::upgrade`]. + pub struct UpgradeArc(UpgradeArcInner => RwLockWriteGuardArc); + #[cfg(all(feature = "std", not(target_family = "wasm")))] + pub(crate) wait(); +} + pin_project_lite::pin_project! { /// The future returned by [`RwLockUpgradableReadGuardArc::upgrade`]. - pub struct UpgradeArc { + struct UpgradeArcInner { // Raw read lock upgrade future, doesn't depend on `T`. // `'static` is a lie, this field is actually referencing the // `Arc` data. But since this struct also stores said `Arc`, we know @@ -285,7 +423,7 @@ pin_project_lite::pin_project! { pub(super) lock: ManuallyDrop>>, } - impl PinnedDrop for UpgradeArc { + impl PinnedDrop for UpgradeArcInner { fn drop(this: Pin<&mut Self>) { let this = this.project(); if !this.raw.is_ready() { @@ -302,6 +440,16 @@ pin_project_lite::pin_project! { } } +impl UpgradeArc { + #[inline] + pub(super) unsafe fn new( + raw: ManuallyDrop>, + lock: ManuallyDrop>>, + ) -> Self { + Self::_new(UpgradeArcInner { raw, lock }) + } +} + impl fmt::Debug for UpgradeArc { #[inline] fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -309,15 +457,20 @@ impl fmt::Debug for UpgradeArc { } } -impl Future for UpgradeArc { +impl EventListenerFuture for UpgradeArcInner { type Output = RwLockWriteGuardArc; #[inline] - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + fn poll_with_strategy<'x, S: Strategy<'x>>( + self: Pin<&mut Self>, + strategy: &mut S, + cx: &mut S::Context, + ) -> Poll { let this = self.project(); unsafe { // SAFETY: Practically, this is a pin projection. - ready!(Pin::new_unchecked(&mut **this.raw.get_unchecked_mut()).poll(cx)); + ready!(Pin::new_unchecked(&mut **this.raw.get_unchecked_mut()) + .poll_with_strategy(strategy, cx)); } Poll::Ready(RwLockWriteGuardArc { diff --git a/src/rwlock/raw.rs b/src/rwlock/raw.rs index 816e491..df08edb 100644 --- a/src/rwlock/raw.rs +++ b/src/rwlock/raw.rs @@ -6,13 +6,13 @@ //! the locking code only once, and also lets us make //! [`RwLockReadGuard`](super::RwLockReadGuard) covariant in `T`. -use core::future::Future; use core::mem::forget; use core::pin::Pin; use core::sync::atomic::{AtomicUsize, Ordering}; -use core::task::{Context, Poll}; +use core::task::Poll; use event_listener::{Event, EventListener}; +use event_listener_strategy::{EventListenerFuture, Strategy}; use crate::futures::Lock; use crate::Mutex; @@ -53,7 +53,6 @@ impl RawRwLock { } /// Returns `true` iff a read lock was successfully acquired. - pub(super) fn try_read(&self) -> bool { let mut state = self.state.load(Ordering::Acquire); @@ -298,10 +297,14 @@ pin_project_lite::pin_project! { } } -impl<'a> Future for RawRead<'a> { +impl<'a> EventListenerFuture for RawRead<'a> { type Output = (); - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + fn poll_with_strategy<'x, S: Strategy<'x>>( + self: Pin<&mut Self>, + strategy: &mut S, + cx: &mut S::Context, + ) -> Poll<()> { let mut this = self.project(); loop { @@ -331,7 +334,7 @@ impl<'a> Future for RawRead<'a> { Ordering::SeqCst } else { // Wait for the writer to finish. - ready!(this.listener.as_mut().poll(cx)); + ready!(strategy.poll(this.listener.as_mut(), cx)); // Notify the next reader waiting in list. this.lock.no_writer.notify(1); @@ -349,7 +352,6 @@ impl<'a> Future for RawRead<'a> { pin_project_lite::pin_project! { /// The future returned by [`RawRwLock::upgradable_read`]. - pub(super) struct RawUpgradableRead<'a> { // The lock that is being acquired. pub(super) lock: &'a RawRwLock, @@ -360,14 +362,18 @@ pin_project_lite::pin_project! { } } -impl<'a> Future for RawUpgradableRead<'a> { +impl<'a> EventListenerFuture for RawUpgradableRead<'a> { type Output = (); - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + fn poll_with_strategy<'x, S: Strategy<'x>>( + self: Pin<&mut Self>, + strategy: &mut S, + cx: &mut S::Context, + ) -> Poll<()> { let this = self.project(); // Acquire the mutex. - let mutex_guard = ready!(this.acquire.poll(cx)); + let mutex_guard = ready!(this.acquire.poll_with_strategy(strategy, cx)); forget(mutex_guard); // Load the current state. @@ -427,6 +433,7 @@ pin_project_lite::pin_project! { pin_project_lite::pin_project! { #[project = WriteStateProj] + #[project_replace = WriteStateProjReplace] enum WriteState<'a> { // We are currently acquiring the inner mutex. Acquiring { #[pin] lock: Lock<'a, ()> }, @@ -439,17 +446,21 @@ pin_project_lite::pin_project! { } } -impl<'a> Future for RawWrite<'a> { +impl<'a> EventListenerFuture for RawWrite<'a> { type Output = (); - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + fn poll_with_strategy<'x, S: Strategy<'x>>( + self: Pin<&mut Self>, + strategy: &mut S, + cx: &mut S::Context, + ) -> Poll<()> { let mut this = self.project(); loop { match this.state.as_mut().project() { WriteStateProj::Acquiring { lock } => { // First grab the mutex. - let mutex_guard = ready!(lock.poll(cx)); + let mutex_guard = ready!(lock.poll_with_strategy(strategy, cx)); forget(mutex_guard); // Set `WRITER_BIT` and create a guard that unsets it in case this future is canceled. @@ -486,7 +497,7 @@ impl<'a> Future for RawWrite<'a> { this.no_readers.as_mut().listen(); } else { // Wait for the readers to finish. - ready!(this.no_readers.as_mut().poll(cx)); + ready!(strategy.poll(this.no_readers.as_mut(), cx)); }; } WriteStateProj::Acquired => panic!("Write lock already acquired"), @@ -520,10 +531,14 @@ pin_project_lite::pin_project! { } } -impl<'a> Future for RawUpgrade<'a> { +impl<'a> EventListenerFuture for RawUpgrade<'a> { type Output = &'a RawRwLock; - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<&'a RawRwLock> { + fn poll_with_strategy<'x, S: Strategy<'x>>( + self: Pin<&mut Self>, + strategy: &mut S, + cx: &mut S::Context, + ) -> Poll<&'a RawRwLock> { let mut this = self.project(); let lock = this.lock.expect("cannot poll future after completion"); @@ -547,7 +562,7 @@ impl<'a> Future for RawUpgrade<'a> { this.listener.as_mut().listen(); } else { // Wait for the readers to finish. - ready!(this.listener.as_mut().poll(cx)); + ready!(strategy.poll(this.listener.as_mut(), cx)); }; } diff --git a/src/semaphore.rs b/src/semaphore.rs index dedf531..6d5169d 100644 --- a/src/semaphore.rs +++ b/src/semaphore.rs @@ -7,6 +7,7 @@ use core::task::{Context, Poll}; use alloc::sync::Arc; use event_listener::{Event, EventListener}; +use event_listener_strategy::{easy_wrapper, EventListenerFuture, Strategy}; /// A counter for limiting the number of concurrent operations. #[derive(Debug)] @@ -85,10 +86,37 @@ impl Semaphore { /// # }); /// ``` pub fn acquire(&self) -> Acquire<'_> { - Acquire { + Acquire::_new(AcquireInner { semaphore: self, listener: EventListener::new(&self.event), - } + }) + } + + /// Waits for a permit for a concurrent operation. + /// + /// Returns a guard that releases the permit when dropped. + /// + /// # Blocking + /// + /// Rather than using asynchronous waiting, like the [`acquire`] method, this method will + /// block the current thread until the permit is acquired. + /// + /// This method should not be used in an asynchronous context. It is intended to be + /// used in a way that a semaphore can be used in both asynchronous and synchronous contexts. + /// Calling this method in an asynchronous context may result in a deadlock. + /// + /// # Examples + /// + /// ``` + /// use async_lock::Semaphore; + /// + /// let s = Semaphore::new(2); + /// let guard = s.acquire_blocking(); + /// ``` + #[cfg(all(feature = "std", not(target_family = "wasm")))] + #[inline] + pub fn acquire_blocking(&self) -> SemaphoreGuard<'_> { + self.acquire().wait() } /// Attempts to get an owned permit for a concurrent operation. @@ -177,9 +205,15 @@ impl Semaphore { } } -pin_project_lite::pin_project! { +easy_wrapper! { /// The future returned by [`Semaphore::acquire`]. - pub struct Acquire<'a> { + pub struct Acquire<'a>(AcquireInner<'a> => SemaphoreGuard<'a>); + #[cfg(all(feature = "std", not(target_family = "wasm")))] + pub(crate) wait(); +} + +pin_project_lite::pin_project! { + struct AcquireInner<'a> { // The semaphore being acquired. semaphore: &'a Semaphore, @@ -189,16 +223,20 @@ pin_project_lite::pin_project! { } } -impl fmt::Debug for Acquire<'_> { +impl fmt::Debug for AcquireInner<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str("Acquire { .. }") } } -impl<'a> Future for Acquire<'a> { +impl<'a> EventListenerFuture for AcquireInner<'a> { type Output = SemaphoreGuard<'a>; - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + fn poll_with_strategy<'x, S: Strategy<'x>>( + self: Pin<&mut Self>, + strategy: &mut S, + cx: &mut S::Context, + ) -> Poll { let mut this = self.project(); loop { @@ -209,7 +247,7 @@ impl<'a> Future for Acquire<'a> { if !this.listener.is_listening() { this.listener.as_mut().listen(); } else { - ready!(this.listener.as_mut().poll(cx)); + ready!(strategy.poll(this.listener.as_mut(), cx)); } } } diff --git a/tests/barrier.rs b/tests/barrier.rs index 3efd355..657478c 100644 --- a/tests/barrier.rs +++ b/tests/barrier.rs @@ -44,3 +44,43 @@ fn smoke() { } }); } + +#[cfg(all(feature = "std", not(target_family = "wasm")))] +#[test] +fn smoke_blocking() { + future::block_on(async move { + const N: usize = 10; + + let barrier = Arc::new(Barrier::new(N)); + + for _ in 0..10 { + let (tx, rx) = async_channel::unbounded(); + + for _ in 0..N - 1 { + let c = barrier.clone(); + let tx = tx.clone(); + + thread::spawn(move || { + let res = c.wait_blocking(); + tx.send_blocking(res.is_leader()).unwrap(); + }); + } + + // At this point, all spawned threads should be blocked, + // so we shouldn't get anything from the cahnnel. + let res = rx.try_recv(); + assert!(res.is_err()); + + let mut leader_found = barrier.wait_blocking().is_leader(); + + // Now, the barrier is cleared and we should get data. + for _ in 0..N - 1 { + if rx.recv().await.unwrap() { + assert!(!leader_found); + leader_found = true; + } + } + assert!(leader_found); + } + }); +} diff --git a/tests/mutex.rs b/tests/mutex.rs index 267358a..dce687a 100644 --- a/tests/mutex.rs +++ b/tests/mutex.rs @@ -24,6 +24,14 @@ fn smoke() { }) } +#[cfg(all(feature = "std", not(target_family = "wasm")))] +#[test] +fn smoke_blocking() { + let m = Mutex::new(()); + drop(m.lock_blocking()); + drop(m.lock_blocking()); +} + #[test] fn try_lock() { let m = Mutex::new(()); diff --git a/tests/rwlock.rs b/tests/rwlock.rs index 6a5ea86..78c8ee5 100644 --- a/tests/rwlock.rs +++ b/tests/rwlock.rs @@ -47,6 +47,16 @@ fn smoke() { }); } +#[cfg(all(feature = "std", not(target_family = "wasm")))] +#[test] +fn smoke_blocking() { + let lock = RwLock::new(()); + drop(lock.read_blocking()); + drop(lock.write_blocking()); + drop((lock.read_blocking(), lock.read_blocking())); + drop(lock.write_blocking()); +} + #[test] fn try_write() { future::block_on(async { diff --git a/tests/semaphore.rs b/tests/semaphore.rs index 0779b68..0aab0b5 100644 --- a/tests/semaphore.rs +++ b/tests/semaphore.rs @@ -114,6 +114,17 @@ fn yields_when_contended() { check_yields_when_contended(s.try_acquire_arc().unwrap(), s.acquire_arc()); } +#[cfg(all(feature = "std", not(target_family = "wasm")))] +#[test] +fn smoke_blocking() { + let s = Semaphore::new(2); + let g1 = s.acquire_blocking(); + let _g2 = s.acquire_blocking(); + assert!(s.try_acquire().is_none()); + drop(g1); + assert!(s.try_acquire().is_some()); +} + #[test] fn add_permits() { static COUNTER: AtomicUsize = AtomicUsize::new(0);