Skip to content

Commit

Permalink
Refactor NeuralSparseTwoPhaseUtil.
Browse files Browse the repository at this point in the history
Signed-off-by: conggguan <[email protected]>
  • Loading branch information
conggguan committed Apr 25, 2024
1 parent 3a9da75 commit 37646cb
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,40 +35,20 @@ public class NeuralSparseTwoPhaseUtil {
* @param searchContext The searchContext with this query.
*/
public static void addRescoreContextFromNeuralSparseQuery(final Query query, final SearchContext searchContext) {
Map<Query, Float> query2weight = new HashMap<>();
float windowSizeExpansion = populateQueryWeightsMapAndGetWindowSizeExpansion(query, query2weight, 1.0f, 1.0f);
Query twoPhaseQuery;
if (query2weight.isEmpty()) {
return;
}
if (query2weight.size() == 1) {
Map.Entry<Query, Float> entry = query2weight.entrySet().iterator().next();
twoPhaseQuery = new BoostQuery(entry.getKey(), entry.getValue());
} else {
twoPhaseQuery = getNestedTwoPhaseQuery(query2weight);
}
int curWindowSize = (int) (searchContext.size() * windowSizeExpansion);
if (curWindowSize < 0 || curWindowSize > NeuralSparseTwoPhaseParameters.MAX_WINDOW_SIZE) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"Two phase final windowSize %d out of score with limit %d. "
+ "You can change the value of cluster setting [plugins.neural_search.neural_sparse.two_phase.max_window_size] "
+ "to a integer at least 50.",
curWindowSize,
NeuralSparseTwoPhaseParameters.MAX_WINDOW_SIZE
)
);
}
QueryRescorer.QueryRescoreContext rescoreContext = new QueryRescorer.QueryRescoreContext(curWindowSize);
rescoreContext.setQuery(twoPhaseQuery);
rescoreContext.setRescoreQueryWeight(getOriginQueryWeightAfterRescore(searchContext.rescore()));
searchContext.addRescore(rescoreContext);
Map<NeuralSparseQuery, Float> neuralSparseQuery2Weight = new HashMap<>();
// Store all neuralSparse query and it's global weight in neuralSparseQuery2Weight, and get the max windowSizeExpansion of them..
float windowSizeExpansion = populateQueryWeightsMapAndGetWindowSizeExpansion(query, neuralSparseQuery2Weight, 1.0f, 1.0f);
Query twoPhaseQuery = getNestedTwoPhaseQueryFromNeuralSparseQuerySet(neuralSparseQuery2Weight);
if (twoPhaseQuery == null) return;
// Set the valid neural_sparse query's current query to it's highScoreTokenQuery.
neuralSparseQuery2Weight.keySet().forEach(NeuralSparseQuery::setCurrentQueryToHighScoreTokenQuery);
// Add two phase to searchContext's rescore list.
addTwoPhaseQuery2RescoreContext(searchContext, windowSizeExpansion, twoPhaseQuery);
}

private static float populateQueryWeightsMapAndGetWindowSizeExpansion(
final Query query,
Map<Query, Float> query2Weight,
Map<NeuralSparseQuery, Float> query2weight,
float weight,
float windoSizeExpansion
) {
Expand All @@ -77,36 +57,65 @@ private static float populateQueryWeightsMapAndGetWindowSizeExpansion(
weight *= boostQuery.getBoost();
windoSizeExpansion = max(
windoSizeExpansion,
populateQueryWeightsMapAndGetWindowSizeExpansion(boostQuery.getQuery(), query2Weight, weight, windoSizeExpansion)
populateQueryWeightsMapAndGetWindowSizeExpansion(boostQuery.getQuery(), query2weight, weight, windoSizeExpansion)
);
} else if (query instanceof BooleanQuery) {
for (BooleanClause clause : (BooleanQuery) query) {
if (clause.isScoring()) {
windoSizeExpansion = max(
windoSizeExpansion,
populateQueryWeightsMapAndGetWindowSizeExpansion(clause.getQuery(), query2Weight, weight, windoSizeExpansion)
populateQueryWeightsMapAndGetWindowSizeExpansion(clause.getQuery(), query2weight, weight, windoSizeExpansion)
);
}
}
} else if (query instanceof NeuralSparseQuery) {
query2Weight.put(((NeuralSparseQuery) query).getLowScoreTokenQuery(), weight);
((NeuralSparseQuery) query).setCurrentQueryToHighScoreTokenQuery();
query2weight.put(((NeuralSparseQuery) query), weight);
windoSizeExpansion = max(windoSizeExpansion, ((NeuralSparseQuery) query).getRescoreWindowSizeExpansion());
}
// ToDo Support for other compound query.
return windoSizeExpansion;
}

private static float getOriginQueryWeightAfterRescore(List<RescoreContext> rescoreContextList) {
private static float getOriginQueryWeightAfterRescore(final List<RescoreContext> rescoreContextList) {
return rescoreContextList.stream()
.filter(ctx -> ctx instanceof QueryRescorer.QueryRescoreContext)
.map(ctx -> ((QueryRescorer.QueryRescoreContext) ctx).queryWeight())
.reduce(1.0f, (a, b) -> a * b);
}

private static Query getNestedTwoPhaseQuery(Map<Query, Float> query2weight) {
private static void addTwoPhaseQuery2RescoreContext(
final SearchContext searchContext,
final float windowSizeExpansion,
Query twoPhaseQuery
) {
int curWindowSize = (int) (searchContext.size() * windowSizeExpansion);
if (curWindowSize < 0 || curWindowSize > NeuralSparseTwoPhaseParameters.MAX_WINDOW_SIZE) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"Two phase final windowSize %d out of score with limit %d. "
+ "You can change the value of cluster setting [plugins.neural_search.neural_sparse.two_phase.max_window_size] "
+ "to a integer at least 50.",
curWindowSize,
NeuralSparseTwoPhaseParameters.MAX_WINDOW_SIZE
)
);
}
QueryRescorer.QueryRescoreContext rescoreContext = new QueryRescorer.QueryRescoreContext(curWindowSize);
rescoreContext.setQuery(twoPhaseQuery);
rescoreContext.setRescoreQueryWeight(getOriginQueryWeightAfterRescore(searchContext.rescore()));
searchContext.addRescore(rescoreContext);
}

private static Query getNestedTwoPhaseQueryFromNeuralSparseQuerySet(final Map<NeuralSparseQuery, Float> originNeuralSparse2weight) {
if (originNeuralSparse2weight.isEmpty()) return null;
BooleanQuery.Builder builder = new BooleanQuery.Builder();
query2weight.forEach((query, weight) -> { builder.add(new BoostQuery(query, weight), BooleanClause.Occur.SHOULD); });
originNeuralSparse2weight.forEach(
(neuralSparseQuery, weight) -> builder.add(
new BoostQuery(neuralSparseQuery.getLowScoreTokenQuery(), weight),
BooleanClause.Occur.SHOULD
)
);
return builder.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import lombok.SneakyThrows;
import org.junit.Before;
import org.opensearch.Version;
import org.opensearch.cluster.ClusterState;
import org.opensearch.cluster.node.DiscoveryNodes;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.settings.ClusterSettings;
Expand Down Expand Up @@ -45,8 +47,8 @@ public class NeuralSparseTwoPhaseParametersTests extends OpenSearchTestCase {
public static NeuralSparseTwoPhaseParameters TWO_PHASE_PARAMETERS = new NeuralSparseTwoPhaseParameters().enabled(TEST_ENABLED)
.pruning_ratio(TEST_PRUNING_RATIO)
.window_size_expansion(TEST_WINDOW_SIZE_EXPANSION);

private ClusterSettings clusterSettings;
DiscoveryNodes mockDiscoveryNodes = mock(DiscoveryNodes.class);

@Before
public void setUpNeuralSparseTwoPhaseParameters() {
Expand All @@ -60,10 +62,19 @@ public void setUpNeuralSparseTwoPhaseParameters() {
NeuralSearchSettings.NEURAL_SPARSE_TWO_PHASE_MAX_WINDOW_SIZE
)
).collect(Collectors.toSet());

clusterSettings = new ClusterSettings(settings, settingsSet);
ClusterService clusterService = mock(ClusterService.class);
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
NeuralSparseTwoPhaseParameters.initialize(clusterService, settings);

ClusterService mockClusterService = mock(ClusterService.class);
ClusterState mockClusterState = mock(ClusterState.class);

when(mockClusterService.state()).thenReturn(mockClusterState);
when(mockClusterState.getNodes()).thenReturn(mockDiscoveryNodes);
when(mockDiscoveryNodes.getMinNodeVersion()).thenReturn(Version.CURRENT);
when(mockClusterService.getClusterSettings()).thenReturn(clusterSettings);

NeuralSearchClusterUtil.instance().initialize(mockClusterService);
NeuralSparseTwoPhaseParameters.initialize(mockClusterService, settings);
}

public void testDefaultValue() {
Expand Down Expand Up @@ -252,6 +263,7 @@ public void testToXContent() {
assertEquals(NeuralSparseTwoPhaseParameters.DEFAULT_ENABLED, inner.get(NeuralSparseTwoPhaseParameters.ENABLED.getPreferredName()));
}

@SneakyThrows
public void testEquals() {
NeuralSparseTwoPhaseParameters param = NeuralSparseTwoPhaseParameters.getDefaultSettings();
NeuralSparseTwoPhaseParameters paramSame = NeuralSparseTwoPhaseParameters.getDefaultSettings();
Expand All @@ -267,6 +279,7 @@ public void testEquals() {
assertNotEquals(paramDiffEnabled, param);
}

@SneakyThrows
public void testIsEnabled() {
NeuralSparseTwoPhaseParameters enabled = new NeuralSparseTwoPhaseParameters().enabled(true);
NeuralSparseTwoPhaseParameters disabled = new NeuralSparseTwoPhaseParameters().enabled(false);
Expand All @@ -275,6 +288,7 @@ public void testIsEnabled() {
assertFalse(NeuralSparseTwoPhaseParameters.isEnabled(null));
}

@SneakyThrows
public void testIsClusterOnOrAfterMinReqVersionForTwoPhaseSearchSupport() {
ClusterService clusterServiceBefore = NeuralSearchClusterTestUtils.mockClusterService(Version.V_2_13_0);
NeuralSearchClusterUtil.instance().initialize(clusterServiceBefore);
Expand Down

0 comments on commit 37646cb

Please sign in to comment.