From 4a7d20497ddbb2e62c9059e30b4fc6ac4082244d Mon Sep 17 00:00:00 2001 From: Manel Montilla Date: Sun, 22 Oct 2023 19:19:50 +0200 Subject: [PATCH] wruster: refactor the streams module --- wruster/src/lib.rs | 34 ++-- wruster/src/streams/cancellable_stream.rs | 113 ++--------- wruster/src/streams/mod.rs | 231 +++++++--------------- wruster/src/streams/observable.rs | 206 +++++++++++++++++++ wruster/src/streams/test.rs | 53 ++--- wruster/src/streams/timeout_stream.rs | 11 +- wruster/src/streams/tls/mod.rs | 2 +- 7 files changed, 342 insertions(+), 308 deletions(-) create mode 100644 wruster/src/streams/observable.rs diff --git a/wruster/src/lib.rs b/wruster/src/lib.rs index 8ba5fae..d9869b2 100644 --- a/wruster/src/lib.rs +++ b/wruster/src/lib.rs @@ -60,14 +60,14 @@ use http::*; use polling::{Event, Poller}; use router::{Normalize, Router}; -use streams::cancellable_stream::BaseStream; -use streams::timeout_stream::TimeoutStream; -use streams::TrackedStream; - -use crate::streams::cancellable_stream::CancellableStream; -pub use crate::streams::tls::{Certificate, PrivateKey}; -use crate::streams::{tls, TrackedStreamList}; - +use streams::{ + cancellable_stream::CancellableStream, + observable::{ObservedStream, ObservedStreamList}, + timeout_stream::TimeoutStream, + tls, Stream, +}; + +pub use streams::tls::{Certificate, PrivateKey}; /// Contains all the types necessary for dealing with Http messages. pub mod http; /// Contains the router to be used in a [`Server`]. @@ -273,7 +273,7 @@ impl Server { }) } - fn start( + fn start( &mut self, addr: &str, routes: Router, @@ -322,7 +322,7 @@ impl Server { let mut pool = thread_pool::Pool::new(execunits, 100); let stop = Arc::clone(&self.stop); let timeouts = self.timeouts.clone(); - let active_streams = TrackedStreamList::new(); + let active_streams = ObservedStreamList::new(); let handle = thread::spawn(move || { loop { debug!("tracked streams {}", active_streams.len()); @@ -350,7 +350,7 @@ impl Server { continue; } }; - let action_stream = TrackedStreamList::track(&active_streams, action_stream); + let action_stream = ObservedStreamList::track(&active_streams, action_stream); let local_action_stream = action_stream.clone(); let action = move || { handle_conversation( @@ -506,9 +506,9 @@ impl Default for Server { } } -fn handle_busy(stream: TrackedStream, timeouts: Timeouts, src_addr: SocketAddr) +fn handle_busy(stream: ObservedStream, timeouts: Timeouts, src_addr: SocketAddr) where - T: BaseStream, + T: Stream, { debug!("sending too busy to {}", src_addr); let write_timeout = Some(timeouts.write_response_timeout); @@ -527,12 +527,12 @@ where } fn handle_conversation( - mut stream: TrackedStream, + mut stream: ObservedStream, routes: Arc, timeouts: Timeouts, source_addr: SocketAddr, ) where - T: BaseStream + 'static, + T: Stream + 'static, { debug!("handling conversation with {}", source_addr); let mut connection_open = true; @@ -558,13 +558,13 @@ fn handle_conversation( } fn handle_connection( - stream: TrackedStream, + stream: ObservedStream, routes: Arc, source_addr: SocketAddr, timeouts: Timeouts, ) -> bool where - T: BaseStream + 'static, + T: Stream + 'static, { let connection_open: bool; let read_timeout = Some(timeouts.read_request_timeout); diff --git a/wruster/src/streams/cancellable_stream.rs b/wruster/src/streams/cancellable_stream.rs index c61fe00..da1b272 100644 --- a/wruster/src/streams/cancellable_stream.rs +++ b/wruster/src/streams/cancellable_stream.rs @@ -1,100 +1,15 @@ -use super::tls; +use super::BaseStream; use crate::log::debug; -use polling::{Event, Source}; -use std::io::Read; -use std::io::{self, Write}; -use std::net::{Shutdown, TcpStream}; -use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::{atomic, Arc, RwLock}; -use std::time::Duration; - -pub trait BaseStream { - fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()>; - fn shutdown(&self, how: Shutdown) -> io::Result<()>; - fn set_read_timeout(&self, dur: Option) -> io::Result<()>; - fn set_write_timeout(&self, dur: Option) -> io::Result<()>; - fn as_raw(&self) -> std::os::unix::prelude::RawFd; - fn write_buf(&self, buf: &[u8]) -> io::Result; - fn read_buf(&self, buf: &mut [u8]) -> io::Result; - fn flush_data(&self) -> io::Result<()>; -} - -pub trait Stream: Send + Sync + BaseStream {} - -impl BaseStream for TcpStream { - fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { - self.set_nonblocking(nonblocking) - } - - fn shutdown(&self, how: Shutdown) -> io::Result<()> { - self.shutdown(how) - } - - fn set_read_timeout(&self, dur: Option) -> io::Result<()> { - self.set_read_timeout(dur) - } - - fn set_write_timeout(&self, dur: Option) -> io::Result<()> { - self.set_write_timeout(dur) - } - - fn as_raw(&self) -> std::os::unix::prelude::RawFd { - self.raw() - } - - fn write_buf(&self, buf: &[u8]) -> io::Result { - let mut s = self; - <&Self as Write>::write(&mut s, buf) - } - - fn read_buf(&self, buf: &mut [u8]) -> io::Result { - let mut s = self; - <&Self as Read>::read(&mut s, buf) - } - - fn flush_data(&self) -> io::Result<()> { - let mut s = self; - <&Self as Write>::flush(&mut s) - } -} - -impl Stream for TcpStream {} - -impl BaseStream for tls::Stream { - fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { - self.set_nonblocking(nonblocking) - } - - fn shutdown(&self, how: Shutdown) -> io::Result<()> { - self.shutdown(how) - } - - fn set_read_timeout(&self, dur: Option) -> io::Result<()> { - self.set_read_timeout(dur) - } - - fn set_write_timeout(&self, dur: Option) -> io::Result<()> { - self.set_write_timeout(dur) - } - - fn as_raw(&self) -> std::os::unix::prelude::RawFd { - self.as_raw() - } - - fn write_buf(&self, buf: &[u8]) -> io::Result { - self.write_int(buf) - } - - fn read_buf(&self, buf: &mut [u8]) -> io::Result { - self.read_int(buf) - } - - fn flush_data(&self) -> io::Result<()> { - self.flush_data() - } -} - -impl Stream for tls::Stream {} +use polling::Event; +use std::{ + io, + net::Shutdown, + sync::{ + atomic::{self, AtomicBool, Ordering}, + Arc, RwLock, + }, + time::Duration, +}; pub struct CancellableStream { stream: T, @@ -136,10 +51,6 @@ where Ok(()) } - pub fn cancel(&self) -> io::Result<()> { - self.poller.notify() - } - pub fn shutdown(&self, how: Shutdown) -> io::Result<()> { self.done.store(true, Ordering::SeqCst); self.stream.shutdown(how) @@ -186,7 +97,7 @@ where // TODO: Actually this is not correct, we should read all the // events returned by wait, even if we end up reading more bytes - // than the len of the buffer provider by the caller. + // than the len of the buffer provide by the caller. if bytes_read == buf_len { break; } diff --git a/wruster/src/streams/mod.rs b/wruster/src/streams/mod.rs index 14f15ce..e1b8f12 100644 --- a/wruster/src/streams/mod.rs +++ b/wruster/src/streams/mod.rs @@ -1,202 +1,109 @@ -use std::{ - collections::HashMap, - io::{self, Read, Write}, - ops::{Deref, DerefMut}, - sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, RwLock, Weak, - }, -}; +/*! +Contains various types that augment a type that can act as a [Stream], e.g.: a [std::net::TcpStream]. +*/ +use polling::Source; +use std::io::Read; +use std::io::{self, Write}; +use std::net::{Shutdown, TcpStream}; +use std::time::Duration; pub mod cancellable_stream; +pub mod observable; pub mod timeout_stream; pub mod tls; -#[cfg(test)] -mod test; -mod test_utils; - -use timeout_stream::Timeout; - -use self::cancellable_stream::{BaseStream, CancellableStream}; - -pub struct ObservedStream -where - T: BaseStream, -{ - observed: CancellableStream, - parent: Option<(usize, Weak>)>, +pub trait BaseStream { + fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()>; + fn shutdown(&self, how: Shutdown) -> io::Result<()>; + fn set_read_timeout(&self, dur: Option) -> io::Result<()>; + fn set_write_timeout(&self, dur: Option) -> io::Result<()>; + fn as_raw(&self) -> std::os::unix::prelude::RawFd; + fn write_buf(&self, buf: &[u8]) -> io::Result; + fn read_buf(&self, buf: &mut [u8]) -> io::Result; + fn flush_data(&self) -> io::Result<()>; } -impl ObservedStream -where - T: BaseStream, -{ - pub fn new(observed: CancellableStream) -> ObservedStream { - ObservedStream { - observed, - parent: None, - } +impl BaseStream for TcpStream { + fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { + self.set_nonblocking(nonblocking) } -} -impl Drop for ObservedStream -where - T: BaseStream, -{ - fn drop(&mut self) { - let parent = match &self.parent { - Some(it) => it, - _ => return, - }; - let key = parent.0; - if let Some(parent) = parent.1.upgrade() { - parent.dropped(key); - } + fn shutdown(&self, how: Shutdown) -> io::Result<()> { + self.shutdown(how) } -} -impl Deref for ObservedStream -where - T: BaseStream, -{ - type Target = CancellableStream; - - fn deref(&self) -> &Self::Target { - &self.observed + fn set_read_timeout(&self, dur: Option) -> io::Result<()> { + self.set_read_timeout(dur) } -} -impl DerefMut for ObservedStream -where - T: BaseStream, -{ - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.observed + fn set_write_timeout(&self, dur: Option) -> io::Result<()> { + self.set_write_timeout(dur) } -} -impl From> for ObservedStream -where - T: BaseStream, -{ - fn from(it: CancellableStream) -> Self { - ObservedStream::new(it) + fn as_raw(&self) -> std::os::unix::prelude::RawFd { + self.raw() } -} -pub struct TrackedStream -where - T: BaseStream, -{ - stream: Arc>, -} - -impl Clone for TrackedStream -where - T: BaseStream, -{ - fn clone(&self) -> Self { - let stream = Arc::clone(&self.stream); - Self { stream } + fn write_buf(&self, buf: &[u8]) -> io::Result { + let mut s = self; + <&Self as Write>::write(&mut s, buf) } -} -impl Deref for TrackedStream -where - T: BaseStream, -{ - type Target = ObservedStream; - - fn deref(&self) -> &Self::Target { - &self.stream + fn read_buf(&self, buf: &mut [u8]) -> io::Result { + let mut s = self; + <&Self as Read>::read(&mut s, buf) } -} -impl Read for TrackedStream -where - T: BaseStream, -{ - fn read(&mut self, buf: &mut [u8]) -> io::Result { - let mut s = &self.stream.observed; - s.read(buf) + fn flush_data(&self) -> io::Result<()> { + let mut s = self; + <&Self as Write>::flush(&mut s) } } -impl Write for TrackedStream -where - T: BaseStream, -{ - fn write(&mut self, buf: &[u8]) -> io::Result { - let mut s = &self.stream.observed; - s.write(buf) - } - - fn flush(&mut self) -> io::Result<()> { - let mut s = &self.stream.observed; - s.flush() +impl BaseStream for tls::Stream { + fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { + self.set_nonblocking(nonblocking) } -} -impl Timeout for TrackedStream -where - T: BaseStream, -{ - fn set_read_timeout(&self, dur: Option) -> io::Result<()> { - self.stream.set_read_timeout(dur) + fn shutdown(&self, how: Shutdown) -> io::Result<()> { + self.shutdown(how) } - fn set_write_timeout(&self, dur: Option) -> io::Result<()> { - self.stream.set_write_timeout(dur) + fn set_read_timeout(&self, dur: Option) -> io::Result<()> { + self.set_read_timeout(dur) } -} - -pub struct TrackedStreamList -where - T: BaseStream, -{ - items: RwLock>>>, - next_key: AtomicUsize, -} -impl TrackedStreamList -where - T: BaseStream, -{ - pub fn new() -> Arc> { - let items = HashMap::>>::new(); - let list = TrackedStreamList { - items: RwLock::new(items), - next_key: AtomicUsize::new(0), - }; - Arc::new(list) + fn set_write_timeout(&self, dur: Option) -> io::Result<()> { + self.set_write_timeout(dur) } - pub fn track( - list: &Arc>, - stream: CancellableStream, - ) -> TrackedStream { - let mut stream = ObservedStream::new(stream); - let parent = Arc::downgrade(list); - let key = list.next_key.fetch_add(1, Ordering::SeqCst); - stream.parent = Some((key, parent)); - let stream = Arc::new(stream); - let mut items = list.items.write().unwrap(); - items.insert(key, Arc::downgrade(&stream)); - TrackedStream { stream } + fn as_raw(&self) -> std::os::unix::prelude::RawFd { + self.as_raw() } - pub fn len(&self) -> usize { - self.items.read().unwrap().len() + fn write_buf(&self, buf: &[u8]) -> io::Result { + self.write_int(buf) } - fn dropped(&self, key: usize) { - let mut items = self.items.write().unwrap(); - items.remove(&key); + fn read_buf(&self, buf: &mut [u8]) -> io::Result { + self.read_int(buf) } - pub fn drain(&self) -> Vec>> { - let mut items = self.items.write().unwrap(); - items.drain().map(|x| x.1).collect() + fn flush_data(&self) -> io::Result<()> { + self.flush_data() } } + +/** + Defines the shape type that can act as a Stream so its functionality can be extended + by the other types in the package, e.g.: [observable::ObservableStream]. +*/ +pub trait Stream: Send + Sync + BaseStream {} + +impl Stream for tls::Stream {} + +impl Stream for TcpStream {} + +#[cfg(test)] +mod test; +mod test_utils; diff --git a/wruster/src/streams/observable.rs b/wruster/src/streams/observable.rs new file mode 100644 index 0000000..a5485b7 --- /dev/null +++ b/wruster/src/streams/observable.rs @@ -0,0 +1,206 @@ +use super::{cancellable_stream::CancellableStream, timeout_stream::Timeout, Stream}; +use std::{ + collections::HashMap, + io::{self, Read, Write}, + ops::{Deref, DerefMut}, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, RwLock, Weak, + }, +}; + +/** +Wraps a [CancellableStream] so it can be included in a [ObservedStreamList]. +See the [ObservedStreamList] documentation for more info. +*/ +pub struct ObservableStream +where + T: Stream, +{ + observed: CancellableStream, + parent: Option<(usize, Weak>)>, +} + +impl ObservableStream +where + T: Stream, +{ + pub fn new(observed: CancellableStream) -> ObservableStream { + ObservableStream { + observed, + parent: None, + } + } +} + +impl Drop for ObservableStream +where + T: Stream, +{ + fn drop(&mut self) { + let parent = match &self.parent { + Some(it) => it, + _ => return, + }; + let key = parent.0; + if let Some(parent) = parent.1.upgrade() { + parent.dropped(key); + } + } +} + +impl Deref for ObservableStream +where + T: Stream, +{ + type Target = CancellableStream; + + fn deref(&self) -> &Self::Target { + &self.observed + } +} + +impl DerefMut for ObservableStream +where + T: Stream, +{ + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.observed + } +} + +impl From> for ObservableStream +where + T: Stream, +{ + fn from(it: CancellableStream) -> Self { + ObservableStream::new(it) + } +} + +/** + Represents an [ObservableStream] that was added to an [ObservedStreamList], + so it is returned by the method [ObservedStreamList.track]. +* A [ObservedStream] is dereferenced to the [ObservableStream] being Observed. +* It can be cloned and the new ObservedStream will be also observed. +*/ +pub struct ObservedStream +where + T: Stream, +{ + stream: Arc>, +} + +impl Clone for ObservedStream +where + T: Stream, +{ + fn clone(&self) -> Self { + let stream = Arc::clone(&self.stream); + Self { stream } + } +} + +impl Deref for ObservedStream +where + T: Stream, +{ + type Target = ObservableStream; + + fn deref(&self) -> &Self::Target { + &self.stream + } +} + +impl Read for ObservedStream +where + T: Stream, +{ + fn read(&mut self, buf: &mut [u8]) -> io::Result { + let mut s = &self.stream.observed; + s.read(buf) + } +} + +impl Write for ObservedStream +where + T: Stream, +{ + fn write(&mut self, buf: &[u8]) -> io::Result { + let mut s = &self.stream.observed; + s.write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + let mut s = &self.stream.observed; + s.flush() + } +} + +impl Timeout for ObservedStream +where + T: Stream, +{ + fn set_read_timeout(&self, dur: Option) -> io::Result<()> { + self.stream.set_read_timeout(dur) + } + + fn set_write_timeout(&self, dur: Option) -> io::Result<()> { + self.stream.set_write_timeout(dur) + } +} + +/** +Allows to track a list of [ObservableStream], so whenever one of the Streams in the list, and all of its clones, +is dropped, it's automatically removed from the list. An [ObservableStream] is included in an [ObservedStreamList] +by calling the method: [ObservedStreamList.track]. +*/ +pub struct ObservedStreamList +where + T: Stream, +{ + items: RwLock>>>, + next_key: AtomicUsize, +} + +impl ObservedStreamList +where + T: Stream, +{ + pub fn new() -> Arc> { + let items = HashMap::>>::new(); + let list = ObservedStreamList { + items: RwLock::new(items), + next_key: AtomicUsize::new(0), + }; + Arc::new(list) + } + + pub fn track( + list: &Arc>, + stream: CancellableStream, + ) -> ObservedStream { + let mut stream = ObservableStream::new(stream); + let parent = Arc::downgrade(list); + let key = list.next_key.fetch_add(1, Ordering::SeqCst); + stream.parent = Some((key, parent)); + let stream = Arc::new(stream); + let mut items = list.items.write().unwrap(); + items.insert(key, Arc::downgrade(&stream)); + ObservedStream { stream } + } + + pub fn len(&self) -> usize { + self.items.read().unwrap().len() + } + + fn dropped(&self, key: usize) { + let mut items = self.items.write().unwrap(); + items.remove(&key); + } + + pub fn drain(&self) -> Vec>> { + let mut items = self.items.write().unwrap(); + items.drain().map(|x| x.1).collect() + } +} diff --git a/wruster/src/streams/test.rs b/wruster/src/streams/test.rs index e74703e..ea4c776 100644 --- a/wruster/src/streams/test.rs +++ b/wruster/src/streams/test.rs @@ -1,19 +1,27 @@ -use super::*; -use crate::streams::timeout_stream::TimeoutStream; -use crate::streams::tls::test_utils::*; -use crate::test_utils::TestTLSClient; - -use std::io::{BufRead, BufReader, ErrorKind, Read, Write}; -use std::net::{Shutdown, TcpListener}; -use std::str::FromStr; -use std::sync::atomic::AtomicBool; -use std::thread; +use super::{ + cancellable_stream::CancellableStream, + observable::ObservedStreamList, + test_utils::{get_free_port, TcpClient}, + timeout_stream::TimeoutStream, + tls::test_utils::*, + *, +}; -use std::time::Duration; -use test_utils::{get_free_port, TcpClient}; +use crate::test_utils::TestTLSClient; +use std::{ + io::{BufRead, BufReader, ErrorKind, Read, Write}, + net::{Shutdown, TcpListener}, + str::FromStr, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, + thread, + time::Duration, +}; #[test] -fn shutdown_stops_reading() { +fn cancellable_stream_shutdown_stops_reading() { let port = get_free_port(); let addr = format!("127.0.0.1:{}", port); let listener = TcpListener::bind(addr.clone()).unwrap(); @@ -53,7 +61,7 @@ fn shutdown_stops_reading() { } #[test] -fn read_reads_data() { +fn cancellable_stream_read_reads_data() { let port = get_free_port(); let addr = format!("127.0.0.1:{}", port); let listener = TcpListener::bind(addr.clone()).unwrap(); @@ -74,8 +82,7 @@ fn read_reads_data() { } #[test] -fn read_honors_timeout() { - //env::set_var("RUST_LOG", "debug"); +fn cancellable_steeam_read_honors_timeout() { env_logger::init(); let port = get_free_port(); let addr = format!("127.0.0.1:{}", port); @@ -100,7 +107,7 @@ fn read_honors_timeout() { } #[test] -fn write_writes_data() { +fn cancellable_stream_write_writes_data() { let data = "test "; let port = get_free_port(); let addr = format!("127.0.0.1:{}", port); @@ -130,7 +137,7 @@ fn write_writes_data() { } #[test] -fn test_shutdown_list() { +fn observed_stream_list_removes_stream() { let port = get_free_port(); let addr = format!("127.0.0.1:{}", port); let listener = TcpListener::bind(addr.clone()).unwrap(); @@ -138,8 +145,8 @@ fn test_shutdown_list() { let handle = thread::spawn(move || { let (stream, _) = listener.accept().unwrap(); let cstream = CancellableStream::new(stream).unwrap(); - let track_list = TrackedStreamList::new(); - let stream_tracked = TrackedStreamList::track(&track_list, cstream); + let track_list = ObservedStreamList::new(); + let stream_tracked = ObservedStreamList::track(&track_list, cstream); let cstream2 = stream_tracked.clone(); assert_eq!(1, track_list.len()); let handle = thread::spawn(move || { @@ -182,7 +189,7 @@ fn tls_stream_read_reads_data() { } #[test] -fn test_shutdown_list_tls() { +fn observed_stream_list_tracks_tls_streams() { let port = get_free_port(); let addr = format!("127.0.0.1:{}", port); let listener = TcpListener::bind(addr.clone()).unwrap(); @@ -193,8 +200,8 @@ fn test_shutdown_list_tls() { let cert = load_test_certificate().unwrap(); let stream = tls::Stream::new(stream, key, cert).unwrap(); let cstream = CancellableStream::new(stream).unwrap(); - let track_list = TrackedStreamList::new(); - let stream_tracked = TrackedStreamList::track(&track_list, cstream); + let track_list = ObservedStreamList::new(); + let stream_tracked = ObservedStreamList::track(&track_list, cstream); let cstream2 = stream_tracked.clone(); assert_eq!(1, track_list.len()); let handle = thread::spawn(move || { diff --git a/wruster/src/streams/timeout_stream.rs b/wruster/src/streams/timeout_stream.rs index d81729e..6de3c77 100644 --- a/wruster/src/streams/timeout_stream.rs +++ b/wruster/src/streams/timeout_stream.rs @@ -1,8 +1,11 @@ -use std::io::{self, ErrorKind, Read, Write}; -use std::net::TcpStream; -use std::time::{Duration, Instant}; +use std::{ + io::{self, ErrorKind, Read, Write}, + net::TcpStream, + sync::Arc, + time::{Duration, Instant}, +}; -use super::cancellable_stream::{BaseStream, CancellableStream}; +use super::{cancellable_stream::CancellableStream, BaseStream}; pub trait Timeout { fn set_read_timeout(&self, dur: Option) -> io::Result<()>; diff --git a/wruster/src/streams/tls/mod.rs b/wruster/src/streams/tls/mod.rs index 824f3df..4077332 100644 --- a/wruster/src/streams/tls/mod.rs +++ b/wruster/src/streams/tls/mod.rs @@ -9,7 +9,7 @@ use std::{ use rustls::{self, ServerConfig, ServerConnection, StreamOwned}; -use super::cancellable_stream::BaseStream; +use super::BaseStream; pub mod test_utils;