diff --git a/mmcls/datasets/pipelines/__init__.py b/mmcls/datasets/pipelines/__init__.py index 5103b0de979..0289c446d6e 100644 --- a/mmcls/datasets/pipelines/__init__.py +++ b/mmcls/datasets/pipelines/__init__.py @@ -1,5 +1,7 @@ -from .auto_augment import (ColorTransform, Invert, Posterize, Rotate, Shear, - Solarize, Translate) +from .auto_augment import (AutoAugment, AutoContrast, Brightness, + ColorTransform, Contrast, Equalize, Invert, + Posterize, Rotate, Sharpness, Shear, Solarize, + Translate) from .compose import Compose from .formating import (Collect, ImageToTensor, ToNumpy, ToPIL, ToTensor, Transpose, to_tensor) @@ -12,5 +14,6 @@ 'Transpose', 'Collect', 'LoadImageFromFile', 'Resize', 'CenterCrop', 'RandomFlip', 'Normalize', 'RandomCrop', 'RandomResizedCrop', 'RandomGrayscale', 'Shear', 'Translate', 'Rotate', 'Invert', - 'ColorTransform', 'Solarize', 'Posterize' + 'ColorTransform', 'Solarize', 'Posterize', 'AutoContrast', 'Equalize', + 'Contrast', 'Brightness', 'Sharpness', 'AutoAugment' ] diff --git a/mmcls/datasets/pipelines/auto_augment.py b/mmcls/datasets/pipelines/auto_augment.py index ce1b82402cd..1dd66e2ac4c 100644 --- a/mmcls/datasets/pipelines/auto_augment.py +++ b/mmcls/datasets/pipelines/auto_augment.py @@ -1,7 +1,10 @@ +import copy + import mmcv import numpy as np from ..builder import PIPELINES +from .compose import Compose def random_negative(value, random_negative_prob): @@ -9,6 +12,44 @@ def random_negative(value, random_negative_prob): return -value if np.random.rand() < random_negative_prob else value +@PIPELINES.register_module() +class AutoAugment(object): + """Auto augmentation. + This data augmentation is proposed in `AutoAugment: Learning Augmentation + Policies from Data `_. + + Args: + policies (list[list[dict]]): The policies of auto augmentation. Each + policy in ``policies`` is a specific augmentation policy, and is + composed by several augmentations (dict). When AutoAugment is + called, a random policy in ``policies`` will be selected to + augment images. + """ + + def __init__(self, policies): + assert isinstance(policies, list) and len(policies) > 0, \ + 'Policies must be a non-empty list.' + for policy in policies: + assert isinstance(policy, list) and len(policy) > 0, \ + 'Each policy in policies must be a non-empty list.' + for augment in policy: + assert isinstance(augment, dict) and 'type' in augment, \ + 'Each specific augmentation must be a dict with key' \ + ' "type".' + + self.policies = copy.deepcopy(policies) + self.sub_policy = [Compose(policy) for policy in self.policies] + + def __call__(self, results): + sub_policy = np.random.choice(self.sub_policy) + return sub_policy(results) + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(policies={self.policies})' + return repr_str + + @PIPELINES.register_module() class Shear(object): """Shear images. @@ -261,6 +302,36 @@ def __repr__(self): return repr_str +@PIPELINES.register_module() +class AutoContrast(object): + """Auto adjust image contrast. + + Args: + prob (float): The probability for performing invert therefore should + be in range [0, 1]. Defaults to 0.5. + """ + + def __init__(self, prob=0.5): + assert 0 <= prob <= 1.0, 'The prob should be in range [0,1], ' \ + f'got {prob} instead.' + + self.prob = prob + + def __call__(self, results): + if np.random.rand() > self.prob: + return results + for key in results.get('img_fields', ['img']): + img = results[key] + img_contrasted = mmcv.auto_contrast(img) + results[key] = img_contrasted.astype(img.dtype) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(prob={self.prob})' + return repr_str + + @PIPELINES.register_module() class Invert(object): """Invert images. @@ -292,52 +363,39 @@ def __repr__(self): @PIPELINES.register_module() -class ColorTransform(object): - """Adjust the color balance of images. +class Equalize(object): + """Equalize the image histogram. Args: - magnitude (int | float): The magnitude used for color transform. A - positive magnitude would enhance the color and a negative magnitude - would make the image grayer. A magnitude=0 gives the origin img. - prob (float): The probability for performing ColorTransform therefore - should be in range [0, 1]. Defaults to 0.5. - random_negative_prob (float): The probability that turns the magnitude - negative, which should be in range [0,1]. Defaults to 0.5. + prob (float): The probability for performing invert therefore should + be in range [0, 1]. Defaults to 0.5. """ - def __init__(self, magnitude, prob=0.5, random_negative_prob=0.5): - assert isinstance(magnitude, (int, float)), 'The magnitude type must '\ - f'be int or float, but got {type(magnitude)} instead.' + def __init__(self, prob=0.5): assert 0 <= prob <= 1.0, 'The prob should be in range [0,1], ' \ f'got {prob} instead.' - assert 0 <= random_negative_prob <= 1.0, 'The random_negative_prob ' \ - f'should be in range [0,1], got {random_negative_prob} instead.' - self.magnitude = magnitude self.prob = prob - self.random_negative_prob = random_negative_prob def __call__(self, results): if np.random.rand() > self.prob: return results - magnitude = random_negative(self.magnitude, self.random_negative_prob) for key in results.get('img_fields', ['img']): img = results[key] - img_color_adjusted = mmcv.adjust_color(img, alpha=1 + magnitude) - results[key] = img_color_adjusted.astype(img.dtype) + img_equalized = mmcv.imequalize(img) + results[key] = img_equalized.astype(img.dtype) return results def __repr__(self): repr_str = self.__class__.__name__ - repr_str += f'(magnitude={self.magnitude}, ' - repr_str += f'prob={self.prob}, ' - repr_str += f'random_negative_prob={self.random_negative_prob})' + repr_str += f'(prob={self.prob})' return repr_str @PIPELINES.register_module() class Solarize(object): - """Solarize an image (invert all pixel values above a threshold). + """Solarize images (invert all pixel values above a threshold). + Args: thr (int | float): The threshold above which the pixels value will be inverted. @@ -372,7 +430,8 @@ def __repr__(self): @PIPELINES.register_module() class Posterize(object): - """Posterize an image (reduce the number of bits for each color channel). + """Posterize images (reduce the number of bits for each color channel). + Args: bits (int): Number of bits for each pixel in the output img, which should be less or equal to 8. @@ -404,3 +463,182 @@ def __repr__(self): repr_str += f'(bits={self.bits}, ' repr_str += f'prob={self.prob})' return repr_str + + +@PIPELINES.register_module() +class Contrast(object): + """Adjust images contrast. + + Args: + magnitude (int | float): The magnitude used for adjusting contrast. A + positive magnitude would enhance the contrast and a negative + magnitude would make the image grayer. A magnitude=0 gives the + origin img. + prob (float): The probability for performing contrast adjusting + therefore should be in range [0, 1]. Defaults to 0.5. + random_negative_prob (float): The probability that turns the magnitude + negative, which should be in range [0,1]. Defaults to 0.5. + """ + + def __init__(self, magnitude, prob=0.5, random_negative_prob=0.5): + assert isinstance(magnitude, (int, float)), 'The magnitude type must '\ + f'be int or float, but got {type(magnitude)} instead.' + assert 0 <= prob <= 1.0, 'The prob should be in range [0,1], ' \ + f'got {prob} instead.' + assert 0 <= random_negative_prob <= 1.0, 'The random_negative_prob ' \ + f'should be in range [0,1], got {random_negative_prob} instead.' + + self.magnitude = magnitude + self.prob = prob + self.random_negative_prob = random_negative_prob + + def __call__(self, results): + if np.random.rand() > self.prob: + return results + magnitude = random_negative(self.magnitude, self.random_negative_prob) + for key in results.get('img_fields', ['img']): + img = results[key] + img_contrasted = mmcv.adjust_contrast(img, factor=1 + magnitude) + results[key] = img_contrasted.astype(img.dtype) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(magnitude={self.magnitude}, ' + repr_str += f'prob={self.prob}, ' + repr_str += f'random_negative_prob={self.random_negative_prob})' + return repr_str + + +@PIPELINES.register_module() +class ColorTransform(object): + """Adjust images color balance. + + Args: + magnitude (int | float): The magnitude used for color transform. A + positive magnitude would enhance the color and a negative magnitude + would make the image grayer. A magnitude=0 gives the origin img. + prob (float): The probability for performing ColorTransform therefore + should be in range [0, 1]. Defaults to 0.5. + random_negative_prob (float): The probability that turns the magnitude + negative, which should be in range [0,1]. Defaults to 0.5. + """ + + def __init__(self, magnitude, prob=0.5, random_negative_prob=0.5): + assert isinstance(magnitude, (int, float)), 'The magnitude type must '\ + f'be int or float, but got {type(magnitude)} instead.' + assert 0 <= prob <= 1.0, 'The prob should be in range [0,1], ' \ + f'got {prob} instead.' + assert 0 <= random_negative_prob <= 1.0, 'The random_negative_prob ' \ + f'should be in range [0,1], got {random_negative_prob} instead.' + + self.magnitude = magnitude + self.prob = prob + self.random_negative_prob = random_negative_prob + + def __call__(self, results): + if np.random.rand() > self.prob: + return results + magnitude = random_negative(self.magnitude, self.random_negative_prob) + for key in results.get('img_fields', ['img']): + img = results[key] + img_color_adjusted = mmcv.adjust_color(img, alpha=1 + magnitude) + results[key] = img_color_adjusted.astype(img.dtype) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(magnitude={self.magnitude}, ' + repr_str += f'prob={self.prob}, ' + repr_str += f'random_negative_prob={self.random_negative_prob})' + return repr_str + + +@PIPELINES.register_module() +class Brightness(object): + """Adjust images brightness. + + Args: + magnitude (int | float): The magnitude used for adjusting brightness. A + positive magnitude would enhance the brightness and a negative + magnitude would make the image darker. A magnitude=0 gives the + origin img. + prob (float): The probability for performing contrast adjusting + therefore should be in range [0, 1]. Defaults to 0.5. + random_negative_prob (float): The probability that turns the magnitude + negative, which should be in range [0,1]. Defaults to 0.5. + """ + + def __init__(self, magnitude, prob=0.5, random_negative_prob=0.5): + assert isinstance(magnitude, (int, float)), 'The magnitude type must '\ + f'be int or float, but got {type(magnitude)} instead.' + assert 0 <= prob <= 1.0, 'The prob should be in range [0,1], ' \ + f'got {prob} instead.' + assert 0 <= random_negative_prob <= 1.0, 'The random_negative_prob ' \ + f'should be in range [0,1], got {random_negative_prob} instead.' + + self.magnitude = magnitude + self.prob = prob + self.random_negative_prob = random_negative_prob + + def __call__(self, results): + if np.random.rand() > self.prob: + return results + magnitude = random_negative(self.magnitude, self.random_negative_prob) + for key in results.get('img_fields', ['img']): + img = results[key] + img_brightened = mmcv.adjust_brightness(img, factor=1 + magnitude) + results[key] = img_brightened.astype(img.dtype) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(magnitude={self.magnitude}, ' + repr_str += f'prob={self.prob}, ' + repr_str += f'random_negative_prob={self.random_negative_prob})' + return repr_str + + +@PIPELINES.register_module() +class Sharpness(object): + """Adjust images sharpness. + + Args: + magnitude (int | float): The magnitude used for adjusting sharpness. A + positive magnitude would enhance the sharpness and a negative + magnitude would make the image bulr. A magnitude=0 gives the + origin img. + prob (float): The probability for performing contrast adjusting + therefore should be in range [0, 1]. Defaults to 0.5. + random_negative_prob (float): The probability that turns the magnitude + negative, which should be in range [0,1]. Defaults to 0.5. + """ + + def __init__(self, magnitude, prob=0.5, random_negative_prob=0.5): + assert isinstance(magnitude, (int, float)), 'The magnitude type must '\ + f'be int or float, but got {type(magnitude)} instead.' + assert 0 <= prob <= 1.0, 'The prob should be in range [0,1], ' \ + f'got {prob} instead.' + assert 0 <= random_negative_prob <= 1.0, 'The random_negative_prob ' \ + f'should be in range [0,1], got {random_negative_prob} instead.' + + self.magnitude = magnitude + self.prob = prob + self.random_negative_prob = random_negative_prob + + def __call__(self, results): + if np.random.rand() > self.prob: + return results + magnitude = random_negative(self.magnitude, self.random_negative_prob) + for key in results.get('img_fields', ['img']): + img = results[key] + img_sharpened = mmcv.adjust_sharpness(img, factor=1 + magnitude) + results[key] = img_sharpened.astype(img.dtype) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(magnitude={self.magnitude}, ' + repr_str += f'prob={self.prob}, ' + repr_str += f'random_negative_prob={self.random_negative_prob})' + return repr_str diff --git a/tests/test_pipelines/test_auto_augment.py b/tests/test_pipelines/test_auto_augment.py index 4376f1f5160..9c750af0b31 100644 --- a/tests/test_pipelines/test_auto_augment.py +++ b/tests/test_pipelines/test_auto_augment.py @@ -122,6 +122,14 @@ def test_shear(): sheared_img = np.stack([sheared_img, sheared_img, sheared_img], axis=-1) assert (results['img'] == sheared_img).all() + # test auto aug with shear + results = construct_toy_data() + policies = [[transform]] + autoaug = dict(type='AutoAugment', policies=policies) + pipeline = build_from_cfg(autoaug, PIPELINES) + results = pipeline(results) + assert (results['img'] == sheared_img).all() + def test_translate(): # test assertion for invalid type of magnitude @@ -326,6 +334,34 @@ def test_rotate(): assert (results['img'] == results['img2']).all() +def test_auto_contrast(): + # test assertion for invalid value of prob + with pytest.raises(AssertionError): + transform = dict(type='AutoContrast', prob=100) + build_from_cfg(transform, PIPELINES) + + # test case when prob=0, therefore no auto_contrast + results = construct_toy_data() + transform = dict(type='AutoContrast', prob=0.) + pipeline = build_from_cfg(transform, PIPELINES) + results = pipeline(results) + assert (results['img'] == results['ori_img']).all() + + # test case when prob=1 + results = construct_toy_data() + transform = dict(type='AutoContrast', prob=1.) + pipeline = build_from_cfg(transform, PIPELINES) + results = pipeline(results) + auto_contrasted_img = np.array( + [[0, 23, 46, 69], [92, 115, 139, 162], [185, 208, 231, 255]], + dtype=np.uint8) + auto_contrasted_img = np.stack( + [auto_contrasted_img, auto_contrasted_img, auto_contrasted_img], + axis=-1) + assert (results['img'] == auto_contrasted_img).all() + assert (results['img'] == results['img2']).all() + + def test_invert(): # test assertion for invalid value of prob with pytest.raises(AssertionError): @@ -353,70 +389,37 @@ def test_invert(): assert (results['img'] == results['img2']).all() -def test_color_transform(): - # test assertion for invalid type of magnitude - with pytest.raises(AssertionError): - transform = dict(type='ColorTransform', magnitude=None) - build_from_cfg(transform, PIPELINES) +def test_equalize(nb_rand_test=100): - # test assertion for invalid value of prob - with pytest.raises(AssertionError): - transform = dict(type='ColorTransform', magnitude=0.5, prob=100) - build_from_cfg(transform, PIPELINES) + def _imequalize(img): + # equalize the image using PIL.ImageOps.equalize + from PIL import ImageOps, Image + img = Image.fromarray(img) + equalized_img = np.asarray(ImageOps.equalize(img)) + return equalized_img - # test assertion for invalid value of random_negative_prob + # test assertion for invalid value of prob with pytest.raises(AssertionError): - transform = dict( - type='ColorTransform', magnitude=0.5, random_negative_prob=100) + transform = dict(type='Equalize', prob=100) build_from_cfg(transform, PIPELINES) - # test case when magnitude=0, therefore no color transform - results = construct_toy_data_photometric() - transform = dict(type='ColorTransform', magnitude=0., prob=1.) - pipeline = build_from_cfg(transform, PIPELINES) - results = pipeline(results) - assert (results['img'] == results['ori_img']).all() - - # test case when prob=0, therefore no color transform - results = construct_toy_data_photometric() - transform = dict(type='ColorTransform', magnitude=1., prob=0.) + # test case when prob=0, therefore no equalize + results = construct_toy_data() + transform = dict(type='Equalize', prob=0.) pipeline = build_from_cfg(transform, PIPELINES) results = pipeline(results) assert (results['img'] == results['ori_img']).all() - # test case when magnitude=-1, therefore got gray img - results = construct_toy_data_photometric() - transform = dict( - type='ColorTransform', magnitude=-1., prob=1., random_negative_prob=0) - pipeline = build_from_cfg(transform, PIPELINES) - results = pipeline(results) - img_gray = mmcv.bgr2gray(results['ori_img']) - img_gray = np.stack([img_gray, img_gray, img_gray], axis=-1) - assert (results['img'] == img_gray).all() - - # test case when magnitude=0.5 - results = construct_toy_data_photometric() - transform = dict( - type='ColorTransform', magnitude=.5, prob=1., random_negative_prob=0) - pipeline = build_from_cfg(transform, PIPELINES) - results = pipeline(results) - img_r = np.round( - np.clip((results['ori_img'] * 0.5 + img_gray * 0.5), 0, - 255)).astype(results['ori_img'].dtype) - assert (results['img'] == img_r).all() - assert (results['img'] == results['img2']).all() - - # test case when magnitude=0.3, random_negative_prob=1 - results = construct_toy_data_photometric() - transform = dict( - type='ColorTransform', magnitude=.3, prob=1., random_negative_prob=1.) + # test case when prob=1 with randomly sampled image. + results = construct_toy_data() + transform = dict(type='Equalize', prob=1.) pipeline = build_from_cfg(transform, PIPELINES) - results = pipeline(results) - img_r = np.round( - np.clip((results['ori_img'] * 0.7 + img_gray * 0.3), 0, - 255)).astype(results['ori_img'].dtype) - assert (results['img'] == img_r).all() - assert (results['img'] == results['img2']).all() + for _ in range(nb_rand_test): + img = np.clip(np.random.normal(0, 1, (1000, 1200, 3)) * 260, 0, + 255).astype(np.uint8) + results['img'] = img + results = pipeline(copy.deepcopy(results)) + assert (results['img'] == _imequalize(img)).all() def test_solarize(): @@ -512,3 +515,259 @@ def test_posterize(): axis=-1) assert (results['img'] == img_posterized).all() assert (results['img'] == results['img2']).all() + + +def test_contrast(nb_rand_test=100): + + def _adjust_contrast(img, factor): + from PIL.ImageEnhance import Contrast + from PIL import Image + # Image.fromarray defaultly supports RGB, not BGR. + # convert from BGR to RGB + img = Image.fromarray(img[..., ::-1], mode='RGB') + contrasted_img = Contrast(img).enhance(factor) + # convert from RGB to BGR + return np.asarray(contrasted_img)[..., ::-1] + + # test assertion for invalid type of magnitude + with pytest.raises(AssertionError): + transform = dict(type='Contrast', magnitude=None) + build_from_cfg(transform, PIPELINES) + + # test assertion for invalid value of prob + with pytest.raises(AssertionError): + transform = dict(type='Contrast', magnitude=0.5, prob=100) + build_from_cfg(transform, PIPELINES) + + # test assertion for invalid value of random_negative_prob + with pytest.raises(AssertionError): + transform = dict( + type='Contrast', magnitude=0.5, random_negative_prob=100) + build_from_cfg(transform, PIPELINES) + + # test case when magnitude=0, therefore no adjusting contrast + results = construct_toy_data_photometric() + transform = dict(type='Contrast', magnitude=0., prob=1.) + pipeline = build_from_cfg(transform, PIPELINES) + results = pipeline(results) + assert (results['img'] == results['ori_img']).all() + + # test case when prob=0, therefore no adjusting contrast + results = construct_toy_data_photometric() + transform = dict(type='Contrast', magnitude=1., prob=0.) + pipeline = build_from_cfg(transform, PIPELINES) + results = pipeline(results) + assert (results['img'] == results['ori_img']).all() + + # test case when prob=1 with randomly sampled image. + results = construct_toy_data() + for _ in range(nb_rand_test): + magnitude = np.random.uniform() * np.random.choice([-1, 1]) + transform = dict( + type='Contrast', + magnitude=magnitude, + prob=1., + random_negative_prob=0.) + pipeline = build_from_cfg(transform, PIPELINES) + img = np.clip(np.random.uniform(0, 1, (1200, 1000, 3)) * 260, 0, + 255).astype(np.uint8) + results['img'] = img + results = pipeline(copy.deepcopy(results)) + # Note the gap (less_equal 1) between PIL.ImageEnhance.Contrast + # and mmcv.adjust_contrast comes from the gap that converts from + # a color image to gray image using mmcv or PIL. + np.testing.assert_allclose( + results['img'], + _adjust_contrast(img, 1 + magnitude), + rtol=0, + atol=1) + + +def test_color_transform(): + # test assertion for invalid type of magnitude + with pytest.raises(AssertionError): + transform = dict(type='ColorTransform', magnitude=None) + build_from_cfg(transform, PIPELINES) + + # test assertion for invalid value of prob + with pytest.raises(AssertionError): + transform = dict(type='ColorTransform', magnitude=0.5, prob=100) + build_from_cfg(transform, PIPELINES) + + # test assertion for invalid value of random_negative_prob + with pytest.raises(AssertionError): + transform = dict( + type='ColorTransform', magnitude=0.5, random_negative_prob=100) + build_from_cfg(transform, PIPELINES) + + # test case when magnitude=0, therefore no color transform + results = construct_toy_data_photometric() + transform = dict(type='ColorTransform', magnitude=0., prob=1.) + pipeline = build_from_cfg(transform, PIPELINES) + results = pipeline(results) + assert (results['img'] == results['ori_img']).all() + + # test case when prob=0, therefore no color transform + results = construct_toy_data_photometric() + transform = dict(type='ColorTransform', magnitude=1., prob=0.) + pipeline = build_from_cfg(transform, PIPELINES) + results = pipeline(results) + assert (results['img'] == results['ori_img']).all() + + # test case when magnitude=-1, therefore got gray img + results = construct_toy_data_photometric() + transform = dict( + type='ColorTransform', magnitude=-1., prob=1., random_negative_prob=0) + pipeline = build_from_cfg(transform, PIPELINES) + results = pipeline(results) + img_gray = mmcv.bgr2gray(results['ori_img']) + img_gray = np.stack([img_gray, img_gray, img_gray], axis=-1) + assert (results['img'] == img_gray).all() + + # test case when magnitude=0.5 + results = construct_toy_data_photometric() + transform = dict( + type='ColorTransform', magnitude=.5, prob=1., random_negative_prob=0) + pipeline = build_from_cfg(transform, PIPELINES) + results = pipeline(results) + img_r = np.round( + np.clip((results['ori_img'] * 0.5 + img_gray * 0.5), 0, + 255)).astype(results['ori_img'].dtype) + assert (results['img'] == img_r).all() + assert (results['img'] == results['img2']).all() + + # test case when magnitude=0.3, random_negative_prob=1 + results = construct_toy_data_photometric() + transform = dict( + type='ColorTransform', magnitude=.3, prob=1., random_negative_prob=1.) + pipeline = build_from_cfg(transform, PIPELINES) + results = pipeline(results) + img_r = np.round( + np.clip((results['ori_img'] * 0.7 + img_gray * 0.3), 0, + 255)).astype(results['ori_img'].dtype) + assert (results['img'] == img_r).all() + assert (results['img'] == results['img2']).all() + + +def test_brightness(nb_rand_test=100): + + def _adjust_brightness(img, factor): + # adjust the brightness of image using + # PIL.ImageEnhance.Brightness + from PIL.ImageEnhance import Brightness + from PIL import Image + img = Image.fromarray(img) + brightened_img = Brightness(img).enhance(factor) + return np.asarray(brightened_img) + + # test assertion for invalid type of magnitude + with pytest.raises(AssertionError): + transform = dict(type='Brightness', magnitude=None) + build_from_cfg(transform, PIPELINES) + + # test assertion for invalid value of prob + with pytest.raises(AssertionError): + transform = dict(type='Brightness', magnitude=0.5, prob=100) + build_from_cfg(transform, PIPELINES) + + # test assertion for invalid value of random_negative_prob + with pytest.raises(AssertionError): + transform = dict( + type='Brightness', magnitude=0.5, random_negative_prob=100) + build_from_cfg(transform, PIPELINES) + + # test case when magnitude=0, therefore no adjusting brightness + results = construct_toy_data_photometric() + transform = dict(type='Brightness', magnitude=0., prob=1.) + pipeline = build_from_cfg(transform, PIPELINES) + results = pipeline(results) + assert (results['img'] == results['ori_img']).all() + + # test case when prob=0, therefore no adjusting brightness + results = construct_toy_data_photometric() + transform = dict(type='Brightness', magnitude=1., prob=0.) + pipeline = build_from_cfg(transform, PIPELINES) + results = pipeline(results) + assert (results['img'] == results['ori_img']).all() + + # test case when prob=1 with randomly sampled image. + results = construct_toy_data() + for _ in range(nb_rand_test): + magnitude = np.random.uniform() * np.random.choice([-1, 1]) + transform = dict( + type='Brightness', + magnitude=magnitude, + prob=1., + random_negative_prob=0.) + pipeline = build_from_cfg(transform, PIPELINES) + img = np.clip(np.random.uniform(0, 1, (1200, 1000, 3)) * 260, 0, + 255).astype(np.uint8) + results['img'] = img + results = pipeline(copy.deepcopy(results)) + np.testing.assert_allclose( + results['img'], + _adjust_brightness(img, 1 + magnitude), + rtol=0, + atol=1) + + +def test_sharpness(nb_rand_test=100): + + def _adjust_sharpness(img, factor): + # adjust the sharpness of image using + # PIL.ImageEnhance.Sharpness + from PIL.ImageEnhance import Sharpness + from PIL import Image + img = Image.fromarray(img) + sharpened_img = Sharpness(img).enhance(factor) + return np.asarray(sharpened_img) + + # test assertion for invalid type of magnitude + with pytest.raises(AssertionError): + transform = dict(type='Sharpness', magnitude=None) + build_from_cfg(transform, PIPELINES) + + # test assertion for invalid value of prob + with pytest.raises(AssertionError): + transform = dict(type='Sharpness', magnitude=0.5, prob=100) + build_from_cfg(transform, PIPELINES) + + # test assertion for invalid value of random_negative_prob + with pytest.raises(AssertionError): + transform = dict( + type='Sharpness', magnitude=0.5, random_negative_prob=100) + build_from_cfg(transform, PIPELINES) + + # test case when magnitude=0, therefore no adjusting sharpness + results = construct_toy_data_photometric() + transform = dict(type='Sharpness', magnitude=0., prob=1.) + pipeline = build_from_cfg(transform, PIPELINES) + results = pipeline(results) + assert (results['img'] == results['ori_img']).all() + + # test case when prob=0, therefore no adjusting sharpness + results = construct_toy_data_photometric() + transform = dict(type='Sharpness', magnitude=1., prob=0.) + pipeline = build_from_cfg(transform, PIPELINES) + results = pipeline(results) + assert (results['img'] == results['ori_img']).all() + + # test case when prob=1 with randomly sampled image. + results = construct_toy_data() + for _ in range(nb_rand_test): + magnitude = np.random.uniform() * np.random.choice([-1, 1]) + transform = dict( + type='Sharpness', + magnitude=magnitude, + prob=1., + random_negative_prob=0.) + pipeline = build_from_cfg(transform, PIPELINES) + img = np.clip(np.random.uniform(0, 1, (1200, 1000, 3)) * 260, 0, + 255).astype(np.uint8) + results['img'] = img + results = pipeline(copy.deepcopy(results)) + np.testing.assert_allclose( + results['img'][1:-1, 1:-1], + _adjust_sharpness(img, 1 + magnitude)[1:-1, 1:-1], + rtol=0, + atol=1)