Skip to content

Commit

Permalink
Updating dcg/ndcg calculations to avoid NaN.
Browse files (browse the repository at this point in the history)
  • Loading branch information
jzonthemtn committed Dec 10, 2024
1 parent d62ca4d commit 5b20319
Show file tree
Hide file tree
Showing 7 changed files with 54 additions and 26 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#!/bin/bash -e
#
# Queries the sqe_metrics_sample_data index on the search cluster listening at
# localhost:9200 and pretty-prints the JSON response with jq.

# Without pipefail the pipeline's exit status is jq's alone, so a failed curl
# (for example, connection refused) would go unnoticed despite "-e".
set -o pipefail

# -s hides the progress meter; -S still reports errors on stderr.
# The identity filter "." is given explicitly rather than relying on jq's
# version-dependent default when no filter is supplied.
curl -sS "http://localhost:9200/sqe_metrics_sample_data/_search" | jq .
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/bin/bash -e

QUERY_SET_ID="${1}"
JUDGMENTS_ID="9183599e-46dd-49e0-9584-df816164a4c2"
QUERY_SET_ID="955df665-9cd5-4828-bda0-9aea2d002279"
JUDGMENTS_ID="0615e159-675b-4c60-875a-24daeb8c126c"
INDEX="ecommerce"
ID_FIELD="asin"
K="20"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,9 @@ protected double calculateDcg(final List<Double> relevanceScores) {
final double numerator = Math.pow(2, relevanceScore) - 1.0;
final double denominator = Math.log(i) / Math.log(i + 2);

LOGGER.debug("numerator = {}, denominator = {}", numerator, denominator);
dcg += (numerator / denominator);
if (denominator != 0) {
dcg += (numerator / denominator);
}

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
*/
package org.opensearch.eval.metrics;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;

/**
Expand All @@ -32,13 +34,27 @@ public String getName() {
/**
 * Calculates the normalized discounted cumulative gain (nDCG): the DCG of the
 * relevance scores in their ranked order divided by the DCG of the same scores
 * in ideal (largest to smallest) order.
 *
 * @return the nDCG, or {@code 0} when either the DCG or the ideal DCG is zero,
 *         so that a division by zero never produces {@code NaN}.
 */
@Override
public double calculate() {

    // The DCG must be computed on the scores in their original ranked order,
    // i.e. before any ideal ordering is built.
    final double dcg = super.calculate();

    // The ndcg is 0. No need to continue.
    if (dcg == 0) {
        return 0;
    }

    // Make the ideal relevance scores by sorting a copy of the relevance scores
    // largest to smallest. Sorting relevanceScores itself would leave the shared
    // list in ideal order, so a later calculate() call would presumably compute
    // the DCG on already-sorted scores (NOTE(review): confirm super.calculate()
    // reads this same list).
    final List<Double> idealRelevanceScores = new ArrayList<>(relevanceScores);
    idealRelevanceScores.sort(Comparator.reverseOrder());

    final double idcg = super.calculateDcg(idealRelevanceScores);

    // Guard against dividing by zero.
    if (idcg == 0) {
        return 0;
    }

    return dcg / idcg;

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ protected List<Double> getRelevanceScores(final String judgmentsId, final String
// Ordered list of scores.
final List<Double> scores = new ArrayList<>();

// Go through each document up to k and get the score.
// For each document (up to k), get the judgment for the document.
for (int i = 0; i < k && i < orderedDocumentIds.size(); i++) {

final String documentId = orderedDocumentIds.get(i);
Expand All @@ -177,6 +177,8 @@ protected List<Double> getRelevanceScores(final String judgmentsId, final String
// If a judgment for this query/doc pair is not found, Double.NaN will be returned.
if(!Double.isNaN(judgmentValue)) {
scores.add(judgmentValue);
} else {
LOGGER.info("No score found for document ID {} with judgments {} and query {}", documentId, judgmentsId, query);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.opensearch.eval.metrics.NdcgSearchMetric;
import org.opensearch.eval.metrics.PrecisionSearchMetric;
import org.opensearch.eval.metrics.SearchMetric;
import org.opensearch.eval.utils.TimeUtils;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;
Expand Down Expand Up @@ -145,13 +146,15 @@ public void onFailure(Exception ex) {
final Map<String, Double> sumOfMetrics = new HashMap<>();
for(final QueryResult queryResult : queryResults) {
for(final SearchMetric searchMetric : queryResult.getSearchMetrics()) {
//LOGGER.info("Summing: {} - {}", searchMetric.getName(), searchMetric.getValue());
sumOfMetrics.merge(searchMetric.getName(), searchMetric.getValue(), Double::sum);
}
}

// Now divide by the number of queries.
final Map<String, Double> querySetMetrics = new HashMap<>();
for(final String metric : sumOfMetrics.keySet()) {
//LOGGER.info("Dividing by the query set size: {} / {}", sumOfMetrics.get(metric), querySetSize);
querySetMetrics.put(metric, sumOfMetrics.get(metric) / querySetSize);
}

Expand Down Expand Up @@ -180,38 +183,41 @@ public void save(final QuerySetRunResult result) throws Exception {

// Add each metric to the object to index.
for (final String metric : result.getSearchMetrics().keySet()) {
results.put(metric, result.getSearchMetrics().get(metric));
results.put(metric, String.valueOf(result.getSearchMetrics().get(metric)));
}

final IndexRequest indexRequest = new IndexRequest(SearchQualityEvaluationPlugin.QUERY_SETS_RUN_RESULTS_INDEX_NAME)
.id(UUID.randomUUID().toString())
.source(results);

client.index(indexRequest, new ActionListener<>() {
@Override
public void onResponse(IndexResponse indexResponse) {
LOGGER.debug("Query set results indexed.");
}

@Override
public void onFailure(Exception ex) {
throw new RuntimeException(ex);
}
});

// TODO: Index the metrics as expected by the dashboards.
client.index(indexRequest).get();
// client.index(indexRequest, new ActionListener<>() {
// @Override
// public void onResponse(IndexResponse indexResponse) {
// LOGGER.debug("Query set results indexed.");
// }
//
// @Override
// public void onFailure(Exception ex) {
// throw new RuntimeException(ex);
// }
// });

// Now, index the metrics as expected by the dashboards.

// See https://github.com/o19s/opensearch-search-quality-evaluation/blob/main/opensearch-dashboard-prototyping/METRICS_SCHEMA.md
// See https://github.com/o19s/opensearch-search-quality-evaluation/blob/main/opensearch-dashboard-prototyping/sample_data.ndjson

final BulkRequest bulkRequest = new BulkRequest();
final String timestamp = TimeUtils.getTimestamp();

for(final QueryResult queryResult : result.getQueryResults()) {

for(final SearchMetric searchMetric : queryResult.getSearchMetrics()) {

// TODO: Make sure all of these items have values.
final Map<String, Object> metrics = new HashMap<>();
metrics.put("datetime", "2024-09-01T00:00:00");
metrics.put("datetime", timestamp);
metrics.put("search_config", "research_1");
metrics.put("query_set_id", result.getQuerySetId());
metrics.put("query", queryResult.getQuery());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ public Collection<Map<String, Object>> getQueryResultsAsMap() {

// Calculate and add each metric to the map.
for(final SearchMetric searchMetric : queryResult.getSearchMetrics()) {
q.put(searchMetric.getName(), searchMetric.calculate());
q.put(searchMetric.getName(), String.valueOf(searchMetric.calculate()));
}

qs.add(q);
Expand Down

0 comments on commit 5b20319

Please sign in to comment.