diff --git a/async-nats/Cargo.toml b/async-nats/Cargo.toml index b1f10af40..869484af1 100644 --- a/async-nats/Cargo.toml +++ b/async-nats/Cargo.toml @@ -44,6 +44,7 @@ criterion = { version = "0.5", features = ["async_tokio"]} nats-server = { path = "../nats-server" } rand = "0.8" tokio = { version = "1.25.0", features = ["rt-multi-thread"] } +tokio-util = { version = "0.7", features = ["io"] } futures = { version = "0.3.28", default-features = false, features = ["std", "async-await"] } tracing-subscriber = "0.3" async-nats = {path = ".", features = ["experimental"]} diff --git a/async-nats/src/jetstream/object_store/mod.rs b/async-nats/src/jetstream/object_store/mod.rs index 64a23f77e..0627dba9a 100644 --- a/async-nats/src/jetstream/object_store/mod.rs +++ b/async-nats/src/jetstream/object_store/mod.rs @@ -12,21 +12,23 @@ // limitations under the License. //! Object Store module -use std::collections::VecDeque; use std::fmt::Display; -use std::{cmp, str::FromStr, task::Poll, time::Duration}; +use std::io; +use std::pin::Pin; +use std::task::Context; +use std::{str::FromStr, task::Poll, time::Duration}; use crate::subject::Subject; use crate::{HeaderMap, HeaderValue}; use base64::engine::general_purpose::{STANDARD, URL_SAFE}; use base64::engine::Engine; -use bytes::BytesMut; +use bytes::{Bytes, BytesMut}; use futures::future::BoxFuture; use once_cell::sync::Lazy; use ring::digest::SHA256; use tokio::io::AsyncReadExt; -use futures::{Stream, StreamExt}; +use futures::{FutureExt, Stream, StreamExt}; use regex::Regex; use serde::{Deserialize, Serialize}; use tracing::{debug, trace}; @@ -88,24 +90,42 @@ pub struct ObjectStore { impl ObjectStore { /// Gets an [Object] from the [ObjectStore]. /// - /// [Object] implements [tokio::io::AsyncRead] that allows - /// to read the data from Object Store. + /// [Object] implements [Stream] that allows + /// to stream chunks from Object Store. /// /// # Examples /// /// ```no_run /// # #[tokio::main] - /// # async fn main() -> Result<(), async_nats::Error> { - /// use tokio::io::AsyncReadExt; + /// # async fn main() -> Result<(), Box> { + /// use std::env; + /// + /// use tokio::fs::File; + /// /// let client = async_nats::connect("demo.nats.io").await?; /// let jetstream = async_nats::jetstream::new(client); /// /// let bucket = jetstream.get_object_store("store").await?; /// let mut object = bucket.get("FOO").await?; /// - /// // Object implements `tokio::io::AsyncRead`. - /// let mut bytes = vec![]; - /// object.read_to_end(&mut bytes).await?; + /// // Use the `Stream` implementation + /// use futures::TryStreamExt as _; + /// use tokio::io::AsyncWriteExt as _; + /// + /// let mut file = File::create(env::temp_dir().join("FOO.bin")).await?; + /// while let Some(chunk) = object.try_next().await? { + /// file.write_all(&chunk).await?; + /// } + /// file.sync_all().await?; + /// + /// // Alternatively use `tokio_util` with the `io` feature + /// // to convert the `Stream` into `AsyncRead` + /// // (less efficient because of the added memcpy) + /// let mut reader = tokio_util::io::StreamReader::new(object); + /// + /// let mut file = File::create(env::temp_dir().join("FOO.bin")).await?; + /// tokio::io::copy(&mut reader, &mut file).await?; + /// file.sync_all().await?; /// # Ok(()) /// # } /// ``` @@ -919,7 +939,6 @@ impl Stream for List { /// Represents an object stored in a bucket. pub struct Object { pub info: ObjectInfo, - remaining_bytes: VecDeque, has_pending_messages: bool, digest: Option, subscription: Option, @@ -932,7 +951,6 @@ impl Object { Object { subscription: None, info, - remaining_bytes: VecDeque::new(), has_pending_messages: true, digest: Some(ring::digest::Context::new(&SHA256)), subscription_future: None, @@ -946,24 +964,19 @@ impl Object { } } -impl tokio::io::AsyncRead for Object { - fn poll_read( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, - ) -> std::task::Poll> { - let (buf1, _buf2) = self.remaining_bytes.as_slices(); - if !buf1.is_empty() { - let len = cmp::min(buf.remaining(), buf1.len()); - buf.put_slice(&buf1[..len]); - self.remaining_bytes.drain(..len); - return Poll::Ready(Ok(())); +impl Stream for Object { + type Item = io::Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if !self.has_pending_messages { + return Poll::Ready(None); } - if self.has_pending_messages { - if self.subscription.is_none() { - let future = match self.subscription_future.as_mut() { - Some(future) => future, + let subscription = match &mut self.subscription { + Some(subscription) => subscription, + None => { + let subscription_future = match &mut self.subscription_future { + Some(subscription_future) => subscription_future, None => { let stream = self.stream.clone(); let bucket = self.info.bucket.clone(); @@ -982,77 +995,68 @@ impl tokio::io::AsyncRead for Object { })) } }; - match future.as_mut().poll(cx) { - Poll::Ready(subscription) => { - self.subscription = Some(subscription.unwrap()); + + match subscription_future.poll_unpin(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Ok(subscription)) => self.subscription.insert(subscription), + Poll::Ready(Err(err)) => { + return Poll::Ready(Some(Err(io::Error::new( + io::ErrorKind::Other, + format!("error from JetStream create subscription: {err}"), + )))) } - Poll::Pending => (), } } - if let Some(subscription) = self.subscription.as_mut() { - match subscription.poll_next_unpin(cx) { - Poll::Ready(message) => match message { - Some(message) => { - let message = message.map_err(|err| { - std::io::Error::new( - std::io::ErrorKind::Other, - format!("error from JetStream subscription: {err}"), - ) - })?; - let len = cmp::min(buf.remaining(), message.payload.len()); - buf.put_slice(&message.payload[..len]); - if let Some(context) = &mut self.digest { - context.update(&message.payload); - } - self.remaining_bytes.extend(&message.payload[len..]); + }; - let info = message.info().map_err(|err| { - std::io::Error::new( - std::io::ErrorKind::Other, - format!("error from JetStream subscription: {err}"), - ) - })?; - if info.pending == 0 { - let digest = self.digest.take().map(|context| context.finish()); - if let Some(digest) = digest { - if self - .info - .digest - .as_ref() - .map(|digest_self| { - format!("SHA-256={}", URL_SAFE.encode(digest)) - != *digest_self - }) - .unwrap_or(false) - { - return Poll::Ready(Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "wrong digest", - ))); - } - } else { - return Poll::Ready(Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "digest should be Some", - ))); - } - self.has_pending_messages = false; - self.subscription = None; - } - Poll::Ready(Ok(())) + match subscription.poll_next_unpin(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Some(Ok(message))) => { + if let Some(digest) = &mut self.digest { + digest.update(&message.message.payload); + } + + let info = message.info().map_err(|err| { + io::Error::new( + io::ErrorKind::Other, + format!("error from JetStream subscription: {err}"), + ) + })?; + + if info.pending == 0 { + self.has_pending_messages = false; + self.subscription = None; + + if let Some(digest) = self.digest.take() { + let digest = digest.finish(); + + if self + .info + .digest + .as_ref() + .map(|digest_self| { + format!("SHA-256={}", URL_SAFE.encode(digest)) != *digest_self + }) + .unwrap_or(false) + { + return Poll::Ready(Some(Err(io::Error::new( + io::ErrorKind::InvalidData, + "wrong digest", + )))); } - None => Poll::Ready(Err(std::io::Error::new( - std::io::ErrorKind::Other, - "subscription ended before reading whole object", - ))), - }, - Poll::Pending => Poll::Pending, + } } - } else { - Poll::Pending + + Poll::Ready(Some(Ok(message.message.payload))) } - } else { - Poll::Ready(Ok(())) + Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(io::Error::new( + io::ErrorKind::Other, + format!("error from JetStream subscription: {err}"), + )))), + Poll::Ready(None) => Poll::Ready(Some(Err(io::Error::new( + io::ErrorKind::Other, + "subscription ended before reading whole object", + )))), } } } diff --git a/async-nats/tests/compatibility.rs b/async-nats/tests/compatibility.rs index 1509ad4ce..7716116d9 100644 --- a/async-nats/tests/compatibility.rs +++ b/async-nats/tests/compatibility.rs @@ -13,7 +13,7 @@ #[cfg(feature = "compatibility_tests")] mod compatibility { - use futures::{pin_mut, stream::Peekable, StreamExt}; + use futures::{pin_mut, stream::Peekable, StreamExt, TryStreamExt}; use core::panic; use std::{collections::HashMap, pin::Pin, str::from_utf8}; @@ -27,7 +27,6 @@ mod compatibility { }; use ring::digest::{self, SHA256}; use serde::{Deserialize, Serialize}; - use tokio::io::AsyncReadExt; #[tokio::test] async fn kv() { @@ -226,7 +225,9 @@ mod compatibility { let mut object = bucket.get(request.object).await.unwrap(); let mut contents = vec![]; - object.read_to_end(&mut contents).await.unwrap(); + while let Some(chunk) = object.try_next().await.unwrap() { + contents.extend_from_slice(&chunk); + } let digest = digest::digest(&SHA256, &contents); @@ -295,7 +296,9 @@ mod compatibility { let mut object = bucket.get(request.object).await.unwrap(); let mut contents = vec![]; - object.read_to_end(&mut contents).await.unwrap(); + while let Some(chunk) = object.try_next().await.unwrap() { + contents.extend_from_slice(&chunk); + } let digest = digest::digest(&SHA256, &contents); diff --git a/async-nats/tests/object_store.rs b/async-nats/tests/object_store.rs index b7366cabe..10534742a 100644 --- a/async-nats/tests/object_store.rs +++ b/async-nats/tests/object_store.rs @@ -20,10 +20,9 @@ mod object_store { stream::DirectGetErrorKind, }; use base64::Engine; - use futures::StreamExt; + use futures::{StreamExt, TryStreamExt}; use rand::RngCore; use ring::digest::SHA256; - use tokio::io::AsyncReadExt; #[tokio::test] async fn get_and_put() { @@ -51,16 +50,8 @@ mod object_store { let mut object = bucket.get("FOO").await.unwrap(); let mut result = Vec::new(); - loop { - let mut buffer = [0; 1024]; - if let Ok(n) = object.read(&mut buffer).await { - if n == 0 { - println!("finished"); - break; - } - - result.extend_from_slice(&buffer[..n]); - } + while let Some(chunk) = object.try_next().await.unwrap() { + result.extend_from_slice(&chunk); } assert_eq!( Some(format!( @@ -79,7 +70,9 @@ mod object_store { let mut contents = Vec::new(); tracing::info!("reading content"); - object_link.read_to_end(&mut contents).await.unwrap(); + while let Some(chunk) = object_link.try_next().await.unwrap() { + contents.extend_from_slice(&chunk); + } assert_eq!(contents, result); bucket @@ -350,7 +343,9 @@ mod object_store { assert_eq!(object.info.digest, Some(format!("SHA-256={digest}"))); let mut result = Vec::new(); - object.read_to_end(&mut result).await.unwrap(); + while let Some(chunk) = object.try_next().await.unwrap() { + result.extend_from_slice(&chunk); + } assert_eq!(result, file); } }