Skip to content

Commit

Permalink
Merge pull request #49 from o19s/opensearch-query-sets
Browse files Browse the repository at this point in the history
Adding ability to run query sets and save results on OpenSearch
  • Loading branch information
jzonthemtn authored Dec 3, 2024
2 parents 2687f87 + 21ece38 commit c47127d
Show file tree
Hide file tree
Showing 14 changed files with 424 additions and 52 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/bin/bash -e

#QUERY_SET=`curl -s -X POST "http://localhost:9200/_plugins/search_quality_eval/queryset?name=test&description=fake&sampling=pptss" | jq .query_set | tr -d '"'`
curl -s -X POST "http://localhost:9200/_plugins/search_quality_eval/queryset?name=test&description=fake&sampling=pptss&query_set_size=5000"
curl -s -X POST "http://localhost:9200/_plugins/search_quality_eval/queryset?name=test&description=fake&sampling=pptss&query_set_size=100"

#echo ${QUERY_SET}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#!/bin/bash -e

# Runs a default _search against the search_quality_eval_query_sets_run_results
# index on the local cluster and pretty-prints the JSON response with jq.

# Without pipefail the pipeline's status is jq's alone, so a failed curl
# (e.g. cluster not running) would still exit 0 despite the -e flag.
set -o pipefail

# -s hides the progress meter; -S keeps curl's error message visible on failure.
curl -sS "http://localhost:9200/search_quality_eval_query_sets_run_results/_search" | jq .
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#!/bin/bash -e

# Runs a default _search against the search_quality_eval_query_sets index on
# the local cluster and pretty-prints the JSON response with jq.

# Without pipefail the pipeline's status is jq's alone, so a failed curl
# (e.g. cluster not running) would still exit 0 despite the -e flag.
set -o pipefail

# -s hides the progress meter; -S keeps curl's error message visible on failure.
curl -sS "http://localhost:9200/search_quality_eval_query_sets/_search" | jq .
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#!/bin/bash -e

# Runs a query set through the search_quality_eval plugin's /run endpoint.
#
# Usage: <this script> QUERY_SET_ID
#
# The request body is the query to execute for every entry in the query set.
# The plugin requires it to contain the placeholder #$query##, which is
# replaced by each user query before the search is run.

# Fail fast with a usage message when no query set ID is given; otherwise the
# request would be sent with an empty "id" parameter.
QUERY_SET_ID="${1:?Usage: $0 QUERY_SET_ID}"
JUDGMENTS_ID="12345"
INDEX="ecommerce"
ID_FIELD="asin"
K="10"

# The body is single-quoted so the shell does not expand "$query" inside the
# placeholder.
curl -s -X POST "http://localhost:9200/_plugins/search_quality_eval/run?id=${QUERY_SET_ID}&judgments_id=${JUDGMENTS_ID}&index=${INDEX}&id_field=${ID_FIELD}&k=${K}" \
-H "Content-Type: application/json" \
--data-binary '{
"query": {
"match": {
"description": {
"query": "#$query##"
}
}
}
}'
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.action.index.IndexResponse;
import org.opensearch.action.support.WriteRequest;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
Expand Down Expand Up @@ -139,10 +140,25 @@ public void runJob(final ScheduledJobParameter jobParameter, final JobExecutionC
job.put("invocation", "scheduled");
job.put("max_rank", searchQualityEvaluationJobParameter.getMaxRank());

final IndexRequest indexRequest = new IndexRequest().index(SearchQualityEvaluationPlugin.COMPLETED_JOBS_INDEX_NAME)
.id(UUID.randomUUID().toString()).source(job).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);

client.index(indexRequest).get();
final String judgmentsId = UUID.randomUUID().toString();

final IndexRequest indexRequest = new IndexRequest()
.index(SearchQualityEvaluationPlugin.COMPLETED_JOBS_INDEX_NAME)
.id(judgmentsId)
.source(job)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);

client.index(indexRequest, new ActionListener<>() {
@Override
public void onResponse(IndexResponse indexResponse) {
LOGGER.info("Successfully indexed implicit judgments {}", judgmentsId);
}

@Override
public void onFailure(Exception ex) {
LOGGER.error("Unable to index implicit judgments", ex);
}
});

}, exception -> { throw new IllegalStateException("Failed to acquire lock."); }));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,11 @@ public class SearchQualityEvaluationPlugin extends Plugin implements ActionPlugi
*/
public static final String QUERY_SETS_INDEX_NAME = "search_quality_eval_query_sets";

/**
* The name of the index that stores the query set run results.
*/
public static final String QUERY_SETS_RUN_RESULTS = "search_quality_eval_query_sets_run_results";

@Override
public Collection<Object> createComponents(
final Client client,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.opensearch.client.node.NodeClient;
import org.opensearch.common.xcontent.json.JsonXContent;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.eval.judgments.clickmodel.coec.CoecClickModel;
import org.opensearch.eval.judgments.clickmodel.coec.CoecClickModelParameters;
Expand All @@ -40,6 +41,7 @@
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicBoolean;

public class SearchQualityEvaluationRestHandler extends BaseRestHandler {

Expand All @@ -65,6 +67,11 @@ public class SearchQualityEvaluationRestHandler extends BaseRestHandler {
*/
public static final String QUERYSET_RUN_URL = "/_plugins/search_quality_eval/run";

/**
* The placeholder in the query that gets replaced by the query term when running a query set.
*/
public static final String QUERY_PLACEHOLDER = "#$query##";

@Override
public String getName() {
return "Search Quality Evaluation Framework";
Expand Down Expand Up @@ -98,7 +105,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
if (AllQueriesQuerySampler.NAME.equalsIgnoreCase(sampling)) {

// If we are not sampling queries, the query sets should just be directly
// indexed into OpenSearch using the `ubu_queries` index directly.
// indexed into OpenSearch using the `ubi_queries` index directly.

try {

Expand Down Expand Up @@ -148,20 +155,43 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
} else if(QUERYSET_RUN_URL.equalsIgnoreCase(request.path())) {

final String querySetId = request.param("id");
final String judgmentsId = request.param("judgments_id");
final String index = request.param("index");
final String idField = request.param("id_field", "_id");
final int k = Integer.parseInt(request.param("k", "10"));

if(querySetId == null || judgmentsId == null || index == null) {
return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.BAD_REQUEST, "{\"error\": \"Missing required parameters.\"}"));
}

if(k < 1) {
return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.BAD_REQUEST, "{\"error\": \"k must be a positive integer.\"}"));
}

if(!request.hasContent()) {
return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.BAD_REQUEST, "{\"error\": \"Missing query in body.\"}"));
}

// Get the query JSON from the content.
final String query = new String(BytesReference.toBytes(request.content()));

// Validate the query has a QUERY_PLACEHOLDER.
if(!query.contains(QUERY_PLACEHOLDER)) {
return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.BAD_REQUEST, "{\"error\": \"Missing query placeholder in query.\"}"));
}

try {

final OpenSearchQuerySetRunner openSearchQuerySetRunner = new OpenSearchQuerySetRunner(client);
final QuerySetRunResult querySetRunResult = openSearchQuerySetRunner.run(querySetId);

// TODO: Index the querySetRunResult.
final QuerySetRunResult querySetRunResult = openSearchQuerySetRunner.run(querySetId, judgmentsId, index, idField, query, k);
openSearchQuerySetRunner.save(querySetRunResult);

} catch (Exception ex) {
LOGGER.error("Unable to retrieve query set with ID {}", querySetId);
LOGGER.error("Unable to run query set with ID {}: ", querySetId, ex);
return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR, ex.getMessage()));
}

return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.OK, "{\"message\": \"Query set " + querySetId + " run initiated.\"}"));
return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.OK, "{\"message\": \"Run initiated for query set " + querySetId + "\"}"));

// Handle the on-demand creation of implicit judgments.
} else if(IMPLICIT_JUDGMENTS_URL.equalsIgnoreCase(request.path())) {
Expand Down Expand Up @@ -196,16 +226,35 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
job.put("invocation", "on_demand");
job.put("max_rank", maxRank);

final IndexRequest indexRequest = new IndexRequest().index(SearchQualityEvaluationPlugin.COMPLETED_JOBS_INDEX_NAME)
.id(UUID.randomUUID().toString()).source(job).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
final String judgmentsId = UUID.randomUUID().toString();

try {
client.index(indexRequest).get();
} catch (Exception e) {
throw new RuntimeException(e);
}
final IndexRequest indexRequest = new IndexRequest()
.index(SearchQualityEvaluationPlugin.COMPLETED_JOBS_INDEX_NAME)
.id(judgmentsId)
.source(job)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);

final AtomicBoolean success = new AtomicBoolean(false);

client.index(indexRequest, new ActionListener<>() {
@Override
public void onResponse(final IndexResponse indexResponse) {
LOGGER.debug("Judgments indexed: {}", judgmentsId);
success.set(true);
}

return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.OK, "{\"message\": \"Implicit judgment generation initiated.\"}"));
@Override
public void onFailure(final Exception ex) {
LOGGER.error("Unable to index judgment with ID {}", judgmentsId, ex);
success.set(false);
}
});

if(success.get()) {
return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.OK, "{\"judgments_id\": \"" + judgmentsId + "\"}"));
} else {
return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR,"Unable to index judgments."));
}

} else {
return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.BAD_REQUEST, "{\"error\": \"Invalid click model.\"}"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,28 +8,50 @@
*/
package org.opensearch.eval.runners;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.action.index.IndexResponse;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Client;
import org.opensearch.core.action.ActionListener;
import org.opensearch.eval.SearchQualityEvaluationPlugin;
import org.opensearch.eval.judgments.model.Judgment;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class OpenSearchQuerySetRunner extends QuerySetRunner {
import static org.opensearch.eval.SearchQualityEvaluationRestHandler.QUERY_PLACEHOLDER;

/**
* A {@link QuerySetRunner} for Amazon OpenSearch.
*/
public class OpenSearchQuerySetRunner implements QuerySetRunner {

private static final Logger LOGGER = LogManager.getLogger(OpenSearchQuerySetRunner.class);

final Client client;

/**
* Creates a new query set runner that sends its search and index requests through the given client.
* @param client An OpenSearch {@link Client}. The reference is stored as-is; no null check is performed.
*/
public OpenSearchQuerySetRunner(final Client client) {
this.client = client;
}

@Override
public QuerySetRunResult run(String querySetId) {
public QuerySetRunResult run(final String querySetId, final String judgmentsId, final String index, final String idField, final String query, final int k) {

// TODO: Get the judgments we will use for metric calculation.
final List<Judgment> judgments = new ArrayList<>();

// Get the query set.
final SearchSourceBuilder getQuerySetSearchSourceBuilder = new SearchSourceBuilder();
Expand All @@ -40,43 +62,70 @@ public QuerySetRunResult run(String querySetId) {

try {

// TODO: Don't use .get()
final SearchResponse searchResponse = client.search(getQuerySetSearchRequest).get();

// The queries from the query set that will be run.
final Collection<String> queries = (Collection<String>) searchResponse.getHits().getAt(0).getSourceAsMap().get("queries");
final Collection<Map<String, Long>> queries = (Collection<Map<String, Long>>) searchResponse.getHits().getAt(0).getSourceAsMap().get("queries");

// The results of each query.
final Collection<QueryResult> queryResults = new ArrayList<>();
final List<QueryResult> queryResults = new ArrayList<>();

// TODO: Initiate the running of the query set.
for(final String query : queries) {
for(Map<String, Long> queryMap : queries) {

// TODO: What should this query be?
final SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
searchSourceBuilder.query(QueryBuilders.matchQuery("title", query));
// TODO: Just fetch the id ("asin") field and not all the unnecessary fields.
// Loop over each query in the map and run each one.
for (final String userQuery : queryMap.keySet()) {

// TODO: Allow for setting this index name.
final SearchRequest searchRequest = new SearchRequest("ecommerce");
getQuerySetSearchRequest.source(getQuerySetSearchSourceBuilder);
// Replace the query placeholder with the user query.
final String q = query.replace(QUERY_PLACEHOLDER, userQuery);

final SearchResponse sr = client.search(searchRequest).get();
// Build the query from the one that was passed in.
final SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
searchSourceBuilder.query(QueryBuilders.wrapperQuery(q));
searchSourceBuilder.from(0);
// TODO: If k is > 10, we'll need to page through these.
searchSourceBuilder.size(k);

final List<String> orderedDocumentIds = new ArrayList<>();
String[] includeFields = new String[] {idField};
String[] excludeFields = new String[] {};
searchSourceBuilder.fetchSource(includeFields, excludeFields);

for(final SearchHit hit : sr.getHits().getHits()) {
// TODO: Allow for setting this index name.
final SearchRequest searchRequest = new SearchRequest(index);
getQuerySetSearchRequest.source(searchSourceBuilder);

// TODO: This field needs to be customizable.
orderedDocumentIds.add(hit.getFields().get("asin").toString());
client.search(searchRequest, new ActionListener<>() {

}
@Override
public void onResponse(final SearchResponse searchResponse) {

final List<String> orderedDocumentIds = new ArrayList<>();

for (final SearchHit hit : searchResponse.getHits().getHits()) {

queryResults.add(new QueryResult(orderedDocumentIds));
final Map<String, Object> sourceAsMap = hit.getSourceAsMap();
final String documentId = sourceAsMap.get(idField).toString();

orderedDocumentIds.add(documentId);

}

queryResults.add(new QueryResult(query, orderedDocumentIds, judgments, k));

}

@Override
public void onFailure(Exception ex) {
LOGGER.error("Unable to search for query: {}", query, ex);
}
});

}

}

// TODO: Calculate the search metrics given the results and the judgments.
final SearchMetrics searchMetrics = new SearchMetrics();
final SearchMetrics searchMetrics = new SearchMetrics(queryResults, judgments, k);

return new QuerySetRunResult(queryResults, searchMetrics);

Expand All @@ -86,4 +135,31 @@ public QuerySetRunResult run(String querySetId) {

}

/**
 * Indexes the outcome of a query set run into the
 * {@link SearchQualityEvaluationPlugin#QUERY_SETS_RUN_RESULTS} index.
 *
 * <p>The index request is issued asynchronously, so this method returns as soon as the
 * request has been handed to the client. The outcome of the indexing operation is reported
 * through the log only.
 *
 * @param result The {@link QuerySetRunResult} to persist. Its run ID, search metrics, and
 *               query results are written as the {@code run_id}, {@code search_metrics},
 *               and {@code query_results} fields of the document.
 * @throws Exception Retained from the original signature so existing callers are unaffected.
 */
@Override
public void save(final QuerySetRunResult result) throws Exception {

    // Build the document source from the run result.
    final Map<String, Object> results = new HashMap<>();
    results.put("run_id", result.getRunId());
    results.put("search_metrics", result.getSearchMetrics().getSearchMetricsAsMap());
    results.put("query_results", result.getQueryResultsAsMap());

    final IndexRequest indexRequest = new IndexRequest(SearchQualityEvaluationPlugin.QUERY_SETS_RUN_RESULTS);
    indexRequest.source(results);

    client.index(indexRequest, new ActionListener<>() {
        @Override
        public void onResponse(final IndexResponse indexResponse) {
            LOGGER.debug("Query set results indexed.");
        }

        @Override
        public void onFailure(final Exception ex) {
            // Log instead of throwing: this callback runs after save() has already
            // returned, so an exception thrown here can never reach save()'s caller.
            LOGGER.error("Unable to index query set run results for run {}", result.getRunId(), ex);
        }
    });
}

}
Loading

0 comments on commit c47127d

Please sign in to comment.