Skip to content

Commit

Permalink
add average parameter to MeanAveragePrecision to specify micro or mac…
Browse files Browse the repository at this point in the history
…ro calculation (#2412)
  • Loading branch information
imatiach-msft authored Nov 14, 2023
1 parent 4c01e0c commit 6a5d308
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1191,6 +1191,7 @@ def compute_object_detection_metrics(
continue

metric_OD = MeanAveragePrecision(
average=aggregate_method.lower(),
class_metrics=True,
iou_thresholds=normalized_iou_threshold).to(device)
true_y_cohort = [true_y[cohort_index] for cohort_index
Expand Down
11 changes: 6 additions & 5 deletions responsibleai_vision/tests/test_rai_vision_insights.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,11 +332,12 @@ def run_rai_insights(model, test_data, target_column,
ignore_index)
if task_type == ModelTask.OBJECT_DETECTION:
selection_indexes = [[0]]
aggregate_method = 'Macro'
class_name = classes[0]
iou_threshold = 70
object_detection_cache = {}
metrics = rai_insights.compute_object_detection_metrics(
selection_indexes, aggregate_method, class_name, iou_threshold,
object_detection_cache)
assert len(metrics) == 2
aggregate_methods = ['macro', 'micro']
for aggregate_method in aggregate_methods:
metrics = rai_insights.compute_object_detection_metrics(
selection_indexes, aggregate_method, class_name, iou_threshold,
object_detection_cache)
assert len(metrics) == 2

0 comments on commit 6a5d308

Please sign in to comment.