From a35a17e120fa70d6643bfee3e9ab9a761242c13e Mon Sep 17 00:00:00 2001 From: Manel Montilla Date: Tue, 5 Dec 2023 21:15:09 +0100 Subject: [PATCH] wruster/src/streams/cancellable_stream: fix bug making streams not to detect a connection close --- wruster/src/streams/cancellable_stream.rs | 18 ++++++++++------- wruster/src/streams/test.rs | 24 +++++++++++++++++++++++ 2 files changed, 35 insertions(+), 7 deletions(-) diff --git a/wruster/src/streams/cancellable_stream.rs b/wruster/src/streams/cancellable_stream.rs index f1f0bee..43e1498 100644 --- a/wruster/src/streams/cancellable_stream.rs +++ b/wruster/src/streams/cancellable_stream.rs @@ -82,19 +82,22 @@ where let read_buf = &mut buf[bytes_read..]; let s = &self.stream; - match s.read_buf(read_buf) { + let result = match s.read_buf(read_buf) { Ok(0) if self.done.load(Ordering::SeqCst) => { - return Err(io::Error::from(io::ErrorKind::NotConnected)); + Err(io::Error::from(io::ErrorKind::NotConnected)) } + Ok(n) => Ok(n), + Err(err) => Err(err), + }; + match result { Ok(n) => { bytes_read += n; } - Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} - Err(err) => { - return Err(err); + Err(err) if err.kind() == io::ErrorKind::WouldBlock => { + return Err(io::Error::from(io::ErrorKind::Interrupted)); } + Err(err) => return Err(err), }; - // TODO: Actually this is not correct, we should read all the // events returned by wait, even if we end up reading more bytes // than the len of the buffer provide by the caller. @@ -106,7 +109,8 @@ where // the operation by returning and error of kind :Interrupted. // Reference: https://doc.rust-lang.org/std/io/trait.Read.html#errors. if bytes_read == 0 { - return Err(io::Error::from(io::ErrorKind::Interrupted)); + return Err(io::Error::from(io::ErrorKind::ConnectionReset)); + //return Err(io::Error::from(io::ErrorKind::Interrupted)); } Ok(bytes_read) } diff --git a/wruster/src/streams/test.rs b/wruster/src/streams/test.rs index 1391c9c..00cf7dc 100644 --- a/wruster/src/streams/test.rs +++ b/wruster/src/streams/test.rs @@ -60,6 +60,30 @@ fn cancellable_stream_shutdown_stops_reading() { handle.join().unwrap(); } +#[test] +fn cancellable_stream_read_stops_connection_close() { + let port = get_free_port(); + let addr = format!("127.0.0.1:{}", port); + let listener = TcpListener::bind(addr.clone()).unwrap(); + let handle = thread::spawn(move || { + let (stream, _) = listener.accept().unwrap(); + let mut cstream = CancellableStream::new(stream).unwrap(); + cstream + .set_read_timeout(Some(Duration::from_secs(2))) + .unwrap(); + let mut reader = BufReader::new(&mut cstream); + let mut content = Vec::new(); + reader + .read_until(b' ', &mut content) + .expect_err("connetion close"); + }); + + let mut client = TcpClient::connect(addr.to_string()).unwrap(); + thread::sleep(Duration::from_secs(1)); + client.close().unwrap(); + handle.join().unwrap(); +} + #[test] fn cancellable_stream_read_reads_data() { let port = get_free_port();