Skip to content

Commit

Permalink
Adding threshold value for calculating precision.
Browse files Browse the repository at this point in the history
Signed-off-by: jzonthemtn <[email protected]>
  • Loading branch information
jzonthemtn committed Dec 6, 2024
1 parent 4b2c2d6 commit 41ca4a4
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ JUDGMENTS_ID="9183599e-46dd-49e0-9584-df816164a4c2"
INDEX="ecommerce"
ID_FIELD="asin"
K="20"
THRESHOLD="1.0" # Default value

curl -s -X DELETE "http://localhost:9200/search_quality_eval_query_sets_run_results"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
final String searchPipeline = request.param("search_pipeline", null);
final String idField = request.param("id_field", "_id");
final int k = Integer.parseInt(request.param("k", "10"));
final double threshold = Double.parseDouble(request.param("threshold", "1.0"));

if(querySetId == null || querySetId.isEmpty() || judgmentsId == null || judgmentsId.isEmpty() || index == null || index.isEmpty()) {
return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.BAD_REQUEST, "{\"error\": \"Missing required parameters.\"}"));
Expand All @@ -183,7 +184,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
try {

final OpenSearchQuerySetRunner openSearchQuerySetRunner = new OpenSearchQuerySetRunner(client);
final QuerySetRunResult querySetRunResult = openSearchQuerySetRunner.run(querySetId, judgmentsId, index, searchPipeline, idField, query, k);
final QuerySetRunResult querySetRunResult = openSearchQuerySetRunner.run(querySetId, judgmentsId, index, searchPipeline, idField, query, k, threshold);
openSearchQuerySetRunner.save(querySetRunResult);

} catch (Exception ex) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@

public class PrecisionSearchMetric extends SearchMetric {

private final double threshold;
private final List<Double> relevanceScores;

public PrecisionSearchMetric(final int k, final List<Double> relevanceScores) {
public PrecisionSearchMetric(final int k, final double threshold, final List<Double> relevanceScores) {
super(k);
this.threshold = threshold;
this.relevanceScores = relevanceScores;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,14 @@ public AbstractQuerySetRunner(final Client client) {
* @param idField The field in the index that is used to uniquely identify a document.
* @param query The query that will be used to run the query set.
* @param k The k used for metrics calculation, i.e. DCG@k.
* @param threshold The cutoff for binary judgments. A judgment score greater than or equal
* to this value will be assigned a binary judgment value of 1. A judgment score
* less than this value will be assigned a binary judgment value of 0.
* @return The query set {@link QuerySetRunResult results} and calculated metrics.
*/
abstract QuerySetRunResult run(String querySetId, final String judgmentsId, final String index, final String searchPipeline,
final String idField, final String query, final int k) throws Exception;
final String idField, final String query, final int k,
final double threshold) throws Exception;

/**
* Saves the query set results to a persistent store, which may be the search engine itself.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public OpenSearchQuerySetRunner(final Client client) {
@Override
public QuerySetRunResult run(final String querySetId, final String judgmentsId, final String index,
final String searchPipeline, final String idField, final String query,
final int k) throws Exception {
final int k, final double threshold) throws Exception {

final Collection<Map<String, Long>> querySet = getQuerySet(querySetId);
LOGGER.info("Found {} queries in query set {}", querySet.size(), querySetId);
Expand Down Expand Up @@ -115,7 +115,7 @@ public void onResponse(final SearchResponse searchResponse) {
// Calculate the metrics for this query.
final SearchMetric dcgSearchMetric = new DcgSearchMetric(k, relevanceScores);
final SearchMetric ndcgSearchmetric = new NdcgSearchMetric(k, relevanceScores);
final SearchMetric precisionSearchMetric = new PrecisionSearchMetric(k, relevanceScores);
final SearchMetric precisionSearchMetric = new PrecisionSearchMetric(k, threshold, relevanceScores);

final Collection<SearchMetric> searchMetrics = List.of(dcgSearchMetric, ndcgSearchmetric, precisionSearchMetric);

Expand Down

0 comments on commit 41ca4a4

Please sign in to comment.