deanna-emery's picture
updates
93528c6
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for autoaugment."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import random
from absl.testing import parameterized
import numpy as np
import tensorflow as tf, tf_keras
from official.vision.ops import augment
def get_dtype_test_cases():
return [
('uint8', tf.uint8),
('int32', tf.int32),
('float16', tf.float16),
('float32', tf.float32),
]
@parameterized.named_parameters(get_dtype_test_cases())
class TransformsTest(parameterized.TestCase, tf.test.TestCase):
"""Basic tests for fundamental transformations."""
def test_to_from_4d(self, dtype):
for shape in [(10, 10), (10, 10, 10), (10, 10, 10, 10)]:
original_ndims = len(shape)
image = tf.zeros(shape, dtype=dtype)
image_4d = augment.to_4d(image)
self.assertEqual(4, tf.rank(image_4d))
self.assertAllEqual(image, augment.from_4d(image_4d, original_ndims))
def test_transform(self, dtype):
image = tf.constant([[1, 2], [3, 4]], dtype=dtype)
self.assertAllEqual(
augment.transform(image, transforms=[1] * 8), [[4, 4], [4, 4]])
def test_translate(self, dtype):
image = tf.constant(
[[1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]], dtype=dtype)
translations = [-1, -1]
translated = augment.translate(image=image, translations=translations)
expected = [[1, 0, 1, 1], [0, 1, 0, 0], [1, 0, 1, 1], [1, 0, 1, 1]]
self.assertAllEqual(translated, expected)
def test_translate_shapes(self, dtype):
translation = [0, 0]
for shape in [(3, 3), (5, 5), (224, 224, 3)]:
image = tf.zeros(shape, dtype=dtype)
self.assertAllEqual(image, augment.translate(image, translation))
def test_translate_invalid_translation(self, dtype):
image = tf.zeros((1, 1), dtype=dtype)
invalid_translation = [[[1, 1]]]
with self.assertRaisesRegex(TypeError, 'rank 1 or 2'):
_ = augment.translate(image, invalid_translation)
def test_rotate(self, dtype):
image = tf.reshape(tf.cast(tf.range(9), dtype), (3, 3))
rotation = 90.
transformed = augment.rotate(image=image, degrees=rotation)
expected = [[2, 5, 8], [1, 4, 7], [0, 3, 6]]
self.assertAllEqual(transformed, expected)
def test_rotate_shapes(self, dtype):
degrees = 0.
for shape in [(3, 3), (5, 5), (224, 224, 3)]:
image = tf.zeros(shape, dtype=dtype)
self.assertAllEqual(image, augment.rotate(image, degrees))
def test_random_cutout_video(self, dtype):
for num_channels in (1, 2, 3):
video = tf.ones((2, 2, 2, num_channels), dtype=dtype)
video = augment.cutout_video(video)
num_zeros = np.sum(video == 0)
self.assertGreater(num_zeros, 0)
def test_cutout_video_with_fixed_shape(self, dtype):
tf.random.set_seed(0)
video = tf.ones((10, 10, 10, 1), dtype=dtype)
video = augment.cutout_video(video, mask_shape=tf.constant([2, 2, 2]))
num_zeros = np.sum(video == 0)
self.assertEqual(num_zeros, 8)
class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
AVAILABLE_POLICIES = [
'v0',
'test',
'simple',
'reduced_cifar10',
'svhn',
'reduced_imagenet',
'detection_v0',
'vit',
]
def test_autoaugment(self):
"""Smoke test to be sure there are no syntax errors."""
image = tf.zeros((224, 224, 3), dtype=tf.uint8)
for policy in self.AVAILABLE_POLICIES:
augmenter = augment.AutoAugment(augmentation_name=policy)
aug_image = augmenter.distort(image)
self.assertEqual((224, 224, 3), aug_image.shape)
def test_autoaugment_with_bboxes(self):
"""Smoke test to be sure there are no syntax errors with bboxes."""
image = tf.zeros((224, 224, 3), dtype=tf.uint8)
bboxes = tf.ones((2, 4), dtype=tf.float32)
for policy in self.AVAILABLE_POLICIES:
augmenter = augment.AutoAugment(augmentation_name=policy)
aug_image, aug_bboxes = augmenter.distort_with_boxes(image, bboxes)
self.assertEqual((224, 224, 3), aug_image.shape)
self.assertEqual((2, 4), aug_bboxes.shape)
def test_randaug(self):
"""Smoke test to be sure there are no syntax errors."""
image = tf.zeros((224, 224, 3), dtype=tf.uint8)
augmenter = augment.RandAugment()
aug_image = augmenter.distort(image)
self.assertEqual((224, 224, 3), aug_image.shape)
def test_randaug_with_bboxes(self):
"""Smoke test to be sure there are no syntax errors with bboxes."""
image = tf.zeros((224, 224, 3), dtype=tf.uint8)
bboxes = tf.ones((2, 4), dtype=tf.float32)
augmenter = augment.RandAugment()
aug_image, aug_bboxes = augmenter.distort_with_boxes(image, bboxes)
self.assertEqual((224, 224, 3), aug_image.shape)
self.assertEqual((2, 4), aug_bboxes.shape)
def test_randaug_build_for_detection(self):
"""Smoke test to be sure there are no syntax errors built for detection."""
image = tf.zeros((224, 224, 3), dtype=tf.uint8)
bboxes = tf.ones((2, 4), dtype=tf.float32)
augmenter = augment.RandAugment.build_for_detection()
self.assertCountEqual(augmenter.available_ops, [
'AutoContrast', 'Equalize', 'Invert', 'Posterize', 'Solarize', 'Color',
'Contrast', 'Brightness', 'Sharpness', 'Cutout', 'SolarizeAdd',
'Rotate_BBox', 'ShearX_BBox', 'ShearY_BBox', 'TranslateX_BBox',
'TranslateY_BBox'
])
aug_image, aug_bboxes = augmenter.distort_with_boxes(image, bboxes)
self.assertEqual((224, 224, 3), aug_image.shape)
self.assertEqual((2, 4), aug_bboxes.shape)
def test_all_policy_ops(self):
"""Smoke test to be sure all augmentation functions can execute."""
prob = 1
magnitude = 10
replace_value = [128] * 3
cutout_const = 100
translate_const = 250
image = tf.ones((224, 224, 3), dtype=tf.uint8)
bboxes = None
for op_name in augment.NAME_TO_FUNC.keys() - augment.REQUIRE_BOXES_FUNCS:
func, _, args = augment._parse_policy_info(op_name, prob, magnitude,
replace_value, cutout_const,
translate_const)
image, bboxes = func(image, bboxes, *args)
self.assertEqual((224, 224, 3), image.shape)
self.assertIsNone(bboxes)
def test_all_policy_ops_with_bboxes(self):
"""Smoke test to be sure all augmentation functions can execute."""
prob = 1
magnitude = 10
replace_value = [128] * 3
cutout_const = 100
translate_const = 250
image = tf.ones((224, 224, 3), dtype=tf.uint8)
bboxes = tf.ones((2, 4), dtype=tf.float32)
for op_name in augment.NAME_TO_FUNC:
func, _, args = augment._parse_policy_info(op_name, prob, magnitude,
replace_value, cutout_const,
translate_const)
image, bboxes = func(image, bboxes, *args)
self.assertEqual((224, 224, 3), image.shape)
self.assertEqual((2, 4), bboxes.shape)
def test_autoaugment_video(self):
"""Smoke test with video to be sure there are no syntax errors."""
image = tf.zeros((2, 224, 224, 3), dtype=tf.uint8)
for policy in self.AVAILABLE_POLICIES:
augmenter = augment.AutoAugment(augmentation_name=policy)
aug_image = augmenter.distort(image)
self.assertEqual((2, 224, 224, 3), aug_image.shape)
def test_autoaugment_video_with_boxes(self):
"""Smoke test with video to be sure there are no syntax errors."""
image = tf.zeros((2, 224, 224, 3), dtype=tf.uint8)
bboxes = tf.ones((2, 2, 4), dtype=tf.float32)
for policy in self.AVAILABLE_POLICIES:
augmenter = augment.AutoAugment(augmentation_name=policy)
aug_image, aug_bboxes = augmenter.distort_with_boxes(image, bboxes)
self.assertEqual((2, 224, 224, 3), aug_image.shape)
self.assertEqual((2, 2, 4), aug_bboxes.shape)
def test_randaug_video(self):
"""Smoke test with video to be sure there are no syntax errors."""
image = tf.zeros((2, 224, 224, 3), dtype=tf.uint8)
augmenter = augment.RandAugment()
aug_image = augmenter.distort(image)
self.assertEqual((2, 224, 224, 3), aug_image.shape)
def test_all_policy_ops_video(self):
"""Smoke test to be sure all video augmentation functions can execute."""
prob = 1
magnitude = 10
replace_value = [128] * 3
cutout_const = 100
translate_const = 250
image = tf.ones((2, 224, 224, 3), dtype=tf.uint8)
bboxes = None
for op_name in augment.NAME_TO_FUNC.keys() - augment.REQUIRE_BOXES_FUNCS:
func, _, args = augment._parse_policy_info(op_name, prob, magnitude,
replace_value, cutout_const,
translate_const)
image, bboxes = func(image, bboxes, *args)
self.assertEqual((2, 224, 224, 3), image.shape)
self.assertIsNone(bboxes)
def test_all_policy_ops_video_with_bboxes(self):
"""Smoke test to be sure all video augmentation functions can execute."""
prob = 1
magnitude = 10
replace_value = [128] * 3
cutout_const = 100
translate_const = 250
image = tf.ones((2, 224, 224, 3), dtype=tf.uint8)
bboxes = tf.ones((2, 2, 4), dtype=tf.float32)
for op_name in augment.NAME_TO_FUNC:
func, _, args = augment._parse_policy_info(op_name, prob, magnitude,
replace_value, cutout_const,
translate_const)
if op_name in {
'Rotate_BBox',
'ShearX_BBox',
'ShearY_BBox',
'TranslateX_BBox',
'TranslateY_BBox',
'TranslateY_Only_BBoxes',
}:
with self.assertRaises(ValueError):
func(image, bboxes, *args)
else:
image, bboxes = func(image, bboxes, *args)
self.assertEqual((2, 224, 224, 3), image.shape)
self.assertEqual((2, 2, 4), bboxes.shape)
def _generate_test_policy(self):
"""Generate a test policy at random."""
op_list = list(augment.NAME_TO_FUNC.keys())
size = 6
prob = [round(random.uniform(0., 1.), 1) for _ in range(size)]
mag = [round(random.uniform(0, 10)) for _ in range(size)]
policy = []
for i in range(0, size, 2):
policy.append([(op_list[i], prob[i], mag[i]),
(op_list[i + 1], prob[i + 1], mag[i + 1])])
return policy
def test_custom_policy(self):
"""Test autoaugment with a custom policy."""
image = tf.zeros((224, 224, 3), dtype=tf.uint8)
augmenter = augment.AutoAugment(policies=self._generate_test_policy())
aug_image = augmenter.distort(image)
self.assertEqual((224, 224, 3), aug_image.shape)
def test_autoaugment_three_augment(self):
"""Test three augmentation."""
image = tf.random.normal(shape=(224, 224, 3), dtype=tf.float32)
augmenter = augment.AutoAugment(augmentation_name='deit3_three_augment')
aug_image = augmenter.distort(image)
self.assertEqual((224, 224, 3), aug_image.shape)
self.assertFalse(tf.math.reduce_all(image == aug_image))
@parameterized.named_parameters(
{'testcase_name': '_OutOfRangeProb',
'sub_policy': ('Equalize', 1.1, 3), 'value': '1.1'},
{'testcase_name': '_OutOfRangeMag',
'sub_policy': ('Equalize', 0.9, 11), 'value': '11'},
)
def test_invalid_custom_sub_policy(self, sub_policy, value):
"""Test autoaugment with out-of-range values in the custom policy."""
image = tf.zeros((224, 224, 3), dtype=tf.uint8)
policy = self._generate_test_policy()
policy[0][0] = sub_policy
augmenter = augment.AutoAugment(policies=policy)
with self.assertRaisesRegex(
tf.errors.InvalidArgumentError,
r'Expected \'tf.Tensor\(False, shape=\(\), dtype=bool\)\' to be true. '
r'Summarized data: ({})'.format(value)):
augmenter.distort(image)
def test_invalid_custom_policy_ndim(self):
"""Test autoaugment with wrong dimension in the custom policy."""
policy = [[('Equalize', 0.8, 1), ('Shear', 0.8, 4)],
[('TranslateY', 0.6, 3), ('Rotate', 0.9, 3)]]
policy = [[policy]]
with self.assertRaisesRegex(
ValueError,
r'Expected \(:, :, 3\) but got \(1, 1, 2, 2, 3\).'):
augment.AutoAugment(policies=policy)
def test_invalid_custom_policy_shape(self):
"""Test autoaugment with wrong shape in the custom policy."""
policy = [[('Equalize', 0.8, 1, 1), ('Shear', 0.8, 4, 1)],
[('TranslateY', 0.6, 3, 1), ('Rotate', 0.9, 3, 1)]]
with self.assertRaisesRegex(
ValueError,
r'Expected \(:, :, 3\) but got \(2, 2, 4\)'):
augment.AutoAugment(policies=policy)
def test_invalid_custom_policy_key(self):
"""Test autoaugment with invalid key in the custom policy."""
image = tf.zeros((224, 224, 3), dtype=tf.uint8)
policy = [[('AAAAA', 0.8, 1), ('Shear', 0.8, 4)],
[('TranslateY', 0.6, 3), ('Rotate', 0.9, 3)]]
augmenter = augment.AutoAugment(policies=policy)
with self.assertRaisesRegex(KeyError, '\'AAAAA\''):
augmenter.distort(image)
class RandomErasingTest(tf.test.TestCase, parameterized.TestCase):
def test_random_erase_replaces_some_pixels(self):
image = tf.zeros((224, 224, 3), dtype=tf.float32)
augmenter = augment.RandomErasing(probability=1., max_count=10)
aug_image = augmenter.distort(image)
self.assertEqual((224, 224, 3), aug_image.shape)
self.assertNotEqual(0, tf.reduce_max(aug_image))
class MixupAndCutmixTest(tf.test.TestCase, parameterized.TestCase):
def test_mixup_and_cutmix_smoothes_labels(self):
batch_size = 12
num_classes = 1000
label_smoothing = 0.1
images = tf.random.normal((batch_size, 224, 224, 3), dtype=tf.float32)
labels = tf.range(batch_size)
augmenter = augment.MixupAndCutmix(
num_classes=num_classes, label_smoothing=label_smoothing)
aug_images, aug_labels = augmenter.distort(images, labels)
self.assertEqual(images.shape, aug_images.shape)
self.assertEqual(images.dtype, aug_images.dtype)
self.assertEqual([batch_size, num_classes], aug_labels.shape)
self.assertAllLessEqual(aug_labels, 1. - label_smoothing +
2. / num_classes) # With tolerance
self.assertAllGreaterEqual(aug_labels, label_smoothing / num_classes -
1e4) # With tolerance
def test_mixup_changes_image(self):
batch_size = 12
num_classes = 1000
label_smoothing = 0.1
images = tf.random.normal((batch_size, 224, 224, 3), dtype=tf.float32)
labels = tf.range(batch_size)
augmenter = augment.MixupAndCutmix(
mixup_alpha=1., cutmix_alpha=0., num_classes=num_classes)
aug_images, aug_labels = augmenter.distort(images, labels)
self.assertEqual(images.shape, aug_images.shape)
self.assertEqual(images.dtype, aug_images.dtype)
self.assertEqual([batch_size, num_classes], aug_labels.shape)
self.assertAllLessEqual(aug_labels, 1. - label_smoothing +
2. / num_classes) # With tolerance
self.assertAllGreaterEqual(aug_labels, label_smoothing / num_classes -
1e4) # With tolerance
self.assertFalse(tf.math.reduce_all(images == aug_images))
def test_cutmix_changes_image(self):
batch_size = 12
num_classes = 1000
label_smoothing = 0.1
images = tf.random.normal((batch_size, 224, 224, 3), dtype=tf.float32)
labels = tf.range(batch_size)
augmenter = augment.MixupAndCutmix(
mixup_alpha=0., cutmix_alpha=1., num_classes=num_classes)
aug_images, aug_labels = augmenter.distort(images, labels)
self.assertEqual(images.shape, aug_images.shape)
self.assertEqual(images.dtype, aug_images.dtype)
self.assertEqual([batch_size, num_classes], aug_labels.shape)
self.assertAllLessEqual(aug_labels, 1. - label_smoothing +
2. / num_classes) # With tolerance
self.assertAllGreaterEqual(aug_labels, label_smoothing / num_classes -
1e4) # With tolerance
self.assertFalse(tf.math.reduce_all(images == aug_images))
def test_mixup_and_cutmix_smoothes_labels_with_videos(self):
batch_size = 12
num_classes = 1000
label_smoothing = 0.1
images = tf.random.normal((batch_size, 8, 224, 224, 3), dtype=tf.float32)
labels = tf.range(batch_size)
augmenter = augment.MixupAndCutmix(
num_classes=num_classes, label_smoothing=label_smoothing)
aug_images, aug_labels = augmenter.distort(images, labels)
self.assertEqual(images.shape, aug_images.shape)
self.assertEqual(images.dtype, aug_images.dtype)
self.assertEqual([batch_size, num_classes], aug_labels.shape)
self.assertAllLessEqual(aug_labels, 1. - label_smoothing +
2. / num_classes) # With tolerance
self.assertAllGreaterEqual(aug_labels, label_smoothing / num_classes -
1e4) # With tolerance
def test_mixup_changes_video(self):
batch_size = 12
num_classes = 1000
label_smoothing = 0.1
images = tf.random.normal((batch_size, 8, 224, 224, 3), dtype=tf.float32)
labels = tf.range(batch_size)
augmenter = augment.MixupAndCutmix(
mixup_alpha=1., cutmix_alpha=0., num_classes=num_classes)
aug_images, aug_labels = augmenter.distort(images, labels)
self.assertEqual(images.shape, aug_images.shape)
self.assertEqual(images.dtype, aug_images.dtype)
self.assertEqual([batch_size, num_classes], aug_labels.shape)
self.assertAllLessEqual(aug_labels, 1. - label_smoothing +
2. / num_classes) # With tolerance
self.assertAllGreaterEqual(aug_labels, label_smoothing / num_classes -
1e4) # With tolerance
self.assertFalse(tf.math.reduce_all(images == aug_images))
def test_cutmix_changes_video(self):
batch_size = 12
num_classes = 1000
label_smoothing = 0.1
images = tf.random.normal((batch_size, 8, 224, 224, 3), dtype=tf.float32)
labels = tf.range(batch_size)
augmenter = augment.MixupAndCutmix(
mixup_alpha=0., cutmix_alpha=1., num_classes=num_classes)
aug_images, aug_labels = augmenter.distort(images, labels)
self.assertEqual(images.shape, aug_images.shape)
self.assertEqual(images.dtype, aug_images.dtype)
self.assertEqual([batch_size, num_classes], aug_labels.shape)
self.assertAllLessEqual(aug_labels, 1. - label_smoothing +
2. / num_classes) # With tolerance
self.assertAllGreaterEqual(aug_labels, label_smoothing / num_classes -
1e4) # With tolerance
self.assertFalse(tf.math.reduce_all(images == aug_images))
if __name__ == '__main__':
tf.test.main()