From 4625c322d5a5aee8d6a6cc017be28b9858d71a21 Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Fri, 3 Nov 2023 16:03:29 -0400 Subject: [PATCH] fix RangeIndex error in error analysis for object detection models --- .../managers/error_analysis_manager.py | 20 +++++++++- .../tests/rai_vision_insights_validator.py | 9 ++++- .../tests/test_rai_vision_insights.py | 39 ++++++++++++++++++- 3 files changed, 62 insertions(+), 6 deletions(-) diff --git a/responsibleai_vision/responsibleai_vision/managers/error_analysis_manager.py b/responsibleai_vision/responsibleai_vision/managers/error_analysis_manager.py index 6ac914d7a2..771d0a64eb 100644 --- a/responsibleai_vision/responsibleai_vision/managers/error_analysis_manager.py +++ b/responsibleai_vision/responsibleai_vision/managers/error_analysis_manager.py @@ -115,6 +115,22 @@ def __init__(self, model, dataset, image_mode, transformations, self.predictions = np.array(predictions_joined) self.predict_proba = self.model.predict_proba(test) + def index_predictions(self, index, predictions): + """Index the predictions. + + :param index: The index to use. + :type index: list + :param predictions: The predictions to index. + :type predictions: list + """ + if not isinstance(index, list): + index = list(index) + if isinstance(predictions, list): + predictions = [predictions[i] for i in index] + else: + predictions = predictions[index] + return predictions + def predict(self, X): """Predict the class labels for the provided data. @@ -124,7 +140,7 @@ def predict(self, X): :rtype: list """ index = X.index - predictions = self.predictions[index] + predictions = self.index_predictions(index, self.predictions) if self.task_type == ModelTask.MULTILABEL_IMAGE_CLASSIFICATION or \ self.task_type == ModelTask.OBJECT_DETECTION: return predictions @@ -141,7 +157,7 @@ def predict_proba(self, X): :rtype: list[list] """ index = X.index - pred_proba = self.predict_proba[index] + pred_proba = self.index_predictions(index, self.predict_proba) return pred_proba diff --git a/responsibleai_vision/tests/rai_vision_insights_validator.py b/responsibleai_vision/tests/rai_vision_insights_validator.py index a098802224..b91e166e90 100644 --- a/responsibleai_vision/tests/rai_vision_insights_validator.py +++ b/responsibleai_vision/tests/rai_vision_insights_validator.py @@ -19,9 +19,14 @@ def validate_rai_vision_insights( rai_vision_insights, test_data, target_column, - task_type + task_type, + ignore_index=False ): - pd.testing.assert_frame_equal(rai_vision_insights.test, test_data) + rai_vision_test = rai_vision_insights.test + if ignore_index: + rai_vision_test = rai_vision_test.reset_index(drop=True) + test_data = test_data.reset_index(drop=True) + pd.testing.assert_frame_equal(rai_vision_test, test_data) assert rai_vision_insights.target_column == target_column assert rai_vision_insights.task_type == task_type diff --git a/responsibleai_vision/tests/test_rai_vision_insights.py b/responsibleai_vision/tests/test_rai_vision_insights.py index 3acbee3ecc..a8da6e5b94 100644 --- a/responsibleai_vision/tests/test_rai_vision_insights.py +++ b/responsibleai_vision/tests/test_rai_vision_insights.py @@ -18,6 +18,7 @@ load_multilabel_fridge_dataset, retrieve_fridge_object_detection_model, retrieve_or_train_fridge_model) +from ml_wrappers import wrap_model from ml_wrappers.common.constants import Device from rai_vision_insights_validator import validate_rai_vision_insights @@ -119,6 +120,38 @@ def test_rai_insights_object_detection_fridge(self, num_masks, mask_res): num_masks=num_masks, mask_res=mask_res, test_error_analysis=True) + def test_rai_insights_object_detection_jagged_list(self): + data = load_fridge_object_detection_dataset() + model = retrieve_fridge_object_detection_model( + load_fridge_weights=True + ) + task_type = ModelTask.OBJECT_DETECTION + wrapped_model = wrap_model(model, data, ModelTask.OBJECT_DETECTION) + + class DummyPredictWrapper(object): + def __init__(self, model): + self.model = model + self._model = model._model + + def to(self, dummy): + pass + + def predict(self, X): + return self.model.predict(X).tolist() + + def predict_proba(self, X): + return self.model.predict_proba(X) + + wrapped_model = DummyPredictWrapper(wrapped_model) + class_names = np.array(['can', 'carton', + 'milk_bottle', 'water_bottle']) + # test case where there are different numbers of objects in labels + data = data.iloc[[1, 50, 120]] + run_rai_insights(wrapped_model, data, ImageColumns.LABEL, + task_type, class_names, + test_error_analysis=True, + ignore_index=True) + @pytest.mark.parametrize('num_masks', [-100, -1, 0]) def test_rai_insights_invalid_num_masks(self, num_masks): data = load_fridge_object_detection_dataset() @@ -264,7 +297,8 @@ def run_rai_insights(model, test_data, target_column, upscale=False, max_evals=DEFAULT_MAX_EVALS, num_masks=DEFAULT_NUM_MASKS, mask_res=DEFAULT_MASK_RES, - device=Device.AUTO.value): + device=Device.AUTO.value, + ignore_index=False): feature_metadata = None if dropped_features: feature_metadata = FeatureMetadata(dropped_features=dropped_features) @@ -294,7 +328,8 @@ def run_rai_insights(model, test_data, target_column, # Validate validate_rai_vision_insights( rai_insights, test_data, - target_column, task_type) + target_column, task_type, + ignore_index) if task_type == ModelTask.OBJECT_DETECTION: selection_indexes = [[0]] aggregate_method = 'Macro'