From 41ca4a40a0ef6ddca0ead894108646e16c98b9ce Mon Sep 17 00:00:00 2001 From: jzonthemtn Date: Fri, 6 Dec 2024 16:30:11 -0500 Subject: [PATCH] Adding threshold value for calculating precision. Signed-off-by: jzonthemtn --- .../scripts/run-query-set.sh | 1 + .../opensearch/eval/SearchQualityEvaluationRestHandler.java | 3 ++- .../org/opensearch/eval/metrics/PrecisionSearchMetric.java | 4 +++- .../org/opensearch/eval/runners/AbstractQuerySetRunner.java | 6 +++++- .../opensearch/eval/runners/OpenSearchQuerySetRunner.java | 4 ++-- 5 files changed, 13 insertions(+), 5 deletions(-) diff --git a/opensearch-search-quality-evaluation-plugin/scripts/run-query-set.sh b/opensearch-search-quality-evaluation-plugin/scripts/run-query-set.sh index 9befe60..477dbd9 100755 --- a/opensearch-search-quality-evaluation-plugin/scripts/run-query-set.sh +++ b/opensearch-search-quality-evaluation-plugin/scripts/run-query-set.sh @@ -5,6 +5,7 @@ JUDGMENTS_ID="9183599e-46dd-49e0-9584-df816164a4c2" INDEX="ecommerce" ID_FIELD="asin" K="20" +THRESHOLD="1.0" # Default value curl -s -X DELETE "http://localhost:9200/search_quality_eval_query_sets_run_results" diff --git a/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/SearchQualityEvaluationRestHandler.java b/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/SearchQualityEvaluationRestHandler.java index 69577fa..48f4840 100644 --- a/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/SearchQualityEvaluationRestHandler.java +++ b/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/SearchQualityEvaluationRestHandler.java @@ -159,6 +159,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli final String searchPipeline = request.param("search_pipeline", null); final String idField = request.param("id_field", "_id"); final int k = Integer.parseInt(request.param("k", "10")); + final double threshold = Double.parseDouble(request.param("threshold", "1.0")); if(querySetId == null || querySetId.isEmpty() || judgmentsId == null || judgmentsId.isEmpty() || index == null || index.isEmpty()) { return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.BAD_REQUEST, "{\"error\": \"Missing required parameters.\"}")); @@ -183,7 +184,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli try { final OpenSearchQuerySetRunner openSearchQuerySetRunner = new OpenSearchQuerySetRunner(client); - final QuerySetRunResult querySetRunResult = openSearchQuerySetRunner.run(querySetId, judgmentsId, index, searchPipeline, idField, query, k); + final QuerySetRunResult querySetRunResult = openSearchQuerySetRunner.run(querySetId, judgmentsId, index, searchPipeline, idField, query, k, threshold); openSearchQuerySetRunner.save(querySetRunResult); } catch (Exception ex) { diff --git a/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/metrics/PrecisionSearchMetric.java b/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/metrics/PrecisionSearchMetric.java index 6269e50..a1bcbcd 100644 --- a/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/metrics/PrecisionSearchMetric.java +++ b/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/metrics/PrecisionSearchMetric.java @@ -12,10 +12,12 @@ public class PrecisionSearchMetric extends SearchMetric { + private final double threshold; private final List relevanceScores; - public PrecisionSearchMetric(final int k, final List relevanceScores) { + public PrecisionSearchMetric(final int k, final double threshold, final List relevanceScores) { super(k); + this.threshold = threshold; this.relevanceScores = relevanceScores; } diff --git a/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/runners/AbstractQuerySetRunner.java b/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/runners/AbstractQuerySetRunner.java index ea475ab..f64b4ac 100644 --- a/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/runners/AbstractQuerySetRunner.java +++ b/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/runners/AbstractQuerySetRunner.java @@ -46,10 +46,14 @@ public AbstractQuerySetRunner(final Client client) { * @param idField The field in the index that is used to uniquely identify a document. * @param query The query that will be used to run the query set. * @param k The k used for metrics calculation, i.e. DCG@k. + * @param threshold The cutoff for binary judgments. A judgment score greater than or equal + * to this value will be assigned a binary judgment value of 1. A judgment score + * less than this value will be assigned a binary judgment value of 0. * @return The query set {@link QuerySetRunResult results} and calculated metrics. */ abstract QuerySetRunResult run(String querySetId, final String judgmentsId, final String index, final String searchPipeline, - final String idField, final String query, final int k) throws Exception; + final String idField, final String query, final int k, + final double threshold) throws Exception; /** * Saves the query set results to a persistent store, which may be the search engine itself. diff --git a/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/runners/OpenSearchQuerySetRunner.java b/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/runners/OpenSearchQuerySetRunner.java index eb323c2..1124a23 100644 --- a/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/runners/OpenSearchQuerySetRunner.java +++ b/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/runners/OpenSearchQuerySetRunner.java @@ -53,7 +53,7 @@ public OpenSearchQuerySetRunner(final Client client) { @Override public QuerySetRunResult run(final String querySetId, final String judgmentsId, final String index, final String searchPipeline, final String idField, final String query, - final int k) throws Exception { + final int k, final double threshold) throws Exception { final Collection> querySet = getQuerySet(querySetId); LOGGER.info("Found {} queries in query set {}", querySet.size(), querySetId); @@ -115,7 +115,7 @@ public void onResponse(final SearchResponse searchResponse) { // Calculate the metrics for this query. final SearchMetric dcgSearchMetric = new DcgSearchMetric(k, relevanceScores); final SearchMetric ndcgSearchmetric = new NdcgSearchMetric(k, relevanceScores); - final SearchMetric precisionSearchMetric = new PrecisionSearchMetric(k, relevanceScores); + final SearchMetric precisionSearchMetric = new PrecisionSearchMetric(k, threshold, relevanceScores); final Collection searchMetrics = List.of(dcgSearchMetric, ndcgSearchmetric, precisionSearchMetric);