Skip to content

Commit

Permalink
fix image downloader failing with automl format on deserialize due to…
Browse files Browse the repository at this point in the history
… missing label transformations
  • Loading branch information
imatiach-msft committed Dec 1, 2023
1 parent 26e8cca commit 08ab314
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -787,7 +787,8 @@ def _format_od_labels(self, y, class_names):
object_labels_lst = [0] * len(class_names)
for detection in image:
# tracking number of same objects in the image
object_labels_lst[int(detection[0] - 1)] += 1
object_index = int(detection[0] - 1)
object_labels_lst[object_index] += 1
formatted_labels.append(object_labels_lst)

return formatted_labels
Expand Down Expand Up @@ -855,11 +856,12 @@ def _save_ext_data(self, path):
os.makedirs(mltable_directory, exist_ok=True)
mltable_data_dict = {}
if self.test_mltable_path:
mltable_dir = self.test_mltable_path.split('/')[-1]
test_mltable_path = Path(self.test_mltable_path)
mltable_dir = test_mltable_path.name
mltable_data_dict[_TEST_MLTABLE_PATH] = mltable_dir
test_dir = mltable_directory / mltable_dir
shutil.copytree(
Path(self.test_mltable_path), test_dir
test_mltable_path, test_dir
)
if mltable_data_dict:
dict_path = mltable_directory / _MLTABLE_METADATA_FILENAME
Expand Down Expand Up @@ -1095,12 +1097,14 @@ def _load_ext_data(inst, path):
mltable_dict = {}
with open(mltable_dict_path, 'r') as file:
mltable_dict = json.load(file)

if mltable_dict.get(_TEST_MLTABLE_PATH, ''):
inst.test_mltable_path = str(mltable_directory / mltable_dict[
_TEST_MLTABLE_PATH])
test_dataset = inst._image_downloader(inst.test_mltable_path)
inst.test = test_dataset._images_df
if inst.task_type == ModelTask.OBJECT_DETECTION.value:
inst.test = transform_object_detection_labels(
inst.test, target_column, inst._classes)

@staticmethod
def _load_transformations(inst, path):
Expand Down
10 changes: 10 additions & 0 deletions responsibleai_vision/tests/common_vision_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,16 @@ def _get_model_path(self, path):
return os.path.join(path, 'image-classification-model')


class ObjectDetectionPipelineSerializer(object):
def save(self, model, path):
pass

def load(self, path):
return retrieve_fridge_object_detection_model(
load_fridge_weights=True
)


class DummyFlowersPipelineSerializer(object):
def save(self, model, path):
pass
Expand Down
49 changes: 34 additions & 15 deletions responsibleai_vision/tests/rai_vision_insights_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@ def validate_rai_vision_insights(
test_data,
target_column,
task_type,
ignore_index=False
ignore_index=False,
ignore_test_data=False
):
rai_vision_test = rai_vision_insights.test
if ignore_index:
rai_vision_test = rai_vision_test.reset_index(drop=True)
test_data = test_data.reset_index(drop=True)
pd.testing.assert_frame_equal(rai_vision_test, test_data)
if not ignore_test_data:
rai_vision_test = rai_vision_insights.test
if ignore_index:
rai_vision_test = rai_vision_test.reset_index(drop=True)
test_data = test_data.reset_index(drop=True)
pd.testing.assert_frame_equal(rai_vision_test, test_data)
assert rai_vision_insights.target_column == target_column
assert rai_vision_insights.task_type == task_type

Expand All @@ -38,7 +40,9 @@ def run_and_validate_serialization(
class_names,
label,
serializer,
image_width=None
image_width=None,
ignore_test_data=False,
image_downloader=None
):
"""Run and validate serialization.
Expand All @@ -57,17 +61,28 @@ def run_and_validate_serialization(
:param image_width: Image width in inches
:type image_width: int
"""
rai_insights = RAIVisionInsights(
pred, test, label,
task_type=task_type,
classes=class_names,
serializer=serializer,
image_width=image_width)

with TemporaryDirectory() as tmpdir:
save_1 = Path(tmpdir) / "first_save"
save_2 = Path(tmpdir) / "second_save"

test_data_path = None
if image_downloader is not None:
test_data_path = str(Path(tmpdir) / "fake_downloaded_test_data")
dir_path = Path(test_data_path)
dir_path.mkdir(exist_ok=True, parents=True)
fake_file_path = dir_path / 'fake_file.txt'
with open(fake_file_path, 'w') as file:
file.write("fake content")

rai_insights = RAIVisionInsights(
pred, test, label,
test_data_path=test_data_path,
task_type=task_type,
classes=class_names,
serializer=serializer,
image_width=image_width,
image_downloader=image_downloader)

# Save it
rai_insights.save(save_1)
assert len(os.listdir(save_1 / ManagerNames.EXPLAINER)) == 0
Expand All @@ -80,7 +95,11 @@ def run_and_validate_serialization(
# Validate
validate_rai_vision_insights(
rai_2, test,
label, task_type)
label, task_type,
ignore_test_data=ignore_test_data)

# Test calling get_data works
rai_2.get_data()

# Save again
rai_2.save(save_2)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,36 @@
import shutil
from pathlib import Path
from tempfile import TemporaryDirectory
from unittest.mock import patch

import numpy as np
import PIL
import pytest
from common_vision_utils import (DummyFlowersPipelineSerializer,
ImageClassificationPipelineSerializer,
ObjectDetectionPipelineSerializer,
create_dummy_model,
create_image_classification_pipeline,
load_flowers_dataset, load_imagenet_dataset,
load_imagenet_labels)
load_flowers_dataset,
load_fridge_object_detection_dataset,
load_imagenet_dataset, load_imagenet_labels,
retrieve_fridge_object_detection_model)
from rai_vision_insights_validator import run_and_validate_serialization

from responsibleai_vision import ModelTask, RAIVisionInsights
from responsibleai_vision.common.constants import ImageColumns

FRIDGE_CLASS_NAMES = np.array(['can', 'carton',
'milk_bottle', 'water_bottle'])


class FakeImageDownloader:
def __init__(self, test_mltable_path):
self._images_df = self.get_data()

def get_data(self):
return None


class TestRAIVisionInsightsSaveAndLoadScenarios(object):

Expand Down Expand Up @@ -78,3 +94,37 @@ def test_loading_rai_insights_without_model_file(self):
with pytest.raises(OSError, match=match_msg):
without_model_rai_insights = RAIVisionInsights.load(save_path)
assert without_model_rai_insights.model is None

@pytest.mark.parametrize('automl_format', [True, False])
def test_rai_insights_object_detection(self, automl_format):
data = load_fridge_object_detection_dataset(automl_format)
model = retrieve_fridge_object_detection_model(
load_fridge_weights=True
)
task_type = ModelTask.OBJECT_DETECTION
test = data[:3]
label = ImageColumns.LABEL
serializer = ObjectDetectionPipelineSerializer()

run_and_validate_serialization(
model, test, task_type, FRIDGE_CLASS_NAMES, label,
serializer, ignore_test_data=True)

def test_rai_insights_image_downloader_object_detection(self):
data = load_fridge_object_detection_dataset(True)
model = retrieve_fridge_object_detection_model(
load_fridge_weights=True
)
task_type = ModelTask.OBJECT_DETECTION
test = data[:3]
label = ImageColumns.LABEL
serializer = ObjectDetectionPipelineSerializer()

get_data = ('test_rai_vision_insights_save_and_load_scenarios'
'.FakeImageDownloader.get_data')
with patch(get_data) as mock_images_df:
mock_images_df.return_value = test.copy()
run_and_validate_serialization(
model, test, task_type, FRIDGE_CLASS_NAMES, label,
serializer, ignore_test_data=True,
image_downloader=FakeImageDownloader)

0 comments on commit 08ab314

Please sign in to comment.