Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor async blob read to avoid blocking calls, support non multipa… #10192

Merged
merged 1 commit into from
Sep 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -228,35 +228,50 @@ public void readBlobAsync(String blobName, ActionListener<ReadContext> listener)
try (AmazonAsyncS3Reference amazonS3Reference = SocketAccess.doPrivileged(blobStore::asyncClientReference)) {
final S3AsyncClient s3AsyncClient = amazonS3Reference.get().client();
final String bucketName = blobStore.bucket();
final String blobKey = buildKey(blobName);

final GetObjectAttributesResponse blobMetadata = getBlobMetadata(s3AsyncClient, bucketName, blobName).get();
final CompletableFuture<GetObjectAttributesResponse> blobMetadataFuture = getBlobMetadata(s3AsyncClient, bucketName, blobKey);

final long blobSize = blobMetadata.objectSize();
final int numberOfParts = blobMetadata.objectParts().totalPartsCount();
final String blobChecksum = blobMetadata.checksum().checksumCRC32();

final List<InputStreamContainer> blobPartStreams = new ArrayList<>();
final List<CompletableFuture<InputStreamContainer>> blobPartInputStreamFutures = new ArrayList<>();
// S3 multipart files use 1 to n indexing
for (int partNumber = 1; partNumber <= numberOfParts; partNumber++) {
blobPartInputStreamFutures.add(getBlobPartInputStreamContainer(s3AsyncClient, bucketName, blobName, partNumber));
}

CompletableFuture.allOf(blobPartInputStreamFutures.toArray(CompletableFuture[]::new)).whenComplete((unused, throwable) -> {
if (throwable == null) {
listener.onResponse(
new ReadContext(
blobSize,
blobPartInputStreamFutures.stream().map(CompletableFuture::join).collect(Collectors.toList()),
blobChecksum
)
);
} else {
blobMetadataFuture.whenComplete((blobMetadata, throwable) -> {
if (throwable != null) {
Exception ex = throwable.getCause() instanceof Exception
? (Exception) throwable.getCause()
: new Exception(throwable.getCause());
listener.onFailure(ex);
return;
}

final List<CompletableFuture<InputStreamContainer>> blobPartInputStreamFutures = new ArrayList<>();
final long blobSize = blobMetadata.objectSize();
final Integer numberOfParts = blobMetadata.objectParts() == null ? null : blobMetadata.objectParts().totalPartsCount();
final String blobChecksum = blobMetadata.checksum().checksumCRC32();

if (numberOfParts == null) {
blobPartInputStreamFutures.add(getBlobPartInputStreamContainer(s3AsyncClient, bucketName, blobKey, null));
} else {
// S3 multipart files use 1 to n indexing
for (int partNumber = 1; partNumber <= numberOfParts; partNumber++) {
blobPartInputStreamFutures.add(getBlobPartInputStreamContainer(s3AsyncClient, bucketName, blobKey, partNumber));
}
}

CompletableFuture.allOf(blobPartInputStreamFutures.toArray(CompletableFuture[]::new))
.whenComplete((unused, partThrowable) -> {
if (partThrowable == null) {
listener.onResponse(
new ReadContext(
blobSize,
blobPartInputStreamFutures.stream().map(CompletableFuture::join).collect(Collectors.toList()),
blobChecksum
)
);
} else {
Exception ex = partThrowable.getCause() instanceof Exception
? (Exception) partThrowable.getCause()
: new Exception(partThrowable.getCause());
listener.onFailure(ex);
}
});
});
} catch (Exception ex) {
listener.onFailure(SdkException.create("Error occurred while fetching blob parts from the repository", ex));
Expand Down Expand Up @@ -685,41 +700,47 @@ static Tuple<Long, Long> numberOfMultiparts(final long totalSize, final long par
* the stream and its related metadata.
* @param s3AsyncClient Async client to be utilized to fetch the object part
* @param bucketName Name of the S3 bucket
* @param blobName Identifier of the blob for which the parts will be fetched
* @param partNumber Part number for the blob to be retrieved
* @param blobKey Identifier of the blob for which the parts will be fetched
* @param partNumber Optional part number for the blob to be retrieved
* @return A future of {@link InputStreamContainer} containing the stream and stream metadata.
*/
CompletableFuture<InputStreamContainer> getBlobPartInputStreamContainer(
S3AsyncClient s3AsyncClient,
String bucketName,
String blobName,
int partNumber
String blobKey,
@Nullable Integer partNumber
) {
final GetObjectRequest.Builder getObjectRequestBuilder = GetObjectRequest.builder()
.bucket(bucketName)
.key(blobName)
.partNumber(partNumber);
final boolean isMultipartObject = partNumber != null;
final GetObjectRequest.Builder getObjectRequestBuilder = GetObjectRequest.builder().bucket(bucketName).key(blobKey);

if (isMultipartObject) {
getObjectRequestBuilder.partNumber(partNumber);
}

return SocketAccess.doPrivileged(
() -> s3AsyncClient.getObject(getObjectRequestBuilder.build(), AsyncResponseTransformer.toBlockingInputStream())
.thenApply(S3BlobContainer::transformResponseToInputStreamContainer)
.thenApply(response -> transformResponseToInputStreamContainer(response, isMultipartObject))
);
}

/**
* Transforms the stream response object from S3 into an {@link InputStreamContainer}
* @param streamResponse Response stream object from S3
* @param isMultipartObject Flag to denote a multipart object response
* @return {@link InputStreamContainer} containing the stream and stream metadata
*/
// Package-Private for testing.
static InputStreamContainer transformResponseToInputStreamContainer(ResponseInputStream<GetObjectResponse> streamResponse) {
static InputStreamContainer transformResponseToInputStreamContainer(
ResponseInputStream<GetObjectResponse> streamResponse,
boolean isMultipartObject
) {
final GetObjectResponse getObjectResponse = streamResponse.response();
final String contentRange = getObjectResponse.contentRange();
final Long contentLength = getObjectResponse.contentLength();
if (contentRange == null || contentLength == null) {
if ((isMultipartObject && contentRange == null) || contentLength == null) {
andrross marked this conversation as resolved.
Show resolved Hide resolved
throw SdkException.builder().message("Failed to fetch required metadata for blob part").build();
}
final Long offset = HttpRangeUtils.getStartOffsetFromRangeHeader(getObjectResponse.contentRange());
final long offset = isMultipartObject ? HttpRangeUtils.getStartOffsetFromRangeHeader(getObjectResponse.contentRange()) : 0L;
return new InputStreamContainer(streamResponse, getObjectResponse.contentLength(), offset);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@
import org.opensearch.common.io.InputStreamContainer;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.unit.ByteSizeUnit;
import org.opensearch.repositories.s3.async.AsyncTransferManager;
import org.opensearch.test.OpenSearchTestCase;

import java.io.ByteArrayInputStream;
Expand All @@ -100,7 +99,6 @@
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
Expand Down Expand Up @@ -919,7 +917,7 @@ public void testListBlobsByPrefixInLexicographicOrderWithLimitGreaterThanNumberO
testListBlobsByPrefixInLexicographicOrder(12, 2, BlobContainer.BlobNameSortOrder.LEXICOGRAPHIC);
}

public void testReadBlobAsync() throws Exception {
public void testReadBlobAsyncMultiPart() throws Exception {
final String bucketName = randomAlphaOfLengthBetween(1, 10);
final String blobName = randomAlphaOfLengthBetween(1, 10);
final String checksum = randomAlphaOfLength(10);
Expand All @@ -932,19 +930,14 @@ public void testReadBlobAsync() throws Exception {
final AmazonAsyncS3Reference amazonAsyncS3Reference = new AmazonAsyncS3Reference(
AmazonAsyncS3WithCredentials.create(s3AsyncClient, s3AsyncClient, null)
);
final AsyncTransferManager asyncTransferManager = new AsyncTransferManager(
10000L,
mock(ExecutorService.class),
mock(ExecutorService.class)
);

final S3BlobStore blobStore = mock(S3BlobStore.class);
final BlobPath blobPath = new BlobPath();

when(blobStore.bucket()).thenReturn(bucketName);
when(blobStore.getStatsMetricPublisher()).thenReturn(new StatsMetricPublisher());
when(blobStore.serverSideEncryption()).thenReturn(false);
when(blobStore.asyncClientReference()).thenReturn(amazonAsyncS3Reference);
when(blobStore.getAsyncTransferManager()).thenReturn(asyncTransferManager);

CompletableFuture<GetObjectAttributesResponse> getObjectAttributesResponseCompletableFuture = new CompletableFuture<>();
getObjectAttributesResponseCompletableFuture.complete(
Expand Down Expand Up @@ -984,6 +977,60 @@ public void testReadBlobAsync() throws Exception {
}
}

public void testReadBlobAsyncSinglePart() throws Exception {
final String bucketName = randomAlphaOfLengthBetween(1, 10);
final String blobName = randomAlphaOfLengthBetween(1, 10);
final String checksum = randomAlphaOfLength(10);

final int objectSize = 100;

final S3AsyncClient s3AsyncClient = mock(S3AsyncClient.class);
final AmazonAsyncS3Reference amazonAsyncS3Reference = new AmazonAsyncS3Reference(
AmazonAsyncS3WithCredentials.create(s3AsyncClient, s3AsyncClient, null)
);
final S3BlobStore blobStore = mock(S3BlobStore.class);
final BlobPath blobPath = new BlobPath();

when(blobStore.bucket()).thenReturn(bucketName);
when(blobStore.getStatsMetricPublisher()).thenReturn(new StatsMetricPublisher());
when(blobStore.serverSideEncryption()).thenReturn(false);
when(blobStore.asyncClientReference()).thenReturn(amazonAsyncS3Reference);

CompletableFuture<GetObjectAttributesResponse> getObjectAttributesResponseCompletableFuture = new CompletableFuture<>();
getObjectAttributesResponseCompletableFuture.complete(
GetObjectAttributesResponse.builder()
.checksum(Checksum.builder().checksumCRC32(checksum).build())
.objectSize((long) objectSize)
.build()
);
when(s3AsyncClient.getObjectAttributes(any(GetObjectAttributesRequest.class))).thenReturn(
getObjectAttributesResponseCompletableFuture
);

mockObjectResponse(s3AsyncClient, bucketName, blobName, objectSize);

CountDownLatch countDownLatch = new CountDownLatch(1);
CountingCompletionListener<ReadContext> readContextActionListener = new CountingCompletionListener<>();
LatchedActionListener<ReadContext> listener = new LatchedActionListener<>(readContextActionListener, countDownLatch);

final S3BlobContainer blobContainer = new S3BlobContainer(blobPath, blobStore);
blobContainer.readBlobAsync(blobName, listener);
countDownLatch.await();

assertEquals(1, readContextActionListener.getResponseCount());
assertEquals(0, readContextActionListener.getFailureCount());
ReadContext readContext = readContextActionListener.getResponse();
assertEquals(1, readContext.getNumberOfParts());
assertEquals(checksum, readContext.getBlobChecksum());
assertEquals(objectSize, readContext.getBlobSize());

InputStreamContainer inputStreamContainer = readContext.getPartStreams().stream().findFirst().get();
assertEquals(objectSize, inputStreamContainer.getContentLength());
assertEquals(0, inputStreamContainer.getOffset());
assertEquals(objectSize, inputStreamContainer.getInputStream().readAllBytes().length);

}

public void testReadBlobAsyncFailure() throws Exception {
final String bucketName = randomAlphaOfLengthBetween(1, 10);
final String blobName = randomAlphaOfLengthBetween(1, 10);
Expand All @@ -996,19 +1043,14 @@ public void testReadBlobAsyncFailure() throws Exception {
final AmazonAsyncS3Reference amazonAsyncS3Reference = new AmazonAsyncS3Reference(
AmazonAsyncS3WithCredentials.create(s3AsyncClient, s3AsyncClient, null)
);
final AsyncTransferManager asyncTransferManager = new AsyncTransferManager(
10000L,
mock(ExecutorService.class),
mock(ExecutorService.class)
);

final S3BlobStore blobStore = mock(S3BlobStore.class);
final BlobPath blobPath = new BlobPath();

when(blobStore.bucket()).thenReturn(bucketName);
when(blobStore.getStatsMetricPublisher()).thenReturn(new StatsMetricPublisher());
when(blobStore.serverSideEncryption()).thenReturn(false);
when(blobStore.asyncClientReference()).thenReturn(amazonAsyncS3Reference);
when(blobStore.getAsyncTransferManager()).thenReturn(asyncTransferManager);

CompletableFuture<GetObjectAttributesResponse> getObjectAttributesResponseCompletableFuture = new CompletableFuture<>();
getObjectAttributesResponseCompletableFuture.complete(
Expand Down Expand Up @@ -1071,7 +1113,7 @@ public void testGetBlobPartInputStream() throws Exception {
final String blobName = randomAlphaOfLengthBetween(1, 10);
final String bucketName = randomAlphaOfLengthBetween(1, 10);
final long contentLength = 10L;
final String contentRange = "bytes 0-10/100";
final String contentRange = "bytes 10-20/100";
final InputStream inputStream = ResponseInputStream.nullInputStream();

final S3AsyncClient s3AsyncClient = mock(S3AsyncClient.class);
Expand All @@ -1095,9 +1137,17 @@ public void testGetBlobPartInputStream() throws Exception {
)
).thenReturn(getObjectPartResponse);

// Header based offset in case of a multi part object request
InputStreamContainer inputStreamContainer = blobContainer.getBlobPartInputStreamContainer(s3AsyncClient, bucketName, blobName, 0)
.get();

assertEquals(10, inputStreamContainer.getOffset());
assertEquals(contentLength, inputStreamContainer.getContentLength());
assertEquals(inputStream.available(), inputStreamContainer.getInputStream().available());

// 0 offset in case of a single part object request
inputStreamContainer = blobContainer.getBlobPartInputStreamContainer(s3AsyncClient, bucketName, blobName, null).get();

assertEquals(0, inputStreamContainer.getOffset());
assertEquals(contentLength, inputStreamContainer.getContentLength());
assertEquals(inputStream.available(), inputStreamContainer.getInputStream().available());
Expand All @@ -1108,28 +1158,65 @@ public void testTransformResponseToInputStreamContainer() throws Exception {
final long contentLength = 10L;
final InputStream inputStream = ResponseInputStream.nullInputStream();

final S3AsyncClient s3AsyncClient = mock(S3AsyncClient.class);

GetObjectResponse getObjectResponse = GetObjectResponse.builder().contentLength(contentLength).build();

// Exception when content range absent for multipart object
ResponseInputStream<GetObjectResponse> responseInputStreamNoRange = new ResponseInputStream<>(getObjectResponse, inputStream);
assertThrows(SdkException.class, () -> S3BlobContainer.transformResponseToInputStreamContainer(responseInputStreamNoRange));
assertThrows(SdkException.class, () -> S3BlobContainer.transformResponseToInputStreamContainer(responseInputStreamNoRange, true));

// No exception when content range absent for single part object
ResponseInputStream<GetObjectResponse> responseInputStreamNoRangeSinglePart = new ResponseInputStream<>(
getObjectResponse,
inputStream
);
InputStreamContainer inputStreamContainer = S3BlobContainer.transformResponseToInputStreamContainer(
responseInputStreamNoRangeSinglePart,
false
);
assertEquals(contentLength, inputStreamContainer.getContentLength());
assertEquals(0, inputStreamContainer.getOffset());

// Exception when length is absent
getObjectResponse = GetObjectResponse.builder().contentRange(contentRange).build();
ResponseInputStream<GetObjectResponse> responseInputStreamNoContentLength = new ResponseInputStream<>(
getObjectResponse,
inputStream
);
assertThrows(SdkException.class, () -> S3BlobContainer.transformResponseToInputStreamContainer(responseInputStreamNoContentLength));
assertThrows(
SdkException.class,
() -> S3BlobContainer.transformResponseToInputStreamContainer(responseInputStreamNoContentLength, true)
);

// No exception when range and length both are present
getObjectResponse = GetObjectResponse.builder().contentRange(contentRange).contentLength(contentLength).build();
ResponseInputStream<GetObjectResponse> responseInputStream = new ResponseInputStream<>(getObjectResponse, inputStream);
InputStreamContainer inputStreamContainer = S3BlobContainer.transformResponseToInputStreamContainer(responseInputStream);
inputStreamContainer = S3BlobContainer.transformResponseToInputStreamContainer(responseInputStream, true);
assertEquals(contentLength, inputStreamContainer.getContentLength());
assertEquals(0, inputStreamContainer.getOffset());
assertEquals(inputStream.available(), inputStreamContainer.getInputStream().available());
}

private void mockObjectResponse(S3AsyncClient s3AsyncClient, String bucketName, String blobName, int objectSize) {

final InputStream inputStream = new ByteArrayInputStream(randomByteArrayOfLength(objectSize));

GetObjectResponse getObjectResponse = GetObjectResponse.builder().contentLength((long) objectSize).build();

CompletableFuture<ResponseInputStream<GetObjectResponse>> getObjectPartResponse = new CompletableFuture<>();
ResponseInputStream<GetObjectResponse> responseInputStream = new ResponseInputStream<>(getObjectResponse, inputStream);
getObjectPartResponse.complete(responseInputStream);

GetObjectRequest getObjectRequest = GetObjectRequest.builder().bucket(bucketName).key(blobName).build();

when(
s3AsyncClient.getObject(
eq(getObjectRequest),
ArgumentMatchers.<AsyncResponseTransformer<GetObjectResponse, ResponseInputStream<GetObjectResponse>>>any()
)
).thenReturn(getObjectPartResponse);

}

private void mockObjectPartResponse(
S3AsyncClient s3AsyncClient,
String bucketName,
Expand Down