diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQuery.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQuery.java index 6a107955e..83f5dc952 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQuery.java +++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQuery.java @@ -32,6 +32,11 @@ public final class NeuralSparseQuery extends Query { private final Query lowScoreTokenQuery; private final Float rescoreWindowSizeExpansion; + /** + * Returns a string representation of this query. + * @param field The default field name, per the contract of Lucene's Query#toString(String). + * @return A string representation that starts with "NeuralSparseQuery(". + */ @Override public String toString(String field) { return "NeuralSparseQuery(" @@ -80,6 +85,14 @@ public int hashCode() { return h; } + /** + * Creates the weight for this query by delegating to currentQuery. + * @param searcher The searcher that executes the neural_sparse query. + * @param scoreMode How the produced scorers will be consumed. + * @param boost The boost that is propagated by the parent queries. + * @return The weight of currentQuery. + * @throws IOException If createWeight fails. + */ @Override public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { return currentQuery.createWeight(searcher, scoreMode, boost); diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java index 8af6804e0..ecd48aa23 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java @@ -452,7 +452,7 @@ private Map getFilteredScoreTokens(boolean aboveThreshold, float .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); } - public BooleanQuery buildFeatureFieldQueryFromTokens(Map tokens, String fieldName) { + private BooleanQuery buildFeatureFieldQueryFromTokens(Map tokens, String fieldName) { BooleanQuery.Builder builder = new BooleanQuery.Builder(); for (Map.Entry entry : tokens.entrySet()) { 
builder.add(FeatureField.newLinearQuery(fieldName, entry.getKey(), entry.getValue()), BooleanClause.Occur.SHOULD); diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseTwoPhaseParameters.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseTwoPhaseParameters.java index 6a4284429..5f19b6b8f 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseTwoPhaseParameters.java +++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseTwoPhaseParameters.java @@ -28,6 +28,11 @@ import java.util.Locale; import java.util.Objects; +/** + * Represents the parameters for neural_sparse two-phase process. + * This class encapsulates settings related to window size expansion, pruning ratio, and whether the two-phase search is enabled. + * It includes mechanisms to update settings from the cluster dynamically. + */ @Getter @Setter @Accessors(chain = true, fluent = true) @@ -52,6 +57,12 @@ public class NeuralSparseTwoPhaseParameters implements Writeable { private Float pruning_ratio; private Boolean enabled; + /** + * Initializes the default two-phase parameters; intended to be called when the cluster starts. + * + * @param clusterService The OpenSearch ClusterService. + * @param settings The environment settings from which the default values are read. + */ public static void initialize(ClusterService clusterService, Settings settings) { DEFAULT_ENABLED = NeuralSearchSettings.NEURAL_SPARSE_TWO_PHASE_DEFAULT_ENABLED.get(settings); DEFAULT_WINDOW_SIZE_EXPANSION = NeuralSearchSettings.NEURAL_SPARSE_TWO_PHASE_DEFAULT_WINDOW_SIZE_EXPANSION.get(settings); @@ -98,6 +109,13 @@ public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(enabled); } + /** + * Builds the content of this object into an XContentBuilder, typically for JSON serialization. + * + * @param builder The builder to fill. + * @return the given XContentBuilder with object content added. + * @throws IOException if building the content fails. 
+ */ public XContentBuilder doXContent(XContentBuilder builder) throws IOException { builder.startObject(NAME.getPreferredName()); builder.field(WINDOW_SIZE_EXPANSION.getPreferredName(), window_size_expansion); @@ -107,6 +125,13 @@ public XContentBuilder doXContent(XContentBuilder builder) throws IOException { return builder; } + /** + * Parses a NeuralSparseTwoPhaseParameters object from XContent (typically JSON). + * + * @param parser the XContentParser to extract data from. + * @return a new instance of NeuralSparseTwoPhaseParameters initialized from the parser. + * @throws IOException if parsing fails. + */ public static NeuralSparseTwoPhaseParameters parseFromXContent(XContentParser parser) throws IOException { XContentParser.Token token; String currentFieldName = ""; @@ -157,6 +182,12 @@ public int hashcode() { return builder.toHashCode(); } + /** + * Checks if the two-phase search feature is enabled based on the given parameters. + * + * @param neuralSparseTwoPhaseParameters The parameters to check. + * @return true if enabled, false otherwise. + */ public static boolean isEnabled(NeuralSparseTwoPhaseParameters neuralSparseTwoPhaseParameters) { if (Objects.isNull(neuralSparseTwoPhaseParameters)) { return false; @@ -164,6 +195,11 @@ public static boolean isEnabled(NeuralSparseTwoPhaseParameters neuralSparseTwoPh return neuralSparseTwoPhaseParameters.enabled(); } + /** + * Checks whether the cluster supports the two-phase search feature. + * + * @return true if the cluster's minimum version supports two-phase search, false otherwise. 
+ */ public static boolean isClusterOnOrAfterMinReqVersionForTwoPhaseSearchSupport() { return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_TWO_PHASE_SEARCH); } diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java index c172fcbac..174a953d6 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java @@ -25,7 +25,7 @@ import lombok.extern.log4j.Log4j2; -import static org.opensearch.neuralsearch.search.util.NeuralSparseTwoPhaseUtil.addRescoreContextFromNeuralSparseSparseQuery; +import static org.opensearch.neuralsearch.search.util.NeuralSparseTwoPhaseUtil.addRescoreContextFromNeuralSparseQuery; import static org.opensearch.neuralsearch.util.HybridQueryUtil.hasAliasFilter; import static org.opensearch.neuralsearch.util.HybridQueryUtil.hasNestedFieldOrNestedDocs; import static org.opensearch.neuralsearch.util.HybridQueryUtil.isHybridQuery; @@ -45,7 +45,7 @@ public boolean searchWith( final boolean hasFilterCollector, final boolean hasTimeout ) throws IOException { - addRescoreContextFromNeuralSparseSparseQuery(query, searchContext); + addRescoreContextFromNeuralSparseQuery(query, searchContext); if (!isHybridQuery(query, searchContext)) { validateQuery(searchContext, query); return super.searchWith(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout); diff --git a/src/main/java/org/opensearch/neuralsearch/search/util/NeuralSparseTwoPhaseUtil.java b/src/main/java/org/opensearch/neuralsearch/search/util/NeuralSparseTwoPhaseUtil.java index 6c163d01b..1d0bfbc9e 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/util/NeuralSparseTwoPhaseUtil.java +++ b/src/main/java/org/opensearch/neuralsearch/search/util/NeuralSparseTwoPhaseUtil.java @@ 
-27,6 +27,35 @@ * Include adding the second phase query to searchContext and set the currentQuery to highScoreTokenQuery. */ public class NeuralSparseTwoPhaseUtil { + /** + * @param query The whole query, including the neuralSparseQuery, to be executed. + * @param searchContext The searchContext associated with this query. + */ + public static void addRescoreContextFromNeuralSparseQuery(final Query query, SearchContext searchContext) { + Map query2weight = new HashMap<>(); + float windowSizeExpansion = populateQueryWeightsMapAndGetWindowSizeExpansion(query, query2weight, 1.0f, 1.0f); + Query twoPhaseQuery; + if (query2weight.isEmpty()) { + return; + } else if (query2weight.size() == 1) { + Map.Entry entry = query2weight.entrySet().stream().findFirst().get(); + twoPhaseQuery = new BoostQuery(entry.getKey(), entry.getValue()); + } else { + twoPhaseQuery = getNestedTwoPhaseQuery(query2weight); + } + int curWindowSize = (int) (searchContext.size() * windowSizeExpansion); + if (curWindowSize < 0 + || curWindowSize > min( + NeuralSparseTwoPhaseParameters.MAX_WINDOW_SIZE, + MAX_RESCORE_WINDOW_SETTING.get(searchContext.getQueryShardContext().getIndexSettings().getSettings()) + )) { + throw new IllegalArgumentException("Two phase final windowSize out of score with value " + curWindowSize + "."); + } + QueryRescorer.QueryRescoreContext rescoreContext = new QueryRescorer.QueryRescoreContext(curWindowSize); + rescoreContext.setQuery(twoPhaseQuery); + rescoreContext.setRescoreQueryWeight(getOriginQueryWeightAfterRescore(searchContext.rescore())); + searchContext.addRescore(rescoreContext); + } private static float populateQueryWeightsMapAndGetWindowSizeExpansion( final Query query, @@ -72,34 +101,4 @@ private static Query getNestedTwoPhaseQuery(Map query2weight) { return builder.build(); } - /** - * - * @param query The whole query include neuralSparseQuery to executed. - * @param searchContext The searchContext with this query. 
- */ - public static void addRescoreContextFromNeuralSparseSparseQuery(final Query query, SearchContext searchContext) { - Map query2weight = new HashMap<>(); - float windowSizeExpansion = populateQueryWeightsMapAndGetWindowSizeExpansion(query, query2weight, 1.0f, 1.0f); - Query twoPhaseQuery; - if (query2weight.isEmpty()) { - return; - } else if (query2weight.size() == 1) { - Map.Entry entry = query2weight.entrySet().stream().findFirst().get(); - twoPhaseQuery = new BoostQuery(entry.getKey(), entry.getValue()); - } else { - twoPhaseQuery = getNestedTwoPhaseQuery(query2weight); - } - int curWindowSize = (int) (searchContext.size() * windowSizeExpansion); - if (curWindowSize < 0 - || curWindowSize > min( - NeuralSparseTwoPhaseParameters.MAX_WINDOW_SIZE, - MAX_RESCORE_WINDOW_SETTING.get(searchContext.getQueryShardContext().getIndexSettings().getSettings()) - )) { - throw new IllegalArgumentException("Two phase final windowSize out of score with value " + curWindowSize + "."); - } - QueryRescorer.QueryRescoreContext rescoreContext = new QueryRescorer.QueryRescoreContext(curWindowSize); - rescoreContext.setQuery(twoPhaseQuery); - rescoreContext.setRescoreQueryWeight(getOriginQueryWeightAfterRescore(searchContext.rescore())); - searchContext.addRescore(rescoreContext); - } } diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java index 1d103b01d..c04fdc59a 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java @@ -925,20 +925,6 @@ private void setUpClusterService(Version version) { NeuralSearchClusterUtil.instance().initialize(clusterService); } - @SneakyThrows - public void testBuildFeatureFieldQueryFormTokens() { - NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new 
NeuralSparseQueryBuilder().fieldName(FIELD_NAME) - .queryText(QUERY_TEXT) - .modelId(MODEL_ID) - .queryTokensSupplier(QUERY_TOKENS_SUPPLIER); - BooleanQuery booleanQuery = sparseEncodingQueryBuilder.buildFeatureFieldQueryFromTokens( - sparseEncodingQueryBuilder.queryTokensSupplier().get(), - FIELD_NAME - ); - assertNotNull(booleanQuery); - assertSame(booleanQuery.clauses().size(), 2); - } - @SneakyThrows public void testTokenDividedByScores_whenDefaultSettings() { Map map = new HashMap<>(); diff --git a/src/test/java/org/opensearch/neuralsearch/search/util/NeuralSparseTwoPhaseUtilTests.java b/src/test/java/org/opensearch/neuralsearch/search/util/NeuralSparseTwoPhaseUtilTests.java index 693d356db..d03d4881c 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/util/NeuralSparseTwoPhaseUtilTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/util/NeuralSparseTwoPhaseUtilTests.java @@ -45,7 +45,7 @@ import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import static org.opensearch.neuralsearch.search.util.NeuralSparseTwoPhaseUtil.addRescoreContextFromNeuralSparseSparseQuery; +import static org.opensearch.neuralsearch.search.util.NeuralSparseTwoPhaseUtil.addRescoreContextFromNeuralSparseQuery; public class NeuralSparseTwoPhaseUtilTests extends OpenSearchTestCase { @@ -98,22 +98,22 @@ public void testInitialize() { @SneakyThrows public void testAddTwoPhaseNeuralSparseQuery_whenQuery2WeightEmpty_thenNoRescoreAdded() { Query query = mock(Query.class); - addRescoreContextFromNeuralSparseSparseQuery(query, mockSearchContext); + addRescoreContextFromNeuralSparseQuery(query, mockSearchContext); verify(mockSearchContext, never()).addRescore(any()); } @SneakyThrows public void testAddTwoPhaseNeuralSparseQuery_whenUnSupportedQuery_thenNoRescoreAdded() { FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(normalNeuralSparseQuery, mock(DoubleValuesSource.class)); - 
addRescoreContextFromNeuralSparseSparseQuery(functionScoreQuery, mockSearchContext); + addRescoreContextFromNeuralSparseQuery(functionScoreQuery, mockSearchContext); DisjunctionMaxQuery disjunctionMaxQuery = new DisjunctionMaxQuery(Collections.emptyList(), 1.0f); - addRescoreContextFromNeuralSparseSparseQuery(disjunctionMaxQuery, mockSearchContext); + addRescoreContextFromNeuralSparseQuery(disjunctionMaxQuery, mockSearchContext); List subQueries = new ArrayList<>(); List filterQueries = new ArrayList<>(); subQueries.add(normalNeuralSparseQuery); filterQueries.add(new MatchAllDocsQuery()); HybridQuery hybridQuery = new HybridQuery(subQueries, filterQueries); - addRescoreContextFromNeuralSparseSparseQuery(hybridQuery, mockSearchContext); + addRescoreContextFromNeuralSparseQuery(hybridQuery, mockSearchContext); assertEquals(normalNeuralSparseQuery.getCurrentQuery(), currentQuery); verify(mockSearchContext, never()).addRescore(any()); } @@ -121,7 +121,7 @@ public void testAddTwoPhaseNeuralSparseQuery_whenUnSupportedQuery_thenNoRescoreA @SneakyThrows public void testAddTwoPhaseNeuralSparseQuery_whenSingleEntryInQuery2Weight_thenRescoreAdded() { NeuralSparseQuery neuralSparseQuery = new NeuralSparseQuery(mock(Query.class), mock(Query.class), mock(Query.class), 5.0f); - addRescoreContextFromNeuralSparseSparseQuery(neuralSparseQuery, mockSearchContext); + addRescoreContextFromNeuralSparseQuery(neuralSparseQuery, mockSearchContext); verify(mockSearchContext).addRescore(any(QueryRescorer.QueryRescoreContext.class)); } @@ -135,7 +135,7 @@ public void testAddTwoPhaseNeuralSparseQuery_whenCompoundBooleanQuery_thenRescor queryBuilder.add(boostQuery1, BooleanClause.Occur.SHOULD); queryBuilder.add(boostQuery2, BooleanClause.Occur.SHOULD); BooleanQuery booleanQuery = queryBuilder.build(); - addRescoreContextFromNeuralSparseSparseQuery(booleanQuery, mockSearchContext); + addRescoreContextFromNeuralSparseQuery(booleanQuery, mockSearchContext); 
verify(mockSearchContext).addRescore(any(QueryRescorer.QueryRescoreContext.class)); } @@ -155,7 +155,7 @@ public void testAddTwoPhaseNeuralSparseQuery_whenBooleanClauseType_thenVerifyBoo queryBuilder.add(boostQuery3, BooleanClause.Occur.FILTER); queryBuilder.add(boostQuery4, BooleanClause.Occur.MUST_NOT); BooleanQuery booleanQuery = queryBuilder.build(); - addRescoreContextFromNeuralSparseSparseQuery(booleanQuery, mockSearchContext); + addRescoreContextFromNeuralSparseQuery(booleanQuery, mockSearchContext); ArgumentCaptor rtxCaptor = ArgumentCaptor.forClass(RescoreContext.class); verify(mockSearchContext).addRescore(rtxCaptor.capture()); QueryRescorer.QueryRescoreContext context = (QueryRescorer.QueryRescoreContext) rtxCaptor.getValue(); @@ -179,7 +179,7 @@ public void testAddTwoPhaseNeuralSparseQuery_whenBooleanClauseType_thenVerifyBoo @SneakyThrows public void testWindowSize_whenNormalConditions_thenWindowSizeIsAsSet() { NeuralSparseQuery query = normalNeuralSparseQuery; - addRescoreContextFromNeuralSparseSparseQuery(query, mockSearchContext); + addRescoreContextFromNeuralSparseQuery(query, mockSearchContext); ArgumentCaptor rescoreContextArgumentCaptor = ArgumentCaptor.forClass( QueryRescorer.QueryRescoreContext.class ); @@ -192,14 +192,16 @@ public void testWindowSize_whenBoundaryConditions_thenThrowException() { NeuralSparseQuery query = new NeuralSparseQuery(new MatchAllDocsQuery(), new MatchAllDocsQuery(), new MatchAllDocsQuery(), 5000f); NeuralSparseQuery finalQuery1 = query; - expectThrows(IllegalArgumentException.class, () -> { - addRescoreContextFromNeuralSparseSparseQuery(finalQuery1, mockSearchContext); - }); + expectThrows( + IllegalArgumentException.class, + () -> { addRescoreContextFromNeuralSparseQuery(finalQuery1, mockSearchContext); } + ); query = new NeuralSparseQuery(new MatchAllDocsQuery(), new MatchAllDocsQuery(), new MatchAllDocsQuery(), Float.MAX_VALUE); NeuralSparseQuery finalQuery = query; - expectThrows(IllegalArgumentException.class, () 
-> { - addRescoreContextFromNeuralSparseSparseQuery(finalQuery, mockSearchContext); - }); + expectThrows( + IllegalArgumentException.class, + () -> { addRescoreContextFromNeuralSparseQuery(finalQuery, mockSearchContext); } + ); } @SneakyThrows @@ -211,7 +213,7 @@ public void testRescoreListWeightCalculation_whenMultipleRescoreContexts_thenCal List rescoreContextList = Arrays.asList(mockContext1, mockContext2); when(mockSearchContext.rescore()).thenReturn(rescoreContextList); NeuralSparseQuery query = normalNeuralSparseQuery; - addRescoreContextFromNeuralSparseSparseQuery(query, mockSearchContext); + addRescoreContextFromNeuralSparseQuery(query, mockSearchContext); ArgumentCaptor rtxCaptor = ArgumentCaptor.forClass(RescoreContext.class); verify(mockSearchContext).addRescore(rtxCaptor.capture()); QueryRescorer.QueryRescoreContext context = (QueryRescorer.QueryRescoreContext) rtxCaptor.getValue(); @@ -222,7 +224,7 @@ public void testRescoreListWeightCalculation_whenMultipleRescoreContexts_thenCal public void testEmptyRescoreListWeight_whenRescoreListEmpty_thenDefaultWeightUsed() { when(mockSearchContext.rescore()).thenReturn(Collections.emptyList()); NeuralSparseQuery query = normalNeuralSparseQuery; - addRescoreContextFromNeuralSparseSparseQuery(query, mockSearchContext); + addRescoreContextFromNeuralSparseQuery(query, mockSearchContext); ArgumentCaptor rtxCaptor = ArgumentCaptor.forClass(RescoreContext.class); verify(mockSearchContext).addRescore(rtxCaptor.capture()); QueryRescorer.QueryRescoreContext context = (QueryRescorer.QueryRescoreContext) rtxCaptor.getValue();