Spaces:
Runtime error
Runtime error
# Copyright 2023 The TensorFlow Authors. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Tests for autoaugment.""" | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import random | |
from absl.testing import parameterized | |
import numpy as np | |
import tensorflow as tf, tf_keras | |
from official.vision.ops import augment | |
def get_dtype_test_cases(): | |
return [ | |
('uint8', tf.uint8), | |
('int32', tf.int32), | |
('float16', tf.float16), | |
('float32', tf.float32), | |
] | |
class TransformsTest(parameterized.TestCase, tf.test.TestCase): | |
"""Basic tests for fundamental transformations.""" | |
def test_to_from_4d(self, dtype): | |
for shape in [(10, 10), (10, 10, 10), (10, 10, 10, 10)]: | |
original_ndims = len(shape) | |
image = tf.zeros(shape, dtype=dtype) | |
image_4d = augment.to_4d(image) | |
self.assertEqual(4, tf.rank(image_4d)) | |
self.assertAllEqual(image, augment.from_4d(image_4d, original_ndims)) | |
def test_transform(self, dtype): | |
image = tf.constant([[1, 2], [3, 4]], dtype=dtype) | |
self.assertAllEqual( | |
augment.transform(image, transforms=[1] * 8), [[4, 4], [4, 4]]) | |
def test_translate(self, dtype): | |
image = tf.constant( | |
[[1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]], dtype=dtype) | |
translations = [-1, -1] | |
translated = augment.translate(image=image, translations=translations) | |
expected = [[1, 0, 1, 1], [0, 1, 0, 0], [1, 0, 1, 1], [1, 0, 1, 1]] | |
self.assertAllEqual(translated, expected) | |
def test_translate_shapes(self, dtype): | |
translation = [0, 0] | |
for shape in [(3, 3), (5, 5), (224, 224, 3)]: | |
image = tf.zeros(shape, dtype=dtype) | |
self.assertAllEqual(image, augment.translate(image, translation)) | |
def test_translate_invalid_translation(self, dtype): | |
image = tf.zeros((1, 1), dtype=dtype) | |
invalid_translation = [[[1, 1]]] | |
with self.assertRaisesRegex(TypeError, 'rank 1 or 2'): | |
_ = augment.translate(image, invalid_translation) | |
def test_rotate(self, dtype): | |
image = tf.reshape(tf.cast(tf.range(9), dtype), (3, 3)) | |
rotation = 90. | |
transformed = augment.rotate(image=image, degrees=rotation) | |
expected = [[2, 5, 8], [1, 4, 7], [0, 3, 6]] | |
self.assertAllEqual(transformed, expected) | |
def test_rotate_shapes(self, dtype): | |
degrees = 0. | |
for shape in [(3, 3), (5, 5), (224, 224, 3)]: | |
image = tf.zeros(shape, dtype=dtype) | |
self.assertAllEqual(image, augment.rotate(image, degrees)) | |
def test_random_cutout_video(self, dtype): | |
for num_channels in (1, 2, 3): | |
video = tf.ones((2, 2, 2, num_channels), dtype=dtype) | |
video = augment.cutout_video(video) | |
num_zeros = np.sum(video == 0) | |
self.assertGreater(num_zeros, 0) | |
def test_cutout_video_with_fixed_shape(self, dtype): | |
tf.random.set_seed(0) | |
video = tf.ones((10, 10, 10, 1), dtype=dtype) | |
video = augment.cutout_video(video, mask_shape=tf.constant([2, 2, 2])) | |
num_zeros = np.sum(video == 0) | |
self.assertEqual(num_zeros, 8) | |
class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase): | |
AVAILABLE_POLICIES = [ | |
'v0', | |
'test', | |
'simple', | |
'reduced_cifar10', | |
'svhn', | |
'reduced_imagenet', | |
'detection_v0', | |
'vit', | |
] | |
def test_autoaugment(self): | |
"""Smoke test to be sure there are no syntax errors.""" | |
image = tf.zeros((224, 224, 3), dtype=tf.uint8) | |
for policy in self.AVAILABLE_POLICIES: | |
augmenter = augment.AutoAugment(augmentation_name=policy) | |
aug_image = augmenter.distort(image) | |
self.assertEqual((224, 224, 3), aug_image.shape) | |
def test_autoaugment_with_bboxes(self): | |
"""Smoke test to be sure there are no syntax errors with bboxes.""" | |
image = tf.zeros((224, 224, 3), dtype=tf.uint8) | |
bboxes = tf.ones((2, 4), dtype=tf.float32) | |
for policy in self.AVAILABLE_POLICIES: | |
augmenter = augment.AutoAugment(augmentation_name=policy) | |
aug_image, aug_bboxes = augmenter.distort_with_boxes(image, bboxes) | |
self.assertEqual((224, 224, 3), aug_image.shape) | |
self.assertEqual((2, 4), aug_bboxes.shape) | |
def test_randaug(self): | |
"""Smoke test to be sure there are no syntax errors.""" | |
image = tf.zeros((224, 224, 3), dtype=tf.uint8) | |
augmenter = augment.RandAugment() | |
aug_image = augmenter.distort(image) | |
self.assertEqual((224, 224, 3), aug_image.shape) | |
def test_randaug_with_bboxes(self): | |
"""Smoke test to be sure there are no syntax errors with bboxes.""" | |
image = tf.zeros((224, 224, 3), dtype=tf.uint8) | |
bboxes = tf.ones((2, 4), dtype=tf.float32) | |
augmenter = augment.RandAugment() | |
aug_image, aug_bboxes = augmenter.distort_with_boxes(image, bboxes) | |
self.assertEqual((224, 224, 3), aug_image.shape) | |
self.assertEqual((2, 4), aug_bboxes.shape) | |
def test_randaug_build_for_detection(self): | |
"""Smoke test to be sure there are no syntax errors built for detection.""" | |
image = tf.zeros((224, 224, 3), dtype=tf.uint8) | |
bboxes = tf.ones((2, 4), dtype=tf.float32) | |
augmenter = augment.RandAugment.build_for_detection() | |
self.assertCountEqual(augmenter.available_ops, [ | |
'AutoContrast', 'Equalize', 'Invert', 'Posterize', 'Solarize', 'Color', | |
'Contrast', 'Brightness', 'Sharpness', 'Cutout', 'SolarizeAdd', | |
'Rotate_BBox', 'ShearX_BBox', 'ShearY_BBox', 'TranslateX_BBox', | |
'TranslateY_BBox' | |
]) | |
aug_image, aug_bboxes = augmenter.distort_with_boxes(image, bboxes) | |
self.assertEqual((224, 224, 3), aug_image.shape) | |
self.assertEqual((2, 4), aug_bboxes.shape) | |
def test_all_policy_ops(self): | |
"""Smoke test to be sure all augmentation functions can execute.""" | |
prob = 1 | |
magnitude = 10 | |
replace_value = [128] * 3 | |
cutout_const = 100 | |
translate_const = 250 | |
image = tf.ones((224, 224, 3), dtype=tf.uint8) | |
bboxes = None | |
for op_name in augment.NAME_TO_FUNC.keys() - augment.REQUIRE_BOXES_FUNCS: | |
func, _, args = augment._parse_policy_info(op_name, prob, magnitude, | |
replace_value, cutout_const, | |
translate_const) | |
image, bboxes = func(image, bboxes, *args) | |
self.assertEqual((224, 224, 3), image.shape) | |
self.assertIsNone(bboxes) | |
def test_all_policy_ops_with_bboxes(self): | |
"""Smoke test to be sure all augmentation functions can execute.""" | |
prob = 1 | |
magnitude = 10 | |
replace_value = [128] * 3 | |
cutout_const = 100 | |
translate_const = 250 | |
image = tf.ones((224, 224, 3), dtype=tf.uint8) | |
bboxes = tf.ones((2, 4), dtype=tf.float32) | |
for op_name in augment.NAME_TO_FUNC: | |
func, _, args = augment._parse_policy_info(op_name, prob, magnitude, | |
replace_value, cutout_const, | |
translate_const) | |
image, bboxes = func(image, bboxes, *args) | |
self.assertEqual((224, 224, 3), image.shape) | |
self.assertEqual((2, 4), bboxes.shape) | |
def test_autoaugment_video(self): | |
"""Smoke test with video to be sure there are no syntax errors.""" | |
image = tf.zeros((2, 224, 224, 3), dtype=tf.uint8) | |
for policy in self.AVAILABLE_POLICIES: | |
augmenter = augment.AutoAugment(augmentation_name=policy) | |
aug_image = augmenter.distort(image) | |
self.assertEqual((2, 224, 224, 3), aug_image.shape) | |
def test_autoaugment_video_with_boxes(self): | |
"""Smoke test with video to be sure there are no syntax errors.""" | |
image = tf.zeros((2, 224, 224, 3), dtype=tf.uint8) | |
bboxes = tf.ones((2, 2, 4), dtype=tf.float32) | |
for policy in self.AVAILABLE_POLICIES: | |
augmenter = augment.AutoAugment(augmentation_name=policy) | |
aug_image, aug_bboxes = augmenter.distort_with_boxes(image, bboxes) | |
self.assertEqual((2, 224, 224, 3), aug_image.shape) | |
self.assertEqual((2, 2, 4), aug_bboxes.shape) | |
def test_randaug_video(self): | |
"""Smoke test with video to be sure there are no syntax errors.""" | |
image = tf.zeros((2, 224, 224, 3), dtype=tf.uint8) | |
augmenter = augment.RandAugment() | |
aug_image = augmenter.distort(image) | |
self.assertEqual((2, 224, 224, 3), aug_image.shape) | |
def test_all_policy_ops_video(self): | |
"""Smoke test to be sure all video augmentation functions can execute.""" | |
prob = 1 | |
magnitude = 10 | |
replace_value = [128] * 3 | |
cutout_const = 100 | |
translate_const = 250 | |
image = tf.ones((2, 224, 224, 3), dtype=tf.uint8) | |
bboxes = None | |
for op_name in augment.NAME_TO_FUNC.keys() - augment.REQUIRE_BOXES_FUNCS: | |
func, _, args = augment._parse_policy_info(op_name, prob, magnitude, | |
replace_value, cutout_const, | |
translate_const) | |
image, bboxes = func(image, bboxes, *args) | |
self.assertEqual((2, 224, 224, 3), image.shape) | |
self.assertIsNone(bboxes) | |
def test_all_policy_ops_video_with_bboxes(self): | |
"""Smoke test to be sure all video augmentation functions can execute.""" | |
prob = 1 | |
magnitude = 10 | |
replace_value = [128] * 3 | |
cutout_const = 100 | |
translate_const = 250 | |
image = tf.ones((2, 224, 224, 3), dtype=tf.uint8) | |
bboxes = tf.ones((2, 2, 4), dtype=tf.float32) | |
for op_name in augment.NAME_TO_FUNC: | |
func, _, args = augment._parse_policy_info(op_name, prob, magnitude, | |
replace_value, cutout_const, | |
translate_const) | |
if op_name in { | |
'Rotate_BBox', | |
'ShearX_BBox', | |
'ShearY_BBox', | |
'TranslateX_BBox', | |
'TranslateY_BBox', | |
'TranslateY_Only_BBoxes', | |
}: | |
with self.assertRaises(ValueError): | |
func(image, bboxes, *args) | |
else: | |
image, bboxes = func(image, bboxes, *args) | |
self.assertEqual((2, 224, 224, 3), image.shape) | |
self.assertEqual((2, 2, 4), bboxes.shape) | |
def _generate_test_policy(self): | |
"""Generate a test policy at random.""" | |
op_list = list(augment.NAME_TO_FUNC.keys()) | |
size = 6 | |
prob = [round(random.uniform(0., 1.), 1) for _ in range(size)] | |
mag = [round(random.uniform(0, 10)) for _ in range(size)] | |
policy = [] | |
for i in range(0, size, 2): | |
policy.append([(op_list[i], prob[i], mag[i]), | |
(op_list[i + 1], prob[i + 1], mag[i + 1])]) | |
return policy | |
def test_custom_policy(self): | |
"""Test autoaugment with a custom policy.""" | |
image = tf.zeros((224, 224, 3), dtype=tf.uint8) | |
augmenter = augment.AutoAugment(policies=self._generate_test_policy()) | |
aug_image = augmenter.distort(image) | |
self.assertEqual((224, 224, 3), aug_image.shape) | |
def test_autoaugment_three_augment(self): | |
"""Test three augmentation.""" | |
image = tf.random.normal(shape=(224, 224, 3), dtype=tf.float32) | |
augmenter = augment.AutoAugment(augmentation_name='deit3_three_augment') | |
aug_image = augmenter.distort(image) | |
self.assertEqual((224, 224, 3), aug_image.shape) | |
self.assertFalse(tf.math.reduce_all(image == aug_image)) | |
def test_invalid_custom_sub_policy(self, sub_policy, value): | |
"""Test autoaugment with out-of-range values in the custom policy.""" | |
image = tf.zeros((224, 224, 3), dtype=tf.uint8) | |
policy = self._generate_test_policy() | |
policy[0][0] = sub_policy | |
augmenter = augment.AutoAugment(policies=policy) | |
with self.assertRaisesRegex( | |
tf.errors.InvalidArgumentError, | |
r'Expected \'tf.Tensor\(False, shape=\(\), dtype=bool\)\' to be true. ' | |
r'Summarized data: ({})'.format(value)): | |
augmenter.distort(image) | |
def test_invalid_custom_policy_ndim(self): | |
"""Test autoaugment with wrong dimension in the custom policy.""" | |
policy = [[('Equalize', 0.8, 1), ('Shear', 0.8, 4)], | |
[('TranslateY', 0.6, 3), ('Rotate', 0.9, 3)]] | |
policy = [[policy]] | |
with self.assertRaisesRegex( | |
ValueError, | |
r'Expected \(:, :, 3\) but got \(1, 1, 2, 2, 3\).'): | |
augment.AutoAugment(policies=policy) | |
def test_invalid_custom_policy_shape(self): | |
"""Test autoaugment with wrong shape in the custom policy.""" | |
policy = [[('Equalize', 0.8, 1, 1), ('Shear', 0.8, 4, 1)], | |
[('TranslateY', 0.6, 3, 1), ('Rotate', 0.9, 3, 1)]] | |
with self.assertRaisesRegex( | |
ValueError, | |
r'Expected \(:, :, 3\) but got \(2, 2, 4\)'): | |
augment.AutoAugment(policies=policy) | |
def test_invalid_custom_policy_key(self): | |
"""Test autoaugment with invalid key in the custom policy.""" | |
image = tf.zeros((224, 224, 3), dtype=tf.uint8) | |
policy = [[('AAAAA', 0.8, 1), ('Shear', 0.8, 4)], | |
[('TranslateY', 0.6, 3), ('Rotate', 0.9, 3)]] | |
augmenter = augment.AutoAugment(policies=policy) | |
with self.assertRaisesRegex(KeyError, '\'AAAAA\''): | |
augmenter.distort(image) | |
class RandomErasingTest(tf.test.TestCase, parameterized.TestCase): | |
def test_random_erase_replaces_some_pixels(self): | |
image = tf.zeros((224, 224, 3), dtype=tf.float32) | |
augmenter = augment.RandomErasing(probability=1., max_count=10) | |
aug_image = augmenter.distort(image) | |
self.assertEqual((224, 224, 3), aug_image.shape) | |
self.assertNotEqual(0, tf.reduce_max(aug_image)) | |
class MixupAndCutmixTest(tf.test.TestCase, parameterized.TestCase): | |
def test_mixup_and_cutmix_smoothes_labels(self): | |
batch_size = 12 | |
num_classes = 1000 | |
label_smoothing = 0.1 | |
images = tf.random.normal((batch_size, 224, 224, 3), dtype=tf.float32) | |
labels = tf.range(batch_size) | |
augmenter = augment.MixupAndCutmix( | |
num_classes=num_classes, label_smoothing=label_smoothing) | |
aug_images, aug_labels = augmenter.distort(images, labels) | |
self.assertEqual(images.shape, aug_images.shape) | |
self.assertEqual(images.dtype, aug_images.dtype) | |
self.assertEqual([batch_size, num_classes], aug_labels.shape) | |
self.assertAllLessEqual(aug_labels, 1. - label_smoothing + | |
2. / num_classes) # With tolerance | |
self.assertAllGreaterEqual(aug_labels, label_smoothing / num_classes - | |
1e4) # With tolerance | |
def test_mixup_changes_image(self): | |
batch_size = 12 | |
num_classes = 1000 | |
label_smoothing = 0.1 | |
images = tf.random.normal((batch_size, 224, 224, 3), dtype=tf.float32) | |
labels = tf.range(batch_size) | |
augmenter = augment.MixupAndCutmix( | |
mixup_alpha=1., cutmix_alpha=0., num_classes=num_classes) | |
aug_images, aug_labels = augmenter.distort(images, labels) | |
self.assertEqual(images.shape, aug_images.shape) | |
self.assertEqual(images.dtype, aug_images.dtype) | |
self.assertEqual([batch_size, num_classes], aug_labels.shape) | |
self.assertAllLessEqual(aug_labels, 1. - label_smoothing + | |
2. / num_classes) # With tolerance | |
self.assertAllGreaterEqual(aug_labels, label_smoothing / num_classes - | |
1e4) # With tolerance | |
self.assertFalse(tf.math.reduce_all(images == aug_images)) | |
def test_cutmix_changes_image(self): | |
batch_size = 12 | |
num_classes = 1000 | |
label_smoothing = 0.1 | |
images = tf.random.normal((batch_size, 224, 224, 3), dtype=tf.float32) | |
labels = tf.range(batch_size) | |
augmenter = augment.MixupAndCutmix( | |
mixup_alpha=0., cutmix_alpha=1., num_classes=num_classes) | |
aug_images, aug_labels = augmenter.distort(images, labels) | |
self.assertEqual(images.shape, aug_images.shape) | |
self.assertEqual(images.dtype, aug_images.dtype) | |
self.assertEqual([batch_size, num_classes], aug_labels.shape) | |
self.assertAllLessEqual(aug_labels, 1. - label_smoothing + | |
2. / num_classes) # With tolerance | |
self.assertAllGreaterEqual(aug_labels, label_smoothing / num_classes - | |
1e4) # With tolerance | |
self.assertFalse(tf.math.reduce_all(images == aug_images)) | |
def test_mixup_and_cutmix_smoothes_labels_with_videos(self): | |
batch_size = 12 | |
num_classes = 1000 | |
label_smoothing = 0.1 | |
images = tf.random.normal((batch_size, 8, 224, 224, 3), dtype=tf.float32) | |
labels = tf.range(batch_size) | |
augmenter = augment.MixupAndCutmix( | |
num_classes=num_classes, label_smoothing=label_smoothing) | |
aug_images, aug_labels = augmenter.distort(images, labels) | |
self.assertEqual(images.shape, aug_images.shape) | |
self.assertEqual(images.dtype, aug_images.dtype) | |
self.assertEqual([batch_size, num_classes], aug_labels.shape) | |
self.assertAllLessEqual(aug_labels, 1. - label_smoothing + | |
2. / num_classes) # With tolerance | |
self.assertAllGreaterEqual(aug_labels, label_smoothing / num_classes - | |
1e4) # With tolerance | |
def test_mixup_changes_video(self): | |
batch_size = 12 | |
num_classes = 1000 | |
label_smoothing = 0.1 | |
images = tf.random.normal((batch_size, 8, 224, 224, 3), dtype=tf.float32) | |
labels = tf.range(batch_size) | |
augmenter = augment.MixupAndCutmix( | |
mixup_alpha=1., cutmix_alpha=0., num_classes=num_classes) | |
aug_images, aug_labels = augmenter.distort(images, labels) | |
self.assertEqual(images.shape, aug_images.shape) | |
self.assertEqual(images.dtype, aug_images.dtype) | |
self.assertEqual([batch_size, num_classes], aug_labels.shape) | |
self.assertAllLessEqual(aug_labels, 1. - label_smoothing + | |
2. / num_classes) # With tolerance | |
self.assertAllGreaterEqual(aug_labels, label_smoothing / num_classes - | |
1e4) # With tolerance | |
self.assertFalse(tf.math.reduce_all(images == aug_images)) | |
def test_cutmix_changes_video(self): | |
batch_size = 12 | |
num_classes = 1000 | |
label_smoothing = 0.1 | |
images = tf.random.normal((batch_size, 8, 224, 224, 3), dtype=tf.float32) | |
labels = tf.range(batch_size) | |
augmenter = augment.MixupAndCutmix( | |
mixup_alpha=0., cutmix_alpha=1., num_classes=num_classes) | |
aug_images, aug_labels = augmenter.distort(images, labels) | |
self.assertEqual(images.shape, aug_images.shape) | |
self.assertEqual(images.dtype, aug_images.dtype) | |
self.assertEqual([batch_size, num_classes], aug_labels.shape) | |
self.assertAllLessEqual(aug_labels, 1. - label_smoothing + | |
2. / num_classes) # With tolerance | |
self.assertAllGreaterEqual(aug_labels, label_smoothing / num_classes - | |
1e4) # With tolerance | |
self.assertFalse(tf.math.reduce_all(images == aug_images)) | |
if __name__ == '__main__': | |
tf.test.main() | |