diff --git a/server/src/main/java/org/opensearch/index/compositeindex/datacube/startree/utils/StarTreeQueryHelper.java b/server/src/main/java/org/opensearch/index/compositeindex/datacube/startree/utils/StarTreeQueryHelper.java index e538be5d5bece..e46cf6f56b36e 100644 --- a/server/src/main/java/org/opensearch/index/compositeindex/datacube/startree/utils/StarTreeQueryHelper.java +++ b/server/src/main/java/org/opensearch/index/compositeindex/datacube/startree/utils/StarTreeQueryHelper.java @@ -152,7 +152,7 @@ private static MetricStat validateStarTreeMetricSupport( MetricStat metricStat = ((MetricAggregatorFactory) aggregatorFactory).getMetricStat(); field = ((MetricAggregatorFactory) aggregatorFactory).getField(); - if (supportedMetrics.containsKey(field) && supportedMetrics.get(field).contains(metricStat)) { + if (field != null && supportedMetrics.containsKey(field) && supportedMetrics.get(field).contains(metricStat)) { return metricStat; } } diff --git a/server/src/main/java/org/opensearch/search/aggregations/support/ValuesSourceAggregatorFactory.java b/server/src/main/java/org/opensearch/search/aggregations/support/ValuesSourceAggregatorFactory.java index d862b2c2784de..41344fd06cbbc 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/support/ValuesSourceAggregatorFactory.java +++ b/server/src/main/java/org/opensearch/search/aggregations/support/ValuesSourceAggregatorFactory.java @@ -104,6 +104,6 @@ public String getStatsSubtype() { } public String getField() { - return config.fieldContext().field(); + return config.fieldContext() != null ? config.fieldContext().field() : null; } } diff --git a/server/src/test/java/org/opensearch/search/aggregations/startree/MetricAggregatorTests.java b/server/src/test/java/org/opensearch/search/aggregations/startree/MetricAggregatorTests.java index 12e83cbbadd5d..05f48eb9243af 100644 --- a/server/src/test/java/org/opensearch/search/aggregations/startree/MetricAggregatorTests.java +++ b/server/src/test/java/org/opensearch/search/aggregations/startree/MetricAggregatorTests.java @@ -28,18 +28,27 @@ import org.opensearch.common.lucene.Lucene; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.FeatureFlags; +import org.opensearch.common.util.MockBigArrays; +import org.opensearch.common.util.MockPageCacheRecycler; +import org.opensearch.core.indices.breaker.CircuitBreakerService; +import org.opensearch.core.indices.breaker.NoneCircuitBreakerService; import org.opensearch.index.codec.composite.CompositeIndexFieldInfo; import org.opensearch.index.codec.composite.CompositeIndexReader; import org.opensearch.index.codec.composite.composite912.Composite912Codec; import org.opensearch.index.codec.composite912.datacube.startree.StarTreeDocValuesFormatTests; import org.opensearch.index.compositeindex.datacube.Dimension; +import org.opensearch.index.compositeindex.datacube.Metric; +import org.opensearch.index.compositeindex.datacube.MetricStat; import org.opensearch.index.compositeindex.datacube.NumericDimension; import org.opensearch.index.mapper.MappedFieldType; import org.opensearch.index.mapper.MapperService; import org.opensearch.index.mapper.NumberFieldMapper; import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryShardContext; import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.AggregatorFactories; +import org.opensearch.search.aggregations.AggregatorFactory; import org.opensearch.search.aggregations.AggregatorTestCase; import org.opensearch.search.aggregations.InternalAggregation; import org.opensearch.search.aggregations.metrics.AvgAggregationBuilder; @@ -49,14 +58,17 @@ import org.opensearch.search.aggregations.metrics.InternalSum; import org.opensearch.search.aggregations.metrics.InternalValueCount; import org.opensearch.search.aggregations.metrics.MaxAggregationBuilder; +import org.opensearch.search.aggregations.metrics.MetricAggregatorFactory; import org.opensearch.search.aggregations.metrics.MinAggregationBuilder; import org.opensearch.search.aggregations.metrics.SumAggregationBuilder; import org.opensearch.search.aggregations.metrics.ValueCountAggregationBuilder; +import org.opensearch.search.aggregations.support.ValuesSourceAggregatorFactory; import org.junit.After; import org.junit.Before; import java.io.IOException; import java.util.ArrayList; +import java.util.Collections; import java.util.LinkedList; import java.util.List; import java.util.Random; @@ -69,6 +81,8 @@ import static org.opensearch.search.aggregations.AggregationBuilders.min; import static org.opensearch.search.aggregations.AggregationBuilders.sum; import static org.opensearch.test.InternalAggregationTestCase.DEFAULT_MAX_BUCKETS; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; public class MetricAggregatorTests extends AggregatorTestCase { @@ -267,6 +281,110 @@ public void testStarTreeDocValues() throws IOException { ); } + CircuitBreakerService circuitBreakerService = new NoneCircuitBreakerService(); + + QueryShardContext queryShardContext = queryShardContextMock( + indexSearcher, + mapperServiceMock(), + createIndexSettings(), + circuitBreakerService, + new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), circuitBreakerService).withCircuitBreaking() + ); + + MetricAggregatorFactory aggregatorFactory = mock(MetricAggregatorFactory.class); + when(aggregatorFactory.getSubFactories()).thenReturn(AggregatorFactories.EMPTY); + when(aggregatorFactory.getField()).thenReturn(FIELD_NAME); + when(aggregatorFactory.getMetricStat()).thenReturn(MetricStat.SUM); + + // Case when field and metric type in aggregation are fully supported by star tree. + testCase( + indexSearcher, + query, + queryBuilder, + sumAggregationBuilder, + starTree, + supportedDimensions, + List.of(new Metric(FIELD_NAME, List.of(MetricStat.SUM, MetricStat.MAX, MetricStat.MIN, MetricStat.AVG))), + verifyAggregation(InternalSum::getValue), + aggregatorFactory, + true + ); + + // Case when the field is not supported by star tree + SumAggregationBuilder invalidFieldSumAggBuilder = sum("_name").field("hello"); + testCase( + indexSearcher, + query, + queryBuilder, + invalidFieldSumAggBuilder, + starTree, + supportedDimensions, + Collections.emptyList(), + verifyAggregation(InternalSum::getValue), + invalidFieldSumAggBuilder.build(queryShardContext, null), + false // Invalid fields will return null StarTreeQueryContext which will not cause early termination by leaf collector + ); + + // Case when metric type in aggregation is not supported by star tree but the field is supported. + testCase( + indexSearcher, + query, + queryBuilder, + sumAggregationBuilder, + starTree, + supportedDimensions, + List.of(new Metric(FIELD_NAME, List.of(MetricStat.MAX, MetricStat.MIN, MetricStat.AVG))), + verifyAggregation(InternalSum::getValue), + aggregatorFactory, + false + ); + + // Case when field is not present in supported metrics + testCase( + indexSearcher, + query, + queryBuilder, + sumAggregationBuilder, + starTree, + supportedDimensions, + List.of(new Metric("hello", List.of(MetricStat.MAX, MetricStat.MIN, MetricStat.AVG))), + verifyAggregation(InternalSum::getValue), + aggregatorFactory, + false + ); + + AggregatorFactories aggregatorFactories = mock(AggregatorFactories.class); + when(aggregatorFactories.getFactories()).thenReturn(new AggregatorFactory[] { mock(MetricAggregatorFactory.class) }); + when(aggregatorFactory.getSubFactories()).thenReturn(aggregatorFactories); + + // Case when sub aggregations are present + testCase( + indexSearcher, + query, + queryBuilder, + sumAggregationBuilder, + starTree, + supportedDimensions, + List.of(new Metric("hello", List.of(MetricStat.MAX, MetricStat.MIN, MetricStat.AVG))), + verifyAggregation(InternalSum::getValue), + aggregatorFactory, + false + ); + + // Case when aggregation factory is not metric aggregation + testCase( + indexSearcher, + query, + queryBuilder, + sumAggregationBuilder, + starTree, + supportedDimensions, + List.of(new Metric("hello", List.of(MetricStat.MAX, MetricStat.MIN, MetricStat.AVG))), + verifyAggregation(InternalSum::getValue), + mock(ValuesSourceAggregatorFactory.class), + false + ); + ir.close(); directory.close(); } @@ -287,6 +405,21 @@ private void testC CompositeIndexFieldInfo starTree, List supportedDimensions, BiConsumer verify + ) throws IOException { + testCase(searcher, query, queryBuilder, aggBuilder, starTree, supportedDimensions, Collections.emptyList(), verify, null, true); + } + + private void testCase( + IndexSearcher searcher, + Query query, + QueryBuilder queryBuilder, + T aggBuilder, + CompositeIndexFieldInfo starTree, + List supportedDimensions, + List supportedMetrics, + BiConsumer verify, + AggregatorFactory aggregatorFactory, + boolean assertCollectorEarlyTermination ) throws IOException { V starTreeAggregation = searchAndReduceStarTree( createIndexSettings(), @@ -296,8 +429,11 @@ private void testC aggBuilder, starTree, supportedDimensions, + supportedMetrics, DEFAULT_MAX_BUCKETS, false, + aggregatorFactory, + assertCollectorEarlyTermination, DEFAULT_MAPPED_FIELD ); V expectedAggregation = searchAndReduceStarTree( @@ -308,8 +444,11 @@ private void testC aggBuilder, null, null, + null, DEFAULT_MAX_BUCKETS, false, + aggregatorFactory, + assertCollectorEarlyTermination, DEFAULT_MAPPED_FIELD ); verify.accept(expectedAggregation, starTreeAggregation); diff --git a/server/src/test/java/org/opensearch/search/aggregations/startree/StarTreeFilterTests.java b/server/src/test/java/org/opensearch/search/aggregations/startree/StarTreeFilterTests.java index b03cb5ac7bb9d..c1cb19b9576e4 100644 --- a/server/src/test/java/org/opensearch/search/aggregations/startree/StarTreeFilterTests.java +++ b/server/src/test/java/org/opensearch/search/aggregations/startree/StarTreeFilterTests.java @@ -87,7 +87,8 @@ public void testStarTreeFilterWithDocsInSVDFieldButNoStarNode() throws IOExcepti testStarTreeFilter(10, false); } - private void testStarTreeFilter(int maxLeafDoc, boolean skipStarNodeCreationForSDVDimension) throws IOException { + private Directory createStarTreeIndex(int maxLeafDoc, boolean skipStarNodeCreationForSDVDimension, List docs) + throws IOException { Directory directory = newDirectory(); IndexWriterConfig conf = newIndexWriterConfig(null); conf.setCodec(getCodec(maxLeafDoc, skipStarNodeCreationForSDVDimension)); @@ -95,7 +96,6 @@ private void testStarTreeFilter(int maxLeafDoc, boolean skipStarNodeCreationForS RandomIndexWriter iw = new RandomIndexWriter(random(), directory, conf); int totalDocs = 100; - List docs = new ArrayList<>(); for (int i = 0; i < totalDocs; i++) { Document doc = new Document(); doc.add(new SortedNumericDocValuesField(SNDV, i)); @@ -110,6 +110,15 @@ private void testStarTreeFilter(int maxLeafDoc, boolean skipStarNodeCreationForS } iw.forceMerge(1); iw.close(); + return directory; + } + + private void testStarTreeFilter(int maxLeafDoc, boolean skipStarNodeCreationForSDVDimension) throws IOException { + List docs = new ArrayList<>(); + + Directory directory = createStarTreeIndex(maxLeafDoc, skipStarNodeCreationForSDVDimension, docs); + + int totalDocs = docs.size(); DirectoryReader ir = DirectoryReader.open(directory); initValuesSourceRegistry(); diff --git a/test/framework/src/main/java/org/opensearch/search/aggregations/AggregatorTestCase.java b/test/framework/src/main/java/org/opensearch/search/aggregations/AggregatorTestCase.java index e1728c4476699..27142b298db52 100644 --- a/test/framework/src/main/java/org/opensearch/search/aggregations/AggregatorTestCase.java +++ b/test/framework/src/main/java/org/opensearch/search/aggregations/AggregatorTestCase.java @@ -93,6 +93,7 @@ import org.opensearch.index.cache.query.DisabledQueryCache; import org.opensearch.index.codec.composite.CompositeIndexFieldInfo; import org.opensearch.index.compositeindex.datacube.Dimension; +import org.opensearch.index.compositeindex.datacube.Metric; import org.opensearch.index.compositeindex.datacube.startree.utils.StarTreeQueryHelper; import org.opensearch.index.fielddata.IndexFieldData; import org.opensearch.index.fielddata.IndexFieldDataCache; @@ -348,7 +349,9 @@ protected CountingAggregator createCountingAggregator( IndexSettings indexSettings, CompositeIndexFieldInfo starTree, List supportedDimensions, + List supportedMetrics, MultiBucketConsumer bucketConsumer, + AggregatorFactory aggregatorFactory, MappedFieldType... fieldTypes ) throws IOException { SearchContext searchContext; @@ -360,7 +363,9 @@ protected CountingAggregator createCountingAggregator( queryBuilder, starTree, supportedDimensions, + supportedMetrics, bucketConsumer, + aggregatorFactory, fieldTypes ); } else { @@ -389,7 +394,9 @@ protected SearchContext createSearchContextWithStarTreeContext( QueryBuilder queryBuilder, CompositeIndexFieldInfo starTree, List supportedDimensions, + List supportedMetrics, MultiBucketConsumer bucketConsumer, + AggregatorFactory aggregatorFactory, MappedFieldType... fieldTypes ) throws IOException { SearchContext searchContext = createSearchContext( @@ -406,7 +413,12 @@ protected SearchContext createSearchContextWithStarTreeContext( AggregatorFactories aggregatorFactories = mock(AggregatorFactories.class); when(searchContext.aggregations()).thenReturn(searchContextAggregations); when(searchContextAggregations.factories()).thenReturn(aggregatorFactories); - when(aggregatorFactories.getFactories()).thenReturn(new AggregatorFactory[] {}); + + if (aggregatorFactory != null) { + when(aggregatorFactories.getFactories()).thenReturn(new AggregatorFactory[] { aggregatorFactory }); + } else { + when(aggregatorFactories.getFactories()).thenReturn(new AggregatorFactory[] {}); + } CompositeDataCubeFieldType compositeMappedFieldType = mock(CompositeDataCubeFieldType.class); when(compositeMappedFieldType.name()).thenReturn(starTree.getField()); @@ -414,6 +426,7 @@ protected SearchContext createSearchContextWithStarTreeContext( Set compositeFieldTypes = Set.of(compositeMappedFieldType); when((compositeMappedFieldType).getDimensions()).thenReturn(supportedDimensions); + when((compositeMappedFieldType).getMetrics()).thenReturn(supportedMetrics); MapperService mapperService = mock(MapperService.class); when(mapperService.getCompositeFieldTypes()).thenReturn(compositeFieldTypes); when(searchContext.mapperService()).thenReturn(mapperService); @@ -740,8 +753,11 @@ protected A searchAndReduc AggregationBuilder builder, CompositeIndexFieldInfo compositeIndexFieldInfo, List supportedDimensions, + List supportedMetrics, int maxBucket, boolean hasNested, + AggregatorFactory aggregatorFactory, + boolean assertCollectorEarlyTermination, MappedFieldType... fieldTypes ) throws IOException { query = query.rewrite(searcher); @@ -764,7 +780,9 @@ protected A searchAndReduc indexSettings, compositeIndexFieldInfo, supportedDimensions, + supportedMetrics, bucketConsumer, + aggregatorFactory, fieldTypes ); @@ -772,7 +790,7 @@ protected A searchAndReduc searcher.search(query, countingAggregator); countingAggregator.postCollection(); aggs.add(countingAggregator.buildTopLevel()); - if (compositeIndexFieldInfo != null) { + if (compositeIndexFieldInfo != null && assertCollectorEarlyTermination) { assertEquals(0, countingAggregator.collectCounter.get()); }