diff --git a/Cargo.toml b/Cargo.toml index 76b80cf..0ea35c2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,9 @@ repository = "https://github.com/sbstp/attohttpc" [dependencies] base64 = {version = "0.13.0", optional = true} +bytes = {version = "1.2.1", optional = true} +cookie = {version = "0.16.0", optional = true} +cookie_store = {version = "0.16.1", optional = true} encoding_rs = {version = "0.8.31", optional = true} encoding_rs_io = {version = "0.1.7", optional = true} flate2 = {version = "1.0.24", default-features = false, optional = true} @@ -34,6 +37,7 @@ webpki-roots = {version = "0.22.4", optional = true} [dev-dependencies] anyhow = "1.0.61" +axum = "0.6.4" env_logger = "0.9.0" futures = "0.3.23" futures-util = "0.3.23" @@ -51,7 +55,8 @@ charsets = ["encoding_rs", "encoding_rs_io"] compress = ["flate2/default"] compress-zlib = ["flate2/zlib"] compress-zlib-ng = ["flate2/zlib-ng"] -default = ["compress", "tls-native"] +cookies = ["bytes", "cookie", "cookie_store"] +default = ["compress", "tls-native", "cookies"] form = ["serde", "serde_urlencoded"] json = ["serde", "serde_json"] multipart-form = ["multipart", "mime"] diff --git a/src/cookies.rs b/src/cookies.rs new file mode 100644 index 0000000..16092c4 --- /dev/null +++ b/src/cookies.rs @@ -0,0 +1,161 @@ +use std::{cell::RefCell, fmt::Write, rc::Rc}; + +use bytes::Bytes; +pub use cookie::Cookie; +use cookie_store::CookieStore; +use url::Url; + +use crate::header::HeaderValue; + +/// Values that can be converted into a [`Cookie`]. +pub trait IntoCookie { + /// Convert the value into a [`Cookie`]. + fn into_cookie(self) -> Cookie<'static>; +} + +impl IntoCookie for (T1, T2) +where + T1: Into, + T2: Into, +{ + fn into_cookie(self) -> Cookie<'static> { + Cookie::build(self.0.into(), self.1.into()).finish() + } +} + +impl<'a> IntoCookie for Cookie<'a> { + fn into_cookie(self) -> Cookie<'static> { + self.into_owned() + } +} + +impl<'a> IntoCookie for cookie::CookieBuilder<'a> { + fn into_cookie(self) -> Cookie<'static> { + self.finish().into_owned() + } +} + +/// Persists cookies between requests. +/// +/// All the typical cookie properties, such as expiry, domain, path and secure are respected. +/// Cookies should always be accessed through a [`Url`] for security reasons. +#[derive(Clone, Debug)] +pub struct CookieJar(Rc>); + +impl CookieJar { + pub(crate) fn new() -> Self { + CookieJar(Rc::new(RefCell::new(CookieStore::default()))) + } + + /// Get available cookies for the given [`Url`]. Only cookies that match + /// the domain, path and secure setting will be returned. Expired cookies + /// are not returned either. + pub fn cookies_for_url(&self, url: &Url) -> Vec<(String, String)> { + self.0 + .borrow() + .get_request_values(url) + .map(|(name, value)| (name.into(), value.into())) + .collect() + } + + /// Store the given [`Cookie`] in the [`CookieJar`] for the given [`Url`]. + /// If the [`Cookie`] has additional properties such as a specific path, domain or secure, + /// the [`Cookie`] will be stored with those properties. + pub fn store_cookie_for_url(&self, cookie: impl IntoCookie, url: &Url) { + self.0 + .borrow_mut() + .store_response_cookies(Some(cookie.into_cookie()).into_iter(), url) + } + + /// Remove all the cookies stored in the [CookieJar]. + pub fn clear(&mut self) { + self.0.borrow_mut().clear(); + } + + /// Get the cookies formatted as required by the `Cookie` header. + pub(crate) fn header_for_url(&self, url: &Url) -> Option { + let mut hvalue = String::new(); + for (idx, (name, value)) in self.0.borrow().get_request_values(url).enumerate() { + if idx > 0 { + hvalue.push_str("; "); + } + write!(hvalue, "{name}={value}").unwrap(); + } + + if hvalue.is_empty() { + return None; + } + + HeaderValue::from_maybe_shared(Bytes::from(hvalue)).ok() + } + + /// Store cookies into the jar using unparsed `Set-Cookie` headers. + pub(crate) fn store_header_for_url<'a>( + &self, + url: &Url, + set_cookie_headers: impl Iterator, + ) { + fn parse_cookie(buf: &[u8]) -> Result> { + let s = std::str::from_utf8(buf)?; + let c = Cookie::parse(s)?; + Ok(c) + } + + let iter = set_cookie_headers.filter_map(|v| match parse_cookie(v.as_bytes()) { + Ok(c) => Some(c.into_owned()), + Err(err) => { + warn!("Invalid cookie could not be stored to jar: {}", err); + None + } + }); + + self.0.borrow_mut().store_response_cookies(iter, url) + } +} + +impl Default for CookieJar { + fn default() -> Self { + Self::new() + } +} + +#[test] +fn test_cookies_for_url() { + let url = Url::parse("http://example.com").expect("invalid url"); + let jar = CookieJar::new(); + + jar.store_cookie_for_url(("foo", "baz"), &url); + assert!(!jar.cookies_for_url(&url).is_empty()); +} + +#[test] +fn test_header_for_url() { + let url = Url::parse("http://example.com").expect("invalid url"); + let jar = CookieJar::new(); + jar.store_cookie_for_url(("foo", "bar"), &url); + jar.store_cookie_for_url(("qux", "baz"), &url); + + let val = jar.header_for_url(&url).unwrap(); + + // unfortunately the cookies are stored in a HashMap and the iteration order is not guaranteed. + let val = std::str::from_utf8(val.as_bytes()).unwrap(); + let mut cookies = val.split("; ").collect::>(); + cookies.sort(); + + assert_eq!(cookies, vec!["foo=bar", "qux=baz"]); +} + +#[test] +fn test_security_secure() { + let url = Url::parse("https://example.com").expect("invalid url"); + let insecure_url = Url::parse("http://example.com").expect("invalid url"); + + let jar = CookieJar::new(); + jar.store_cookie_for_url(Cookie::build("foo", "baz").secure(true), &url); + + // same URL which is secure, have cookie + assert!(!jar.cookies_for_url(&url).is_empty()); + + // insecure URL, no cookie + assert!(jar.cookies_for_url(&insecure_url).is_empty()); +} diff --git a/src/lib.rs b/src/lib.rs index 6dc0d47..6a34f48 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -73,6 +73,8 @@ macro_rules! warn { #[cfg(feature = "charsets")] pub mod charsets; +#[cfg(feature = "cookies")] +mod cookies; mod error; mod happy; #[cfg(feature = "multipart")] @@ -82,6 +84,8 @@ mod request; mod streams; mod tls; +#[cfg(feature = "cookies")] +pub use crate::cookies::{Cookie, CookieJar, IntoCookie}; pub use crate::error::{Error, ErrorKind, InvalidResponseKind, Result}; #[cfg(feature = "multipart")] pub use crate::multipart::{Multipart, MultipartBuilder, MultipartFile}; diff --git a/src/parsing/compressed_reader.rs b/src/parsing/compressed_reader.rs index 1b9cd6f..61d44d2 100644 --- a/src/parsing/compressed_reader.rs +++ b/src/parsing/compressed_reader.rs @@ -155,7 +155,12 @@ mod tests { let _ = write!(buf, "HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n", payload.len()); buf.extend(payload); - let req = PreparedRequest::new(Method::GET, "http://google.ca"); + let req = PreparedRequest::new( + Method::GET, + "http://google.ca", + #[cfg(feature = "cookies")] + None, + ); let sock = BaseStream::mock(buf); let response = parse_response(sock, &req).unwrap(); @@ -178,7 +183,12 @@ mod tests { ); buf.extend(payload); - let req = PreparedRequest::new(Method::GET, "http://google.ca"); + let req = PreparedRequest::new( + Method::GET, + "http://google.ca", + #[cfg(feature = "cookies")] + None, + ); let sock = BaseStream::mock(buf); let response = parse_response(sock, &req).unwrap(); @@ -201,7 +211,12 @@ mod tests { ); buf.extend(payload); - let req = PreparedRequest::new(Method::GET, "http://google.ca"); + let req = PreparedRequest::new( + Method::GET, + "http://google.ca", + #[cfg(feature = "cookies")] + None, + ); let sock = BaseStream::mock(buf); let response = parse_response(sock, &req).unwrap(); @@ -214,7 +229,12 @@ mod tests { fn test_no_body_with_gzip() { let buf = b"HTTP/1.1 200 OK\r\ncontent-encoding: gzip\r\n\r\n"; - let req = PreparedRequest::new(Method::GET, "http://google.ca"); + let req = PreparedRequest::new( + Method::GET, + "http://google.ca", + #[cfg(feature = "cookies")] + None, + ); let sock = BaseStream::mock(buf.to_vec()); // Fixed by the move from libflate to flate2 assert!(parse_response(sock, &req).is_ok()); @@ -225,7 +245,12 @@ mod tests { fn test_no_body_with_gzip_head() { let buf = b"HTTP/1.1 200 OK\r\ncontent-encoding: gzip\r\n\r\n"; - let req = PreparedRequest::new(Method::HEAD, "http://google.ca"); + let req = PreparedRequest::new( + Method::HEAD, + "http://google.ca", + #[cfg(feature = "cookies")] + None, + ); let sock = BaseStream::mock(buf.to_vec()); assert!(parse_response(sock, &req).is_ok()); } diff --git a/src/request/builder.rs b/src/request/builder.rs index 74ee4e3..09f4cbd 100644 --- a/src/request/builder.rs +++ b/src/request/builder.rs @@ -15,6 +15,8 @@ use url::Url; #[cfg(feature = "charsets")] use crate::charsets::Charset; +#[cfg(feature = "cookies")] +use crate::cookies::CookieJar; use crate::error::{Error, ErrorKind, Result}; use crate::parsing::Response; use crate::request::{ @@ -38,6 +40,8 @@ pub struct RequestBuilder { method: Method, body: B, base_settings: BaseSettings, + #[cfg(feature = "cookies")] + cookie_jar: Option, } impl RequestBuilder { @@ -60,17 +64,40 @@ impl RequestBuilder { where U: AsRef, { - Self::try_with_settings(method, base_url, BaseSettings::default()) + Self::try_with_settings( + method, + base_url, + BaseSettings::default(), + #[cfg(feature = "cookies")] + None, + ) } - pub(crate) fn with_settings(method: Method, base_url: U, base_settings: BaseSettings) -> Self + pub(crate) fn with_settings( + method: Method, + base_url: U, + base_settings: BaseSettings, + #[cfg(feature = "cookies")] cookie_jar: Option, + ) -> Self where U: AsRef, { - Self::try_with_settings(method, base_url, base_settings).expect("invalid url or method") + Self::try_with_settings( + method, + base_url, + base_settings, + #[cfg(feature = "cookies")] + cookie_jar, + ) + .expect("invalid url or method") } - pub(crate) fn try_with_settings(method: Method, base_url: U, base_settings: BaseSettings) -> Result + pub(crate) fn try_with_settings( + method: Method, + base_url: U, + base_settings: BaseSettings, + #[cfg(feature = "cookies")] cookie_jar: Option, + ) -> Result where U: AsRef, { @@ -85,6 +112,8 @@ impl RequestBuilder { method, body: body::Empty, base_settings, + #[cfg(feature = "cookies")] + cookie_jar, }) } } @@ -152,6 +181,8 @@ impl RequestBuilder { method: self.method, body, base_settings: self.base_settings, + #[cfg(feature = "cookies")] + cookie_jar: self.cookie_jar, } } @@ -415,10 +446,21 @@ impl RequestBuilder { method: self.method, body: self.body, base_settings: self.base_settings, + #[cfg(feature = "cookies")] + cookie_jar: self.cookie_jar.clone(), }; header_insert(&mut prepped.base_settings.headers, CONNECTION, "close")?; - prepped.set_compression()?; + + #[cfg(feature = "flate2")] + if prepped.base_settings.allow_compression { + header_insert( + &mut prepped.base_settings.headers, + crate::header::ACCEPT_ENCODING, + "gzip, deflate", + )?; + } + match prepped.body.kind()? { BodyKind::Empty => (), BodyKind::KnownLength(len) => { @@ -436,6 +478,13 @@ impl RequestBuilder { header_insert_if_missing(&mut prepped.base_settings.headers, ACCEPT, "*/*")?; header_insert_if_missing(&mut prepped.base_settings.headers, USER_AGENT, DEFAULT_USER_AGENT)?; + #[cfg(feature = "cookies")] + if let Some(cookie_jar) = self.cookie_jar { + if let Some(header_val) = cookie_jar.header_for_url(&prepped.url) { + header_insert(&mut prepped.base_settings.headers, crate::header::COOKIE, header_val)?; + } + } + Ok(prepped) } diff --git a/src/request/mod.rs b/src/request/mod.rs index 2741406..db209ac 100644 --- a/src/request/mod.rs +++ b/src/request/mod.rs @@ -3,14 +3,14 @@ use std::io::{prelude::*, BufWriter}; use std::str; use std::time::Instant; -#[cfg(feature = "flate2")] -use http::header::ACCEPT_ENCODING; use http::{ header::{HeaderValue, IntoHeaderName, HOST}, HeaderMap, Method, StatusCode, Version, }; use url::Url; +#[cfg(feature = "cookies")] +use crate::cookies::CookieJar; use crate::error::{Error, ErrorKind, InvalidResponseKind, Result}; use crate::parsing::{parse_response, Response}; use crate::streams::{BaseStream, ConnectInfo}; @@ -67,11 +67,13 @@ pub struct PreparedRequest { method: Method, body: B, pub(crate) base_settings: BaseSettings, + #[cfg(feature = "cookies")] + cookie_jar: Option, } #[cfg(test)] impl PreparedRequest { - pub(crate) fn new(method: Method, base_url: U) -> Self + pub(crate) fn new(method: Method, base_url: U, #[cfg(feature = "cookies")] cookie_jar: Option) -> Self where U: AsRef, { @@ -80,24 +82,13 @@ impl PreparedRequest { method, body: body::Empty, base_settings: BaseSettings::default(), + #[cfg(feature = "cookies")] + cookie_jar, } } } impl PreparedRequest { - #[cfg(not(feature = "flate2"))] - fn set_compression(&mut self) -> Result { - Ok(()) - } - - #[cfg(feature = "flate2")] - fn set_compression(&mut self) -> Result { - if self.base_settings.allow_compression { - header_insert(&mut self.base_settings.headers, ACCEPT_ENCODING, "gzip, deflate")?; - } - Ok(()) - } - fn base_redirect_url(&self, location: &str, previous_url: &Url) -> Result { match Url::parse(location) { Ok(url) => Ok(url), @@ -232,6 +223,11 @@ impl PreparedRequest { debug!("status code {}", resp.status().as_u16()); + #[cfg(feature = "cookies")] + if let Some(cookie_jar) = &self.cookie_jar { + cookie_jar.store_header_for_url(&url, resp.headers().get_all(crate::header::SET_COOKIE).iter()); + } + let is_redirect = matches!( resp.status(), StatusCode::MOVED_PERMANENTLY @@ -334,6 +330,8 @@ mod test { url: Url::parse("http://reddit.com/r/rust").unwrap(), body: Empty, base_settings: BaseSettings::default(), + #[cfg(feature = "cookies")] + cookie_jar: None, }; let proxy = Url::parse("http://proxy:3128").unwrap(); @@ -353,6 +351,8 @@ mod test { url: Url::parse("http://reddit.com/r/rust").unwrap(), body: Empty, base_settings: BaseSettings::default(), + #[cfg(feature = "cookies")] + cookie_jar: None, }; let proxy = Url::parse("http://proxy:3128").unwrap(); diff --git a/src/request/session.rs b/src/request/session.rs index 7b15345..ff4109a 100644 --- a/src/request/session.rs +++ b/src/request/session.rs @@ -6,6 +6,8 @@ use http::Method; #[cfg(feature = "charsets")] use crate::charsets::Charset; +#[cfg(feature = "cookies")] +use crate::cookies::CookieJar; use crate::error::{Error, Result}; use crate::request::proxy::ProxySettings; use crate::request::{header_append, header_insert, BaseSettings, RequestBuilder}; @@ -16,6 +18,8 @@ use crate::tls::Certificate; #[derive(Debug, Default)] pub struct Session { base_settings: BaseSettings, + #[cfg(feature = "cookies")] + cookie_jar: CookieJar, } impl Session { @@ -23,15 +27,30 @@ impl Session { pub fn new() -> Session { Session { base_settings: BaseSettings::default(), + #[cfg(feature = "cookies")] + cookie_jar: CookieJar::new(), } } + fn make_request_builder(&self, method: Method, base_url: U) -> RequestBuilder + where + U: AsRef, + { + RequestBuilder::with_settings( + method, + base_url, + self.base_settings.clone(), + #[cfg(feature = "cookies")] + Some(self.cookie_jar.clone()), + ) + } + /// Create a new `RequestBuilder` with the GET method and this Session's settings applied on it. pub fn get(&self, base_url: U) -> RequestBuilder where U: AsRef, { - RequestBuilder::with_settings(Method::GET, base_url, self.base_settings.clone()) + self.make_request_builder(Method::GET, base_url) } /// Create a new `RequestBuilder` with the POST method and this Session's settings applied on it. @@ -39,7 +58,7 @@ impl Session { where U: AsRef, { - RequestBuilder::with_settings(Method::POST, base_url, self.base_settings.clone()) + self.make_request_builder(Method::POST, base_url) } /// Create a new `RequestBuilder` with the PUT method and this Session's settings applied on it. @@ -47,7 +66,7 @@ impl Session { where U: AsRef, { - RequestBuilder::with_settings(Method::PUT, base_url, self.base_settings.clone()) + self.make_request_builder(Method::PUT, base_url) } /// Create a new `RequestBuilder` with the DELETE method and this Session's settings applied on it. @@ -55,7 +74,7 @@ impl Session { where U: AsRef, { - RequestBuilder::with_settings(Method::DELETE, base_url, self.base_settings.clone()) + self.make_request_builder(Method::DELETE, base_url) } /// Create a new `RequestBuilder` with the HEAD method and this Session's settings applied on it. @@ -63,7 +82,7 @@ impl Session { where U: AsRef, { - RequestBuilder::with_settings(Method::HEAD, base_url, self.base_settings.clone()) + self.make_request_builder(Method::HEAD, base_url) } /// Create a new `RequestBuilder` with the OPTIONS method and this Session's settings applied on it. @@ -71,7 +90,7 @@ impl Session { where U: AsRef, { - RequestBuilder::with_settings(Method::OPTIONS, base_url, self.base_settings.clone()) + self.make_request_builder(Method::OPTIONS, base_url) } /// Create a new `RequestBuilder` with the PATCH method and this Session's settings applied on it. @@ -79,7 +98,7 @@ impl Session { where U: AsRef, { - RequestBuilder::with_settings(Method::PATCH, base_url, self.base_settings.clone()) + self.make_request_builder(Method::PATCH, base_url) } /// Create a new `RequestBuilder` with the TRACE method and this Session's settings applied on it. @@ -87,7 +106,7 @@ impl Session { where U: AsRef, { - RequestBuilder::with_settings(Method::TRACE, base_url, self.base_settings.clone()) + self.make_request_builder(Method::TRACE, base_url) } // @@ -249,4 +268,28 @@ impl Session { pub fn add_root_certificate(&mut self, cert: Certificate) { self.base_settings.root_certificates.0.push(cert); } + + /// Get a reference to the [`CookieJar`] within the session. + /// + /// The [`CookieJar`] can be used to retrieve and/or modify the cookies in the [`Session`]. + /// The cookies are automatically persisted across requests in a secure manner. + /// + /// # Examples + /// + /// ``` + /// # use std::error::Error; + /// # use attohttpc::Session; + /// # use url::Url; + /// # fn main() -> Result<(), Box> { + /// let url = Url::parse("http://example.com")?; + /// let sess = Session::new(); + /// sess.cookie_jar().store_cookie_for_url(("token", "ABCDEF123"), &url); + /// sess.get("http://example.com").send()?; + /// # Ok(()) + /// # } + /// ``` + #[cfg(feature = "cookies")] + pub fn cookie_jar(&self) -> &CookieJar { + &self.cookie_jar + } } diff --git a/tests/test_cookies.rs b/tests/test_cookies.rs new file mode 100644 index 0000000..adfe017 --- /dev/null +++ b/tests/test_cookies.rs @@ -0,0 +1,31 @@ +use std::net::SocketAddr; + +use attohttpc::Session; +use axum::http::{header, StatusCode}; +use axum::response::IntoResponse; +use axum::routing::get; +use axum::Router; +use url::Url; + +#[tokio::test(flavor = "multi_thread")] +#[cfg(feature = "cookies")] +async fn test_redirection_default() -> Result<(), anyhow::Error> { + async fn root() -> impl IntoResponse { + (StatusCode::OK, [(header::SET_COOKIE, "foo=bar")], "Hello, World!") + } + + let app = Router::new().route("/", get(root)); + + let addr = SocketAddr::from(([127, 0, 0, 1], 3939)); + tokio::spawn(axum::Server::bind(&addr).serve(app.into_make_service())); + + let sess = Session::new(); + sess.get("http://localhost:3939").send()?; + let cookies = sess + .cookie_jar() + .cookies_for_url(&Url::parse("http://localhost:3939").unwrap()); + + assert!(!cookies.is_empty()); + + Ok(()) +} diff --git a/tools/tests.bash b/tools/tests.bash index ec7356e..641bcc6 100755 --- a/tools/tests.bash +++ b/tools/tests.bash @@ -18,6 +18,7 @@ cargo test --no-default-features --features charsets cargo test --no-default-features --features compress cargo test --no-default-features --features compress-zlib cargo test --no-default-features --features compress-zlib-ng +cargo test --no-default-features --features cookies cargo test --no-default-features --features form cargo test --no-default-features --features multipart-form cargo test --no-default-features --features json