From 061bc75a444c15096e89374777693236d29b11ea Mon Sep 17 00:00:00 2001 From: eric Date: Mon, 7 Oct 2024 17:18:53 +0800 Subject: [PATCH] Refactor --- leaf/src/proxy/failover/mod.rs | 318 ++++++++++++++++----------------- 1 file changed, 158 insertions(+), 160 deletions(-) diff --git a/leaf/src/proxy/failover/mod.rs b/leaf/src/proxy/failover/mod.rs index 7298f6e57..126268315 100644 --- a/leaf/src/proxy/failover/mod.rs +++ b/leaf/src/proxy/failover/mod.rs @@ -22,160 +22,172 @@ pub use datagram::Handler as DatagramHandler; pub use stream::Handler as StreamHandler; #[derive(Debug, Eq, Ord, PartialEq, PartialOrd)] -pub(self) struct Measure(usize, u128, String); // (index, duration in millis, tag) +struct Measure { + idx: usize, + rtt: u128, + tag: String, +} + +impl Measure { + fn new(idx: usize, rtt: u128, tag: String) -> Self { + Self { idx, rtt, tag } + } +} -pub(self) async fn health_check( +async fn single_health_check( network: Network, idx: usize, tag: String, h: AnyOutboundHandler, dns_client: SyncDnsClient, delay: Duration, - health_check_timeout: u32, ) -> Measure { tokio::time::sleep(delay).await; - debug!("health checking [{}] ({}) index ({})", &tag, &network, idx); + let dest = match network { + Network::Tcp => SocksAddr::Domain("www.google.com".to_string(), 443), + Network::Udp => SocksAddr::Ip(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)), 53)), + }; - let tag_c = tag.clone(); + let sess = Session { + destination: dest, + new_conn_once: true, + ..Default::default() + }; - let measure = async move { - let dest = match network { - Network::Tcp => SocksAddr::Domain("www.google.com".to_string(), 443), - Network::Udp => { - SocksAddr::Ip(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)), 53)) - } - }; - - let sess = Session { - destination: dest, - new_conn_once: true, - ..Default::default() - }; - - let start = Instant::now(); - - match network { - Network::Tcp => { - let stream = - match crate::proxy::connect_stream_outbound(&sess, dns_client, &h).await { - Ok(s) => s, - Err(_) => return Measure(idx, u128::MAX, tag), + let start = Instant::now(); + + match network { + Network::Tcp => { + let stream = match crate::proxy::connect_stream_outbound(&sess, dns_client, &h).await { + Ok(s) => s, + Err(_) => return Measure::new(idx, u128::MAX, tag), + }; + let m: Measure; + + let Ok(h) = h.stream() else { + return Measure::new(idx, u128::MAX, tag); + }; + + // TODO Mock an LHS stream with the given payload. + match h.handle(&sess, None, stream).await { + Ok(stream) => { + let Ok(tls_handler) = crate::proxy::tls::outbound::StreamHandler::new( + String::from(""), + vec![], + None, + false, + ) else { + return Measure::new(idx, u128::MAX, tag); }; - let m: Measure; - let Ok(h) = h.stream() else { - return Measure(idx, u128::MAX, tag); - }; + let Ok(mut stream) = tls_handler.handle(&sess, None, Some(stream)).await else { + return Measure::new(idx, u128::MAX - 1, tag); + }; - // TODO Mock an LHS stream with the given payload. - match h.handle(&sess, None, stream).await { - Ok(stream) => { - let Ok(tls_handler) = crate::proxy::tls::outbound::StreamHandler::new( - String::from(""), - vec![], - None, - false, - ) else { - return Measure(idx, u128::MAX, tag); - }; - - let Ok(mut stream) = tls_handler.handle(&sess, None, Some(stream)).await - else { - return Measure(idx, u128::MAX - 1, tag); - }; - - if stream.write_all(b"GET / HTTP/1.1\r\n\r\n").await.is_err() { - return Measure(idx, u128::MAX - 2, tag); + if stream.write_all(b"GET / HTTP/1.1\r\n\r\n").await.is_err() { + return Measure::new(idx, u128::MAX - 2, tag); + } + let mut buf = BytesMut::with_capacity(2 * 1024); + match stream.read_buf(&mut buf).await { + Ok(n) => { + debug!( + "received {} bytes tcp health check response: {}", + n, + String::from_utf8_lossy(&buf[..n.min(12)]), + ); + let elapsed = Instant::now().duration_since(start); + m = Measure::new(idx, elapsed.as_millis(), tag); } - let mut buf = BytesMut::with_capacity(2 * 1024); - match stream.read_buf(&mut buf).await { - Ok(n) => { - debug!( - "received {} bytes tcp health check response: {}", - n, - String::from_utf8_lossy(&buf[..n.min(12)]), - ); - let elapsed = Instant::now().duration_since(start); - m = Measure(idx, elapsed.as_millis(), tag); - } - Err(_) => { - m = Measure(idx, u128::MAX - 3, tag); - } + Err(_) => { + m = Measure::new(idx, u128::MAX - 3, tag); } - let _ = stream.shutdown().await; - } - Err(_) => { - m = Measure(idx, u128::MAX, tag); } + let _ = stream.shutdown().await; + } + Err(_) => { + m = Measure::new(idx, u128::MAX, tag); } - return m; } - Network::Udp => { - let transport = - match crate::proxy::connect_datagram_outbound(&sess, dns_client, &h).await { - Ok(t) => t, - Err(_) => return Measure(idx, u128::MAX, tag), - }; - let h = if let Ok(h) = h.datagram() { - h - } else { - return Measure(idx, u128::MAX, tag); + return m; + } + Network::Udp => { + let transport = + match crate::proxy::connect_datagram_outbound(&sess, dns_client, &h).await { + Ok(t) => t, + Err(_) => return Measure::new(idx, u128::MAX, tag), }; - match h.handle(&sess, transport).await { - Ok(socket) => { - let addr = SocksAddr::Ip(SocketAddr::new( - IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)), - 53, - )); - let mut msg = Message::new(); - let name = match Name::from_str("www.google.com.") { - Ok(n) => n, - Err(e) => { - warn!("invalid domain name: {}", e); - return Measure(idx, u128::MAX, tag); - } - }; - let query = Query::query(name, RecordType::A); - msg.add_query(query); - let mut rng = StdRng::from_entropy(); - let id: u16 = rng.gen(); - msg.set_id(id); - msg.set_op_code(OpCode::Query); - msg.set_message_type(MessageType::Query); - msg.set_recursion_desired(true); - let msg_buf = match msg.to_vec() { - Ok(b) => b, - Err(e) => { - warn!("encode message to buffer failed: {}", e); - return Measure(idx, u128::MAX, tag); - } - }; - - let (mut recv, mut send) = socket.split(); - - if send.send_to(&msg_buf, &addr).await.is_err() { - return Measure(idx, u128::MAX - 2, tag); + let h = if let Ok(h) = h.datagram() { + h + } else { + return Measure::new(idx, u128::MAX, tag); + }; + match h.handle(&sess, transport).await { + Ok(socket) => { + let addr = + SocksAddr::Ip(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)), 53)); + let mut msg = Message::new(); + let name = match Name::from_str("www.google.com.") { + Ok(n) => n, + Err(e) => { + warn!("invalid domain name: {}", e); + return Measure::new(idx, u128::MAX, tag); + } + }; + let query = Query::query(name, RecordType::A); + msg.add_query(query); + let mut rng = StdRng::from_entropy(); + let id: u16 = rng.gen(); + msg.set_id(id); + msg.set_op_code(OpCode::Query); + msg.set_message_type(MessageType::Query); + msg.set_recursion_desired(true); + let msg_buf = match msg.to_vec() { + Ok(b) => b, + Err(e) => { + warn!("encode message to buffer failed: {}", e); + return Measure::new(idx, u128::MAX, tag); } - let mut buf = vec![0u8; 1500]; - match recv.recv_from(&mut buf).await { - Ok((n, _)) => { - debug!("received {} bytes udp health check response", n); - let elapsed = tokio::time::Instant::now().duration_since(start); - Measure(idx, elapsed.as_millis(), tag) - } - Err(_) => Measure(idx, u128::MAX - 3, tag), + }; + + let (mut recv, mut send) = socket.split(); + + if send.send_to(&msg_buf, &addr).await.is_err() { + return Measure::new(idx, u128::MAX - 2, tag); + } + let mut buf = vec![0u8; 1500]; + match recv.recv_from(&mut buf).await { + Ok((n, _)) => { + debug!("received {} bytes udp health check response", n); + let elapsed = tokio::time::Instant::now().duration_since(start); + Measure::new(idx, elapsed.as_millis(), tag) } + Err(_) => Measure::new(idx, u128::MAX - 3, tag), } - Err(_) => Measure(idx, u128::MAX, tag), } + Err(_) => Measure::new(idx, u128::MAX, tag), } } - }; + } +} - timeout(Duration::from_secs(health_check_timeout.into()), measure) - .await - .unwrap_or(Measure(idx, u128::MAX - 1, tag_c)) +async fn health_check( + network: Network, + idx: usize, + tag: String, + h: AnyOutboundHandler, + dns_client: SyncDnsClient, + delay: Duration, + health_check_timeout: u64, +) -> Measure { + debug!("health checking [{}] ({}) index ({})", &tag, &network, idx); + + timeout( + Duration::from_secs(health_check_timeout), + single_health_check(network, idx, tag.clone(), h, dns_client, delay), + ) + .await + .unwrap_or(Measure::new(idx, u128::MAX - 1, tag)) } pub(self) async fn health_check_task( @@ -212,43 +224,29 @@ pub(self) async fn health_check_task( a.clone(), dns_client_cloned, delay, - health_check_timeout, + health_check_timeout as u64, ))); } let mut measures = futures::future::join_all(checks).await; - measures.sort_by(|a, b| a.1.cmp(&b.1)); + measures.sort_by(|a, b| a.rtt.cmp(&b.rtt)); - debug!( - "[{}] sorted health check results:\n{:#?}", - network, measures - ); + debug!("[{}] sorted health check results: {:?}", network, measures); if !health_check_prefers.is_empty() { // Find the minimal RTT among the preferred outbounds. let mut min_prefer_actor_rtt = Duration::from_secs(health_check_timeout as u64).as_millis(); - for a in health_check_prefers.iter() { - if let Some(rtt) = measures.iter().find(|x| &x.2 == a) { - if rtt.1 < min_prefer_actor_rtt { - min_prefer_actor_rtt = rtt.1; + for t in health_check_prefers.iter() { + if let Some(m) = measures.iter().find(|x| &x.tag == t) { + if m.rtt < min_prefer_actor_rtt { + min_prefer_actor_rtt = m.rtt; } } } - fn is_preferred_actor( - idx: usize, - prefers: &[String], - actors: &[AnyOutboundHandler], - ) -> bool { - for a in prefers.iter() { - if let Ok(m_idx) = actors.binary_search_by_key(&a, |x| x.tag()) { - if idx == m_idx { - return true; - } - } - } - false + fn is_preferred_actor(tag: &String, prefers: &[String]) -> bool { + prefers.iter().find(|x| x == &tag).is_some() } // If an outbound is preferred, we subtract its RTT with the minimal @@ -256,15 +254,15 @@ pub(self) async fn health_check_task( // The min RTT must not larger than the timeout value to avoid // preferring unavailable outbounds. for m in measures.iter_mut() { - if is_preferred_actor(m.0, &health_check_prefers, &actors) { - m.1 -= min_prefer_actor_rtt; + if is_preferred_actor(&m.tag, &health_check_prefers) { + m.rtt -= min_prefer_actor_rtt; } } - measures.sort_by(|a, b| a.1.cmp(&b.1)); + measures.sort_by(|a, b| a.rtt.cmp(&b.rtt)); debug!( - "[{}] sorted health check results after applying prefer actors:\n{:#?}", + "[{}] sorted health check results after applying preferred actors: {:?}", network, measures ); } @@ -272,9 +270,9 @@ pub(self) async fn health_check_task( let priorities: Vec = measures .iter() .map(|m| { - let mut repr = actors[m.0].tag().to_owned(); + let mut repr = actors[m.idx].tag().to_owned(); repr.push('('); - repr.push_str(m.1.to_string().as_str()); + repr.push_str(m.rtt.to_string().as_str()); repr.push(')'); repr }) @@ -292,7 +290,7 @@ pub(self) async fn health_check_task( let all_failed = |measures: &Vec| -> bool { let threshold = Duration::from_secs(health_check_timeout.into()).as_millis(); for m in measures.iter() { - if m.1 < threshold { + if m.rtt < threshold { return false; } } @@ -302,12 +300,12 @@ pub(self) async fn health_check_task( if !(last_resort.is_some() && all_failed(&measures)) { if !failover { // if failover is disabled, put only 1 actor in schedule - schedule.push(measures[0].0); - trace!("put {} in schedule", measures[0].0); + schedule.push(measures[0].idx); + trace!("put {} in schedule", measures[0].idx); } else { for m in measures { - schedule.push(m.0); - trace!("put {} in schedule", m.0); + schedule.push(m.idx); + trace!("put {} in schedule", m.idx); } } }