Skip to content

Commit

Permalink
Merge pull request #23 from manelmontilla/fix-big-files-bug
Browse files Browse the repository at this point in the history
Fixes problems serving big files.
  • Loading branch information
manelmontilla authored Nov 18, 2023
2 parents 88c3844 + 4fb4f8a commit cbdc1fc
Show file tree
Hide file tree
Showing 8 changed files with 69 additions and 35 deletions.
4 changes: 2 additions & 2 deletions wrustatic/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ fn main() {
log_middleware(Box::new(move |request| serve_static(&dir, request)));
routes.add("/", http::HttpMethod::GET, serve_dir);
let timeouts = Timeouts {
write_response_timeout: Duration::from_secs(5),
read_request_timeout: Duration::from_secs(5),
write_response_timeout: Duration::from_secs(60),
read_request_timeout: Duration::from_secs(60),
};
let mut server = Server::from_timeouts(timeouts);
let running = match cli.tls_cert {
Expand Down
2 changes: 1 addition & 1 deletion wruster/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ mime = "0.3.16"
atomic_refcell = "0.1.8"
log = "0.4.14"
env_logger = "0.8.4"
polling = "2.0.0"
polling = "2.8.0"
url = "2.2.2"
crossbeam = "0.8"
cfg-if = "0.1"
Expand Down
21 changes: 17 additions & 4 deletions wruster/src/http/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::io::{self, BufReader};
use std::io::{self, BufReader, BufWriter};
use std::io::{prelude::*, Cursor};

use std::convert::Infallible;
Expand Down Expand Up @@ -370,10 +370,23 @@ impl Body {
*/
pub fn write<T: io::Write>(&mut self, to: &mut T) -> HttpResult<()> {
let src = &mut self.content;
if let Err(err) = io::copy(src, to) {
return Err(HttpError::Unknown(err.to_string()));

// When the content to write is large (>2MB's) the io::copy function
// uses buffers that are too short, to avoid that we use this technique:
// https://github.com/rust-lang/rust/pull/78641 to set the buffer to
// 1MB.
let res = match self.content_length {
x if x > u64::pow(2, 20) * 2 => {
let buff_size = usize::pow(2, 20);
let mut dest = BufWriter::with_capacity(buff_size, to);
io::copy(src, &mut dest)
}
_ => io::copy(src, to),
};
Ok(())
match res {
Err(err) => Err(HttpError::Unknown(err.to_string())),
Ok(_) => Ok(()),
}
}

/**
Expand Down
23 changes: 11 additions & 12 deletions wruster/src/streams/cancellable_stream.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use super::BaseStream;
use crate::log::debug;
use polling::Event;
use std::{
io,
Expand Down Expand Up @@ -29,7 +28,6 @@ where
let read_timeout = RwLock::new(None);
let write_timeout = RwLock::new(None);
let done = atomic::AtomicBool::new(false);
poller.add(stream.as_raw(), Event::all(1))?;
Ok(CancellableStream {
stream,
done,
Expand Down Expand Up @@ -57,7 +55,6 @@ where
}

fn read_int(&self, buf: &mut [u8]) -> io::Result<usize> {
debug!("read int");
self.poller
.modify(self.stream.as_raw(), Event::readable(1))?;
let mut events = Vec::new();
Expand Down Expand Up @@ -112,13 +109,14 @@ where
}

fn write_int(&self, buf: &[u8]) -> io::Result<usize> {
self.poller
.modify(self.stream.as_raw(), Event::writable(1))?;
let mut events = Vec::new();
let timeout = &self.write_timeout.write().unwrap().clone();
let mut bytes_written = 0;
let buf_len = buf.len();
while bytes_written < buf_len {
events.clear();
self.poller
.modify(self.stream.as_raw(), Event::writable(1))?;
if self.poller.wait(&mut events, *timeout)? == 0 {
let stop = self.done.load(atomic::Ordering::SeqCst);
if stop {
Expand All @@ -132,24 +130,25 @@ where
return Err(io::Error::from(io::ErrorKind::TimedOut));
}
for evt in &events {
if evt.key != 1 {
if evt.key != 1 || !evt.writable {
continue;
}
let write_buf = &buf[bytes_written..];
let s = &self.stream;
match s.write_buf(write_buf) {
Ok(n) => bytes_written += n,
Err(err) if err.kind() == io::ErrorKind::WouldBlock => continue,
Ok(n) => {
bytes_written += n;
}
Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
break;
}
Err(err) => {
self.stream.set_nonblocking(false)?;
return Err(err);
}
};
if bytes_written == buf_len {
break;
}
break;
}
events.clear();
}
Ok(bytes_written)
}
Expand Down
30 changes: 16 additions & 14 deletions wruster/src/streams/test.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::{
cancellable_stream::CancellableStream,
observable::ObservedStreamList,
test_utils::{get_free_port, TcpClient},
test_utils::{get_free_port, load_test_file, test_file_size, TcpClient},
timeout_stream::TimeoutStream,
tls::test_utils::*,
*,
Expand Down Expand Up @@ -82,7 +82,7 @@ fn cancellable_stream_read_reads_data() {
}

#[test]
fn cancellable_steeam_read_honors_timeout() {
fn cancellable_stream_read_honors_timeout() {
env_logger::init();
let port = get_free_port();
let addr = format!("127.0.0.1:{}", port);
Expand All @@ -108,32 +108,34 @@ fn cancellable_steeam_read_honors_timeout() {

#[test]
fn cancellable_stream_write_writes_data() {
let data = "test ";
let port = get_free_port();
let addr = format!("127.0.0.1:{}", port);
let listener = TcpListener::bind(addr.clone()).unwrap();
let server_data = data.clone();
let mut server_data = load_test_file("big.png").unwrap();
let handle = thread::spawn(move || {
let (stream, _) = listener.accept().unwrap();
let mut cstream = CancellableStream::new(stream).unwrap();
let data = server_data.as_bytes();
let mut data = Vec::new();
server_data.read_to_end(&mut data).unwrap();
cstream.write(&data)
});

let mut client = TcpClient::connect(addr.to_string()).unwrap();
let bytes_sent = handle
.join()
.unwrap()
.expect("expected data to be written correctly");
assert_eq!(bytes_sent, data.len());

let mut reader = BufReader::new(&mut client);
let mut content = Vec::new();
let mut expected_file = load_test_file("big.png").unwrap();
let mut expected_data = Vec::new();
expected_file.read_to_end(&mut expected_data).unwrap();
let len = test_file_size("big.png").unwrap();
reader
.read_until(b' ', &mut content)
.read_to_end(&mut content)
.expect("expect data to available");
let content = String::from_utf8(content).expect("expect data to be valid");
assert_eq!(content, "test ".to_string());
assert_eq!(content, expected_data);
let bytes_sent = handle
.join()
.unwrap()
.expect("expected data to be written correctly");
assert_eq!(bytes_sent, len.try_into().unwrap());
}

#[test]
Expand Down
20 changes: 20 additions & 0 deletions wruster/src/streams/test_utils.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use std::{
error::Error,
fs::{self, File},
io::{self, Read, Write},
net::{Ipv4Addr, Shutdown, SocketAddrV4, TcpListener, TcpStream},
path::PathBuf,
};

pub struct TcpClient {
Expand Down Expand Up @@ -57,3 +59,21 @@ pub fn get_free_port() -> u16 {
.unwrap()
.port()
}

#[allow(dead_code)]
pub fn load_test_file(name: &str) -> Result<File, io::Error> {
let mut file_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
file_path.push("tests/assets");
file_path.push(name);
let file = fs::File::open(&file_path).unwrap();
return Ok(file);
}

#[allow(dead_code)]
pub fn test_file_size(name: &str) -> Result<u64, io::Error> {
let mut file_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
file_path.push("tests/assets");
file_path.push(name);
let metadata = fs::metadata(&file_path).unwrap();
Ok(metadata.len())
}
Binary file added wruster/tests/assets/big.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 2 additions & 2 deletions wruster_handlers/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ Contains a set of helpful handlers, middlewares and utilities to create
new handlers in a wruster web server.
*/
use std::fs;
use std::io::BufReader;
use std::io::Read;
use std::{io, path::PathBuf};

#[macro_use]
Expand Down Expand Up @@ -68,7 +68,7 @@ pub fn serve_static(dir: &str, request: &Request) -> Response {
};
let mime_type = mime_guess::from_path(path).first_or_octet_stream();
let mut headers = Headers::new();
let content = Box::new(BufReader::new(content));
let content: Box<dyn Read> = Box::new(content);
headers.add(Header {
name: String::from("Content-Length"),
value: metadata.len().to_string(),
Expand Down

0 comments on commit cbdc1fc

Please sign in to comment.