Skip to content

Commit

Permalink
addressed PR comments - modified Buffer and parser implementation
Browse files: browse the repository at this point in the history
Signed-off-by: Jeremy Michael <[email protected]>
  • Loading branch information
Jeremy Michael committed Jan 4, 2025
1 parent e5f2f79 commit be297ca
Show file tree
Hide file tree
Showing 15 changed files with 109 additions and 139 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ private Arn getArn() {
try {
return Arn.fromString(awsStsRoleArn);
} catch (final Exception e) {
throw new IllegalArgumentException(String.format("Invalid ARN format for awsStsRoleArn. Check the format of %s", awsStsRoleArn));
}
throw new IllegalArgumentException(String.format("The value provided for sts_role_arn is not a valid AWS ARN. Provided value: %s", awsStsRoleArn)); }
}

public String getAwsStsRoleArn() {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.plugins.source.sqs;

import com.fasterxml.jackson.annotation.JsonProperty;
Expand All @@ -14,10 +19,10 @@
public class QueueConfig {

private static final Integer DEFAULT_MAXIMUM_MESSAGES = null;
private static final Boolean DEFAULT_VISIBILITY_DUPLICATE_PROTECTION = false;
private static final boolean DEFAULT_VISIBILITY_DUPLICATE_PROTECTION = false;
private static final Duration DEFAULT_VISIBILITY_TIMEOUT_SECONDS = null;
private static final Duration DEFAULT_VISIBILITY_DUPLICATE_PROTECTION_TIMEOUT = Duration.ofHours(2);
private static final Duration DEFAULT_WAIT_TIME_SECONDS = Duration.ofSeconds(20);
private static final Duration DEFAULT_WAIT_TIME_SECONDS = null;
private static final Duration DEFAULT_POLL_DELAY_SECONDS = Duration.ofSeconds(0);
static final int DEFAULT_NUMBER_OF_WORKERS = 1;

Expand Down Expand Up @@ -45,7 +50,7 @@ public class QueueConfig {

@JsonProperty("visibility_duplication_protection")
@NotNull
private Boolean visibilityDuplicateProtection = DEFAULT_VISIBILITY_DUPLICATE_PROTECTION;
private boolean visibilityDuplicateProtection = DEFAULT_VISIBILITY_DUPLICATE_PROTECTION;

@JsonProperty("visibility_duplicate_protection_timeout")
@DurationMin(seconds = 30)
Expand Down Expand Up @@ -73,7 +78,7 @@ public Duration getVisibilityTimeout() {
return visibilityTimeout;
}

public Boolean getVisibilityDuplicateProtection() {
public boolean getVisibilityDuplicateProtection() {
return visibilityDuplicateProtection;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,76 +1,51 @@
package org.opensearch.dataprepper.plugins.source.sqs;

/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.opensearch.dataprepper.buffer.common.BufferAccumulator;
import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet;
import org.opensearch.dataprepper.model.buffer.Buffer;
import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.model.event.EventMetadata;
import org.opensearch.dataprepper.model.event.JacksonEvent;
import org.opensearch.dataprepper.model.record.Record;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.services.sqs.model.Message;
import com.fasterxml.jackson.databind.node.ObjectNode;
import java.time.Instant;
import java.util.Objects;
import software.amazon.awssdk.services.sqs.model.MessageAttributeValue;
import software.amazon.awssdk.services.sqs.model.MessageSystemAttributeName;

import java.util.Collections;
import java.util.Map;

/**
* Implements the SqsMessageHandler to read and parse SQS messages generically and push to buffer.
*/
public class RawSqsMessageHandler implements SqsMessageHandler {

private static final Logger LOG = LoggerFactory.getLogger(RawSqsMessageHandler.class);
private static final ObjectMapper objectMapper = new ObjectMapper();

/**
* Processes the SQS message, attempting to parse it as JSON, and adds it to the buffer.
*
* @param message - the SQS message for processing
* @param url - the SQS queue url
* @param bufferAccumulator - the buffer accumulator
* @param acknowledgementSet - the acknowledgement set for end-to-end acknowledgements
*/
@Override
public void handleMessage(final Message message,
final String url,
final BufferAccumulator<Record<Event>> bufferAccumulator,
final String url,
final Buffer<Record<Event>> buffer,
final int bufferTimeoutMillis,
final AcknowledgementSet acknowledgementSet) {
try {
ObjectNode dataNode = objectMapper.createObjectNode();
dataNode.set("message", parseMessageBody(message.body()));
dataNode.put("queueUrl", url);

Instant now = Instant.now();
int unixTimestamp = (int) now.getEpochSecond();
dataNode.put("sentTimestamp", unixTimestamp);

final Record<Event> event = new Record<Event>(JacksonEvent.builder()
.withEventType("sqs-event")
.withData(dataNode)
.build());

if (Objects.nonNull(acknowledgementSet)) {
acknowledgementSet.add(event.getData());
final Map<MessageSystemAttributeName, String> systemAttributes = message.attributes();
final Map<String, MessageAttributeValue> customAttributes = message.messageAttributes();
final Event event = JacksonEvent.builder()
.withEventType("DOCUMENT")
.withData(Collections.singletonMap("message", message.body()))
.build();
final EventMetadata eventMetadata = event.getMetadata();
eventMetadata.setAttribute("url", url);
final String sentTimestamp = systemAttributes.get(MessageSystemAttributeName.SENT_TIMESTAMP);
eventMetadata.setAttribute("SentTimestamp", sentTimestamp);
for (Map.Entry<String, MessageAttributeValue> entry : customAttributes.entrySet()) {
eventMetadata.setAttribute(entry.getKey(), entry.getValue().stringValue());
}

bufferAccumulator.add(new Record<>(event.getData()));

if (acknowledgementSet != null) {
acknowledgementSet.add(event);
}
buffer.write(new Record<>(event), bufferTimeoutMillis);
} catch (Exception e) {
LOG.error("Error processing SQS message: {}", e.getMessage(), e);
throw new RuntimeException(e);
}
}

JsonNode parseMessageBody(String messageBody) {
try {
return objectMapper.readTree(messageBody);
} catch (Exception e) {
return objectMapper.getNodeFactory().textNode(messageBody);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,13 @@

package org.opensearch.dataprepper.plugins.source.sqs;

import org.opensearch.dataprepper.buffer.common.BufferAccumulator;
import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet;
import org.opensearch.dataprepper.model.buffer.Buffer;
import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.model.record.Record;

import software.amazon.awssdk.services.sqs.model.Message;

import java.io.IOException;

public class SqsEventProcessor {
private final SqsMessageHandler sqsMessageHandler;
SqsEventProcessor(final SqsMessageHandler sqsMessageHandler) {
Expand All @@ -22,9 +20,10 @@ public class SqsEventProcessor {

void addSqsObject(final Message message,
final String url,
final BufferAccumulator<Record<Event>> bufferAccumulator,
final Buffer<Record<Event>> buffer,
final int bufferTimeoutmillis,
final AcknowledgementSet acknowledgementSet) throws IOException {
sqsMessageHandler.handleMessage(message, url, bufferAccumulator, acknowledgementSet);
sqsMessageHandler.handleMessage(message, url, buffer, bufferTimeoutmillis, acknowledgementSet);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
*/
package org.opensearch.dataprepper.plugins.source.sqs;

import org.opensearch.dataprepper.buffer.common.BufferAccumulator;
import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet;
import org.opensearch.dataprepper.model.buffer.Buffer;
import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.model.record.Record;
import software.amazon.awssdk.services.sqs.model.Message;
Expand All @@ -14,6 +14,7 @@
public interface SqsMessageHandler {
void handleMessage(final Message message,
final String url,
final BufferAccumulator<Record<Event>> bufferAccumulator,
final Buffer<Record<Event>> buffer,
final int bufferTimeoutMillis,
final AcknowledgementSet acknowledgementSet) throws IOException ;
}
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.plugins.source.sqs;

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,11 @@ public void start() {

sqsSourceConfig.getQueues().forEach(queueConfig -> {
String queueUrl = queueConfig.getUrl();
String queueName = queueUrl.substring(queueUrl.lastIndexOf('/') + 1);

int numWorkers = queueConfig.getNumWorkers();
ExecutorService executorService = Executors.newFixedThreadPool(
numWorkers, BackgroundThreadFactory.defaultExecutorThreadFactory("sqs-source-new-" + queueUrl));
numWorkers, BackgroundThreadFactory.defaultExecutorThreadFactory("sqs-source" + queueName));
allSqsUrlExecutorServices.add(executorService);
List<SqsWorker> workers = IntStream.range(0, numWorkers)
.mapToObj(i -> new SqsWorker(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.plugins.source.sqs;

import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.plugins.source.sqs;

import com.fasterxml.jackson.annotation.JsonProperty;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.core.exception.SdkException;
import software.amazon.awssdk.services.sqs.SqsClient;
import software.amazon.awssdk.services.sqs.model.ChangeMessageVisibilityRequest;
import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchRequest;
import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchRequestEntry;
import software.amazon.awssdk.services.sqs.model.ChangeMessageVisibilityRequest;
import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchResponse;
import software.amazon.awssdk.services.sqs.model.Message;
import software.amazon.awssdk.services.sqs.model.ReceiveMessageRequest;
import software.amazon.awssdk.services.sqs.model.ReceiveMessageResponse;
import software.amazon.awssdk.services.sqs.model.SqsException;
import software.amazon.awssdk.services.sts.model.StsException;
import org.opensearch.dataprepper.buffer.common.BufferAccumulator;
import org.opensearch.dataprepper.model.buffer.Buffer;

import java.time.Duration;
Expand Down Expand Up @@ -61,7 +61,8 @@ public class SqsWorker implements Runnable {
private final boolean endToEndAcknowledgementsEnabled;
private final AcknowledgementSetManager acknowledgementSetManager;
private volatile boolean isStopped = false;
private final BufferAccumulator<Record<Event>> bufferAccumulator;
private final Buffer<Record<Event>> buffer;
private final int bufferTimeoutMillis;
private Map<Message, Integer> messageVisibilityTimesMap;

public SqsWorker(final Buffer<Record<Event>> buffer,
Expand All @@ -79,12 +80,11 @@ public SqsWorker(final Buffer<Record<Event>> buffer,
this.acknowledgementSetManager = acknowledgementSetManager;
this.standardBackoff = backoff;
this.endToEndAcknowledgementsEnabled = sqsSourceConfig.getAcknowledgements();
this.bufferAccumulator = BufferAccumulator.create(buffer, sqsSourceConfig.getNumberOfRecordsToAccumulate(), sqsSourceConfig.getBufferTimeout());
this.buffer = buffer;
this.bufferTimeoutMillis = (int) sqsSourceConfig.getBufferTimeout().toMillis();

messageVisibilityTimesMap = new HashMap<>();

failedAttemptCount = 0;

sqsMessagesReceivedCounter = pluginMetrics.counter(SQS_MESSAGES_RECEIVED_METRIC_NAME);
sqsMessagesDeletedCounter = pluginMetrics.counter(SQS_MESSAGES_DELETED_METRIC_NAME);
sqsMessagesFailedCounter = pluginMetrics.counter(SQS_MESSAGES_FAILED_METRIC_NAME);
Expand Down Expand Up @@ -131,9 +131,11 @@ int processSqsMessages() {
private List<Message> getMessagesFromSqs() {
try {
final ReceiveMessageRequest request = createReceiveMessageRequest();
final List<Message> messages = sqsClient.receiveMessage(request).messages();
final ReceiveMessageResponse response = sqsClient.receiveMessage(request);
List<Message> messages = response.messages();
failedAttemptCount = 0;
return messages;

} catch (final SqsException | StsException e) {
LOG.error("Error reading from SQS: {}. Retrying with exponential backoff.", e.getMessage());
applyBackoff();
Expand All @@ -160,16 +162,18 @@ private void applyBackoff() {
private ReceiveMessageRequest createReceiveMessageRequest() {
ReceiveMessageRequest.Builder requestBuilder = ReceiveMessageRequest.builder()
.queueUrl(queueConfig.getUrl())
.waitTimeSeconds((int) queueConfig.getWaitTime().getSeconds());
.attributeNamesWithStrings("All")
.messageAttributeNames("All");

if (queueConfig.getWaitTime() != null) {
requestBuilder.waitTimeSeconds((int) queueConfig.getWaitTime().getSeconds());
}
if (queueConfig.getMaximumMessages() != null) {
requestBuilder.maxNumberOfMessages(queueConfig.getMaximumMessages());
}

if (queueConfig.getVisibilityTimeout() != null) {
requestBuilder.visibilityTimeout((int) queueConfig.getVisibilityTimeout().getSeconds());
}

return requestBuilder.build();
}

Expand Down Expand Up @@ -244,14 +248,6 @@ private List<DeleteMessageBatchRequestEntry> processSqsEvents(final List<Message
}
}

if (!messages.isEmpty()) {
try {
bufferAccumulator.flush();
} catch (final Exception e) {
throw new RuntimeException(e);
}
}

return deleteMessageBatchRequestEntryCollection;
}

Expand All @@ -260,7 +256,7 @@ private Optional<DeleteMessageBatchRequestEntry> processSqsObject(
final Message message,
final AcknowledgementSet acknowledgementSet) {
try {
sqsEventProcessor.addSqsObject(message, queueConfig.getUrl(), bufferAccumulator, acknowledgementSet);
sqsEventProcessor.addSqsObject(message, queueConfig.getUrl(), buffer, bufferTimeoutMillis, acknowledgementSet);
return Optional.of(buildDeleteMessageBatchRequestEntry(message));
} catch (final Exception e) {
sqsMessagesFailedCounter.increment();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,12 @@ void validateStsRoleArn_with_invalid_format_throws_exception() throws NoSuchFiel

try (final MockedStatic<Arn> arnMockedStatic = mockStatic(Arn.class)) {
arnMockedStatic.when(() -> Arn.fromString(invalidFormatArn))
.thenThrow(new IllegalArgumentException("Invalid ARN format for awsStsRoleArn. Check the format of " + invalidFormatArn));
.thenThrow(new IllegalArgumentException("The value provided for sts_role_arn is not a valid AWS ARN. Provided value: " + invalidFormatArn));

IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> {
awsAuthenticationOptions.validateStsRoleArn();
});
assertThat(exception.getMessage(), equalTo("Invalid ARN format for awsStsRoleArn. Check the format of " + invalidFormatArn));
assertThat(exception.getMessage(), equalTo("The value provided for sts_role_arn is not a valid AWS ARN. Provided value: " + invalidFormatArn));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,6 @@ void testDefaultValues() {
assertFalse(queueConfig.getVisibilityDuplicateProtection(), "Visibility duplicate protection should default to false");
assertEquals(Duration.ofHours(2), queueConfig.getVisibilityDuplicateProtectionTimeout(),
"Visibility duplicate protection timeout should default to 2 hours");
assertEquals(Duration.ofSeconds(20), queueConfig.getWaitTime(), "Wait time should default to 20 seconds");
assertNull(queueConfig.getWaitTime(), "Wait time should default to null");
}
}
Loading

0 comments on commit be297ca

Please sign in to comment.