Skip to content

Commit

Permalink
wruster: refactor the streams module
Browse files Browse the repository at this point in the history
  • Loading branch information
manelmontilla committed Oct 28, 2023
1 parent 5318adb commit bf09256
Show file tree
Hide file tree
Showing 7 changed files with 341 additions and 308 deletions.
34 changes: 17 additions & 17 deletions wruster/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,14 @@ use http::*;
use polling::{Event, Poller};
use router::{Normalize, Router};

use streams::cancellable_stream::BaseStream;
use streams::timeout_stream::TimeoutStream;
use streams::TrackedStream;

use crate::streams::cancellable_stream::CancellableStream;
pub use crate::streams::tls::{Certificate, PrivateKey};
use crate::streams::{tls, TrackedStreamList};

use streams::{
cancellable_stream::CancellableStream,
observable::{ObservedStream, ObservedStreamList},
timeout_stream::TimeoutStream,
tls, Stream,
};

pub use streams::tls::{Certificate, PrivateKey};
/// Contains all the types necessary for dealing with Http messages.
pub mod http;
/// Contains the router to be used in a [`Server`].
Expand Down Expand Up @@ -273,7 +273,7 @@ impl Server {
})
}

fn start<T: BaseStream + Send + Sync + 'static, F>(
fn start<T: Stream + Send + Sync + 'static, F>(
&mut self,
addr: &str,
routes: Router,
Expand Down Expand Up @@ -322,7 +322,7 @@ impl Server {
let mut pool = thread_pool::Pool::new(execunits, 100);
let stop = Arc::clone(&self.stop);
let timeouts = self.timeouts.clone();
let active_streams = TrackedStreamList::new();
let active_streams = ObservedStreamList::new();
let handle = thread::spawn(move || {
loop {
debug!("tracked streams {}", active_streams.len());
Expand Down Expand Up @@ -350,7 +350,7 @@ impl Server {
continue;
}
};
let action_stream = TrackedStreamList::track(&active_streams, action_stream);
let action_stream = ObservedStreamList::track(&active_streams, action_stream);
let local_action_stream = action_stream.clone();
let action = move || {
handle_conversation(
Expand Down Expand Up @@ -506,9 +506,9 @@ impl Default for Server {
}
}

fn handle_busy<T>(stream: TrackedStream<T>, timeouts: Timeouts, src_addr: SocketAddr)
fn handle_busy<T>(stream: ObservedStream<T>, timeouts: Timeouts, src_addr: SocketAddr)
where
T: BaseStream,
T: Stream,
{
debug!("sending too busy to {}", src_addr);
let write_timeout = Some(timeouts.write_response_timeout);
Expand All @@ -527,12 +527,12 @@ where
}

fn handle_conversation<T>(
mut stream: TrackedStream<T>,
mut stream: ObservedStream<T>,
routes: Arc<Router>,
timeouts: Timeouts,
source_addr: SocketAddr,
) where
T: BaseStream + 'static,
T: Stream + 'static,
{
debug!("handling conversation with {}", source_addr);
let mut connection_open = true;
Expand All @@ -558,13 +558,13 @@ fn handle_conversation<T>(
}

fn handle_connection<T>(
stream: TrackedStream<T>,
stream: ObservedStream<T>,
routes: Arc<Router>,
source_addr: SocketAddr,
timeouts: Timeouts,
) -> bool
where
T: BaseStream + 'static,
T: Stream + 'static,
{
let connection_open: bool;
let read_timeout = Some(timeouts.read_request_timeout);
Expand Down
113 changes: 12 additions & 101 deletions wruster/src/streams/cancellable_stream.rs
Original file line number Diff line number Diff line change
@@ -1,100 +1,15 @@
use super::tls;
use super::BaseStream;
use crate::log::debug;
use polling::{Event, Source};
use std::io::Read;
use std::io::{self, Write};
use std::net::{Shutdown, TcpStream};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{atomic, Arc, RwLock};
use std::time::Duration;

pub trait BaseStream {
fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()>;
fn shutdown(&self, how: Shutdown) -> io::Result<()>;
fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()>;
fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()>;
fn as_raw(&self) -> std::os::unix::prelude::RawFd;
fn write_buf(&self, buf: &[u8]) -> io::Result<usize>;
fn read_buf(&self, buf: &mut [u8]) -> io::Result<usize>;
fn flush_data(&self) -> io::Result<()>;
}

pub trait Stream: Send + Sync + BaseStream {}

impl BaseStream for TcpStream {
fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
self.set_nonblocking(nonblocking)
}

fn shutdown(&self, how: Shutdown) -> io::Result<()> {
self.shutdown(how)
}

fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
self.set_read_timeout(dur)
}

fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
self.set_write_timeout(dur)
}

fn as_raw(&self) -> std::os::unix::prelude::RawFd {
self.raw()
}

fn write_buf(&self, buf: &[u8]) -> io::Result<usize> {
let mut s = self;
<&Self as Write>::write(&mut s, buf)
}

fn read_buf(&self, buf: &mut [u8]) -> io::Result<usize> {
let mut s = self;
<&Self as Read>::read(&mut s, buf)
}

fn flush_data(&self) -> io::Result<()> {
let mut s = self;
<&Self as Write>::flush(&mut s)
}
}

impl Stream for TcpStream {}

impl BaseStream for tls::Stream {
fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
self.set_nonblocking(nonblocking)
}

fn shutdown(&self, how: Shutdown) -> io::Result<()> {
self.shutdown(how)
}

fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
self.set_read_timeout(dur)
}

fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
self.set_write_timeout(dur)
}

fn as_raw(&self) -> std::os::unix::prelude::RawFd {
self.as_raw()
}

fn write_buf(&self, buf: &[u8]) -> io::Result<usize> {
self.write_int(buf)
}

fn read_buf(&self, buf: &mut [u8]) -> io::Result<usize> {
self.read_int(buf)
}

fn flush_data(&self) -> io::Result<()> {
self.flush_data()
}
}

impl Stream for tls::Stream {}
use polling::Event;
use std::{
io,
net::Shutdown,
sync::{
atomic::{self, AtomicBool, Ordering},
Arc, RwLock,
},
time::Duration,
};

pub struct CancellableStream<T: BaseStream> {
stream: T,
Expand Down Expand Up @@ -136,10 +51,6 @@ where
Ok(())
}

pub fn cancel(&self) -> io::Result<()> {
self.poller.notify()
}

pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
self.done.store(true, Ordering::SeqCst);
self.stream.shutdown(how)
Expand Down Expand Up @@ -186,7 +97,7 @@ where

// TODO: Actually this is not correct, we should read all the
// events returned by wait, even if we end up reading more bytes
// than the len of the buffer provider by the caller.
// than the len of the buffer provide by the caller.
if bytes_read == buf_len {
break;
}
Expand Down
Loading

0 comments on commit bf09256

Please sign in to comment.