diff --git a/libsql-server/src/main.rs b/libsql-server/src/main.rs
index 8f7f7ab98f..74095c2d03 100644
--- a/libsql-server/src/main.rs
+++ b/libsql-server/src/main.rs
@@ -734,33 +734,6 @@ async fn build_server(
 async fn main() -> Result<()> {
     let args = Cli::parse();
 
-    if let Some(ref subcommand) = args.subcommand {
-        match subcommand {
-            UtilsSubcommands::AdminShell {
-                admin_api_url,
-                namespace,
-                auth,
-            } => {
-                let client = libsql_server::admin_shell::AdminShellClient::new(
-                    admin_api_url.clone(),
-                    auth.clone(),
-                );
-                if let Some(ns) = namespace {
-                    client.run_namespace(ns).await?;
-                }
-            }
-            UtilsSubcommands::WalToolkit {
-                command,
-                path,
-                s3_args,
-            } => {
-                command.exec(path, s3_args).await?;
-            }
-        }
-
-        return Ok(());
-    }
-
     if std::env::var("RUST_LOG").is_err() {
         std::env::set_var("RUST_LOG", "info");
     }
@@ -789,6 +762,33 @@ async fn main() -> Result<()> {
     )
     .init();
 
+    if let Some(ref subcommand) = args.subcommand {
+        match subcommand {
+            UtilsSubcommands::AdminShell {
+                admin_api_url,
+                namespace,
+                auth,
+            } => {
+                let client = libsql_server::admin_shell::AdminShellClient::new(
+                    admin_api_url.clone(),
+                    auth.clone(),
+                );
+                if let Some(ns) = namespace {
+                    client.run_namespace(ns).await?;
+                }
+            }
+            UtilsSubcommands::WalToolkit {
+                command,
+                path,
+                s3_args,
+            } => {
+                command.exec(path, s3_args).await?;
+            }
+        }
+
+        return Ok(());
+    }
+
     args.print_welcome_message();
     let server = build_server(&args, set_log_level).await?;
     server.start().await?;
diff --git a/libsql-wal/src/storage/backend/mod.rs b/libsql-wal/src/storage/backend/mod.rs
index cc90c0db75..4ff3437b3e 100644
--- a/libsql-wal/src/storage/backend/mod.rs
+++ b/libsql-wal/src/storage/backend/mod.rs
@@ -57,7 +57,7 @@ pub trait Backend: Send + Sync + 'static {
         config: &Self::Config,
         namespace: &NamespaceName,
         key: &SegmentKey,
-        segment_data: impl Stream<Item = Result<Bytes>> + Send + Sync + 'static,
+        segment_data: impl Stream<Item = Result<Bytes>> + Send + 'static,
     ) -> impl Future<Output = Result<()>> + Send;
 
     /// Store `segment_data` with its associated `meta`
@@ -224,7 +224,7 @@ impl<T: Backend> Backend for Arc<T> {
         config: &Self::Config,
         namespace: &NamespaceName,
         key: &SegmentKey,
-        segment_data: impl Stream<Item = Result<Bytes>> + Send + Sync + 'static,
+        segment_data: impl Stream<Item = Result<Bytes>> + Send + 'static,
     ) -> Result<()> {
         self.as_ref()
             .store_segment_data(config, namespace, key, segment_data)
diff --git a/libsql-wal/src/storage/backend/s3.rs b/libsql-wal/src/storage/backend/s3.rs
index 70b1b321b7..c9aea76741 100644
--- a/libsql-wal/src/storage/backend/s3.rs
+++ b/libsql-wal/src/storage/backend/s3.rs
@@ -12,11 +12,15 @@ use aws_config::SdkConfig;
 use aws_sdk_s3::operation::create_bucket::CreateBucketError;
 use aws_sdk_s3::operation::get_object::GetObjectOutput;
 use aws_sdk_s3::primitives::{ByteStream, SdkBody};
-use aws_sdk_s3::types::{CreateBucketConfiguration, Object};
+use aws_sdk_s3::types::{
+    CompletedMultipartUpload, CompletedPart, CreateBucketConfiguration, Object,
+};
 use aws_sdk_s3::Client;
 use bytes::{Bytes, BytesMut};
 use chrono::{DateTime, Utc};
-use http_body::{Frame as HttpFrame, SizeHint};
+use futures::future::poll_fn;
+use futures::{StreamExt, TryFutureExt as _, TryStreamExt};
+use http_body::{Body, Frame as HttpFrame, SizeHint};
 use libsql_sys::name::NamespaceName;
 use pin_project_lite::pin_project;
 use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncReadExt, BufReader};
@@ -277,6 +281,121 @@ impl<IO: Io> S3Backend<IO> {
         Ok(())
     }
 
+    #[tracing::instrument(skip_all, fields(key))]
+    async fn s3_put_multipart<B>(
+        &self,
+        config: &S3Config,
+        key: impl ToString,
+        data: B,
+    ) -> Result<()>
+    where
+        B: Body<Data = Bytes, Error = Error>,
+    {
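+        // Pipeline: `make_chunk_fut` slices the incoming body into ~50MiB
+        // chunks spooled to temp files, while `send_chunks_fut` drains the
+        // channel and uploads up to 8 parts concurrently.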
+        const MAX_CHUNK_SIZE: u64 = 50 * 1024 * 1024; // 50 MiB
+        let (s_chunks, r_chunks) = tokio::sync::mpsc::channel(8);
+        let key = key.to_string();
+
+        let upload = self
+            .client
+            .create_multipart_upload()
+            .bucket(&config.bucket)
+            .key(&key)
+            .send()
+            .await
+            .map_err(|e| Error::unhandled(e, "creating multipart upload"))?;
+        let upload_id = upload.upload_id();
+
+        let make_chunk_fut = async {
+            let mut current_chunk_file = self.io.tempfile()?;
+            let mut current_chunk_len = 0;
+            tokio::pin!(data);
+            loop {
+                let Some(frame) = poll_fn(|cx| data.as_mut().poll_frame(cx)).await else {
+                    break;
+                };
+                let frame = frame?;
+                assert!(frame.is_data());
+                let data = frame.into_data().unwrap();
+                let offset = current_chunk_len;
+                current_chunk_len += data.len() as u64;
+                let (_, ret) = current_chunk_file.write_all_at_async(data, offset).await;
+                ret?;
+                if current_chunk_len >= MAX_CHUNK_SIZE {
+                    let new_chunk_file = self.io.tempfile()?;
+                    current_chunk_len = 0;
+                    let old_chunk_file =
+                        std::mem::replace(&mut current_chunk_file, new_chunk_file);
+                    if s_chunks.send(old_chunk_file).await.is_err() {
+                        break;
+                    }
+                }
+            }
+
+            // flush the trailing, possibly partial, chunk so that uploads
+            // smaller than MAX_CHUNK_SIZE are not silently dropped.
+            if current_chunk_len > 0 {
+                let _ = s_chunks.send(current_chunk_file).await;
+            }
+
+            // dropping (and thus moving) the sender into this future closes
+            // the channel, so the receiver stream below eventually exits.
+            drop(s_chunks);
+
+            Ok(())
+        };
+
+        let send_chunks_fut = async {
+            let builder = tokio_stream::wrappers::ReceiverStream::new(r_chunks)
+                .enumerate()
+                .map(|(i, chunk)| {
+                    self.client
+                        .upload_part()
+                        .bucket(&config.bucket)
+                        .key(&key)
+                        .part_number(i as i32 + 1) // part numbers must be between 1 and 10000
+                        .set_upload_id(upload_id.map(ToString::to_string))
+                        .body(FileStreamBody::new(chunk).into_byte_stream())
+                        .send()
+                        .map_err(|e| Error::unhandled(e, "sending chunk"))
+                        .map_ok(move |resp| {
+                            CompletedPart::builder()
+                                .set_e_tag(resp.e_tag)
+                                .set_part_number(Some(i as i32 + 1))
+                                .build()
+                        })
+                })
+                .buffered(8)
+                .try_fold(CompletedMultipartUpload::builder(), |builder, completed| {
+                    std::future::ready(Ok(builder.parts(completed)))
+                })
+                .await?;
+
+            Ok(builder.build())
+        };
+
+        let ret = tokio::try_join!(send_chunks_fut, make_chunk_fut);
+
+        match ret {
+            Ok((parts, _)) => {
+                self.client
+                    .complete_multipart_upload()
+                    .bucket(&config.bucket)
+                    .set_upload_id(upload_id.map(ToString::to_string))
+                    .multipart_upload(parts)
+                    .key(&key)
+                    .send()
+                    .await
+                    .map_err(|e| Error::unhandled(e, "completing multipart upload"))?;
+
+                Ok(())
+            }
+            Err(e) => {
+                self.client
+                    .abort_multipart_upload()
+                    .bucket(&config.bucket)
+                    .set_upload_id(upload_id.map(ToString::to_string))
+                    .key(&key)
+                    .send()
+                    .await
+                    .map_err(|e| Error::unhandled(e, "aborting multipart upload"))?;
+                Err(e)
+            }
+        }
+    }
+
     async fn fetch_segment_index_inner(
         &self,
         config: &S3Config,
@@ -491,20 +610,23 @@ impl<IO: Io> S3Backend<IO> {
         }
     }
 
-    async fn store_segment_data_inner(
+    async fn store_segment_data_inner<B>(
         &self,
         config: &S3Config,
         namespace: &NamespaceName,
-        body: ByteStream,
+        body: B,
         segment_key: &SegmentKey,
-    ) -> Result<()> {
+    ) -> Result<()>
+    where
+        B: Body<Data = Bytes, Error = Error>,
+    {
         let folder_key = FolderKey {
             cluster_id: &config.cluster_id,
             namespace,
         };
         let s3_data_key = s3_segment_data_key(&folder_key, segment_key);
 
-        self.s3_put(config, s3_data_key, body).await
+        self.s3_put_multipart(config, s3_data_key, body).await
     }
 
     async fn store_segment_index_inner(
@@ -635,7 +757,7 @@ where
         segment_index: Vec<u8>,
     ) -> Result<()> {
         let segment_key = SegmentKey::from(&meta);
-        let body = FileStreamBody::new(segment_data).into_byte_stream();
+        let body = FileStreamBody::new(segment_data);
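+        // `body` is now an `http_body::Body` over the segment file; the
+        // conversion to an S3 `ByteStream` happens per ~50MiB part inside
+        // `s3_put_multipart`.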
         self.store_segment_data_inner(config, &meta.namespace, body, &segment_key)
             .await?;
         self.store_segment_index(config, &meta.namespace, &segment_key, segment_index)
             .await?;
@@ -764,11 +886,10 @@ where
         config: &Self::Config,
         namespace: &NamespaceName,
         segment_key: &SegmentKey,
-        segment_data: impl Stream<Item = Result<Bytes>> + Send + Sync + 'static,
+        segment_data: impl Stream<Item = Result<Bytes>> + Send + 'static,
     ) -> Result<()> {
         let byte_stream = StreamBody::new(segment_data);
-        let body = ByteStream::from_body_1_x(byte_stream);
-        self.store_segment_data_inner(config, namespace, body, &segment_key)
+        self.store_segment_data_inner(config, namespace, byte_stream, &segment_key)
             .await?;
 
         Ok(())
@@ -869,7 +990,7 @@ where
     F: FileExt,
 {
     type Data = Bytes;
-    type Error = std::io::Error;
+    type Error = crate::storage::Error;
 
     fn poll_frame(
         mut self: Pin<&mut Self>,
@@ -903,7 +1024,7 @@ where
                 }
                 Poll::Ready(Err(e)) => {
                     self.state = StreamState::Done;
-                    return Poll::Ready(Some(Err(e)));
+                    return Poll::Ready(Some(Err(e.into())));
                 }
                 Poll::Pending => return Poll::Pending,
             },
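Reviewer note: for anyone unfamiliar with the S3 multipart protocol that `s3_put_multipart` pipelines above, below is a minimal, sequential sketch of the same three calls (`create_multipart_upload`, `upload_part`, `complete_multipart_upload`) against a plain `aws_sdk_s3::Client`. The function name, bucket, key, and in-memory `chunks` input are illustrative only; the real code streams ~50MiB temp-file chunks, uploads up to 8 parts concurrently, and aborts the upload on failure.

```rust
use aws_sdk_s3::types::{CompletedMultipartUpload, CompletedPart};
use aws_sdk_s3::{primitives::ByteStream, Client};

/// Sequential multipart upload: one part per input chunk.
async fn put_multipart_sketch(
    client: &Client,
    bucket: &str,
    key: &str,
    chunks: Vec<Vec<u8>>,
) -> Result<(), aws_sdk_s3::Error> {
    // 1. Open the upload and remember its id.
    let upload = client
        .create_multipart_upload()
        .bucket(bucket)
        .key(key)
        .send()
        .await?;
    let upload_id = upload.upload_id().expect("upload id").to_string();

    // 2. Upload each chunk as a numbered part (1..=10_000), recording its ETag.
    let mut parts = Vec::new();
    for (i, chunk) in chunks.into_iter().enumerate() {
        let part_number = i as i32 + 1;
        let resp = client
            .upload_part()
            .bucket(bucket)
            .key(key)
            .upload_id(&upload_id)
            .part_number(part_number)
            .body(ByteStream::from(chunk))
            .send()
            .await?;
        parts.push(
            CompletedPart::builder()
                .set_e_tag(resp.e_tag)
                .part_number(part_number)
                .build(),
        );
    }

    // 3. Stitch the parts together; on any error the real code calls
    //    abort_multipart_upload instead.
    client
        .complete_multipart_upload()
        .bucket(bucket)
        .key(key)
        .upload_id(&upload_id)
        .multipart_upload(
            CompletedMultipartUpload::builder()
                .set_parts(Some(parts))
                .build(),
        )
        .send()
        .await?;
    Ok(())
}
```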