Skip to content

Commit

Permalink
Add some java documentation.
Browse files Browse the repository at this point in the history
Signed-off-by: conggguan <[email protected]>
  • Loading branch information
conggguan committed Apr 22, 2024
1 parent df0c9fc commit 2f1183e
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ public final class NeuralSparseQuery extends Query {
private final Query lowScoreTokenQuery;
private final Float rescoreWindowSizeExpansion;

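/**
* Returns a string representation of this query.
*
* @param field The field name supplied by the caller, as in Lucene's Query#toString(String).
* @return A string describing this query; it begins with "NeuralSparseQuery(".
*/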
@Override
public String toString(String field) {
return "NeuralSparseQuery("
Expand Down Expand Up @@ -80,6 +85,14 @@ public int hashCode() {
return h;
}

/**
* Creates the weight for this query by delegating to currentQuery.
*
* @param searcher The searcher that executes the neural_sparse query.
* @param scoreMode How the produced scorers will be consumed.
* @param boost The boost that is propagated by the parent queries.
* @return The weight of currentQuery.
* @throws IOException If createWeight fails.
*/
@Override
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
return currentQuery.createWeight(searcher, scoreMode, boost);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ private Map<String, Float> getFilteredScoreTokens(boolean aboveThreshold, float
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
}

public BooleanQuery buildFeatureFieldQueryFromTokens(Map<String, Float> tokens, String fieldName) {
private BooleanQuery buildFeatureFieldQueryFromTokens(Map<String, Float> tokens, String fieldName) {
BooleanQuery.Builder builder = new BooleanQuery.Builder();
for (Map.Entry<String, Float> entry : tokens.entrySet()) {
builder.add(FeatureField.newLinearQuery(fieldName, entry.getKey(), entry.getValue()), BooleanClause.Occur.SHOULD);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@
import java.util.Locale;
import java.util.Objects;

/**
* Represents the parameters for neural_sparse two-phase process.
* This class encapsulates settings related to window size expansion, pruning ratio, and whether the two-phase search is enabled.
* It includes mechanisms to update settings from the cluster dynamically.
*/
@Getter
@Setter
@Accessors(chain = true, fluent = true)
Expand All @@ -52,6 +57,12 @@ public class NeuralSparseTwoPhaseParameters implements Writeable {
private Float pruning_ratio;
private Boolean enabled;

/**
* Initializes the default two-phase parameters from the given settings when the cluster starts.
*
* @param clusterService The OpenSearch ClusterService.
* @param settings The environment settings from which the default values are read.
*/
public static void initialize(ClusterService clusterService, Settings settings) {
DEFAULT_ENABLED = NeuralSearchSettings.NEURAL_SPARSE_TWO_PHASE_DEFAULT_ENABLED.get(settings);
DEFAULT_WINDOW_SIZE_EXPANSION = NeuralSearchSettings.NEURAL_SPARSE_TWO_PHASE_DEFAULT_WINDOW_SIZE_EXPANSION.get(settings);
Expand Down Expand Up @@ -98,6 +109,13 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeBoolean(enabled);
}

/**
* Builds the content of this object into an XContentBuilder, typically for JSON serialization.
*
* @param builder The builder to fill.
* @return the given XContentBuilder with object content added.
* @throws IOException if building the content fails.
*/
public XContentBuilder doXContent(XContentBuilder builder) throws IOException {
builder.startObject(NAME.getPreferredName());
builder.field(WINDOW_SIZE_EXPANSION.getPreferredName(), window_size_expansion);
Expand All @@ -107,6 +125,13 @@ public XContentBuilder doXContent(XContentBuilder builder) throws IOException {
return builder;
}

/**
* Parses a NeuralSparseTwoPhaseParameters object from XContent (typically JSON).
*
* @param parser the XContentParser to extract data from.
* @return a new instance of NeuralSparseTwoPhaseParameters initialized from the parser.
* @throws IOException if parsing fails.
*/
public static NeuralSparseTwoPhaseParameters parseFromXContent(XContentParser parser) throws IOException {
XContentParser.Token token;
String currentFieldName = "";
Expand Down Expand Up @@ -157,13 +182,24 @@ public int hashcode() {
return builder.toHashCode();
}

/**
 * Checks whether the two-phase search feature is enabled for the given parameters.
 * A {@code null} parameters object, or one whose enabled flag has not been set, is treated as disabled.
 *
 * @param neuralSparseTwoPhaseParameters The parameters to check; may be {@code null}.
 * @return true only if the parameters are non-null and their enabled flag is {@code true}; false otherwise.
 */
public static boolean isEnabled(NeuralSparseTwoPhaseParameters neuralSparseTwoPhaseParameters) {
    if (Objects.isNull(neuralSparseTwoPhaseParameters)) {
        return false;
    }
    // enabled() exposes a boxed Boolean field; comparing against Boolean.TRUE makes an unset (null)
    // flag evaluate to false instead of throwing a NullPointerException during auto-unboxing.
    return Boolean.TRUE.equals(neuralSparseTwoPhaseParameters.enabled());
}

/**
 * Reports whether the cluster's minimum version is on or after the minimal version required for two-phase search.
 *
 * @return true if the cluster minimum version is on or after MINIMAL_SUPPORTED_VERSION_TWO_PHASE_SEARCH,
 *         false otherwise.
 */
public static boolean isClusterOnOrAfterMinReqVersionForTwoPhaseSearchSupport() {
    final boolean twoPhaseSupported = NeuralSearchClusterUtil.instance()
        .getClusterMinVersion()
        .onOrAfter(MINIMAL_SUPPORTED_VERSION_TWO_PHASE_SEARCH);
    return twoPhaseSupported;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

import lombok.extern.log4j.Log4j2;

import static org.opensearch.neuralsearch.search.util.NeuralSparseTwoPhaseUtil.addRescoreContextFromNeuralSparseSparseQuery;
import static org.opensearch.neuralsearch.search.util.NeuralSparseTwoPhaseUtil.addRescoreContextFromNeuralSparseQuery;
import static org.opensearch.neuralsearch.util.HybridQueryUtil.hasAliasFilter;
import static org.opensearch.neuralsearch.util.HybridQueryUtil.hasNestedFieldOrNestedDocs;
import static org.opensearch.neuralsearch.util.HybridQueryUtil.isHybridQuery;
Expand All @@ -45,7 +45,7 @@ public boolean searchWith(
final boolean hasFilterCollector,
final boolean hasTimeout
) throws IOException {
addRescoreContextFromNeuralSparseSparseQuery(query, searchContext);
addRescoreContextFromNeuralSparseQuery(query, searchContext);
if (!isHybridQuery(query, searchContext)) {
validateQuery(searchContext, query);
return super.searchWith(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,35 @@
* This includes adding the second-phase query to the searchContext and setting the currentQuery to highScoreTokenQuery.
*/
public class NeuralSparseTwoPhaseUtil {
/**
 * Registers a second-phase rescore context on the search context for the given query.
 * <p>
 * The query-to-weight map is filled by populateQueryWeightsMapAndGetWindowSizeExpansion; when it stays empty,
 * nothing is added. A single entry is wrapped in a BoostQuery, several entries are combined through
 * getNestedTwoPhaseQuery. The rescore window is searchContext.size() multiplied by the returned expansion.
 *
 * @param query The whole query, including any neuralSparseQuery, that is about to be executed.
 * @param searchContext The searchContext associated with this query.
 * @throws IllegalArgumentException If the computed window size is negative or larger than the smaller of
 *         NeuralSparseTwoPhaseParameters.MAX_WINDOW_SIZE and the index's MAX_RESCORE_WINDOW_SETTING.
 */
public static void addRescoreContextFromNeuralSparseQuery(final Query query, SearchContext searchContext) {
    final Map<Query, Float> weightByQuery = new HashMap<>();
    final float expansion = populateQueryWeightsMapAndGetWindowSizeExpansion(query, weightByQuery, 1.0f, 1.0f);
    if (weightByQuery.isEmpty()) {
        // No two-phase candidates were collected, so there is nothing to rescore.
        return;
    }

    final Query rescoreQuery;
    if (weightByQuery.size() == 1) {
        final Map.Entry<Query, Float> only = weightByQuery.entrySet().iterator().next();
        rescoreQuery = new BoostQuery(only.getKey(), only.getValue());
    } else {
        rescoreQuery = getNestedTwoPhaseQuery(weightByQuery);
    }

    final int curWindowSize = (int) (searchContext.size() * expansion);
    // The upper bound is only looked up when the size is non-negative (short-circuit evaluation).
    final boolean withinBounds = curWindowSize >= 0
        && curWindowSize <= min(
            NeuralSparseTwoPhaseParameters.MAX_WINDOW_SIZE,
            MAX_RESCORE_WINDOW_SETTING.get(searchContext.getQueryShardContext().getIndexSettings().getSettings())
        );
    if (!withinBounds) {
        throw new IllegalArgumentException("Two phase final windowSize out of score with value " + curWindowSize + ".");
    }

    final QueryRescorer.QueryRescoreContext twoPhaseContext = new QueryRescorer.QueryRescoreContext(curWindowSize);
    twoPhaseContext.setQuery(rescoreQuery);
    twoPhaseContext.setRescoreQueryWeight(getOriginQueryWeightAfterRescore(searchContext.rescore()));
    searchContext.addRescore(twoPhaseContext);
}

private static float populateQueryWeightsMapAndGetWindowSizeExpansion(
final Query query,
Expand Down Expand Up @@ -72,34 +101,4 @@ private static Query getNestedTwoPhaseQuery(Map<Query, Float> query2weight) {
return builder.build();
}

/**
 * Adds a two-phase rescore context to the search context for the given query.
 * <p>
 * Weights are gathered by populateQueryWeightsMapAndGetWindowSizeExpansion. If none are found the method returns
 * without touching the search context; one entry becomes a BoostQuery, more than one are merged by
 * getNestedTwoPhaseQuery. The window size is searchContext.size() times the returned expansion factor.
 *
 * @param query The whole query, including any neuralSparseQuery, that is about to be executed.
 * @param searchContext The searchContext associated with this query.
 * @throws IllegalArgumentException If the computed window size is negative or larger than the smaller of
 *         NeuralSparseTwoPhaseParameters.MAX_WINDOW_SIZE and the index's MAX_RESCORE_WINDOW_SETTING.
 */
public static void addRescoreContextFromNeuralSparseSparseQuery(final Query query, SearchContext searchContext) {
    final Map<Query, Float> collectedWeights = new HashMap<>();
    final float windowFactor = populateQueryWeightsMapAndGetWindowSizeExpansion(query, collectedWeights, 1.0f, 1.0f);
    if (collectedWeights.isEmpty()) {
        // Nothing was collected for the second phase; leave the search context unchanged.
        return;
    }

    final Query secondPhaseQuery;
    if (collectedWeights.size() == 1) {
        final Map.Entry<Query, Float> single = collectedWeights.entrySet().iterator().next();
        secondPhaseQuery = new BoostQuery(single.getKey(), single.getValue());
    } else {
        secondPhaseQuery = getNestedTwoPhaseQuery(collectedWeights);
    }

    final int curWindowSize = (int) (searchContext.size() * windowFactor);
    // Short-circuit keeps the settings lookup from running when the size is already negative.
    final boolean sizeAccepted = curWindowSize >= 0
        && curWindowSize <= min(
            NeuralSparseTwoPhaseParameters.MAX_WINDOW_SIZE,
            MAX_RESCORE_WINDOW_SETTING.get(searchContext.getQueryShardContext().getIndexSettings().getSettings())
        );
    if (!sizeAccepted) {
        throw new IllegalArgumentException("Two phase final windowSize out of score with value " + curWindowSize + ".");
    }

    final QueryRescorer.QueryRescoreContext context = new QueryRescorer.QueryRescoreContext(curWindowSize);
    context.setQuery(secondPhaseQuery);
    context.setRescoreQueryWeight(getOriginQueryWeightAfterRescore(searchContext.rescore()));
    searchContext.addRescore(context);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -925,20 +925,6 @@ private void setUpClusterService(Version version) {
NeuralSearchClusterUtil.instance().initialize(clusterService);
}

/**
 * Verifies that buildFeatureFieldQueryFromTokens returns a non-null BooleanQuery with the expected clause count.
 * NOTE(review): the expected count of 2 presumably matches the number of tokens in QUERY_TOKENS_SUPPLIER - confirm.
 */
@SneakyThrows
public void testBuildFeatureFieldQueryFormTokens() {
    NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(FIELD_NAME)
        .queryText(QUERY_TEXT)
        .modelId(MODEL_ID)
        .queryTokensSupplier(QUERY_TOKENS_SUPPLIER);
    BooleanQuery booleanQuery = sparseEncodingQueryBuilder.buildFeatureFieldQueryFromTokens(
        sparseEncodingQueryBuilder.queryTokensSupplier().get(),
        FIELD_NAME
    );
    assertNotNull(booleanQuery);
    // Compare by value with expected first; assertSame would compare autoboxed Integer identities,
    // which only passes by accident of the Integer cache.
    assertEquals(2, booleanQuery.clauses().size());
}

@SneakyThrows
public void testTokenDividedByScores_whenDefaultSettings() {
Map<String, Float> map = new HashMap<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.neuralsearch.search.util.NeuralSparseTwoPhaseUtil.addRescoreContextFromNeuralSparseSparseQuery;
import static org.opensearch.neuralsearch.search.util.NeuralSparseTwoPhaseUtil.addRescoreContextFromNeuralSparseQuery;

public class NeuralSparseTwoPhaseUtilTests extends OpenSearchTestCase {

Expand Down Expand Up @@ -98,30 +98,30 @@ public void testInitialize() {
@SneakyThrows
public void testAddTwoPhaseNeuralSparseQuery_whenQuery2WeightEmpty_thenNoRescoreAdded() {
Query query = mock(Query.class);
addRescoreContextFromNeuralSparseSparseQuery(query, mockSearchContext);
addRescoreContextFromNeuralSparseQuery(query, mockSearchContext);
verify(mockSearchContext, never()).addRescore(any());
}

@SneakyThrows
public void testAddTwoPhaseNeuralSparseQuery_whenUnSupportedQuery_thenNoRescoreAdded() {
FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(normalNeuralSparseQuery, mock(DoubleValuesSource.class));
addRescoreContextFromNeuralSparseSparseQuery(functionScoreQuery, mockSearchContext);
addRescoreContextFromNeuralSparseQuery(functionScoreQuery, mockSearchContext);
DisjunctionMaxQuery disjunctionMaxQuery = new DisjunctionMaxQuery(Collections.emptyList(), 1.0f);
addRescoreContextFromNeuralSparseSparseQuery(disjunctionMaxQuery, mockSearchContext);
addRescoreContextFromNeuralSparseQuery(disjunctionMaxQuery, mockSearchContext);
List<Query> subQueries = new ArrayList<>();
List<Query> filterQueries = new ArrayList<>();
subQueries.add(normalNeuralSparseQuery);
filterQueries.add(new MatchAllDocsQuery());
HybridQuery hybridQuery = new HybridQuery(subQueries, filterQueries);
addRescoreContextFromNeuralSparseSparseQuery(hybridQuery, mockSearchContext);
addRescoreContextFromNeuralSparseQuery(hybridQuery, mockSearchContext);
assertEquals(normalNeuralSparseQuery.getCurrentQuery(), currentQuery);
verify(mockSearchContext, never()).addRescore(any());
}

@SneakyThrows
public void testAddTwoPhaseNeuralSparseQuery_whenSingleEntryInQuery2Weight_thenRescoreAdded() {
NeuralSparseQuery neuralSparseQuery = new NeuralSparseQuery(mock(Query.class), mock(Query.class), mock(Query.class), 5.0f);
addRescoreContextFromNeuralSparseSparseQuery(neuralSparseQuery, mockSearchContext);
addRescoreContextFromNeuralSparseQuery(neuralSparseQuery, mockSearchContext);
verify(mockSearchContext).addRescore(any(QueryRescorer.QueryRescoreContext.class));
}

Expand All @@ -135,7 +135,7 @@ public void testAddTwoPhaseNeuralSparseQuery_whenCompoundBooleanQuery_thenRescor
queryBuilder.add(boostQuery1, BooleanClause.Occur.SHOULD);
queryBuilder.add(boostQuery2, BooleanClause.Occur.SHOULD);
BooleanQuery booleanQuery = queryBuilder.build();
addRescoreContextFromNeuralSparseSparseQuery(booleanQuery, mockSearchContext);
addRescoreContextFromNeuralSparseQuery(booleanQuery, mockSearchContext);
verify(mockSearchContext).addRescore(any(QueryRescorer.QueryRescoreContext.class));
}

Expand All @@ -155,7 +155,7 @@ public void testAddTwoPhaseNeuralSparseQuery_whenBooleanClauseType_thenVerifyBoo
queryBuilder.add(boostQuery3, BooleanClause.Occur.FILTER);
queryBuilder.add(boostQuery4, BooleanClause.Occur.MUST_NOT);
BooleanQuery booleanQuery = queryBuilder.build();
addRescoreContextFromNeuralSparseSparseQuery(booleanQuery, mockSearchContext);
addRescoreContextFromNeuralSparseQuery(booleanQuery, mockSearchContext);
ArgumentCaptor<RescoreContext> rtxCaptor = ArgumentCaptor.forClass(RescoreContext.class);
verify(mockSearchContext).addRescore(rtxCaptor.capture());
QueryRescorer.QueryRescoreContext context = (QueryRescorer.QueryRescoreContext) rtxCaptor.getValue();
Expand All @@ -179,7 +179,7 @@ public void testAddTwoPhaseNeuralSparseQuery_whenBooleanClauseType_thenVerifyBoo
@SneakyThrows
public void testWindowSize_whenNormalConditions_thenWindowSizeIsAsSet() {
NeuralSparseQuery query = normalNeuralSparseQuery;
addRescoreContextFromNeuralSparseSparseQuery(query, mockSearchContext);
addRescoreContextFromNeuralSparseQuery(query, mockSearchContext);
ArgumentCaptor<QueryRescorer.QueryRescoreContext> rescoreContextArgumentCaptor = ArgumentCaptor.forClass(
QueryRescorer.QueryRescoreContext.class
);
Expand All @@ -192,14 +192,16 @@ public void testWindowSize_whenBoundaryConditions_thenThrowException() {

NeuralSparseQuery query = new NeuralSparseQuery(new MatchAllDocsQuery(), new MatchAllDocsQuery(), new MatchAllDocsQuery(), 5000f);
NeuralSparseQuery finalQuery1 = query;
expectThrows(IllegalArgumentException.class, () -> {
addRescoreContextFromNeuralSparseSparseQuery(finalQuery1, mockSearchContext);
});
expectThrows(
IllegalArgumentException.class,
() -> { addRescoreContextFromNeuralSparseQuery(finalQuery1, mockSearchContext); }
);
query = new NeuralSparseQuery(new MatchAllDocsQuery(), new MatchAllDocsQuery(), new MatchAllDocsQuery(), Float.MAX_VALUE);
NeuralSparseQuery finalQuery = query;
expectThrows(IllegalArgumentException.class, () -> {
addRescoreContextFromNeuralSparseSparseQuery(finalQuery, mockSearchContext);
});
expectThrows(
IllegalArgumentException.class,
() -> { addRescoreContextFromNeuralSparseQuery(finalQuery, mockSearchContext); }
);
}

@SneakyThrows
Expand All @@ -211,7 +213,7 @@ public void testRescoreListWeightCalculation_whenMultipleRescoreContexts_thenCal
List<RescoreContext> rescoreContextList = Arrays.asList(mockContext1, mockContext2);
when(mockSearchContext.rescore()).thenReturn(rescoreContextList);
NeuralSparseQuery query = normalNeuralSparseQuery;
addRescoreContextFromNeuralSparseSparseQuery(query, mockSearchContext);
addRescoreContextFromNeuralSparseQuery(query, mockSearchContext);
ArgumentCaptor<RescoreContext> rtxCaptor = ArgumentCaptor.forClass(RescoreContext.class);
verify(mockSearchContext).addRescore(rtxCaptor.capture());
QueryRescorer.QueryRescoreContext context = (QueryRescorer.QueryRescoreContext) rtxCaptor.getValue();
Expand All @@ -222,7 +224,7 @@ public void testRescoreListWeightCalculation_whenMultipleRescoreContexts_thenCal
public void testEmptyRescoreListWeight_whenRescoreListEmpty_thenDefaultWeightUsed() {
when(mockSearchContext.rescore()).thenReturn(Collections.emptyList());
NeuralSparseQuery query = normalNeuralSparseQuery;
addRescoreContextFromNeuralSparseSparseQuery(query, mockSearchContext);
addRescoreContextFromNeuralSparseQuery(query, mockSearchContext);
ArgumentCaptor<RescoreContext> rtxCaptor = ArgumentCaptor.forClass(RescoreContext.class);
verify(mockSearchContext).addRescore(rtxCaptor.capture());
QueryRescorer.QueryRescoreContext context = (QueryRescorer.QueryRescoreContext) rtxCaptor.getValue();
Expand Down

0 comments on commit 2f1183e

Please sign in to comment.