Skip to content

Commit

Permalink
wruster/src/streams/cancellable_stream: fix bug making streams not to…
Browse files Browse the repository at this point in the history
… detect a connection close
  • Loading branch information
manelmontilla committed Dec 7, 2023
1 parent efa9129 commit a35a17e
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 7 deletions.
18 changes: 11 additions & 7 deletions wruster/src/streams/cancellable_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,19 +82,22 @@ where
let read_buf = &mut buf[bytes_read..];
let s = &self.stream;

match s.read_buf(read_buf) {
let result = match s.read_buf(read_buf) {
Ok(0) if self.done.load(Ordering::SeqCst) => {
return Err(io::Error::from(io::ErrorKind::NotConnected));
Err(io::Error::from(io::ErrorKind::NotConnected))
}
Ok(n) => Ok(n),
Err(err) => Err(err),
};
match result {
Ok(n) => {
bytes_read += n;
}
Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
Err(err) => {
return Err(err);
Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
return Err(io::Error::from(io::ErrorKind::Interrupted));
}
Err(err) => return Err(err),
};

// TODO: Actually this is not correct, we should read all the
// events returned by wait, even if we end up reading more bytes
// than the len of the buffer provide by the caller.
Expand All @@ -106,7 +109,8 @@ where
// the operation by returning and error of kind :Interrupted.
// Reference: https://doc.rust-lang.org/std/io/trait.Read.html#errors.
if bytes_read == 0 {
return Err(io::Error::from(io::ErrorKind::Interrupted));
return Err(io::Error::from(io::ErrorKind::ConnectionReset));
//return Err(io::Error::from(io::ErrorKind::Interrupted));
}
Ok(bytes_read)
}
Expand Down
24 changes: 24 additions & 0 deletions wruster/src/streams/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,30 @@ fn cancellable_stream_shutdown_stops_reading() {
handle.join().unwrap();
}

#[test]
fn cancellable_stream_read_stops_connection_close() {
let port = get_free_port();
let addr = format!("127.0.0.1:{}", port);
let listener = TcpListener::bind(addr.clone()).unwrap();
let handle = thread::spawn(move || {
let (stream, _) = listener.accept().unwrap();
let mut cstream = CancellableStream::new(stream).unwrap();
cstream
.set_read_timeout(Some(Duration::from_secs(2)))
.unwrap();
let mut reader = BufReader::new(&mut cstream);
let mut content = Vec::new();
reader
.read_until(b' ', &mut content)
.expect_err("connetion close");
});

let mut client = TcpClient::connect(addr.to_string()).unwrap();
thread::sleep(Duration::from_secs(1));
client.close().unwrap();
handle.join().unwrap();
}

#[test]
fn cancellable_stream_read_reads_data() {
let port = get_free_port();
Expand Down

0 comments on commit a35a17e

Please sign in to comment.