Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 1.3] Bugfix to guard against stack overflow errors caused by very large reg-ex input #16101

Merged
merged 3 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
### Removed
### Fixed
- Update help output for _cat ([#14722](https://github.com/opensearch-project/OpenSearch/pull/14722))
- Bugfix to guard against stack overflow errors caused by very large reg-ex input ([#16101](https://github.com/opensearch-project/OpenSearch/pull/16101))

### Security

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,22 @@

package org.opensearch.search.aggregations;

import org.opensearch.OpenSearchException;
import org.opensearch.action.index.IndexRequestBuilder;
import org.opensearch.action.search.SearchPhaseExecutionException;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.search.aggregations.bucket.terms.IncludeExclude;
import org.opensearch.search.aggregations.bucket.terms.RareTermsAggregationBuilder;
import org.opensearch.search.aggregations.bucket.terms.SignificantTermsAggregationBuilder;
import org.opensearch.search.aggregations.bucket.terms.SignificantTermsAggregatorFactory;
import org.opensearch.search.aggregations.bucket.terms.Terms;
import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
import org.opensearch.search.aggregations.bucket.terms.TermsAggregatorFactory;
import org.opensearch.test.OpenSearchIntegTestCase;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import static org.opensearch.search.aggregations.AggregationBuilders.terms;
Expand All @@ -50,6 +59,11 @@ public class AggregationsIntegrationIT extends OpenSearchIntegTestCase {

static int numDocs;

private static final String LARGE_STRING = String.join("", Collections.nCopies(2000, "a"));
private static final String LARGE_STRING_EXCEPTION_MESSAGE = "The length of regex ["
+ LARGE_STRING.length()
+ "] used in the request has exceeded the allowed maximum";

@Override
public void setupSuiteScopeCluster() throws Exception {
assertAcked(prepareCreate("index").addMapping("type", "f", "type=keyword").get());
Expand Down Expand Up @@ -85,4 +99,51 @@ public void testScroll() {
assertEquals(numDocs, total);
}

public void testLargeRegExTermsAggregation() {
for (TermsAggregatorFactory.ExecutionMode executionMode : TermsAggregatorFactory.ExecutionMode.values()) {
TermsAggregationBuilder termsAggregation = terms("my_terms").field("f")
.includeExclude(getLargeStringInclude())
.executionHint(executionMode.toString());
runLargeStringAggregationTest(termsAggregation);
}
}

public void testLargeRegExSignificantTermsAggregation() {
for (SignificantTermsAggregatorFactory.ExecutionMode executionMode : SignificantTermsAggregatorFactory.ExecutionMode.values()) {
SignificantTermsAggregationBuilder significantTerms = new SignificantTermsAggregationBuilder("my_terms").field("f")
.includeExclude(getLargeStringInclude())
.executionHint(executionMode.toString());
runLargeStringAggregationTest(significantTerms);
}
}

public void testLargeRegExRareTermsAggregation() {
// currently this only supports "map" as an execution hint
RareTermsAggregationBuilder rareTerms = new RareTermsAggregationBuilder("my_terms").field("f")
.includeExclude(getLargeStringInclude())
.maxDocCount(2);
runLargeStringAggregationTest(rareTerms);
}

private IncludeExclude getLargeStringInclude() {
return new IncludeExclude(LARGE_STRING, null);
}

private void runLargeStringAggregationTest(AggregationBuilder aggregation) {
boolean exceptionThrown = false;
IncludeExclude include = new IncludeExclude(LARGE_STRING, null);
try {
client().prepareSearch("index").addAggregation(aggregation).get();
} catch (SearchPhaseExecutionException ex) {
exceptionThrown = true;
Throwable nestedException = ex.getCause();
assertNotNull(nestedException);
assertTrue(nestedException instanceof OpenSearchException);
assertNotNull(nestedException.getCause());
assertTrue(nestedException.getCause() instanceof IllegalArgumentException);
String actualExceptionMessage = nestedException.getCause().getMessage();
assertTrue(actualExceptionMessage.startsWith(LARGE_STRING_EXCEPTION_MESSAGE));
}
assertTrue("Exception should have been thrown", exceptionThrown);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,15 @@
import org.apache.lucene.util.automaton.Operations;
import org.apache.lucene.util.automaton.RegExp;
import org.opensearch.OpenSearchParseException;
import org.opensearch.common.Nullable;
import org.opensearch.common.ParseField;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.common.io.stream.Writeable;
import org.opensearch.common.xcontent.ToXContentFragment;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentParser;
import org.opensearch.index.IndexSettings;
import org.opensearch.search.DocValueFormat;

import java.io.IOException;
Expand Down Expand Up @@ -337,19 +339,16 @@ public LongBitSet acceptedGlobalOrdinals(SortedSetDocValues globalOrdinals) thro

}

private final RegExp include, exclude;
private final String include, exclude;
private final SortedSet<BytesRef> includeValues, excludeValues;
private final int incZeroBasedPartition;
private final int incNumPartitions;

/**
* @param include The regular expression pattern for the terms to be included
* @param exclude The regular expression pattern for the terms to be excluded
* @param include The string or regular expression pattern for the terms to be included
* @param exclude The string or regular expression pattern for the terms to be excluded
*/
public IncludeExclude(RegExp include, RegExp exclude) {
if (include == null && exclude == null) {
throw new IllegalArgumentException();
}
public IncludeExclude(String include, String exclude) {
this.include = include;
this.exclude = exclude;
this.includeValues = null;
Expand All @@ -358,10 +357,6 @@ public IncludeExclude(RegExp include, RegExp exclude) {
this.incNumPartitions = 0;
}

public IncludeExclude(String include, String exclude) {
this(include == null ? null : new RegExp(include), exclude == null ? null : new RegExp(exclude));
}

/**
* @param includeValues The terms to be included
* @param excludeValues The terms to be excluded
Expand Down Expand Up @@ -412,10 +407,8 @@ public IncludeExclude(StreamInput in) throws IOException {
excludeValues = null;
incZeroBasedPartition = 0;
incNumPartitions = 0;
String includeString = in.readOptionalString();
include = includeString == null ? null : new RegExp(includeString);
String excludeString = in.readOptionalString();
exclude = excludeString == null ? null : new RegExp(excludeString);
include = in.readOptionalString();
exclude = in.readOptionalString();
return;
}
include = null;
Expand Down Expand Up @@ -447,8 +440,8 @@ public void writeTo(StreamOutput out) throws IOException {
boolean regexBased = isRegexBased();
out.writeBoolean(regexBased);
if (regexBased) {
out.writeOptionalString(include == null ? null : include.getOriginalString());
out.writeOptionalString(exclude == null ? null : exclude.getOriginalString());
out.writeOptionalString(include);
out.writeOptionalString(exclude);
} else {
boolean hasIncludes = includeValues != null;
out.writeBoolean(hasIncludes);
Expand Down Expand Up @@ -584,26 +577,54 @@ public boolean isPartitionBased() {
return incNumPartitions > 0;
}

private Automaton toAutomaton() {
Automaton a = null;
private Automaton toAutomaton(@Nullable IndexSettings indexSettings) {
int maxRegexLength = indexSettings == null ? -1 : indexSettings.getMaxRegexLength();
Automaton a;
if (include != null) {
a = include.toAutomaton();
if (include.length() > maxRegexLength) {
throw new IllegalArgumentException(
"The length of regex ["
+ include.length()
+ "] used in the request has exceeded "
+ "the allowed maximum of ["
+ maxRegexLength
+ "]. "
+ "This maximum can be set by changing the ["
+ IndexSettings.MAX_REGEX_LENGTH_SETTING.getKey()
+ "] index level setting."
);
}
a = new RegExp(include).toAutomaton();
} else if (includeValues != null) {
a = Automata.makeStringUnion(includeValues);
} else {
a = Automata.makeAnyString();
}
if (exclude != null) {
a = Operations.minus(a, exclude.toAutomaton(), Operations.DEFAULT_DETERMINIZE_WORK_LIMIT);
if (exclude.length() > maxRegexLength) {
throw new IllegalArgumentException(
"The length of regex ["
+ exclude.length()
+ "] used in the request has exceeded "
+ "the allowed maximum of ["
+ maxRegexLength
+ "]. "
+ "This maximum can be set by changing the ["
+ IndexSettings.MAX_REGEX_LENGTH_SETTING.getKey()
+ "] index level setting."
);
}
Automaton excludeAutomaton = new RegExp(exclude).toAutomaton();
a = Operations.minus(a, excludeAutomaton, Operations.DEFAULT_DETERMINIZE_WORK_LIMIT);
} else if (excludeValues != null) {
a = Operations.minus(a, Automata.makeStringUnion(excludeValues), Operations.DEFAULT_DETERMINIZE_WORK_LIMIT);
}
return a;
}

public StringFilter convertToStringFilter(DocValueFormat format) {
public StringFilter convertToStringFilter(DocValueFormat format, IndexSettings indexSettings) {
if (isRegexBased()) {
return new AutomatonBackedStringFilter(toAutomaton());
return new AutomatonBackedStringFilter(toAutomaton(indexSettings));
}
if (isPartitionBased()) {
return new PartitionedStringFilter();
Expand All @@ -624,10 +645,10 @@ private static SortedSet<BytesRef> parseForDocValues(SortedSet<BytesRef> endUser
return result;
}

public OrdinalsFilter convertToOrdinalsFilter(DocValueFormat format) {
public OrdinalsFilter convertToOrdinalsFilter(DocValueFormat format, IndexSettings indexSettings) {

if (isRegexBased()) {
return new AutomatonBackedOrdinalsFilter(toAutomaton());
return new AutomatonBackedOrdinalsFilter(toAutomaton(indexSettings));
}
if (isPartitionBased()) {
return new PartitionedOrdinalsFilter();
Expand Down Expand Up @@ -684,7 +705,7 @@ public LongFilter convertToDoubleFilter() {
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
if (include != null) {
builder.field(INCLUDE_FIELD.getPreferredName(), include.getOriginalString());
builder.field(INCLUDE_FIELD.getPreferredName(), include);
} else if (includeValues != null) {
builder.startArray(INCLUDE_FIELD.getPreferredName());
for (BytesRef value : includeValues) {
Expand All @@ -698,7 +719,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.endObject();
}
if (exclude != null) {
builder.field(EXCLUDE_FIELD.getPreferredName(), exclude.getOriginalString());
builder.field(EXCLUDE_FIELD.getPreferredName(), exclude);
} else if (excludeValues != null) {
builder.startArray(EXCLUDE_FIELD.getPreferredName());
for (BytesRef value : excludeValues) {
Expand All @@ -711,14 +732,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws

@Override
public int hashCode() {
return Objects.hash(
include == null ? null : include.getOriginalString(),
exclude == null ? null : exclude.getOriginalString(),
includeValues,
excludeValues,
incZeroBasedPartition,
incNumPartitions
);
return Objects.hash(include, exclude, includeValues, excludeValues, incZeroBasedPartition, incNumPartitions);
}

@Override
Expand All @@ -730,14 +744,8 @@ public boolean equals(Object obj) {
return false;
}
IncludeExclude other = (IncludeExclude) obj;
return Objects.equals(
include == null ? null : include.getOriginalString(),
other.include == null ? null : other.include.getOriginalString()
)
&& Objects.equals(
exclude == null ? null : exclude.getOriginalString(),
other.exclude == null ? null : other.exclude.getOriginalString()
)
return Objects.equals(include, other.include)
&& Objects.equals(exclude, other.exclude)
&& Objects.equals(includeValues, other.includeValues)
&& Objects.equals(excludeValues, other.excludeValues)
&& Objects.equals(incZeroBasedPartition, other.incZeroBasedPartition)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

import org.opensearch.common.ParseField;
import org.opensearch.common.logging.DeprecationLogger;
import org.opensearch.index.IndexSettings;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.search.DocValueFormat;
import org.opensearch.search.aggregations.Aggregator;
Expand Down Expand Up @@ -250,7 +251,10 @@ Aggregator create(
double precision,
CardinalityUpperBound cardinality
) throws IOException {
final IncludeExclude.StringFilter filter = includeExclude == null ? null : includeExclude.convertToStringFilter(format);
IndexSettings indexSettings = context.getQueryShardContext().getIndexSettings();
final IncludeExclude.StringFilter filter = includeExclude == null
? null
: includeExclude.convertToStringFilter(format, indexSettings);
return new StringRareTermsAggregator(
name,
factories,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

import org.opensearch.common.ParseField;
import org.opensearch.common.logging.DeprecationLogger;
import org.opensearch.index.IndexSettings;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.search.DocValueFormat;
Expand Down Expand Up @@ -325,8 +326,10 @@ Aggregator create(
CardinalityUpperBound cardinality,
Map<String, Object> metadata
) throws IOException {

final IncludeExclude.StringFilter filter = includeExclude == null ? null : includeExclude.convertToStringFilter(format);
IndexSettings indexSettings = aggregationContext.getQueryShardContext().getIndexSettings();
final IncludeExclude.StringFilter filter = includeExclude == null
? null
: includeExclude.convertToStringFilter(format, indexSettings);
return new MapStringTermsAggregator(
name,
factories,
Expand Down Expand Up @@ -364,8 +367,10 @@ Aggregator create(
CardinalityUpperBound cardinality,
Map<String, Object> metadata
) throws IOException {

final IncludeExclude.OrdinalsFilter filter = includeExclude == null ? null : includeExclude.convertToOrdinalsFilter(format);
IndexSettings indexSettings = aggregationContext.getQueryShardContext().getIndexSettings();
final IncludeExclude.OrdinalsFilter filter = includeExclude == null
? null
: includeExclude.convertToOrdinalsFilter(format, indexSettings);
boolean remapGlobalOrd = true;
if (cardinality == CardinalityUpperBound.ONE && factories == AggregatorFactories.EMPTY && includeExclude == null) {
/*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import org.opensearch.common.util.BigArrays;
import org.opensearch.common.util.BytesRefHash;
import org.opensearch.common.util.ObjectArray;
import org.opensearch.index.IndexSettings;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryShardContext;
Expand Down Expand Up @@ -137,7 +138,10 @@ protected Aggregator createInternal(

// TODO - need to check with mapping that this is indeed a text field....

IncludeExclude.StringFilter incExcFilter = includeExclude == null ? null : includeExclude.convertToStringFilter(DocValueFormat.RAW);
IndexSettings indexSettings = searchContext.getQueryShardContext().getIndexSettings();
IncludeExclude.StringFilter incExcFilter = includeExclude == null
? null
: includeExclude.convertToStringFilter(DocValueFormat.RAW, indexSettings);

MapStringTermsAggregator.CollectorSource collectorSource = new SignificantTextCollectorSource(
queryShardContext.lookup().source(),
Expand Down
Loading
Loading