diff --git a/.config/nats.dic b/.config/nats.dic index b3adc9f98..d9ac2fbf6 100644 --- a/.config/nats.dic +++ b/.config/nats.dic @@ -136,7 +136,7 @@ rustls Acker EndpointSchema auth -filter_subject filter_subjects rollup IoT +RttError diff --git a/async-nats/src/client.rs b/async-nats/src/client.rs index 9be37ad19..9ee4553de 100644 --- a/async-nats/src/client.rs +++ b/async-nats/src/client.rs @@ -485,6 +485,35 @@ impl Client { Ok(()) } + /// Calculates the round trip time between this client and the server, + /// if the server is currently connected. + /// + /// # Examples + /// + /// ```no_run + /// # #[tokio::main] + /// # async fn main() -> Result<(), async_nats::Error> { + /// let client = async_nats::connect("demo.nats.io").await?; + /// let rtt = client.rtt().await?; + /// println!("server rtt: {:?}", rtt); + /// # Ok(()) + /// # } + /// ``` + pub async fn rtt(&self) -> Result { + let (tx, rx) = tokio::sync::oneshot::channel(); + + self.sender.send(Command::Rtt { result: tx }).await?; + + let rtt = rx + .await + // first handle rx error + .map_err(|err| RttError(Box::new(err)))? + // second handle the actual rtt error + .map_err(|err| RttError(Box::new(err)))?; + + Ok(rtt) + } + /// Returns the current state of the connection. /// /// # Examples @@ -684,3 +713,14 @@ impl From for RequestError { RequestError::with_source(RequestErrorKind::Other, e) } } + +/// Error returned when doing a round-trip time measurement fails. +#[derive(Debug, Error)] +#[error("failed to measure round-trip time: {0}")] +pub struct RttError(#[source] Box); + +impl From> for RttError { + fn from(err: tokio::sync::mpsc::error::SendError) -> Self { + RttError(Box::new(err)) + } +} diff --git a/async-nats/src/lib.rs b/async-nats/src/lib.rs index 3288e2a1a..2242d531e 100644 --- a/async-nats/src/lib.rs +++ b/async-nats/src/lib.rs @@ -124,6 +124,7 @@ use thiserror::Error; use futures::future::FutureExt; use futures::select; use futures::stream::Stream; +use std::time::Instant; use tracing::{debug, error}; use core::fmt; @@ -280,6 +281,9 @@ pub(crate) enum Command { result: oneshot::Sender>, }, TryFlush, + Rtt { + result: oneshot::Sender>, + }, } /// `ClientOp` represents all actions of `Client`. @@ -323,6 +327,9 @@ pub(crate) struct ConnectionHandler { info_sender: tokio::sync::watch::Sender, ping_interval: Interval, flush_interval: Interval, + last_ping_time: Option, + last_pong_time: Option, + rtt_senders: Vec>>, } impl ConnectionHandler { @@ -347,6 +354,9 @@ impl ConnectionHandler { info_sender, ping_interval, flush_interval, + last_ping_time: None, + last_pong_time: None, + rtt_senders: Vec::new(), } } @@ -425,6 +435,22 @@ impl ConnectionHandler { } ServerOp::Pong => { debug!("received PONG"); + if self.pending_pings == 1 { + self.last_pong_time = Some(Instant::now()); + + while let Some(sender) = self.rtt_senders.pop() { + if let (Some(ping), Some(pong)) = (self.last_ping_time, self.last_pong_time) + { + let rtt = pong.duration_since(ping); + sender.send(Ok(rtt)).map_err(|_| { + io::Error::new( + io::ErrorKind::Other, + "one shot failed to be received", + ) + })?; + } + } + } self.pending_pings = self.pending_pings.saturating_sub(1); } ServerOp::Error(error) => { @@ -538,6 +564,14 @@ impl ConnectionHandler { } } } + Command::Rtt { result } => { + self.rtt_senders.push(result); + + if self.pending_pings == 0 { + // do a ping and expect a pong - will calculate rtt when handling the pong + self.handle_ping().await?; + } + } Command::Flush { result } => { if let Err(_err) = self.handle_flush().await { if let Err(err) = self.handle_disconnect().await { @@ -612,8 +646,39 @@ impl ConnectionHandler { Ok(()) } + async fn handle_ping(&mut self) -> Result<(), io::Error> { + debug!( + "PING command. Pending pings {}, max pings {}", + self.pending_pings, MAX_PENDING_PINGS + ); + self.pending_pings += 1; + self.ping_interval.reset(); + + if self.pending_pings > MAX_PENDING_PINGS { + debug!( + "pending pings {}, max pings {}. disconnecting", + self.pending_pings, MAX_PENDING_PINGS + ); + self.handle_disconnect().await?; + } + + if self.pending_pings == 1 { + // start the clock for calculating round trip time + self.last_ping_time = Some(Instant::now()); + } + + if let Err(_err) = self.connection.write_op(&ClientOp::Ping).await { + self.handle_disconnect().await?; + } + + self.handle_flush().await?; + Ok(()) + } + async fn handle_disconnect(&mut self) -> io::Result<()> { self.pending_pings = 0; + self.last_ping_time = None; + self.last_pong_time = None; self.connector.events_tx.try_send(Event::Disconnected).ok(); self.connector.state_tx.send(State::Disconnected).ok(); self.handle_reconnect().await?; diff --git a/async-nats/tests/client_tests.rs b/async-nats/tests/client_tests.rs index 5505e21d0..cc8bc7a83 100644 --- a/async-nats/tests/client_tests.rs +++ b/async-nats/tests/client_tests.rs @@ -867,4 +867,15 @@ mod client { .await .unwrap(); } + + #[tokio::test] + async fn rtt() { + let server = nats_server::run_basic_server(); + let client = async_nats::connect(server.client_url()).await.unwrap(); + + let rtt = client.rtt().await.unwrap(); + + println!("rtt: {:?}", rtt); + assert!(rtt.as_nanos() > 0); + } }