Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding ability to run query sets and save results on OpenSearch #49

Merged
merged 11 commits into from
Dec 3, 2024
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/bin/bash -e

#QUERY_SET=`curl -s -X POST "http://localhost:9200/_plugins/search_quality_eval/queryset?name=test&description=fake&sampling=pptss" | jq .query_set | tr -d '"'`
curl -s -X POST "http://localhost:9200/_plugins/search_quality_eval/queryset?name=test&description=fake&sampling=pptss&query_set_size=5000"
curl -s -X POST "http://localhost:9200/_plugins/search_quality_eval/queryset?name=test&description=fake&sampling=pptss&query_set_size=100"

#echo ${QUERY_SET}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#!/bin/bash -e

curl -s "http://localhost:9200/search_quality_eval_query_sets_run_results/_search" | jq
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#!/bin/bash -e

curl -s "http://localhost:9200/search_quality_eval_query_sets/_search" | jq
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#!/bin/bash -e

QUERY_SET_ID="${1}"

curl -s -X POST "http://localhost:9200/_plugins/search_quality_eval/run?id=${QUERY_SET_ID}" | jq
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.action.index.IndexResponse;
import org.opensearch.action.support.WriteRequest;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
Expand Down Expand Up @@ -139,10 +140,25 @@ public void runJob(final ScheduledJobParameter jobParameter, final JobExecutionC
job.put("invocation", "scheduled");
job.put("max_rank", searchQualityEvaluationJobParameter.getMaxRank());

final IndexRequest indexRequest = new IndexRequest().index(SearchQualityEvaluationPlugin.COMPLETED_JOBS_INDEX_NAME)
.id(UUID.randomUUID().toString()).source(job).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);

client.index(indexRequest).get();
final String judgmentsId = UUID.randomUUID().toString();

final IndexRequest indexRequest = new IndexRequest()
.index(SearchQualityEvaluationPlugin.COMPLETED_JOBS_INDEX_NAME)
.id(judgmentsId)
.source(job)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);

client.index(indexRequest, new ActionListener<>() {
@Override
public void onResponse(IndexResponse indexResponse) {
LOGGER.info("Successfully indexed implicit judgments {}", judgmentsId);
}

@Override
public void onFailure(Exception ex) {
LOGGER.error("Unable to index implicit judgments", ex);
}
});

}, exception -> { throw new IllegalStateException("Failed to acquire lock."); }));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,11 @@ public class SearchQualityEvaluationPlugin extends Plugin implements ActionPlugi
*/
public static final String QUERY_SETS_INDEX_NAME = "search_quality_eval_query_sets";

/**
* The name of the index that stores the query set run results.
*/
public static final String QUERY_SETS_RUN_RESULTS = "search_quality_eval_query_sets_run_results";

@Override
public Collection<Object> createComponents(
final Client client,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicBoolean;

public class SearchQualityEvaluationRestHandler extends BaseRestHandler {

Expand Down Expand Up @@ -98,7 +99,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
if (AllQueriesQuerySampler.NAME.equalsIgnoreCase(sampling)) {

// If we are not sampling queries, the query sets should just be directly
// indexed into OpenSearch using the `ubu_queries` index directly.
// indexed into OpenSearch using the `ubi_queries` index directly.

try {

Expand Down Expand Up @@ -148,20 +149,24 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
} else if(QUERYSET_RUN_URL.equalsIgnoreCase(request.path())) {

final String querySetId = request.param("id");
final String judgmentsId = request.param("judgments");
jzonthemtn marked this conversation as resolved.
Show resolved Hide resolved

if(querySetId == null || judgmentsId == null) {
return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.BAD_REQUEST, "{\"error\": \"Missing required parameters.\"}"));
}

try {

final OpenSearchQuerySetRunner openSearchQuerySetRunner = new OpenSearchQuerySetRunner(client);
final QuerySetRunResult querySetRunResult = openSearchQuerySetRunner.run(querySetId);

// TODO: Index the querySetRunResult.
final QuerySetRunResult querySetRunResult = openSearchQuerySetRunner.run(querySetId, judgmentsId);
openSearchQuerySetRunner.save(querySetRunResult);

} catch (Exception ex) {
LOGGER.error("Unable to retrieve query set with ID {}", querySetId);
LOGGER.error("Unable to run query set with ID {}: ", querySetId, ex);
return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR, ex.getMessage()));
}

return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.OK, "{\"message\": \"Query set " + querySetId + " run initiated.\"}"));
return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.OK, "{\"message\": \"Run initiated for query set " + querySetId + "\"}"));

// Handle the on-demand creation of implicit judgments.
} else if(IMPLICIT_JUDGMENTS_URL.equalsIgnoreCase(request.path())) {
Expand Down Expand Up @@ -196,16 +201,35 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
job.put("invocation", "on_demand");
job.put("max_rank", maxRank);

final IndexRequest indexRequest = new IndexRequest().index(SearchQualityEvaluationPlugin.COMPLETED_JOBS_INDEX_NAME)
.id(UUID.randomUUID().toString()).source(job).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
final String judgmentsId = UUID.randomUUID().toString();

try {
client.index(indexRequest).get();
} catch (Exception e) {
throw new RuntimeException(e);
}
final IndexRequest indexRequest = new IndexRequest()
.index(SearchQualityEvaluationPlugin.COMPLETED_JOBS_INDEX_NAME)
.id(judgmentsId)
.source(job)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);

final AtomicBoolean success = new AtomicBoolean(false);

client.index(indexRequest, new ActionListener<>() {
@Override
public void onResponse(final IndexResponse indexResponse) {
LOGGER.debug("Judgments indexed: {}", judgmentsId);
success.set(true);
}

return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.OK, "{\"message\": \"Implicit judgment generation initiated.\"}"));
@Override
public void onFailure(final Exception ex) {
LOGGER.error("Unable to index judgment with ID {}", judgmentsId, ex);
success.set(false);
}
});

if(success.get()) {
return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.OK, "{\"judgments_id\": \"" + judgmentsId + "\"}"));
} else {
return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR,"Unable to index judgments."));
}

} else {
return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.BAD_REQUEST, "{\"error\": \"Invalid click model.\"}"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,28 +8,44 @@
*/
package org.opensearch.eval.runners;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.action.index.IndexResponse;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Client;
import org.opensearch.core.action.ActionListener;
import org.opensearch.eval.SearchQualityEvaluationPlugin;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class OpenSearchQuerySetRunner extends QuerySetRunner {
/**
* A {@link QuerySetRunner} for Amazon OpenSearch.
*/
public class OpenSearchQuerySetRunner implements QuerySetRunner {

private static final Logger LOGGER = LogManager.getLogger(OpenSearchQuerySetRunner.class);

final Client client;

/**
* Creates a new query set runner
* @param client An OpenSearch {@link Client}.
*/
public OpenSearchQuerySetRunner(final Client client) {
this.client = client;
}

@Override
public QuerySetRunResult run(String querySetId) {
public QuerySetRunResult run(final String querySetId, final String judgmentsId) {

// Get the query set.
final SearchSourceBuilder getQuerySetSearchSourceBuilder = new SearchSourceBuilder();
Expand All @@ -40,38 +56,77 @@ public QuerySetRunResult run(String querySetId) {

try {

// TODO: Don't use .get()
final SearchResponse searchResponse = client.search(getQuerySetSearchRequest).get();

// The queries from the query set that will be run.
final Collection<String> queries = (Collection<String>) searchResponse.getHits().getAt(0).getSourceAsMap().get("queries");
final Collection<Map<String, Long>> queries = (Collection<Map<String, Long>>) searchResponse.getHits().getAt(0).getSourceAsMap().get("queries");

// The results of each query.
final Collection<QueryResult> queryResults = new ArrayList<>();
final List<QueryResult> queryResults = new ArrayList<>();

// TODO: Initiate the running of the query set.
for(final String query : queries) {
for(Map<String, Long> queryMap : queries) {

// TODO: What should this query be?
final SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
searchSourceBuilder.query(QueryBuilders.matchQuery("title", query));
// TODO: Just fetch the id ("asin") field and not all the unnecessary fields.
// Loop over each query in the map and run each one.
for (final String query : queryMap.keySet()) {

// TODO: Allow for setting this index name.
final SearchRequest searchRequest = new SearchRequest("ecommerce");
getQuerySetSearchRequest.source(getQuerySetSearchSourceBuilder);
// TODO: Allow the user to pass these values in.
final String index = "ecommerce";
final String idField = "asin";
jzonthemtn marked this conversation as resolved.
Show resolved Hide resolved

final SearchResponse sr = client.search(searchRequest).get();
// TODO: Allow the user to pass this in.
final String q = "{\n" +
" \"query\": {\n" +
" \"match\": {\n" +
" \"description\": {\n" +
" \"query\": \" + query + \"\n" +
" }\n" +
" }\n" +
" }\n" +
"}";

final List<String> orderedDocumentIds = new ArrayList<>();
// TODO: What should this query be?
jzonthemtn marked this conversation as resolved.
Show resolved Hide resolved
final SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
//searchSourceBuilder.query(QueryBuilders.wrapperQuery(q));
searchSourceBuilder.query(QueryBuilders.matchQuery("description", query));
searchSourceBuilder.from(0);
searchSourceBuilder.size(10);
jzonthemtn marked this conversation as resolved.
Show resolved Hide resolved

for(final SearchHit hit : sr.getHits().getHits()) {
String[] includeFields = new String[] {idField};
String[] excludeFields = new String[] {};
searchSourceBuilder.fetchSource(includeFields, excludeFields);

// TODO: This field needs to be customizable.
orderedDocumentIds.add(hit.getFields().get("asin").toString());
// TODO: Allow for setting this index name.
final SearchRequest searchRequest = new SearchRequest(index);
getQuerySetSearchRequest.source(searchSourceBuilder);

}
client.search(searchRequest, new ActionListener<>() {

@Override
public void onResponse(final SearchResponse searchResponse) {

final List<String> orderedDocumentIds = new ArrayList<>();

for (final SearchHit hit : searchResponse.getHits().getHits()) {

final Map<String, Object> sourceAsMap = hit.getSourceAsMap();
final String documentId = sourceAsMap.get(idField).toString();

orderedDocumentIds.add(documentId);

}

queryResults.add(new QueryResult(query, orderedDocumentIds));

}

@Override
public void onFailure(Exception ex) {
LOGGER.error("Unable to search for query: {}", query, ex);
}
});

queryResults.add(new QueryResult(orderedDocumentIds));
}

}

Expand All @@ -86,4 +141,31 @@ public QuerySetRunResult run(String querySetId) {

}

@Override
public void save(final QuerySetRunResult result) throws Exception {

// Index the results into OpenSearch.

final Map<String, Object> results = new HashMap<>();

results.put("run_id", result.getRunId());
results.put("search_metrics", result.getSearchMetrics().getSearchMetricsAsMap());
results.put("query_results", result.getQueryResultsAsMap());

final IndexRequest indexRequest = new IndexRequest(SearchQualityEvaluationPlugin.QUERY_SETS_RUN_RESULTS);
indexRequest.source(results);

client.index(indexRequest, new ActionListener<>() {
@Override
public void onResponse(IndexResponse indexResponse) {
LOGGER.debug("Query set results indexed.");
}

@Override
public void onFailure(Exception ex) {
throw new RuntimeException(ex);
}
});
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,50 @@

import java.util.List;

/**
* Contains the search results for a query.
*/
public class QueryResult {

private final String query;
private final List<String> orderedDocumentIds;

public QueryResult(final List<String> orderedDocumentIds) {
// TODO: Calculate these metrics.
private final SearchMetrics searchMetrics = new SearchMetrics();

/**
* Creates the search results.
* @param query The query used to generate this result.
* @param orderedDocumentIds A list of ordered document IDs in the same order as they appeared
* in the query.
*/
public QueryResult(final String query, final List<String> orderedDocumentIds) {
this.query = query;
this.orderedDocumentIds = orderedDocumentIds;
}

/**
* Gets the query used to generate this result.
* @return The query used to generate this result.
*/
public String getQuery() {
return query;
}

/**
* Gets the list of ordered document IDs.
* @return A list of ordered documented IDs.
*/
public List<String> getOrderedDocumentIds() {
return orderedDocumentIds;
}

/**
* Gets the search metrics for this query.
* @return The {@link SearchMetrics} for this query.
*/
public SearchMetrics getSearchMetrics() {
return searchMetrics;
}

}
Loading
Loading