Skip to content

Commit

Permalink
An attempt at #401 - removing TX busywait
Browse files Browse the repository at this point in the history
  • Loading branch information
Eugeny committed Dec 4, 2024
1 parent 785cfbf commit e2c6e3d
Show file tree
Hide file tree
Showing 7 changed files with 126 additions and 40 deletions.
11 changes: 5 additions & 6 deletions russh/src/channels/channel_ref.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,26 @@
use std::sync::Arc;

use tokio::sync::mpsc::UnboundedSender;
use tokio::sync::Mutex;

use crate::ChannelMsg;

use super::WindowSizeRef;

/// A handle to the [`super::Channel`]'s to be able to transmit messages
/// to it and update it's `window_size`.
#[derive(Debug)]
pub struct ChannelRef {
pub(super) sender: UnboundedSender<ChannelMsg>,
pub(super) window_size: Arc<Mutex<u32>>,
pub(super) window_size: WindowSizeRef,
}

impl ChannelRef {
pub fn new(sender: UnboundedSender<ChannelMsg>) -> Self {
Self {
sender,
window_size: Default::default(),
window_size: WindowSizeRef::new(0),
}
}

pub fn window_size(&self) -> &Arc<Mutex<u32>> {
pub fn window_size(&self) -> &WindowSizeRef {
&self.window_size
}
}
Expand Down
105 changes: 86 additions & 19 deletions russh/src/channels/io/tx.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
use std::convert::TryFrom;
use std::future::Future;
use std::io;
use std::num::NonZero;
use std::ops::DerefMut;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{ready, Context, Poll};
Expand All @@ -7,7 +11,7 @@ use futures::FutureExt;
use tokio::io::AsyncWrite;
use tokio::sync::mpsc::error::SendError;
use tokio::sync::mpsc::{self, OwnedPermit};
use tokio::sync::{Mutex, OwnedMutexGuard};
use tokio::sync::{watch, Mutex, OwnedMutexGuard};

use super::ChannelMsg;
use crate::{ChannelId, CryptoVec};
Expand All @@ -16,13 +20,50 @@ type BoxedThreadsafeFuture<T> = Pin<Box<dyn Sync + Send + std::future::Future<Ou
type OwnedPermitFuture<S> =
BoxedThreadsafeFuture<Result<(OwnedPermit<S>, ChannelMsg, usize), SendError<()>>>;

async fn _watch_changed<T>(
mut w: watch::Receiver<T>,
) -> Result<watch::Receiver<T>, watch::error::RecvError> {
w.changed().await?;
w.borrow_and_update();
Ok(w)
}

struct WatchNotification<T>(
Pin<
Box<dyn Sync + Send + Future<Output = Result<watch::Receiver<T>, watch::error::RecvError>>>,
>,
);

/// A single future that becomes ready every time there's a change
/// in the window size
impl<T: Sync + Send + 'static> WatchNotification<T> {
fn new(w: watch::Receiver<T>) -> Self {
Self(Box::pin(_watch_changed(w)))
}
}

impl<T: Sync + Send + 'static> Future for WatchNotification<T> {
type Output = Result<(), watch::error::RecvError>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let inner = self.deref_mut().0.as_mut();
match ready!(inner.poll(cx)) {
Ok(receiver) => {
*self.get_mut() = WatchNotification::new(receiver);
Poll::Ready(Ok(()))
}
Err(e) => Poll::Ready(Err(e)),
}
}
}

pub struct ChannelTx<S> {
sender: mpsc::Sender<S>,
send_fut: Option<OwnedPermitFuture<S>>,
id: ChannelId,

window_size_fut: Option<BoxedThreadsafeFuture<OwnedMutexGuard<u32>>>,
window_size: Arc<Mutex<u32>>,
window_size_notication: WatchNotification<u32>,
max_packet_size: u32,
ext: Option<u32>,
}
Expand All @@ -35,50 +76,72 @@ where
sender: mpsc::Sender<S>,
id: ChannelId,
window_size: Arc<Mutex<u32>>,
window_size_notification: watch::Receiver<u32>,
max_packet_size: u32,
ext: Option<u32>,
) -> Self {
Self {
sender,
send_fut: None,
id,
window_size_notication: WatchNotification::new(window_size_notification),
window_size,
window_size_fut: None,
max_packet_size,
ext,
}
}

fn poll_mk_msg(&mut self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<(ChannelMsg, usize)> {
fn poll_writable(
&mut self,
cx: &mut Context<'_>,
buf_len: usize,
) -> Poll<Result<NonZero<usize>, watch::error::RecvError>> {
let window_size = self.window_size.clone();
let window_size_fut = self
.window_size_fut
.get_or_insert_with(|| Box::pin(window_size.lock_owned()));
let mut window_size = ready!(window_size_fut.poll_unpin(cx));
self.window_size_fut.take();

let writable = (self.max_packet_size)
.min(*window_size)
.min(buf.len() as u32) as usize;
if writable == 0 {
// TODO fix this busywait
cx.waker().wake_by_ref();
return Poll::Pending;
let writable = (self.max_packet_size).min(*window_size).min(buf_len as u32) as usize;

match NonZero::try_from(writable) {
Ok(w) => {
*window_size -= writable as u32;
Poll::Ready(Ok(w))
}
Err(_) => match ready!(self.window_size_notication.poll_unpin(cx)) {
Ok(_) => {
cx.waker().wake_by_ref();
Poll::Pending
}
Err(e) => Poll::Ready(Err(e)),
},
}
let mut data = CryptoVec::new_zeroed(writable);
#[allow(clippy::indexing_slicing)] // Clamped to maximum `buf.len()` with `.min`
data.copy_from_slice(&buf[..writable]);
data.resize(writable);
}

*window_size -= writable as u32;
drop(window_size);
fn poll_mk_msg(
&mut self,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<(ChannelMsg, NonZero<usize>), watch::error::RecvError>> {
let writable = match ready!(self.poll_writable(cx, buf.len())) {
Ok(w) => w,
Err(e) => return Poll::Ready(Err(e)),
};

let mut data = CryptoVec::new_zeroed(writable.into());
#[allow(clippy::indexing_slicing)] // Clamped to maximum `buf.len()` with `.min`
data.copy_from_slice(&buf[..writable.into()]);
data.resize(writable.into());

let msg = match self.ext {
None => ChannelMsg::Data { data },
Some(ext) => ChannelMsg::ExtendedData { data, ext },
};

Poll::Ready((msg, writable))
Poll::Ready(Ok((msg, writable)))
}

fn activate(&mut self, msg: ChannelMsg, writable: usize) -> &mut OwnedPermitFuture<S> {
Expand Down Expand Up @@ -119,8 +182,12 @@ where
let send_fut = if let Some(x) = self.send_fut.as_mut() {
x
} else {
let (msg, writable) = ready!(self.poll_mk_msg(cx, buf));
self.activate(msg, writable)
let (msg, writable) = match ready!(self.poll_mk_msg(cx, buf)) {
Ok(x) => x,
// Cannot write anymore
Err(_) => return Poll::Ready(Ok(0)),
};
self.activate(msg, writable.into())
};
let r = ready!(send_fut.as_mut().poll_unpin(cx));
Poll::Ready(self.handle_write_result(r))
Expand Down
31 changes: 25 additions & 6 deletions russh/src/channels/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::sync::Arc;

use tokio::io::{AsyncRead, AsyncWrite};
use tokio::sync::mpsc::{Sender, UnboundedReceiver};
use tokio::sync::Mutex;
use tokio::sync::{watch, Mutex};

use crate::{ChannelId, ChannelOpenFailure, CryptoVec, Error, Pty, Sig};

Expand Down Expand Up @@ -112,6 +112,22 @@ pub enum ChannelMsg {
OpenFailure(ChannelOpenFailure),
}

#[derive(Clone, Debug)]
pub struct WindowSizeRef {
pub(crate) value: Arc<Mutex<u32>>,
notifier: watch::Sender<u32>,
}

impl WindowSizeRef {
pub fn new(initial: u32) -> Self {
let (notifier, _) = watch::channel(initial);
Self {
value: Arc::new(Mutex::new(initial)),
notifier,
}
}
}

/// A handle to a session channel.
///
/// Allows you to read and write from a channel without borrowing the session
Expand All @@ -120,7 +136,7 @@ pub struct Channel<Send: From<(ChannelId, ChannelMsg)>> {
pub(crate) sender: Sender<Send>,
pub(crate) receiver: UnboundedReceiver<ChannelMsg>,
pub(crate) max_packet_size: u32,
pub(crate) window_size: Arc<Mutex<u32>>,
pub(crate) window_size: WindowSizeRef,
}

impl<T: From<(ChannelId, ChannelMsg)>> std::fmt::Debug for Channel<T> {
Expand All @@ -137,7 +153,7 @@ impl<S: From<(ChannelId, ChannelMsg)> + Send + Sync + 'static> Channel<S> {
window_size: u32,
) -> (Self, ChannelRef) {
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
let window_size = Arc::new(Mutex::new(window_size));
let window_size = WindowSizeRef::new(window_size);

(
Self {
Expand All @@ -157,7 +173,8 @@ impl<S: From<(ChannelId, ChannelMsg)> + Send + Sync + 'static> Channel<S> {
/// Returns the min between the maximum packet size and the
/// remaining window size in the channel.
pub async fn writable_packet_size(&self) -> usize {
self.max_packet_size.min(*self.window_size.lock().await) as usize
self.max_packet_size
.min(*self.window_size.value.lock().await) as usize
}

pub fn id(&self) -> ChannelId {
Expand Down Expand Up @@ -337,7 +354,8 @@ impl<S: From<(ChannelId, ChannelMsg)> + Send + Sync + 'static> Channel<S> {
io::ChannelTx::new(
self.sender.clone(),
self.id,
self.window_size.clone(),
self.window_size.value.clone(),
self.window_size.notifier.subscribe(),
self.max_packet_size,
None,
),
Expand Down Expand Up @@ -369,7 +387,8 @@ impl<S: From<(ChannelId, ChannelMsg)> + Send + Sync + 'static> Channel<S> {
io::ChannelTx::new(
self.sender.clone(),
self.id,
self.window_size.clone(),
self.window_size.value.clone(),
self.window_size.notifier.subscribe(),
self.max_packet_size,
ext,
)
Expand Down
2 changes: 1 addition & 1 deletion russh/src/client/encrypted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -630,7 +630,7 @@ impl Session {
new_size -= enc.flush_pending(channel_num)? as u32;
}
if let Some(chan) = self.channels.get(&channel_num) {
*chan.window_size().lock().await = new_size;
*chan.window_size().value.lock().await = new_size;

let _ = chan.send(ChannelMsg::WindowAdjusted { new_size });
}
Expand Down
8 changes: 4 additions & 4 deletions russh/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ use tokio::pin;
use tokio::sync::mpsc::{
channel, unbounded_channel, Receiver, Sender, UnboundedReceiver, UnboundedSender,
};
use tokio::sync::{oneshot, Mutex};
use tokio::sync::oneshot;

use crate::channels::{Channel, ChannelMsg, ChannelRef};
use crate::channels::{Channel, ChannelMsg, ChannelRef, WindowSizeRef};
use crate::cipher::{self, clear, CipherPair, OpeningKey};
use crate::keys::key::parse_public_key;
use crate::session::{
Expand Down Expand Up @@ -428,7 +428,7 @@ impl<H: Handler> Handle<H> {
async fn wait_channel_confirmation(
&self,
mut receiver: UnboundedReceiver<ChannelMsg>,
window_size_ref: Arc<Mutex<u32>>,
window_size_ref: WindowSizeRef,
) -> Result<Channel<Msg>, crate::Error> {
loop {
match receiver.recv().await {
Expand All @@ -437,7 +437,7 @@ impl<H: Handler> Handle<H> {
max_packet_size,
window_size,
}) => {
*window_size_ref.lock().await = window_size;
*window_size_ref.value.lock().await = window_size;

return Ok(Channel {
id,
Expand Down
2 changes: 1 addition & 1 deletion russh/src/server/encrypted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -763,7 +763,7 @@ impl Session {
enc.flush_pending(channel_num)?;
}
if let Some(chan) = self.channels.get(&channel_num) {
*chan.window_size().lock().await = new_size;
*chan.window_size().value.lock().await = new_size;

chan.send(ChannelMsg::WindowAdjusted { new_size })
.unwrap_or(())
Expand Down
7 changes: 4 additions & 3 deletions russh/src/server/session.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;

use channels::WindowSizeRef;
use log::debug;
use negotiation::parse_kex_algo_list;
use russh_keys::helpers::NameList;
use russh_keys::map_err;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::sync::mpsc::{unbounded_channel, Receiver, Sender, UnboundedReceiver};
use tokio::sync::{oneshot, Mutex};
use tokio::sync::oneshot;

use super::*;
use crate::channels::{Channel, ChannelMsg, ChannelRef};
Expand Down Expand Up @@ -346,7 +347,7 @@ impl Handle {
async fn wait_channel_confirmation(
&self,
mut receiver: UnboundedReceiver<ChannelMsg>,
window_size_ref: Arc<Mutex<u32>>,
window_size_ref: WindowSizeRef,
) -> Result<Channel<Msg>, Error> {
loop {
match receiver.recv().await {
Expand All @@ -355,7 +356,7 @@ impl Handle {
max_packet_size,
window_size,
}) => {
*window_size_ref.lock().await = window_size;
*window_size_ref.value.lock().await = window_size;

return Ok(Channel {
id,
Expand Down

0 comments on commit e2c6e3d

Please sign in to comment.