Skip to content

Commit

Permalink
Make Client::request return IntoFuture builder
Browse files Browse the repository at this point in the history
Co-authored-by: Casper Beyer <[email protected]>
  • Loading branch information
n1ghtmare and caspervonb committed Jul 1, 2023
1 parent 98706f3 commit 26d8f1c
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 88 deletions.
157 changes: 80 additions & 77 deletions async-nats/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ use std::sync::Arc;
use std::time::Duration;
use thiserror::Error;
use tokio::sync::mpsc;
use tracing::trace;

static VERSION_RE: Lazy<Regex> =
Lazy::new(|| Regex::new(r#"\Av?([0-9]+)\.?([0-9]+)?\.?([0-9]+)?"#).unwrap());
Expand Down Expand Up @@ -317,10 +316,8 @@ impl Client {
/// # Ok(())
/// # }
/// ```
pub async fn request(&self, subject: String, payload: Bytes) -> Result<Message, RequestError> {
trace!("request sent to subject: {} ({})", subject, payload.len());
let request = Request::new().payload(payload);
self.send_request(subject, request).await
pub fn request(&self, subject: String, payload: Bytes) -> Request {
Request::new(self.clone(), subject, payload)
}

/// Sends the request with headers.
Expand All @@ -344,65 +341,11 @@ impl Client {
headers: HeaderMap,
payload: Bytes,
) -> Result<Message, RequestError> {
let request = Request::new().headers(headers).payload(payload);
self.send_request(subject, request).await
}
let message = Request::new(self.clone(), subject, payload)
.headers(headers)
.await?;

/// Sends the request created by the [Request].
///
/// # Examples
///
/// ```no_run
/// # #[tokio::main]
/// # async fn main() -> Result<(), async_nats::Error> {
/// let client = async_nats::connect("demo.nats.io").await?;
/// let request = async_nats::Request::new().payload("data".into());
/// let response = client.send_request("service".into(), request).await?;
/// # Ok(())
/// # }
/// ```
pub async fn send_request(
&self,
subject: String,
request: Request,
) -> Result<Message, RequestError> {
let inbox = request.inbox.unwrap_or_else(|| self.new_inbox());
let timeout = request.timeout.unwrap_or(self.request_timeout);
let mut sub = self.subscribe(inbox.clone()).await?;
let payload: Bytes = request.payload.unwrap_or_else(Bytes::new);
match request.headers {
Some(headers) => {
self.publish_with_reply_and_headers(subject, inbox, headers, payload)
.await?
}
None => self.publish_with_reply(subject, inbox, payload).await?,
}
self.flush()
.await
.map_err(|err| RequestError::with_source(RequestErrorKind::Other, err))?;
let request = match timeout {
Some(timeout) => {
tokio::time::timeout(timeout, sub.next())
.map_err(|err| RequestError::with_source(RequestErrorKind::TimedOut, err))
.await?
}
None => sub.next().await,
};
match request {
Some(message) => {
if message.status == Some(StatusCode::NO_RESPONDERS) {
return Err(RequestError::with_source(
RequestErrorKind::NoResponders,
"no responders",
));
}
Ok(message)
}
None => Err(RequestError::with_source(
RequestErrorKind::Other,
"broken pipe",
)),
}
Ok(message)
}

/// Create a new globally unique inbox which can be used for replies.
Expand Down Expand Up @@ -534,17 +477,26 @@ impl Client {
}

/// Used for building customized requests.
#[derive(Default)]
#[derive(Debug)]
pub struct Request {
client: Client,
subject: String,
payload: Option<Bytes>,
headers: Option<HeaderMap>,
timeout: Option<Option<Duration>>,
inbox: Option<String>,
}

impl Request {
pub fn new() -> Request {
Default::default()
pub fn new(client: Client, subject: String, payload: Bytes) -> Request {
Request {
client,
subject,
payload: Some(payload),
headers: None,
timeout: None,
inbox: None,
}
}

/// Sets the payload of the request. If not used, empty payload will be sent.
Expand All @@ -554,8 +506,7 @@ impl Request {
/// # #[tokio::main]
/// # async fn main() -> Result<(), async_nats::Error> {
/// let client = async_nats::connect("demo.nats.io").await?;
/// let request = async_nats::Request::new().payload("data".into());
/// client.send_request("service".into(), request).await?;
/// client.request("service".into(), "data".into()).await?;
/// # Ok(())
/// # }
/// ```
Expand All @@ -577,10 +528,11 @@ impl Request {
/// "X-Example",
/// async_nats::HeaderValue::from_str("Value").unwrap(),
/// );
/// let request = async_nats::Request::new()
/// client
/// .request("subject".into(), "data".into())
/// .headers(headers)
/// .payload("data".into());
/// client.send_request("service".into(), request).await?;
/// .await?;
///
/// # Ok(())
/// # }
/// ```
Expand All @@ -598,10 +550,11 @@ impl Request {
/// # #[tokio::main]
/// # async fn main() -> Result<(), async_nats::Error> {
/// let client = async_nats::connect("demo.nats.io").await?;
/// let request = async_nats::Request::new()
/// client
/// .request("service".into(), "data".into())
/// .timeout(Some(std::time::Duration::from_secs(15)))
/// .payload("data".into());
/// client.send_request("service".into(), request).await?;
/// .await?;
///
/// # Ok(())
/// # }
/// ```
Expand All @@ -618,17 +571,67 @@ impl Request {
/// # async fn main() -> Result<(), async_nats::Error> {
/// use std::str::FromStr;
/// let client = async_nats::connect("demo.nats.io").await?;
/// let request = async_nats::Request::new()
/// client
/// .request("subject".into(), "data".into())
/// .inbox("custom_inbox".into())
/// .payload("data".into());
/// client.send_request("service".into(), request).await?;
/// .await?;
///
/// # Ok(())
/// # }
/// ```
pub fn inbox(mut self, inbox: String) -> Request {
self.inbox = Some(inbox);
self
}

async fn send(self) -> Result<Message, RequestError> {
let inbox = self.inbox.unwrap_or_else(|| self.client.new_inbox());
let mut subscriber = self.client.subscribe(inbox.clone()).await?;
let mut publish = self
.client
.publish(self.subject, self.payload.unwrap_or_else(Bytes::new));

if let Some(headers) = self.headers {
publish = publish.headers(headers);
}

publish = publish.reply(inbox);
publish.into_future().await?;

self.client
.flush()
.map_err(|err| RequestError::with_source(RequestErrorKind::Other, err))
.await?;

let period = self.timeout.unwrap_or(self.client.request_timeout);
let message = match period {
Some(period) => {
tokio::time::timeout(period, subscriber.next())
.map_err(|_| RequestError::new(RequestErrorKind::TimedOut))
.await?
}
None => subscriber.next().await,
};

match message {
Some(message) => {
if message.status == Some(StatusCode::NO_RESPONDERS) {
return Err(RequestError::new(RequestErrorKind::NoResponders));
}
Ok(message)
}
None => Err(RequestError::new(RequestErrorKind::Other)),
}
}
}

impl IntoFuture for Request {
type Output = Result<Message, RequestError>;
type IntoFuture = Pin<Box<dyn Future<Output = Result<Message, RequestError>> + Send>>;

fn into_future(self) -> Self::IntoFuture {
Box::pin(self.send())
}
}

#[derive(Error, Debug)]
Expand Down
22 changes: 11 additions & 11 deletions async-nats/tests/client_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,11 @@
mod client {
use async_nats::connection::State;
use async_nats::header::HeaderValue;
use async_nats::{
ConnectErrorKind, ConnectOptions, Event, Request, RequestErrorKind, ServerAddr,
};
use async_nats::{ConnectErrorKind, ConnectOptions, Event, RequestErrorKind, ServerAddr};
use bytes::Bytes;
use futures::future::join_all;
use futures::stream::StreamExt;
use std::future::IntoFuture;
use std::path::PathBuf;
use std::str::FromStr;
use std::time::Duration;
Expand Down Expand Up @@ -239,7 +238,9 @@ mod client {

let resp = tokio::time::timeout(
tokio::time::Duration::from_millis(500),
client.request("test".into(), "request".into()),
client
.request("test".into(), "request".into())
.into_future(),
)
.await
.unwrap();
Expand Down Expand Up @@ -268,7 +269,9 @@ mod client {

let err = tokio::time::timeout(
tokio::time::Duration::from_millis(300),
client.request("test".into(), "request".into()),
client
.request("test".into(), "request".into())
.into_future(),
)
.await
.unwrap()
Expand Down Expand Up @@ -296,9 +299,9 @@ mod client {
}
});

let request = Request::new().inbox(inbox.clone());
client
.send_request("service".into(), request)
.request("service".into(), "".into())
.inbox(inbox)
.await
.unwrap();
}
Expand Down Expand Up @@ -770,10 +773,7 @@ mod client {
}
});

client
.request("request".into(), "data".into())
.await
.unwrap();
client.request("".into(), "data".into()).await.unwrap();
inbox_wildcard_subscription.next().await.unwrap();
}

Expand Down

0 comments on commit 26d8f1c

Please sign in to comment.