diff --git a/leaf/src/app/dispatcher.rs b/leaf/src/app/dispatcher.rs index b749d6399..408c911e5 100644 --- a/leaf/src/app/dispatcher.rs +++ b/leaf/src/app/dispatcher.rs @@ -83,18 +83,19 @@ impl Dispatcher { T: 'static + AsyncRead + AsyncWrite + Unpin + Send + Sync, { debug!("dispatching {}:{}", &sess.network, &sess.destination); - let mut lhs: Box = if *option::DOMAIN_SNIFFING - && !sess.destination.is_domain() - && sess.destination.port() == 443 - { + let mut lhs: Box = if sniff::should_sniff(&sess) { let mut lhs = sniff::SniffingStream::new(lhs); - match lhs.sniff().await { + match lhs.sniff(&sess).await { Ok(res) => { if let Some(domain) = res { debug!( "sniffed domain {} for tcp link {} <-> {}", &domain, &sess.source, &sess.destination, ); + // TODO Add an option to use the sniffed domain for routing only + // + // TODO Add DNS sniff, sniff domain name from DNS response, keep + // an IP -> domain mapping, use this info for routing only. sess.destination = match SocksAddr::try_from((&domain, sess.destination.port())) { Ok(a) => a, diff --git a/leaf/src/common/sniff.rs b/leaf/src/common/sniff.rs index 737bd2ab6..78b3fface 100644 --- a/leaf/src/common/sniff.rs +++ b/leaf/src/common/sniff.rs @@ -8,11 +8,46 @@ use bytes::BytesMut; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf}; use tokio::time::timeout; +use crate::option; +use crate::session::Session; + +fn should_sniff_tls(sess: &Session) -> bool { + if *option::TLS_DOMAIN_SNIFFING { + if !*option::TLS_DOMAIN_SNIFFING_ALL && sess.destination.port() != 443 { + return false; + } + true + } else { + false + } +} + +fn should_sniff_http(sess: &Session) -> bool { + if *option::HTTP_DOMAIN_SNIFFING { + if !*option::HTTP_DOMAIN_SNIFFING_ALL && sess.destination.port() != 80 { + return false; + } + true + } else { + false + } +} + +pub fn should_sniff(sess: &Session) -> bool { + !sess.destination.is_domain() && (should_sniff_tls(sess) || should_sniff_http(sess)) +} + pub struct SniffingStream { inner: T, buf: BytesMut, } +enum SniffResult { + NotMatch, + NotEnoughData, + Domain(String), +} + impl SniffingStream where T: AsyncRead + AsyncWrite + Unpin, @@ -24,128 +59,186 @@ where } } - pub async fn sniff(&mut self) -> io::Result> { + fn sniff_http_host(&self, buf: &[u8]) -> SniffResult { + // Credits https://github.com/eycorsican/leaf/pull/288 + + let bytes_str = String::from_utf8_lossy(buf); + let parts: Vec<&str> = bytes_str.split("\r\n").collect(); + + if parts.len() == 0 { + return SniffResult::NotMatch; + } + + let http_methods = [ + "get", "post", "head", "put", "delete", "options", "connect", "patch", "trace", + ]; + let method_str = parts[0]; + + let matched_method = http_methods + .into_iter() + .filter(|item| method_str.to_lowercase().contains(item)) + .count(); + + if matched_method == 0 { + return SniffResult::NotMatch; + } + + for (idx, &el) in parts.iter().enumerate() { + if idx == 0 || el == "" { + continue; + } + let inner_parts: Vec<&str> = el.split(":").collect(); + if inner_parts.len() != 2 { + continue; + } + if inner_parts[0].to_lowercase() == "host" { + return SniffResult::Domain(inner_parts[1].trim().to_string()); + } + } + + SniffResult::NotMatch + } + + fn sniff_tls_sni(&self, buf: &[u8]) -> SniffResult { + // https://tls.ulfheim.net/ + + let sbuf = &buf[..]; + if sbuf.len() < 5 { + return SniffResult::NotEnoughData; + } + // handshake record type + if sbuf[0] != 0x16 { + return SniffResult::NotMatch; + } + // protocol version + if sbuf[1] != 0x3 { + return SniffResult::NotMatch; + } + let header_len = u16::from_be_bytes(sbuf[3..5].try_into().unwrap()) as usize; + if sbuf.len() < 5 + header_len { + return SniffResult::NotEnoughData; + } + let sbuf = &sbuf[5..5 + header_len]; + // ? + if sbuf.len() < 42 { + return SniffResult::NotEnoughData; + } + let session_id_len = sbuf[38] as usize; + if session_id_len > 32 || sbuf.len() < 39 + session_id_len { + return SniffResult::NotEnoughData; + } + let sbuf = &sbuf[39 + session_id_len..]; + if sbuf.len() < 2 { + return SniffResult::NotEnoughData; + } + let cipher_suite_bytes = u16::from_be_bytes(sbuf[..2].try_into().unwrap()) as usize; + if sbuf.len() < 2 + cipher_suite_bytes { + return SniffResult::NotEnoughData; + } + let sbuf = &sbuf[2 + cipher_suite_bytes..]; + if sbuf.is_empty() { + return SniffResult::NotEnoughData; + } + let compression_method_bytes = sbuf[0] as usize; + if sbuf.len() < 1 + compression_method_bytes { + return SniffResult::NotEnoughData; + } + let sbuf = &sbuf[1 + compression_method_bytes..]; + if sbuf.len() < 2 { + return SniffResult::NotEnoughData; + } + let extensions_bytes = u16::from_be_bytes(sbuf[..2].try_into().unwrap()) as usize; + if sbuf.len() < 2 + extensions_bytes { + return SniffResult::NotEnoughData; + } + let mut sbuf = &sbuf[2..2 + extensions_bytes]; + while !sbuf.is_empty() { + // extension + extension-specific-len + if sbuf.len() < 4 { + return SniffResult::NotEnoughData; + } + let extension = u16::from_be_bytes(sbuf[..2].try_into().unwrap()); + let extension_len = u16::from_be_bytes(sbuf[2..4].try_into().unwrap()) as usize; + sbuf = &sbuf[4..]; + if sbuf.len() < extension_len { + return SniffResult::NotEnoughData; + } + // extension "server name" + if extension == 0x0 { + let mut ebuf = &sbuf[..extension_len]; + if ebuf.len() < 2 { + return SniffResult::NotEnoughData; + } + let entry_len = u16::from_be_bytes(ebuf[..2].try_into().unwrap()) as usize; + ebuf = &ebuf[2..]; + if ebuf.len() < entry_len { + return SniffResult::NotEnoughData; + } + // just make sure no oob + if ebuf.is_empty() { + return SniffResult::NotEnoughData; + } + let entry_type = ebuf[0]; + // type "DNS hostname" + if entry_type == 0x0 { + ebuf = &ebuf[1..]; + // just make sure no oob + if ebuf.len() < 2 { + return SniffResult::NotEnoughData; + } + let hostname_len = u16::from_be_bytes(ebuf[..2].try_into().unwrap()) as usize; + ebuf = &ebuf[2..]; + if ebuf.len() < hostname_len { + return SniffResult::NotEnoughData; + } + return SniffResult::Domain( + String::from_utf8_lossy(&ebuf[..hostname_len]).into(), + ); + } else { + // TODO + // I assume there's only "DNS hostname" type + // in the the "server name" extension, should + // check if this is true later. + // + // I also assume there's only one entry in the + // "server name" extension list. + return SniffResult::NotMatch; + } + } else { + sbuf = &sbuf[extension_len..]; + } + } + SniffResult::NotEnoughData + } + + pub async fn sniff(&mut self, sess: &Session) -> io::Result> { let mut buf = vec![0u8; 2 * 1024]; - 'outer: for _ in 0..2 { + for _ in 0..2 { match timeout(Duration::from_millis(100), self.inner.read(&mut buf)).await { Ok(res) => match res { Ok(n) => { self.buf.extend_from_slice(&buf[..n]); - - // https://tls.ulfheim.net/ - - let sbuf = &self.buf[..]; - if sbuf.len() < 5 { - continue; + let mut tls_not_match = true; + let mut http_not_match = true; + if should_sniff_tls(sess) { + tls_not_match = false; + match self.sniff_tls_sni(&buf[..n]) { + SniffResult::NotEnoughData => (), + SniffResult::NotMatch => tls_not_match = true, + SniffResult::Domain(domain) => return Ok(Some(domain)), + } } - // handshake record type - if sbuf[0] != 0x16 { - return Ok(None); + if should_sniff_http(sess) { + http_not_match = false; + match self.sniff_http_host(&buf[..n]) { + SniffResult::NotEnoughData => (), + SniffResult::NotMatch => http_not_match = true, + SniffResult::Domain(domain) => return Ok(Some(domain)), + } } - // protocol version - if sbuf[1] != 0x3 { + if tls_not_match && http_not_match { return Ok(None); } - let header_len = - u16::from_be_bytes(sbuf[3..5].try_into().unwrap()) as usize; - if sbuf.len() < 5 + header_len { - continue; - } - let sbuf = &sbuf[5..5 + header_len]; - // ? - if sbuf.len() < 42 { - continue; - } - let session_id_len = sbuf[38] as usize; - if session_id_len > 32 || sbuf.len() < 39 + session_id_len { - continue; - } - let sbuf = &sbuf[39 + session_id_len..]; - if sbuf.len() < 2 { - continue; - } - let cipher_suite_bytes = - u16::from_be_bytes(sbuf[..2].try_into().unwrap()) as usize; - if sbuf.len() < 2 + cipher_suite_bytes { - continue; - } - let sbuf = &sbuf[2 + cipher_suite_bytes..]; - if sbuf.is_empty() { - continue; - } - let compression_method_bytes = sbuf[0] as usize; - if sbuf.len() < 1 + compression_method_bytes { - continue; - } - let sbuf = &sbuf[1 + compression_method_bytes..]; - if sbuf.len() < 2 { - continue; - } - let extensions_bytes = - u16::from_be_bytes(sbuf[..2].try_into().unwrap()) as usize; - if sbuf.len() < 2 + extensions_bytes { - continue; - } - let mut sbuf = &sbuf[2..2 + extensions_bytes]; - while !sbuf.is_empty() { - // extension + extension-specific-len - if sbuf.len() < 4 { - continue 'outer; - } - let extension = u16::from_be_bytes(sbuf[..2].try_into().unwrap()); - let extension_len = - u16::from_be_bytes(sbuf[2..4].try_into().unwrap()) as usize; - sbuf = &sbuf[4..]; - if sbuf.len() < extension_len { - continue 'outer; - } - // extension "server name" - if extension == 0x0 { - let mut ebuf = &sbuf[..extension_len]; - if ebuf.len() < 2 { - continue 'outer; - } - let entry_len = - u16::from_be_bytes(ebuf[..2].try_into().unwrap()) as usize; - ebuf = &ebuf[2..]; - if ebuf.len() < entry_len { - continue 'outer; - } - // just make sure no oob - if ebuf.is_empty() { - continue 'outer; - } - let entry_type = ebuf[0]; - // type "DNS hostname" - if entry_type == 0x0 { - ebuf = &ebuf[1..]; - // just make sure no oob - if ebuf.len() < 2 { - continue 'outer; - } - let hostname_len = - u16::from_be_bytes(ebuf[..2].try_into().unwrap()) as usize; - ebuf = &ebuf[2..]; - if ebuf.len() < hostname_len { - continue 'outer; - } - return Ok(Some( - String::from_utf8_lossy(&ebuf[..hostname_len]).into(), - )); - } else { - // TODO - // I assume there's only "DNS hostname" type - // in the the "server name" extension, should - // check if this is true later. - // - // I also assume there's only one entry in the - // "server name" extension list. - return Ok(None); - } - } else { - sbuf = &sbuf[extension_len..]; - } - } } Err(e) => { return Err(e); diff --git a/leaf/src/option/mod.rs b/leaf/src/option/mod.rs index 6c6b14a29..1590512e4 100644 --- a/leaf/src/option/mod.rs +++ b/leaf/src/option/mod.rs @@ -97,8 +97,35 @@ lazy_static! { get_env_var_or("LOG_NO_COLOR", false) }; - pub static ref DOMAIN_SNIFFING: bool = { - get_env_var_or("DOMAIN_SNIFFING", false) + /// Turn on TLS SNI sniffing, the sniffed SNI would override the original + /// destination address, by default the sniffing would perform only on + /// connections with destination port 443, set also TLS_DOMAIN_SNIFFING_ALL + /// to make the sniffing work on all connections. + pub static ref TLS_DOMAIN_SNIFFING: bool = { + get_env_var_or_else( + "TLS_DOMAIN_SNIFFING", + || get_env_var_or("DOMAIN_SNIFFING", false), // deprecated env var + ) + }; + + /// Turn on TLS SNI sniffing for all TCP connections, this may slow down the + /// connections a little bit, depending on whether the sniff can make an early + /// return. + pub static ref TLS_DOMAIN_SNIFFING_ALL: bool = { + get_env_var_or("TLS_DOMAIN_SNIFFING_ALL", false) + }; + + /// Turn on HTTP host sniffing, by default only perform on connections with + /// destination port 80. + pub static ref HTTP_DOMAIN_SNIFFING: bool = { + get_env_var_or("HTTP_DOMAIN_SNIFFING", false) + }; + + /// Turn on HTTP host sniffing for all TCP connections, this may slow down the + /// connections a little bit, depending on whether the sniff can make an early + /// return. + pub static ref HTTP_DOMAIN_SNIFFING_ALL: bool = { + get_env_var_or("HTTP_DOMAIN_SNIFFING_ALL", false) }; /// Uplink timeout after downlink EOF.