Skip to content

Commit

Permalink
Merge pull request #1766 from tursodatabase/libsql-wal-multipart-upload
Browse files Browse the repository at this point in the history
libsql wal multipart upload
  • Loading branch information
MarinPostma authored Oct 2, 2024
2 parents 748a73d + 77151c2 commit b2d59ca
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 41 deletions.
54 changes: 27 additions & 27 deletions libsql-server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -734,33 +734,6 @@ async fn build_server(
async fn main() -> Result<()> {
let args = Cli::parse();

if let Some(ref subcommand) = args.subcommand {
match subcommand {
UtilsSubcommands::AdminShell {
admin_api_url,
namespace,
auth,
} => {
let client = libsql_server::admin_shell::AdminShellClient::new(
admin_api_url.clone(),
auth.clone(),
);
if let Some(ns) = namespace {
client.run_namespace(ns).await?;
}
}
UtilsSubcommands::WalToolkit {
command,
path,
s3_args,
} => {
command.exec(path, s3_args).await?;
}
}

return Ok(());
}

if std::env::var("RUST_LOG").is_err() {
std::env::set_var("RUST_LOG", "info");
}
Expand Down Expand Up @@ -789,6 +762,33 @@ async fn main() -> Result<()> {
)
.init();

if let Some(ref subcommand) = args.subcommand {
match subcommand {
UtilsSubcommands::AdminShell {
admin_api_url,
namespace,
auth,
} => {
let client = libsql_server::admin_shell::AdminShellClient::new(
admin_api_url.clone(),
auth.clone(),
);
if let Some(ns) = namespace {
client.run_namespace(ns).await?;
}
}
UtilsSubcommands::WalToolkit {
command,
path,
s3_args,
} => {
command.exec(path, s3_args).await?;
}
}

return Ok(());
}

args.print_welcome_message();
let server = build_server(&args, set_log_level).await?;
server.start().await?;
Expand Down
4 changes: 2 additions & 2 deletions libsql-wal/src/storage/backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ pub trait Backend: Send + Sync + 'static {
config: &Self::Config,
namespace: &NamespaceName,
key: &SegmentKey,
segment_data: impl Stream<Item = Result<Bytes>> + Send + Sync + 'static,
segment_data: impl Stream<Item = Result<Bytes>> + Send + 'static,
) -> impl Future<Output = Result<()>> + Send;

/// Store `segment_data` with its associated `meta`
Expand Down Expand Up @@ -224,7 +224,7 @@ impl<T: Backend> Backend for Arc<T> {
config: &Self::Config,
namespace: &NamespaceName,
key: &SegmentKey,
segment_data: impl Stream<Item = Result<Bytes>> + Send + Sync + 'static,
segment_data: impl Stream<Item = Result<Bytes>> + Send + 'static,
) -> Result<()> {
self.as_ref()
.store_segment_data(config, namespace, key, segment_data)
Expand Down
145 changes: 133 additions & 12 deletions libsql-wal/src/storage/backend/s3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,15 @@ use aws_config::SdkConfig;
use aws_sdk_s3::operation::create_bucket::CreateBucketError;
use aws_sdk_s3::operation::get_object::GetObjectOutput;
use aws_sdk_s3::primitives::{ByteStream, SdkBody};
use aws_sdk_s3::types::{CreateBucketConfiguration, Object};
use aws_sdk_s3::types::{
CompletedMultipartUpload, CompletedPart, CreateBucketConfiguration, Object,
};
use aws_sdk_s3::Client;
use bytes::{Bytes, BytesMut};
use chrono::{DateTime, Utc};
use http_body::{Frame as HttpFrame, SizeHint};
use futures::future::poll_fn;
use futures::{StreamExt, TryFutureExt as _, TryStreamExt};
use http_body::{Body, Frame as HttpFrame, SizeHint};
use libsql_sys::name::NamespaceName;
use pin_project_lite::pin_project;
use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncReadExt, BufReader};
Expand Down Expand Up @@ -277,6 +281,121 @@ impl<IO: Io> S3Backend<IO> {
Ok(())
}

#[tracing::instrument(skip_all, fields(key))]
async fn s3_put_multipart<B>(
&self,
config: &S3Config,
key: impl ToString,
data: B,
) -> Result<()>
where
B: Body<Data = Bytes, Error = crate::storage::Error>,
{
const MAX_CHUNK_SIZE: u64 = 50 * 1024 * 1024; // 50MB
let (s_chunks, r_chunks) = tokio::sync::mpsc::channel(8);
let key = key.to_string();

let upload = self
.client
.create_multipart_upload()
.bucket(&config.bucket)
.key(&key)
.send()
.await
.map_err(|e| Error::unhandled(e, "creating multipart upload"))?;
let upload_id = upload.upload_id();

let make_chunk_fut = async {
let mut current_chunk_file = self.io.tempfile()?;
let mut current_chunk_len = 0;
tokio::pin!(data);
loop {
let Some(frame) = poll_fn(|cx| data.as_mut().poll_frame(cx)).await else {
break;
};
let frame = frame?;
assert!(frame.is_data());
let data = frame.into_data().unwrap();
let offset = current_chunk_len;
current_chunk_len += data.len() as u64;
let (_, ret) = current_chunk_file.write_all_at_async(data, offset).await;
ret?;
if current_chunk_len >= MAX_CHUNK_SIZE {
let new_chunk_file = self.io.tempfile()?;
current_chunk_len = 0;
let old_chunk_file = std::mem::replace(&mut current_chunk_file, new_chunk_file);
if s_chunks.send(old_chunk_file).await.is_err() {
break;
}
}
}

// make sure we move the sender in the future so the chunk sender eventually exits.
drop(s_chunks);

Ok(())
};

let send_chunks_fut = async {
let builder = tokio_stream::wrappers::ReceiverStream::new(r_chunks)
.enumerate()
.map(|(i, chunk)| {
let i = i;
self.client
.upload_part()
.bucket(&config.bucket)
.key(&key)
.part_number(i as i32 + 1) // part number must be between 1-10000
.set_upload_id(upload_id.map(ToString::to_string))
.body(FileStreamBody::new(chunk).into_byte_stream())
.send()
.map_err(|e| Error::unhandled(e, format!("sending chunk")))
.map_ok(move |resp| {
CompletedPart::builder()
.set_e_tag(resp.e_tag)
.set_part_number(Some(i as i32 + 1))
.build()
})
})
.buffered(8)
.try_fold(CompletedMultipartUpload::builder(), |builder, completed| {
std::future::ready(Ok(builder.parts(completed)))
})
.await?;

Ok(builder.build())
};

let ret = tokio::try_join!(send_chunks_fut, make_chunk_fut);

match ret {
Ok((parts, _)) => {
self.client
.complete_multipart_upload()
.bucket(&config.bucket)
.set_upload_id(upload_id.map(ToString::to_string))
.multipart_upload(parts)
.key(&key)
.send()
.await
.map_err(|e| Error::unhandled(e, format!("completing multipart upload")))?;

Ok(())
}
Err(e) => {
self.client
.abort_multipart_upload()
.bucket(&config.bucket)
.set_upload_id(upload_id.map(ToString::to_string))
.key(&key)
.send()
.await
.map_err(|e| Error::unhandled(e, format!("aborting multipart upload")))?;
Err(e)
}
}
}

async fn fetch_segment_index_inner(
&self,
config: &S3Config,
Expand Down Expand Up @@ -491,20 +610,23 @@ impl<IO: Io> S3Backend<IO> {
}
}

async fn store_segment_data_inner(
async fn store_segment_data_inner<B>(
&self,
config: &S3Config,
namespace: &NamespaceName,
body: ByteStream,
body: B,
segment_key: &SegmentKey,
) -> Result<()> {
) -> Result<()>
where
B: Body<Data = Bytes, Error = crate::storage::Error>,
{
let folder_key = FolderKey {
cluster_id: &config.cluster_id,
namespace,
};
let s3_data_key = s3_segment_data_key(&folder_key, segment_key);

self.s3_put(config, s3_data_key, body).await
self.s3_put_multipart(config, s3_data_key, body).await
}

async fn store_segment_index_inner(
Expand Down Expand Up @@ -635,7 +757,7 @@ where
segment_index: Vec<u8>,
) -> Result<()> {
let segment_key = SegmentKey::from(&meta);
let body = FileStreamBody::new(segment_data).into_byte_stream();
let body = FileStreamBody::new(segment_data);
self.store_segment_data_inner(config, &meta.namespace, body, &segment_key)
.await?;
self.store_segment_index(config, &meta.namespace, &segment_key, segment_index)
Expand Down Expand Up @@ -764,11 +886,10 @@ where
config: &Self::Config,
namespace: &NamespaceName,
segment_key: &SegmentKey,
segment_data: impl Stream<Item = Result<Bytes>> + Send + Sync + 'static,
segment_data: impl Stream<Item = Result<Bytes>> + Send + 'static,
) -> Result<()> {
let byte_stream = StreamBody::new(segment_data);
let body = ByteStream::from_body_1_x(byte_stream);
self.store_segment_data_inner(config, namespace, body, &segment_key)
self.store_segment_data_inner(config, namespace, byte_stream, &segment_key)
.await?;

Ok(())
Expand Down Expand Up @@ -869,7 +990,7 @@ where
F: FileExt,
{
type Data = Bytes;
type Error = std::io::Error;
type Error = crate::storage::Error;

fn poll_frame(
mut self: Pin<&mut Self>,
Expand Down Expand Up @@ -903,7 +1024,7 @@ where
}
Poll::Ready(Err(e)) => {
self.state = StreamState::Done;
return Poll::Ready(Some(Err(e)));
return Poll::Ready(Some(Err(e.into())));
}
Poll::Pending => return Poll::Pending,
},
Expand Down

0 comments on commit b2d59ca

Please sign in to comment.