Skip to content

Commit

Permalink
fix RangeIndex error in error analysis for object detection models
Browse files Browse the repository at this point in the history
  • Loading branch information
imatiach-msft committed Nov 13, 2023
1 parent 73dc621 commit 4625c32
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,22 @@ def __init__(self, model, dataset, image_mode, transformations,
self.predictions = np.array(predictions_joined)
self.predict_proba = self.model.predict_proba(test)

def index_predictions(self, index, predictions):
"""Index the predictions.
:param index: The index to use.
:type index: list
:param predictions: The predictions to index.
:type predictions: list
"""
if not isinstance(index, list):
index = list(index)
if isinstance(predictions, list):
predictions = [predictions[i] for i in index]
else:
predictions = predictions[index]
return predictions

def predict(self, X):
"""Predict the class labels for the provided data.
Expand All @@ -124,7 +140,7 @@ def predict(self, X):
:rtype: list
"""
index = X.index
predictions = self.predictions[index]
predictions = self.index_predictions(index, self.predictions)
if self.task_type == ModelTask.MULTILABEL_IMAGE_CLASSIFICATION or \
self.task_type == ModelTask.OBJECT_DETECTION:
return predictions
Expand All @@ -141,7 +157,7 @@ def predict_proba(self, X):
:rtype: list[list]
"""
index = X.index
pred_proba = self.predict_proba[index]
pred_proba = self.index_predictions(index, self.predict_proba)
return pred_proba


Expand Down
9 changes: 7 additions & 2 deletions responsibleai_vision/tests/rai_vision_insights_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,14 @@ def validate_rai_vision_insights(
rai_vision_insights,
test_data,
target_column,
task_type
task_type,
ignore_index=False
):
pd.testing.assert_frame_equal(rai_vision_insights.test, test_data)
rai_vision_test = rai_vision_insights.test
if ignore_index:
rai_vision_test = rai_vision_test.reset_index(drop=True)
test_data = test_data.reset_index(drop=True)
pd.testing.assert_frame_equal(rai_vision_test, test_data)
assert rai_vision_insights.target_column == target_column
assert rai_vision_insights.task_type == task_type

Expand Down
39 changes: 37 additions & 2 deletions responsibleai_vision/tests/test_rai_vision_insights.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
load_multilabel_fridge_dataset,
retrieve_fridge_object_detection_model,
retrieve_or_train_fridge_model)
from ml_wrappers import wrap_model
from ml_wrappers.common.constants import Device
from rai_vision_insights_validator import validate_rai_vision_insights

Expand Down Expand Up @@ -119,6 +120,38 @@ def test_rai_insights_object_detection_fridge(self, num_masks, mask_res):
num_masks=num_masks, mask_res=mask_res,
test_error_analysis=True)

def test_rai_insights_object_detection_jagged_list(self):
data = load_fridge_object_detection_dataset()
model = retrieve_fridge_object_detection_model(
load_fridge_weights=True
)
task_type = ModelTask.OBJECT_DETECTION
wrapped_model = wrap_model(model, data, ModelTask.OBJECT_DETECTION)

class DummyPredictWrapper(object):
def __init__(self, model):
self.model = model
self._model = model._model

def to(self, dummy):
pass

def predict(self, X):
return self.model.predict(X).tolist()

def predict_proba(self, X):
return self.model.predict_proba(X)

wrapped_model = DummyPredictWrapper(wrapped_model)
class_names = np.array(['can', 'carton',
'milk_bottle', 'water_bottle'])
# test case where there are different numbers of objects in labels
data = data.iloc[[1, 50, 120]]
run_rai_insights(wrapped_model, data, ImageColumns.LABEL,
task_type, class_names,
test_error_analysis=True,
ignore_index=True)

@pytest.mark.parametrize('num_masks', [-100, -1, 0])
def test_rai_insights_invalid_num_masks(self, num_masks):
data = load_fridge_object_detection_dataset()
Expand Down Expand Up @@ -264,7 +297,8 @@ def run_rai_insights(model, test_data, target_column,
upscale=False, max_evals=DEFAULT_MAX_EVALS,
num_masks=DEFAULT_NUM_MASKS,
mask_res=DEFAULT_MASK_RES,
device=Device.AUTO.value):
device=Device.AUTO.value,
ignore_index=False):
feature_metadata = None
if dropped_features:
feature_metadata = FeatureMetadata(dropped_features=dropped_features)
Expand Down Expand Up @@ -294,7 +328,8 @@ def run_rai_insights(model, test_data, target_column,
# Validate
validate_rai_vision_insights(
rai_insights, test_data,
target_column, task_type)
target_column, task_type,
ignore_index)
if task_type == ModelTask.OBJECT_DETECTION:
selection_indexes = [[0]]
aggregate_method = 'Macro'
Expand Down

0 comments on commit 4625c32

Please sign in to comment.