Skip to content

Commit

Permalink
fallback to blocking in mpsc channels
Browse files Browse the repository at this point in the history
  • Loading branch information
ibraheemdev committed May 18, 2024
1 parent 8af67ba commit 5c11519
Show file tree
Hide file tree
Showing 4 changed files with 540 additions and 139 deletions.
162 changes: 127 additions & 35 deletions library/std/src/sync/mpmc/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,18 @@ pub(crate) struct Channel<T> {
receivers: SyncWaker,
}

/// The state of the channel after calling `start_recv` or `start_send`.
#[derive(PartialEq, Eq)]
enum Status {
/// The channel is ready to read or write to.
Ready,
/// There is currently a send or receive in progress holding up the queue.
/// All operations must block to preserve linearizability.
InProgress,
/// The channel is empty.
Empty,
}

impl<T> Channel<T> {
/// Creates a bounded channel of capacity `cap`.
pub(crate) fn with_capacity(cap: usize) -> Self {
Expand Down Expand Up @@ -122,7 +134,7 @@ impl<T> Channel<T> {
}

/// Attempts to reserve a slot for sending a message.
fn start_send(&self, token: &mut Token) -> bool {
fn start_send(&self, token: &mut Token) -> Status {
let backoff = Backoff::new();
let mut tail = self.tail.load(Ordering::Relaxed);

Expand All @@ -131,7 +143,7 @@ impl<T> Channel<T> {
if tail & self.mark_bit != 0 {
token.array.slot = ptr::null();
token.array.stamp = 0;
return true;
return Status::Ready;
}

// Deconstruct the tail.
Expand Down Expand Up @@ -166,7 +178,7 @@ impl<T> Channel<T> {
// Prepare the token for the follow-up call to `write`.
token.array.slot = slot as *const Slot<T> as *const u8;
token.array.stamp = tail + 1;
return true;
return Status::Ready;
}
Err(_) => {
backoff.spin_light();
Expand All @@ -180,10 +192,16 @@ impl<T> Channel<T> {
// If the head lags one lap behind the tail as well...
if head.wrapping_add(self.one_lap) == tail {
// ...then the channel is full.
return false;
return Status::Empty;
}

// The head was advanced but the stamp hasn't been updated yet,
// meaning a receive is in-progress. Spin for a bit waiting for
// the receive to complete before falling back to blocking.
if !backoff.try_spin_light() {
return Status::InProgress;
}

backoff.spin_light();
tail = self.tail.load(Ordering::Relaxed);
} else {
// Snooze because we need to wait for the stamp to get updated.
Expand All @@ -200,10 +218,10 @@ impl<T> Channel<T> {
return Err(msg);
}

let slot: &Slot<T> = &*(token.array.slot as *const Slot<T>);
let slot: &Slot<T> = unsafe { &*token.array.slot.cast::<Slot<T>>() };

// Write the message into the slot and update the stamp.
slot.msg.get().write(MaybeUninit::new(msg));
unsafe { slot.msg.get().write(MaybeUninit::new(msg)) }
slot.stamp.store(token.array.stamp, Ordering::Release);

// Wake a sleeping receiver.
Expand All @@ -212,7 +230,7 @@ impl<T> Channel<T> {
}

/// Attempts to reserve a slot for receiving a message.
fn start_recv(&self, token: &mut Token) -> bool {
fn start_recv(&self, token: &mut Token) -> Status {
let backoff = Backoff::new();
let mut head = self.head.load(Ordering::Relaxed);

Expand Down Expand Up @@ -249,7 +267,7 @@ impl<T> Channel<T> {
// Prepare the token for the follow-up call to `read`.
token.array.slot = slot as *const Slot<T> as *const u8;
token.array.stamp = head.wrapping_add(self.one_lap);
return true;
return Status::Ready;
}
Err(_) => {
backoff.spin_light();
Expand All @@ -267,14 +285,20 @@ impl<T> Channel<T> {
// ...then receive an error.
token.array.slot = ptr::null();
token.array.stamp = 0;
return true;
return Status::Ready;
} else {
// Otherwise, the receive operation is not ready.
return false;
return Status::Empty;
}
}

backoff.spin_light();
// The tail was advanced but the stamp hasn't been updated yet,
// meaning a send is in-progress. Spin for a bit waiting for
// the send to complete before falling back to blocking.
if !backoff.try_spin_light() {
return Status::InProgress;
}

head = self.head.load(Ordering::Relaxed);
} else {
// Snooze because we need to wait for the stamp to get updated.
Expand All @@ -291,10 +315,10 @@ impl<T> Channel<T> {
return Err(());
}

let slot: &Slot<T> = &*(token.array.slot as *const Slot<T>);
let slot: &Slot<T> = unsafe { &*token.array.slot.cast::<Slot<T>>() };

// Read the message from the slot and update the stamp.
let msg = slot.msg.get().read().assume_init();
let msg = unsafe { slot.msg.get().read().assume_init() };
slot.stamp.store(token.array.stamp, Ordering::Release);

// Wake a sleeping sender.
Expand All @@ -304,11 +328,13 @@ impl<T> Channel<T> {

/// Attempts to send a message into the channel.
pub(crate) fn try_send(&self, msg: T) -> Result<(), TrySendError<T>> {
let token = &mut Token::default();
if self.start_send(token) {
unsafe { self.write(token, msg).map_err(TrySendError::Disconnected) }
} else {
Err(TrySendError::Full(msg))
match self.send_blocking(msg, None, false) {
Ok(None) => Ok(()),
Ok(Some(msg)) => Err(TrySendError::Full(msg)),
Err(SendTimeoutError::Disconnected(msg)) => Err(TrySendError::Disconnected(msg)),
Err(SendTimeoutError::Timeout(_)) => {
unreachable!("called recv_blocking with deadline: None")
}
}
}

Expand All @@ -318,12 +344,43 @@ impl<T> Channel<T> {
msg: T,
deadline: Option<Instant>,
) -> Result<(), SendTimeoutError<T>> {
self.send_blocking(msg, deadline, true)
.map(|value| assert!(value.is_none(), "called send_blocking with block: true"))
}

/// Sends a message into the channel.
///
/// Blocks until a message is sent if `should_block` is `true`. Otherwise, returns `Ok(Some(msg))` if
/// the channel is full.
///
/// Note this method may still block when `should_block` is `false` if the channel is in an inconsistent state.
pub(crate) fn send_blocking(
&self,
msg: T,
deadline: Option<Instant>,
should_block: bool,
) -> Result<Option<T>, SendTimeoutError<T>> {
let token = &mut Token::default();
let mut state = self.senders.start();
loop {
// Try sending a message.
if self.start_send(token) {
let res = unsafe { self.write(token, msg) };
return res.map_err(SendTimeoutError::Disconnected);
// Try sending a message several times.
let backoff = Backoff::new();
loop {
match self.start_send(token) {
Status::Ready => {
let res = unsafe { self.write(token, msg) };
return res.map(|_| None).map_err(SendTimeoutError::Disconnected);
}
// If the channel is full, return or block immediately.
Status::Empty if !should_block => return Ok(Some(msg)),
Status::Empty => break,
// Otherwise spin for a bit before blocking.
Status::InProgress => {}
}

if !backoff.try_spin_light() {
break;
}
}

if let Some(d) = deadline {
Expand All @@ -335,7 +392,7 @@ impl<T> Channel<T> {
Context::with(|cx| {
// Prepare for blocking until a receiver wakes us up.
let oper = Operation::hook(token);
self.senders.register(oper, cx);
self.senders.register(oper, cx, &state);

// Has the channel become ready just now?
if !self.is_full() || self.is_disconnected() {
Expand All @@ -353,28 +410,61 @@ impl<T> Channel<T> {
Selected::Operation(_) => {}
}
});

state.unpark();
}
}

/// Attempts to receive a message without blocking.
pub(crate) fn try_recv(&self) -> Result<T, TryRecvError> {
let token = &mut Token::default();

if self.start_recv(token) {
unsafe { self.read(token).map_err(|_| TryRecvError::Disconnected) }
} else {
Err(TryRecvError::Empty)
match self.recv_blocking(None, false) {
Ok(Some(value)) => Ok(value),
Ok(None) => Err(TryRecvError::Empty),
Err(RecvTimeoutError::Disconnected) => Err(TryRecvError::Disconnected),
Err(RecvTimeoutError::Timeout) => {
unreachable!("called recv_blocking with deadline: None")
}
}
}

/// Receives a message from the channel.
pub(crate) fn recv(&self, deadline: Option<Instant>) -> Result<T, RecvTimeoutError> {
self.recv_blocking(deadline, true)
.map(|value| value.expect("called recv_blocking with block: true"))
}

/// Receives a message from the channel.
///
/// Blocks until a message is received if `should_block` is `true`. Otherwise, returns `Ok(None)` if
/// the channel is full.
///
/// Note this may still block when `should_block` is `false` if the channel is in an inconsistent state.
pub(crate) fn recv_blocking(
&self,
deadline: Option<Instant>,
should_block: bool,
) -> Result<Option<T>, RecvTimeoutError> {
let token = &mut Token::default();
let mut state = self.receivers.start();
loop {
// Try receiving a message.
if self.start_recv(token) {
let res = unsafe { self.read(token) };
return res.map_err(|_| RecvTimeoutError::Disconnected);
// Try receiving a message several times.
let backoff = Backoff::new();
loop {
match self.start_recv(token) {
Status::Ready => {
let res = unsafe { self.read(token) };
return res.map(Some).map_err(|_| RecvTimeoutError::Disconnected);
}
// If the channel is empty, return or block immediately.
Status::Empty if !should_block => return Ok(None),
Status::Empty => break,
// Otherwise spin for a bit before blocking.
Status::InProgress => {}
}

if !backoff.try_spin_light() {
break;
}
}

if let Some(d) = deadline {
Expand All @@ -386,7 +476,7 @@ impl<T> Channel<T> {
Context::with(|cx| {
// Prepare for blocking until a sender wakes us up.
let oper = Operation::hook(token);
self.receivers.register(oper, cx);
self.receivers.register(oper, cx, &state);

// Has the channel become ready just now?
if !self.is_empty() || self.is_disconnected() {
Expand All @@ -406,6 +496,8 @@ impl<T> Channel<T> {
Selected::Operation(_) => {}
}
});

state.unpark();
}
}

Expand Down
Loading

0 comments on commit 5c11519

Please sign in to comment.