Skip to content

Commit

Permalink
fix image downloader failing with automl format on deserialize due to…
Browse files Browse the repository at this point in the history
… missing label transformations (#2435)
  • Loading branch information
imatiach-msft authored Dec 4, 2023
1 parent 26e8cca commit 85d6373
Show file tree
Hide file tree
Showing 8 changed files with 125 additions and 30 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/CI-e2e-notebooks-text-vision.yml
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ jobs:
- name: Install pytorch with python 3.7
shell: bash -l {0}
run: |
conda install --yes --quiet "pytorch==1.13.1" "torchvision<0.15" captum cpuonly -c pytorch
conda install --yes --quiet "pytorch==1.13.1" "torchvision<0.15" cpuonly -c pytorch
- name: Setup tools
shell: bash -l {0}
Expand All @@ -102,6 +102,7 @@ jobs:
- name: Install dependencies
shell: bash -l {0}
run: |
pip install captum
pip install -r requirements-dev.txt
pip install -v -e .
working-directory: raiwidgets
Expand Down
9 changes: 5 additions & 4 deletions .github/workflows/CI-notebook-vision.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,25 +52,25 @@ jobs:
name: Install pytorch on non-MacOS with python 3.7
shell: bash -l {0}
run: |
conda install --yes --quiet "pytorch==1.13.1" "torchvision<0.15" captum cpuonly -c pytorch
conda install --yes --quiet "pytorch==1.13.1" "torchvision<0.15" cpuonly -c pytorch
- if: ${{ matrix.operatingSystem == 'macos-latest' && matrix.pythonVersion == '3.7' }}
name: Install Anaconda packages on MacOS with python 3.7
shell: bash -l {0}
run: |
conda install --yes --quiet "pytorch==1.13.1" "torchvision<0.15" captum -c pytorch
conda install --yes --quiet "pytorch==1.13.1" "torchvision<0.15" -c pytorch
- if: ${{ matrix.operatingSystem != 'macos-latest' && matrix.pythonVersion != '3.7' }}
name: Install pytorch on non-MacOS
shell: bash -l {0}
run: |
conda install --yes --quiet "pytorch<2.1,>1.13.1" "torchvision<0.16" captum cpuonly -c pytorch
conda install --yes --quiet "pytorch<2.1,>1.13.1" "torchvision<0.16" cpuonly -c pytorch
- if: ${{ matrix.operatingSystem == 'macos-latest' && matrix.pythonVersion != '3.7' }}
name: Install Anaconda packages on MacOS, which should not include cpuonly according to official docs
shell: bash -l {0}
run: |
conda install --yes --quiet "pytorch<2.1,>1.13.1" "torchvision<0.16" captum -c pytorch
conda install --yes --quiet "pytorch<2.1,>1.13.1" "torchvision<0.16" -c pytorch
- name: Setup tools
shell: bash -l {0}
Expand All @@ -82,6 +82,7 @@ jobs:
- name: Install dependencies
shell: bash -l {0}
run: |
pip install captum
pip install -r requirements-dev.txt
pip install .
working-directory: raiwidgets
Expand Down
5 changes: 3 additions & 2 deletions .github/workflows/CI-responsibleai-text-vision-pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,13 @@ jobs:
name: Install pytorch on non-MacOS
shell: bash -l {0}
run: |
conda install --yes --quiet pytorch==1.13.1 "torchvision<0.15" captum cpuonly -c pytorch
conda install --yes --quiet pytorch==1.13.1 "torchvision<0.15" cpuonly -c pytorch
- if: ${{ matrix.operatingSystem == 'macos-latest' }}
name: Install Anaconda packages on MacOS, which should not include cpuonly according to official docs
shell: bash -l {0}
run: |
conda install --yes --quiet pytorch==1.13.1 "torchvision<0.15" captum -c pytorch
conda install --yes --quiet pytorch==1.13.1 "torchvision<0.15" -c pytorch
- name: Setup tools
shell: bash -l {0}
Expand All @@ -75,6 +75,7 @@ jobs:
- name: Install dependencies
shell: bash -l {0}
run: |
pip install captum
pip install -r requirements-dev.txt
pip install .
working-directory: ${{ matrix.packageDirectory }}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,8 @@ def _get_dataset(self):
for _, image in enumerate(images):
if isinstance(image, str):
image = get_image_from_path(image, self.image_mode)
if isinstance(image, list):
image = np.array(image)
s = io.BytesIO()
# IMshow only accepts floats in range [0, 1]
try:
Expand Down Expand Up @@ -787,7 +789,8 @@ def _format_od_labels(self, y, class_names):
object_labels_lst = [0] * len(class_names)
for detection in image:
# tracking number of same objects in the image
object_labels_lst[int(detection[0] - 1)] += 1
object_index = int(detection[0] - 1)
object_labels_lst[object_index] += 1
formatted_labels.append(object_labels_lst)

return formatted_labels
Expand Down Expand Up @@ -855,11 +858,12 @@ def _save_ext_data(self, path):
os.makedirs(mltable_directory, exist_ok=True)
mltable_data_dict = {}
if self.test_mltable_path:
mltable_dir = self.test_mltable_path.split('/')[-1]
test_mltable_path = Path(self.test_mltable_path)
mltable_dir = test_mltable_path.name
mltable_data_dict[_TEST_MLTABLE_PATH] = mltable_dir
test_dir = mltable_directory / mltable_dir
shutil.copytree(
Path(self.test_mltable_path), test_dir
test_mltable_path, test_dir
)
if mltable_data_dict:
dict_path = mltable_directory / _MLTABLE_METADATA_FILENAME
Expand Down Expand Up @@ -1095,12 +1099,14 @@ def _load_ext_data(inst, path):
mltable_dict = {}
with open(mltable_dict_path, 'r') as file:
mltable_dict = json.load(file)

if mltable_dict.get(_TEST_MLTABLE_PATH, ''):
inst.test_mltable_path = str(mltable_directory / mltable_dict[
_TEST_MLTABLE_PATH])
test_dataset = inst._image_downloader(inst.test_mltable_path)
inst.test = test_dataset._images_df
if inst.task_type == ModelTask.OBJECT_DETECTION.value:
inst.test = transform_object_detection_labels(
inst.test, target_column, inst._classes)

@staticmethod
def _load_transformations(inst, path):
Expand Down
10 changes: 10 additions & 0 deletions responsibleai_vision/tests/common_vision_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,16 @@ def _get_model_path(self, path):
return os.path.join(path, 'image-classification-model')


class ObjectDetectionPipelineSerializer(object):
def save(self, model, path):
pass

def load(self, path):
return retrieve_fridge_object_detection_model(
load_fridge_weights=True
)


class DummyFlowersPipelineSerializer(object):
def save(self, model, path):
pass
Expand Down
49 changes: 34 additions & 15 deletions responsibleai_vision/tests/rai_vision_insights_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@ def validate_rai_vision_insights(
test_data,
target_column,
task_type,
ignore_index=False
ignore_index=False,
ignore_test_data=False
):
rai_vision_test = rai_vision_insights.test
if ignore_index:
rai_vision_test = rai_vision_test.reset_index(drop=True)
test_data = test_data.reset_index(drop=True)
pd.testing.assert_frame_equal(rai_vision_test, test_data)
if not ignore_test_data:
rai_vision_test = rai_vision_insights.test
if ignore_index:
rai_vision_test = rai_vision_test.reset_index(drop=True)
test_data = test_data.reset_index(drop=True)
pd.testing.assert_frame_equal(rai_vision_test, test_data)
assert rai_vision_insights.target_column == target_column
assert rai_vision_insights.task_type == task_type

Expand All @@ -38,7 +40,9 @@ def run_and_validate_serialization(
class_names,
label,
serializer,
image_width=None
image_width=None,
ignore_test_data=False,
image_downloader=None
):
"""Run and validate serialization.
Expand All @@ -57,17 +61,28 @@ def run_and_validate_serialization(
:param image_width: Image width in inches
:type image_width: int
"""
rai_insights = RAIVisionInsights(
pred, test, label,
task_type=task_type,
classes=class_names,
serializer=serializer,
image_width=image_width)

with TemporaryDirectory() as tmpdir:
save_1 = Path(tmpdir) / "first_save"
save_2 = Path(tmpdir) / "second_save"

test_data_path = None
if image_downloader is not None:
test_data_path = str(Path(tmpdir) / "fake_downloaded_test_data")
dir_path = Path(test_data_path)
dir_path.mkdir(exist_ok=True, parents=True)
fake_file_path = dir_path / 'fake_file.txt'
with open(fake_file_path, 'w') as file:
file.write("fake content")

rai_insights = RAIVisionInsights(
pred, test, label,
test_data_path=test_data_path,
task_type=task_type,
classes=class_names,
serializer=serializer,
image_width=image_width,
image_downloader=image_downloader)

# Save it
rai_insights.save(save_1)
assert len(os.listdir(save_1 / ManagerNames.EXPLAINER)) == 0
Expand All @@ -80,7 +95,11 @@ def run_and_validate_serialization(
# Validate
validate_rai_vision_insights(
rai_2, test,
label, task_type)
label, task_type,
ignore_test_data=ignore_test_data)

# Test calling get_data works
rai_2.get_data()

# Save again
rai_2.save(save_2)
Expand Down
11 changes: 9 additions & 2 deletions responsibleai_vision/tests/test_rai_vision_insights.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,11 @@ def to(self, dummy):
pass

def predict(self, X):
return self.model.predict(X).tolist()
prediction = self.model.predict(X).tolist()
if len(prediction) == 1:
# fix ndim error for some versions of pandas in tests
prediction = prediction[0]
return prediction

def predict_proba(self, X):
return self.model.predict_proba(X)
Expand All @@ -146,7 +150,10 @@ def predict_proba(self, X):
class_names = np.array(['can', 'carton',
'milk_bottle', 'water_bottle'])
# test case where there are different numbers of objects in labels
data = data.iloc[[1, 50, 120]]
images = ['10.jpg', '29.jpg', '92.jpg']
images = ["./data/odFridgeObjects/images/{}".format(i) for i in images]
data = data.loc[data[ImageColumns.IMAGE.value].isin(images)]
data = data.reset_index(drop=True)
run_rai_insights(wrapped_model, data, ImageColumns.LABEL,
task_type, class_names,
test_error_analysis=True,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,36 @@
import shutil
from pathlib import Path
from tempfile import TemporaryDirectory
from unittest.mock import patch

import numpy as np
import PIL
import pytest
from common_vision_utils import (DummyFlowersPipelineSerializer,
ImageClassificationPipelineSerializer,
ObjectDetectionPipelineSerializer,
create_dummy_model,
create_image_classification_pipeline,
load_flowers_dataset, load_imagenet_dataset,
load_imagenet_labels)
load_flowers_dataset,
load_fridge_object_detection_dataset,
load_imagenet_dataset, load_imagenet_labels,
retrieve_fridge_object_detection_model)
from rai_vision_insights_validator import run_and_validate_serialization

from responsibleai_vision import ModelTask, RAIVisionInsights
from responsibleai_vision.common.constants import ImageColumns

FRIDGE_CLASS_NAMES = np.array(['can', 'carton',
'milk_bottle', 'water_bottle'])


class FakeImageDownloader:
def __init__(self, test_mltable_path):
self._images_df = self.get_data()

def get_data(self):
return None


class TestRAIVisionInsightsSaveAndLoadScenarios(object):

Expand Down Expand Up @@ -78,3 +94,37 @@ def test_loading_rai_insights_without_model_file(self):
with pytest.raises(OSError, match=match_msg):
without_model_rai_insights = RAIVisionInsights.load(save_path)
assert without_model_rai_insights.model is None

@pytest.mark.parametrize('automl_format', [True, False])
def test_rai_insights_object_detection(self, automl_format):
data = load_fridge_object_detection_dataset(automl_format)
model = retrieve_fridge_object_detection_model(
load_fridge_weights=True
)
task_type = ModelTask.OBJECT_DETECTION
test = data[:3]
label = ImageColumns.LABEL
serializer = ObjectDetectionPipelineSerializer()

run_and_validate_serialization(
model, test, task_type, FRIDGE_CLASS_NAMES, label,
serializer, ignore_test_data=True)

def test_rai_insights_image_downloader_object_detection(self):
data = load_fridge_object_detection_dataset(True)
model = retrieve_fridge_object_detection_model(
load_fridge_weights=True
)
task_type = ModelTask.OBJECT_DETECTION
test = data[:3]
label = ImageColumns.LABEL
serializer = ObjectDetectionPipelineSerializer()

get_data = ('test_rai_vision_insights_save_and_load_scenarios'
'.FakeImageDownloader.get_data')
with patch(get_data) as mock_images_df:
mock_images_df.return_value = test.copy()
run_and_validate_serialization(
model, test, task_type, FRIDGE_CLASS_NAMES, label,
serializer, ignore_test_data=True,
image_downloader=FakeImageDownloader)

0 comments on commit 85d6373

Please sign in to comment.