Spaces:
Runtime error
Runtime error
# -------------------------------------------------------- | |
# Based on the timm code base | |
# https://github.com/huggingface/pytorch-image-models | |
# -------------------------------------------------------- | |
""" AutoAugment, RandAugment, and AugMix for PyTorch | |
This code implements the searched ImageNet policies with various tweaks and improvements and | |
does not include any of the search code. | |
AA and RA Implementation adapted from: | |
https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py | |
AugMix adapted from: | |
https://github.com/google-research/augmix | |
Papers: | |
AutoAugment: Learning Augmentation Policies from Data - https://arxiv.org/abs/1805.09501 | |
Learning Data Augmentation Strategies for Object Detection - https://arxiv.org/abs/1906.11172 | |
RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719 | |
AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - https://arxiv.org/abs/1912.02781 | |
Hacked together by / Copyright 2020 Ross Wightman | |
""" | |
import math | |
import random | |
import re | |
import numpy as np | |
import PIL | |
from PIL import Image, ImageEnhance, ImageOps | |
# (major, minor) of the installed Pillow release; used to gate version-specific call paths.
_PIL_VER = tuple(int(part) for part in PIL.__version__.split('.')[:2])

# Default `fillcolor` handed to the PIL transform/rotate calls when hparams carry no 'img_mean'.
_FILL = (128, 128, 128)

# Denominator for conversion from 'Mx' magnitude scale to fractional aug level for op arguments.
_LEVEL_DENOM = 10.

# Hyper-parameters used whenever a caller passes none (or an empty mapping).
_HPARAMS_DEFAULT = {
    'translate_const': 250,
    'img_mean': _FILL,
}

# Candidate resampling filters; one is picked at random per op call when a tuple/list is supplied.
_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
def _interpolation(kwargs):
    """Pop 'resample' from ``kwargs`` and resolve it to a single filter.

    A list or tuple is treated as a set of candidates and one entry is drawn with
    ``random.choice``; any other value is returned unchanged. When the key is absent,
    ``Image.BILINEAR`` is used.
    """
    resample = kwargs.pop('resample', Image.BILINEAR)
    is_candidate_set = isinstance(resample, (list, tuple))
    return random.choice(resample) if is_candidate_set else resample
def _check_args_tf(kwargs):
    """Normalise transform keyword arguments in place.

    Drops 'fillcolor' when the installed Pillow is older than 5.0, and replaces
    'resample' with the single filter chosen by ``_interpolation``.
    """
    if _PIL_VER < (5, 0) and 'fillcolor' in kwargs:
        del kwargs['fillcolor']
    kwargs['resample'] = _interpolation(kwargs)
def shear_x(img, factor, **kwargs):
    """Apply a PIL affine transform with matrix ``(1, factor, 0, 0, 1, 0)`` to ``img``.

    Keyword arguments are normalised by ``_check_args_tf`` and forwarded to
    ``img.transform``; the output is requested at the input's size.
    """
    _check_args_tf(kwargs)
    matrix = (1, factor, 0, 0, 1, 0)
    return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
def shear_y(img, factor, **kwargs):
    """Apply a PIL affine transform with matrix ``(1, 0, 0, factor, 1, 0)`` to ``img``.

    Keyword arguments are normalised by ``_check_args_tf`` and forwarded to
    ``img.transform``; the output is requested at the input's size.
    """
    _check_args_tf(kwargs)
    matrix = (1, 0, 0, factor, 1, 0)
    return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
def translate_x_rel(img, pct, **kwargs):
    """Translate ``img`` along x by ``pct * img.size[0]`` using a PIL affine transform.

    Keyword arguments are normalised by ``_check_args_tf`` and forwarded to ``img.transform``.
    """
    offset = pct * img.size[0]
    _check_args_tf(kwargs)
    matrix = (1, 0, offset, 0, 1, 0)
    return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
def translate_y_rel(img, pct, **kwargs):
    """Translate ``img`` along y by ``pct * img.size[1]`` using a PIL affine transform.

    Keyword arguments are normalised by ``_check_args_tf`` and forwarded to ``img.transform``.
    """
    offset = pct * img.size[1]
    _check_args_tf(kwargs)
    matrix = (1, 0, 0, 0, 1, offset)
    return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
def translate_x_abs(img, pixels, **kwargs):
    """Translate ``img`` along x by ``pixels`` using a PIL affine transform.

    Keyword arguments are normalised by ``_check_args_tf`` and forwarded to ``img.transform``.
    """
    _check_args_tf(kwargs)
    matrix = (1, 0, pixels, 0, 1, 0)
    return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
def translate_y_abs(img, pixels, **kwargs):
    """Translate ``img`` along y by ``pixels`` using a PIL affine transform.

    Keyword arguments are normalised by ``_check_args_tf`` and forwarded to ``img.transform``.
    """
    _check_args_tf(kwargs)
    matrix = (1, 0, 0, 0, 1, pixels)
    return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
def rotate(img, degrees, **kwargs):
    """Rotate ``img`` by ``degrees``, choosing the call path from the installed Pillow version.

    * Pillow >= 5.2: ``img.rotate`` receives all normalised kwargs.
    * 5.0 <= Pillow < 5.2: the rotation about the image centre is expressed as an explicit
      affine matrix and applied with ``img.transform``.
    * Pillow < 5.0: ``img.rotate`` is called with only the resolved 'resample' filter.
    """
    _check_args_tf(kwargs)
    if _PIL_VER >= (5, 2):
        return img.rotate(degrees, **kwargs)
    if _PIL_VER < (5, 0):
        return img.rotate(degrees, resample=kwargs['resample'])

    # 5.0 <= Pillow < 5.2: build the rotation matrix by hand.
    width, height = img.size
    post_trans = (0, 0)
    center_x, center_y = width / 2.0, height / 2.0
    theta = -math.radians(degrees)
    matrix = [
        round(math.cos(theta), 15),
        round(math.sin(theta), 15),
        0.0,
        round(-math.sin(theta), 15),
        round(math.cos(theta), 15),
        0.0,
    ]

    # Push the negated centre through the (translation-free) matrix, then shift back by the
    # centre so the rotation pivots about the middle of the image.
    x = -center_x - post_trans[0]
    y = -center_y - post_trans[1]
    a, b, c, d, e, f = matrix
    matrix[2] = a * x + b * y + c
    matrix[5] = d * x + e * y + f
    matrix[2] += center_x
    matrix[5] += center_y
    return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
def auto_contrast(img, **__):
    """Return ``ImageOps.autocontrast(img)``; any keyword arguments are accepted and ignored."""
    result = ImageOps.autocontrast(img)
    return result
def invert(img, **__):
    """Return ``ImageOps.invert(img)``; any keyword arguments are accepted and ignored."""
    result = ImageOps.invert(img)
    return result
def equalize(img, **__):
    """Return ``ImageOps.equalize(img)``; any keyword arguments are accepted and ignored."""
    result = ImageOps.equalize(img)
    return result
def solarize(img, thresh, **__):
    """Return ``ImageOps.solarize(img, thresh)``; extra keyword arguments are ignored."""
    result = ImageOps.solarize(img, thresh)
    return result
def solarize_add(img, add, thresh=128, **__):
    """Add ``add`` to every pixel value below ``thresh``, capping the result at 255.

    The mapping is applied through ``img.point`` with a 256-entry lookup table (repeated three
    times for 'RGB'). Images whose mode is neither 'L' nor 'RGB' are returned unchanged.
    Extra keyword arguments are ignored.
    """
    lut = [min(255, value + add) if value < thresh else value for value in range(256)]
    if img.mode not in ("L", "RGB"):
        return img
    if img.mode == "RGB":
        # one copy of the table per band
        lut = lut * 3
    return img.point(lut)
def posterize(img, bits_to_keep, **__):
    """Posterize ``img`` via ``ImageOps.posterize``, keeping ``bits_to_keep`` bits.

    When ``bits_to_keep`` is 8 or more the input image is returned untouched.
    Extra keyword arguments are ignored.
    """
    return img if bits_to_keep >= 8 else ImageOps.posterize(img, bits_to_keep)
def contrast(img, factor, **__):
    """Return ``ImageEnhance.Contrast(img).enhance(factor)``; extra keyword arguments are ignored."""
    enhancer = ImageEnhance.Contrast(img)
    return enhancer.enhance(factor)
def color(img, factor, **__):
    """Return ``ImageEnhance.Color(img).enhance(factor)``; extra keyword arguments are ignored."""
    enhancer = ImageEnhance.Color(img)
    return enhancer.enhance(factor)
def brightness(img, factor, **__):
    """Return ``ImageEnhance.Brightness(img).enhance(factor)``; extra keyword arguments are ignored."""
    enhancer = ImageEnhance.Brightness(img)
    return enhancer.enhance(factor)
def sharpness(img, factor, **__):
    """Return ``ImageEnhance.Sharpness(img).enhance(factor)``; extra keyword arguments are ignored."""
    enhancer = ImageEnhance.Sharpness(img)
    return enhancer.enhance(factor)
def _randomly_negate(v): | |
"""With 50% prob, negate the value""" | |
return -v if random.random() > 0.5 else v | |
def _rotate_level_to_arg(level, _hparams):
    """Map a magnitude level to a 1-tuple holding a randomly signed rotation angle.

    The angle is ``(level / _LEVEL_DENOM) * 30``, i.e. within [-30, 30] for levels in
    [0, _LEVEL_DENOM]. ``_hparams`` is unused.
    """
    degrees = (level / _LEVEL_DENOM) * 30.
    return (_randomly_negate(degrees),)
def _enhance_level_to_arg(level, _hparams):
    """Map a magnitude level to a 1-tuple holding an enhancement factor.

    The factor is ``(level / _LEVEL_DENOM) * 1.8 + 0.1``, i.e. within [0.1, 1.9] for levels
    in [0, _LEVEL_DENOM]. ``_hparams`` is unused.
    """
    factor = (level / _LEVEL_DENOM) * 1.8 + 0.1
    return (factor,)
def _enhance_increasing_level_to_arg(level, _hparams):
    """Map a magnitude level to a 1-tuple holding an enhancement factor centred on 1.0.

    1.0 is the 'no change' factor; a randomly signed offset of ``(level / _LEVEL_DENOM) * 0.9``
    moves away from it, so higher levels give a stronger effect. The result is floored at 0.1,
    giving [0.1, 1.9] for levels up to _LEVEL_DENOM. ``_hparams`` is unused.
    """
    offset = (level / _LEVEL_DENOM) * .9
    factor = max(0.1, 1.0 + _randomly_negate(offset))  # keep it >= 0.1
    return (factor,)
def _shear_level_to_arg(level, _hparams):
    """Map a magnitude level to a 1-tuple holding a randomly signed shear factor.

    The factor is ``(level / _LEVEL_DENOM) * 0.3``, i.e. within [-0.3, 0.3] for levels in
    [0, _LEVEL_DENOM]. ``_hparams`` is unused.
    """
    shear = (level / _LEVEL_DENOM) * 0.3
    return (_randomly_negate(shear),)
def _translate_abs_level_to_arg(level, hparams):
    """Map a magnitude level to a 1-tuple holding a randomly signed absolute translation.

    The translation is ``(level / _LEVEL_DENOM) * hparams['translate_const']``; the key is
    required and a missing key raises ``KeyError``.
    """
    max_shift = float(hparams['translate_const'])
    shift = (level / _LEVEL_DENOM) * max_shift
    return (_randomly_negate(shift),)
def _translate_rel_level_to_arg(level, hparams):
    """Map a magnitude level to a 1-tuple holding a randomly signed relative translation.

    The fraction is ``(level / _LEVEL_DENOM) * translate_pct`` where ``translate_pct`` comes
    from ``hparams`` and defaults to 0.45 (so the default range is [-0.45, 0.45] for levels in
    [0, _LEVEL_DENOM]).
    """
    max_fraction = hparams.get('translate_pct', 0.45)
    fraction = (level / _LEVEL_DENOM) * max_fraction
    return (_randomly_negate(fraction),)
def _posterize_level_to_arg(level, _hparams):
    """Map a magnitude level to a 1-tuple holding the number of bits to keep.

    As per Tensorflow TPU EfficientNet impl: range [0, 4], 'keep 0 up to 4 MSB of original
    image', so intensity/severity of augmentation decreases with level. ``_hparams`` is unused.
    """
    bits = int((level / _LEVEL_DENOM) * 4)
    return (bits,)
def _posterize_increasing_level_to_arg(level, hparams):
    """Map a magnitude level to a 1-tuple holding the number of bits to keep (inverted scale).

    As per Tensorflow models research and UDA impl: range [4, 0], 'keep 4 down to 0 MSB of
    original image', so intensity/severity of augmentation increases with level. Computed as
    4 minus the value from ``_posterize_level_to_arg``.
    """
    base_bits = _posterize_level_to_arg(level, hparams)[0]
    return (4 - base_bits,)
def _posterize_original_level_to_arg(level, _hparams):
    """Map a magnitude level to a 1-tuple holding the number of bits to keep (paper scale).

    As per original AutoAugment paper description: range [4, 8], 'keep 4 up to 8 MSB of
    image', so intensity/severity of augmentation decreases with level. ``_hparams`` is unused.
    """
    bits = int((level / _LEVEL_DENOM) * 4) + 4
    return (bits,)
def _solarize_level_to_arg(level, _hparams):
    """Map a magnitude level to a 1-tuple holding a solarize threshold.

    The threshold is ``int((level / _LEVEL_DENOM) * 256)``, range [0, 256] for levels in
    [0, _LEVEL_DENOM]; intensity/severity of augmentation decreases with level.
    ``_hparams`` is unused.
    """
    threshold = int((level / _LEVEL_DENOM) * 256)
    return (threshold,)
def _solarize_increasing_level_to_arg(level, _hparams):
    """Map a magnitude level to a 1-tuple holding a solarize threshold (inverted scale).

    Computed as 256 minus the value from ``_solarize_level_to_arg``: range [0, 256], with
    intensity/severity of augmentation increasing with level.
    """
    base_threshold = _solarize_level_to_arg(level, _hparams)[0]
    return (256 - base_threshold,)
def _solarize_add_level_to_arg(level, _hparams):
    """Map a magnitude level to a 1-tuple holding the amount added by ``solarize_add``.

    The amount is ``int((level / _LEVEL_DENOM) * 110)``, range [0, 110] for levels in
    [0, _LEVEL_DENOM]. ``_hparams`` is unused.
    """
    amount = int((level / _LEVEL_DENOM) * 110)
    return (amount,)
# Op name -> function converting a magnitude level into the op's positional argument tuple.
# None marks ops that take no level-derived argument.
LEVEL_TO_ARG = dict(
    # ops without a magnitude argument
    AutoContrast=None,
    Equalize=None,
    Invert=None,
    Rotate=_rotate_level_to_arg,
    # There are several variations of the posterize level scaling in various
    # Tensorflow/Google repositories/papers
    Posterize=_posterize_level_to_arg,
    PosterizeIncreasing=_posterize_increasing_level_to_arg,
    PosterizeOriginal=_posterize_original_level_to_arg,
    Solarize=_solarize_level_to_arg,
    SolarizeIncreasing=_solarize_increasing_level_to_arg,
    SolarizeAdd=_solarize_add_level_to_arg,
    # enhancement ops, plain and 'increasing' variants
    Color=_enhance_level_to_arg,
    ColorIncreasing=_enhance_increasing_level_to_arg,
    Contrast=_enhance_level_to_arg,
    ContrastIncreasing=_enhance_increasing_level_to_arg,
    Brightness=_enhance_level_to_arg,
    BrightnessIncreasing=_enhance_increasing_level_to_arg,
    Sharpness=_enhance_level_to_arg,
    SharpnessIncreasing=_enhance_increasing_level_to_arg,
    # geometric ops
    ShearX=_shear_level_to_arg,
    ShearY=_shear_level_to_arg,
    TranslateX=_translate_abs_level_to_arg,
    TranslateY=_translate_abs_level_to_arg,
    TranslateXRel=_translate_rel_level_to_arg,
    TranslateYRel=_translate_rel_level_to_arg,
)
# Op name -> image function. Variant names ('...Increasing', '...Original') share the base
# op's function and differ only in how LEVEL_TO_ARG scales their argument.
NAME_TO_OP = dict(
    AutoContrast=auto_contrast,
    Equalize=equalize,
    Invert=invert,
    Rotate=rotate,
    Posterize=posterize,
    PosterizeIncreasing=posterize,
    PosterizeOriginal=posterize,
    Solarize=solarize,
    SolarizeIncreasing=solarize,
    SolarizeAdd=solarize_add,
    Color=color,
    ColorIncreasing=color,
    Contrast=contrast,
    ContrastIncreasing=contrast,
    Brightness=brightness,
    BrightnessIncreasing=brightness,
    Sharpness=sharpness,
    SharpnessIncreasing=sharpness,
    ShearX=shear_x,
    ShearY=shear_y,
    TranslateX=translate_x_abs,
    TranslateY=translate_y_abs,
    TranslateXRel=translate_x_rel,
    TranslateYRel=translate_y_rel,
)
class AugmentOp:
    """A single named augmentation applied with a given probability and magnitude.

    :param name: key into ``NAME_TO_OP`` / ``LEVEL_TO_ARG`` selecting the op and its level mapping
    :param prob: probability of applying the op on a given call
    :param magnitude: nominal magnitude handed to the level mapping
    :param hparams: optional mapping of hyper-parameters; ``_HPARAMS_DEFAULT`` is used when falsy.
        Keys read here: 'img_mean', 'interpolation', 'magnitude_std', 'magnitude_max'.
    """

    def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
        if not hparams:
            hparams = _HPARAMS_DEFAULT
        self.aug_fn = NAME_TO_OP[name]
        self.level_fn = LEVEL_TO_ARG[name]
        self.prob = prob
        self.magnitude = magnitude
        # private copy so later changes to the caller's mapping do not affect this op
        self.hparams = hparams.copy()
        self.kwargs = {
            'fillcolor': hparams.get('img_mean', _FILL),
            'resample': hparams.get('interpolation', _RANDOM_INTERPOLATION),
        }

        # If magnitude_std is > 0, we introduce some randomness
        # in the usually fixed policy and sample magnitude from a normal distribution
        # with mean `magnitude` and std-dev of `magnitude_std`.
        # NOTE This is my own hack, being tested, not in papers or reference impls.
        # If magnitude_std is inf, we sample magnitude from a uniform distribution
        self.magnitude_std = self.hparams.get('magnitude_std', 0)
        self.magnitude_max = self.hparams.get('magnitude_max', None)

    def __call__(self, img):
        """Apply the op to ``img`` (or return ``img`` untouched when the probability draw skips it)."""
        skip = self.prob < 1.0 and random.random() > self.prob
        if skip:
            return img

        magnitude = self.magnitude
        std = self.magnitude_std
        if std > 0:
            # magnitude randomization enabled
            if std == float('inf'):
                magnitude = random.uniform(0, magnitude)
            else:
                magnitude = random.gauss(magnitude, std)

        # default upper_bound for the timm RA impl is _LEVEL_DENOM (10)
        # setting magnitude_max overrides this to allow M > 10 (behaviour closer to Google TF RA impl)
        upper_bound = self.magnitude_max or _LEVEL_DENOM
        magnitude = max(0., min(magnitude, upper_bound))

        if self.level_fn is None:
            level_args = tuple()
        else:
            level_args = self.level_fn(magnitude, self.hparams)
        return self.aug_fn(img, *level_args, **self.kwargs)
def auto_augment_policy_v0(hparams):
    """Build the 'v0' AutoAugment policy as a list of sub-policies of ``AugmentOp`` instances.

    :param hparams: hyper-parameters forwarded to every ``AugmentOp``
    :return: list of sub-policies, each a list of ``AugmentOp``
    """
    # ImageNet v0 policy from TPU EfficientNet impl, cannot find a paper reference.
    # Every entry is (op name, probability, magnitude); every row is one sub-policy.
    policy = [
        [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
        [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
        [('Color', 0.4, 1), ('Rotate', 0.6, 8)],
        [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
        [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
        [('Color', 0.2, 0), ('Equalize', 0.8, 8)],
        [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
        [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
        [('Color', 0.6, 1), ('Equalize', 1.0, 2)],
        [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
        [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
        [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
        [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
        [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
        [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
        [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
        [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
        [('ShearY', 0.8, 0), ('Color', 0.6, 4)],
        [('Color', 1.0, 0), ('Rotate', 0.6, 2)],
        [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
        [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
        [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
        [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)],  # This results in black image with Tpu posterize
        [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
        [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
    ]
    built = []
    for sub_policy in policy:
        built.append([
            AugmentOp(name, prob, magnitude, hparams=hparams)
            for name, prob, magnitude in sub_policy
        ])
    return built
def auto_augment_policy_v0r(hparams):
    """Build the 'v0r' AutoAugment policy as a list of sub-policies of ``AugmentOp`` instances.

    :param hparams: hyper-parameters forwarded to every ``AugmentOp``
    :return: list of sub-policies, each a list of ``AugmentOp``
    """
    # ImageNet v0 policy from TPU EfficientNet impl, with variation of Posterize used
    # in Google research implementation (number of bits discarded increases with magnitude)
    # Every entry is (op name, probability, magnitude); every row is one sub-policy.
    policy = [
        [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
        [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
        [('Color', 0.4, 1), ('Rotate', 0.6, 8)],
        [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
        [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
        [('Color', 0.2, 0), ('Equalize', 0.8, 8)],
        [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
        [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
        [('Color', 0.6, 1), ('Equalize', 1.0, 2)],
        [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
        [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
        [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
        [('PosterizeIncreasing', 0.4, 6), ('AutoContrast', 0.4, 7)],
        [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
        [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
        [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
        [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
        [('ShearY', 0.8, 0), ('Color', 0.6, 4)],
        [('Color', 1.0, 0), ('Rotate', 0.6, 2)],
        [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
        [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
        [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
        [('PosterizeIncreasing', 0.8, 2), ('Solarize', 0.6, 10)],
        [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
        [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
    ]
    built = []
    for sub_policy in policy:
        built.append([
            AugmentOp(name, prob, magnitude, hparams=hparams)
            for name, prob, magnitude in sub_policy
        ])
    return built
def auto_augment_policy_original(hparams):
    """Build the 'original' AutoAugment policy as a list of sub-policies of ``AugmentOp`` instances.

    :param hparams: hyper-parameters forwarded to every ``AugmentOp``
    :return: list of sub-policies, each a list of ``AugmentOp``
    """
    # ImageNet policy from https://arxiv.org/abs/1805.09501
    # Every entry is (op name, probability, magnitude); every row is one sub-policy.
    policy = [
        [('PosterizeOriginal', 0.4, 8), ('Rotate', 0.6, 9)],
        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
        [('PosterizeOriginal', 0.6, 7), ('PosterizeOriginal', 0.6, 6)],
        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
        [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
        [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
        [('PosterizeOriginal', 0.8, 5), ('Equalize', 1.0, 2)],
        [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
        [('Equalize', 0.6, 8), ('PosterizeOriginal', 0.4, 6)],
        [('Rotate', 0.8, 8), ('Color', 0.4, 0)],
        [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
        [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
        [('Rotate', 0.8, 8), ('Color', 1.0, 2)],
        [('Color', 0.8, 8), ('Solarize', 0.8, 7)],
        [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
        [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
        [('Color', 0.4, 0), ('Equalize', 0.6, 3)],
        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
    ]
    built = []
    for sub_policy in policy:
        built.append([
            AugmentOp(name, prob, magnitude, hparams=hparams)
            for name, prob, magnitude in sub_policy
        ])
    return built
def auto_augment_policy_originalr(hparams):
    """Build the 'originalr' AutoAugment policy as a list of sub-policies of ``AugmentOp`` instances.

    :param hparams: hyper-parameters forwarded to every ``AugmentOp``
    :return: list of sub-policies, each a list of ``AugmentOp``
    """
    # ImageNet policy from https://arxiv.org/abs/1805.09501 with research posterize variation
    # Every entry is (op name, probability, magnitude); every row is one sub-policy.
    policy = [
        [('PosterizeIncreasing', 0.4, 8), ('Rotate', 0.6, 9)],
        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
        [('PosterizeIncreasing', 0.6, 7), ('PosterizeIncreasing', 0.6, 6)],
        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
        [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
        [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
        [('PosterizeIncreasing', 0.8, 5), ('Equalize', 1.0, 2)],
        [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
        [('Equalize', 0.6, 8), ('PosterizeIncreasing', 0.4, 6)],
        [('Rotate', 0.8, 8), ('Color', 0.4, 0)],
        [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
        [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
        [('Rotate', 0.8, 8), ('Color', 1.0, 2)],
        [('Color', 0.8, 8), ('Solarize', 0.8, 7)],
        [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
        [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
        [('Color', 0.4, 0), ('Equalize', 0.6, 3)],
        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
    ]
    built = []
    for sub_policy in policy:
        built.append([
            AugmentOp(name, prob, magnitude, hparams=hparams)
            for name, prob, magnitude in sub_policy
        ])
    return built
def auto_augment_policy(name='v0', hparams=None):
    """Return the AutoAugment policy registered under ``name``.

    :param name: one of 'original', 'originalr', 'v0', 'v0r'
    :param hparams: hyper-parameters forwarded to the policy builder; ``_HPARAMS_DEFAULT`` is
        used when falsy
    :return: list of sub-policies, each a list of ``AugmentOp``
    :raises AssertionError: if ``name`` is not a known policy
    """
    hparams = hparams or _HPARAMS_DEFAULT
    if name == 'original':
        return auto_augment_policy_original(hparams)
    if name == 'originalr':
        return auto_augment_policy_originalr(hparams)
    if name == 'v0':
        return auto_augment_policy_v0(hparams)
    if name == 'v0r':
        return auto_augment_policy_v0r(hparams)
    # Raised explicitly rather than via `assert False`: assert statements are stripped under
    # `python -O`, which would make an unknown name silently return None. The exception type
    # is kept as AssertionError so existing callers see the same error.
    raise AssertionError('Unknown AA policy (%s)' % name)
class AutoAugment:
    """Callable transform that applies one randomly chosen sub-policy per image.

    :param policy: sequence of sub-policies; each sub-policy is an iterable of callables
        (img -> img) applied in order
    """

    def __init__(self, policy):
        self.policy = policy

    def __call__(self, img):
        """Pick a sub-policy with ``random.choice`` and run ``img`` through each of its ops."""
        for op in random.choice(self.policy):
            img = op(img)
        return img
def auto_augment_transform(config_str, hparams):
    """
    Create a AutoAugment transform

    :param config_str: String defining configuration of auto augmentation. Consists of multiple sections separated by
        dashes ('-'). The first section defines the AutoAugment policy (one of 'v0', 'v0r', 'original', 'originalr').
        The remaining sections, not order specific, determine
            'mstd' - float std deviation of magnitude noise applied
        Ex 'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5
        Sections that contain no digit are skipped.
    :param hparams: Other hparams (kwargs) for the AutoAugmentation scheme. Parsed values are written into this
        mapping with ``setdefault``, so keys already present take precedence.
    :return: A PyTorch compatible Transform
    :raises AssertionError: if a section with a value has an unrecognised key
    """
    policy_name, *sections = config_str.split('-')
    for section in sections:
        # split into the alphabetic key and the value that starts at the first digit
        parts = re.split(r'(\d.*)', section)
        if len(parts) < 2:
            continue
        key, val = parts[:2]
        if key == 'mstd':
            # noise param injected via hparams for now
            hparams.setdefault('magnitude_std', float(val))
        else:
            # Raised explicitly (not `assert False`) so the check survives `python -O`;
            # the exception type is unchanged for existing callers.
            raise AssertionError('Unknown AutoAugment config section')
    aa_policy = auto_augment_policy(policy_name, hparams=hparams)
    return AutoAugment(aa_policy)
_RAND_TRANSFORMS = [ | |
'AutoContrast', | |
'Equalize', | |
'Invert', | |
'Rotate', | |
'Posterize', | |
'Solarize', | |
'SolarizeAdd', | |
'Color', | |
'Contrast', | |
'Brightness', | |
'Sharpness', | |
'ShearX', | |
'ShearY', | |
'TranslateXRel', | |
'TranslateYRel', | |
# 'Cutout' # NOTE I've implement this as random erasing separately | |
] | |
_RAND_INCREASING_TRANSFORMS = [ | |
'AutoContrast', | |
'Equalize', | |
'Invert', | |
'Rotate', | |
'PosterizeIncreasing', | |
'SolarizeIncreasing', | |
'SolarizeAdd', | |
'ColorIncreasing', | |
'ContrastIncreasing', | |
'BrightnessIncreasing', | |
'SharpnessIncreasing', | |
'ShearX', | |
'ShearY', | |
'TranslateXRel', | |
'TranslateYRel', | |
# 'Cutout' # NOTE I've implement this as random erasing separately | |
] | |
# These experimental weights are based loosely on the relative improvements mentioned in paper. | |
# They may not result in increased performance, but could likely be tuned to so. | |
_RAND_CHOICE_WEIGHTS_0 = { | |
'Rotate': 0.3, | |
'ShearX': 0.2, | |
'ShearY': 0.2, | |
'TranslateXRel': 0.1, | |
'TranslateYRel': 0.1, | |
'Color': .025, | |
'Sharpness': 0.025, | |
'AutoContrast': 0.025, | |
'Solarize': .005, | |
'SolarizeAdd': .005, | |
'Contrast': .005, | |
'Brightness': .005, | |
'Equalize': .005, | |
'Posterize': 0, | |
'Invert': 0, | |
} | |
def _select_rand_weights(weight_idx=0, transforms=None):
    """Return normalised selection probabilities for ``transforms``.

    :param weight_idx: index of the weight set; only 0 is currently available
    :param transforms: op names to look up in ``_RAND_CHOICE_WEIGHTS_0``; defaults to
        ``_RAND_TRANSFORMS`` when falsy
    :return: NumPy array of weights, in ``transforms`` order, scaled to sum to 1
    """
    names = transforms or _RAND_TRANSFORMS
    assert weight_idx == 0  # only one set of weights currently
    table = _RAND_CHOICE_WEIGHTS_0
    weights = np.array([table[name] for name in names])
    return weights / np.sum(weights)
def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
    """Create one ``AugmentOp`` (probability 0.5) per RandAugment transform name.

    :param magnitude: magnitude given to every op
    :param hparams: hyper-parameters for the ops; ``_HPARAMS_DEFAULT`` when falsy
    :param transforms: op names to build; ``_RAND_TRANSFORMS`` when falsy
    :return: list of ``AugmentOp`` in ``transforms`` order
    """
    if not hparams:
        hparams = _HPARAMS_DEFAULT
    if not transforms:
        transforms = _RAND_TRANSFORMS
    ops = []
    for name in transforms:
        ops.append(AugmentOp(name, prob=0.5, magnitude=magnitude, hparams=hparams))
    return ops
class RandAugment:
    """Callable transform that applies ``num_layers`` randomly selected ops to an image.

    :param ops: sequence of callables (img -> img) to draw from
    :param num_layers: number of ops drawn per call
    :param choice_weights: optional per-op probabilities; when given, ops are drawn without
        replacement, otherwise uniformly with replacement
    """

    def __init__(self, ops, num_layers=2, choice_weights=None):
        self.ops = ops
        self.num_layers = num_layers
        self.choice_weights = choice_weights

    def __call__(self, img):
        """Draw the ops with ``np.random.choice`` and apply them to ``img`` in the drawn order."""
        weights = self.choice_weights
        # no replacement when using weighted choice
        with_replacement = weights is None
        selected = np.random.choice(self.ops, self.num_layers, replace=with_replacement, p=weights)
        for op in selected:
            img = op(img)
        return img
def rand_augment_transform(config_str, hparams):
    """
    Create a RandAugment transform

    :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
        dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The
        remaining sections, not order specific, determine
            'm' - integer magnitude of rand augment
            'n' - integer num layers (number of transform ops selected per image)
            'w' - integer probability weight index (index of a set of weights to influence choice of op)
            'mstd' - float std deviation of magnitude noise applied, or uniform sampling if infinity (or > 100)
            'mmax' - set upper bound for magnitude to something other than default of _LEVEL_DENOM (10)
            'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
        Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
        'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2
        Sections that contain no digit are skipped.
    :param hparams: Other hparams (kwargs) for the RandAugmentation scheme. 'mstd' and 'mmax' are written into this
        mapping with ``setdefault``, so keys already present take precedence.
    :return: A PyTorch compatible Transform
    :raises AssertionError: if the first section is not 'rand' or a section key is unrecognised
    """
    magnitude = _LEVEL_DENOM  # default to _LEVEL_DENOM for magnitude (currently 10)
    num_layers = 2  # default to 2 ops per image
    weight_idx = None  # default to no probability weights for op choice
    transforms = _RAND_TRANSFORMS
    config = config_str.split('-')
    # Explicit raises instead of assert statements so validation survives `python -O`;
    # AssertionError is kept so existing callers see the same exception type.
    if config[0] != 'rand':
        raise AssertionError("RandAugment config must start with 'rand'")
    for c in config[1:]:
        # split into the alphabetic key and the value that starts at the first digit
        cs = re.split(r'(\d.*)', c)
        if len(cs) < 2:
            continue
        key, val = cs[:2]
        if key == 'mstd':
            # noise param / randomization of magnitude values
            mstd = float(val)
            if mstd > 100:
                # use uniform sampling in 0 to magnitude if mstd is > 100
                mstd = float('inf')
            hparams.setdefault('magnitude_std', mstd)
        elif key == 'mmax':
            # clip magnitude between [0, mmax] instead of default [0, _LEVEL_DENOM]
            hparams.setdefault('magnitude_max', int(val))
        elif key == 'inc':
            # val is a string: bool('0') would be True, so convert to int before testing
            if bool(int(val)):
                transforms = _RAND_INCREASING_TRANSFORMS
        elif key == 'm':
            magnitude = int(val)
        elif key == 'n':
            num_layers = int(val)
        elif key == 'w':
            weight_idx = int(val)
        else:
            raise AssertionError('Unknown RandAugment config section')
    ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams, transforms=transforms)
    # Weights are looked up for the base transform names; _RAND_INCREASING_TRANSFORMS lists its
    # ops in the same order, so the probabilities still line up position by position.
    choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx)
    return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
_AUGMIX_TRANSFORMS = [ | |
'AutoContrast', | |
'ColorIncreasing', # not in paper | |
'ContrastIncreasing', # not in paper | |
'BrightnessIncreasing', # not in paper | |
'SharpnessIncreasing', # not in paper | |
'Equalize', | |
'Rotate', | |
'PosterizeIncreasing', | |
'SolarizeIncreasing', | |
'ShearX', | |
'ShearY', | |
'TranslateXRel', | |
'TranslateYRel', | |
] | |
def augmix_ops(magnitude=10, hparams=None, transforms=None):
    """Create one always-applied ``AugmentOp`` (probability 1.0) per AugMix transform name.

    :param magnitude: magnitude given to every op
    :param hparams: hyper-parameters for the ops; ``_HPARAMS_DEFAULT`` when falsy
    :param transforms: op names to build; ``_AUGMIX_TRANSFORMS`` when falsy
    :return: list of ``AugmentOp`` in ``transforms`` order
    """
    if not hparams:
        hparams = _HPARAMS_DEFAULT
    if not transforms:
        transforms = _AUGMIX_TRANSFORMS
    ops = []
    for name in transforms:
        ops.append(AugmentOp(name, prob=1.0, magnitude=magnitude, hparams=hparams))
    return ops
class AugMixAugment:
    """ AugMix Transform
    Adapted and improved from impl here: https://github.com/google-research/augmix/blob/master/imagenet.py
    From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty -
    https://arxiv.org/abs/1912.02781

    :param ops: sequence of callables (img -> img) sampled to build each augmentation chain
    :param alpha: parameter of both the Dirichlet draw (chain weights) and the Beta draw (final blend)
    :param width: number of augmentation chains mixed together
    :param depth: ops per chain; values <= 0 draw a random depth in [1, 3] for every chain
    :param blended: use the per-chain PIL blend path instead of the NumPy accumulation path
    """
    def __init__(self, ops, alpha=1., width=3, depth=-1, blended=False):
        self.ops = ops
        self.alpha = alpha
        self.width = width
        self.depth = depth
        self.blended = blended  # blended mode is faster but not well tested

    def _calc_blended_weights(self, ws, m):
        """Convert chain weights ``ws`` scaled by ``m`` into per-step blend factors.

        Walks the scaled weights from last to first, dividing each by the share of the
        result not yet claimed, and returns the factors in the original order as float32.
        """
        ws = ws * m
        cump = 1.
        rws = []
        for w in ws[::-1]:
            alpha = w / cump
            cump *= (1 - alpha)
            rws.append(alpha)
        return np.array(rws[::-1], dtype=np.float32)

    def _apply_blended(self, img, mixing_weights, m):
        """Blend one augmented chain at a time into the running result with ``Image.blend``."""
        # This is my first crack and implementing a slightly faster mixed augmentation. Instead
        # of accumulating the mix for each chain in a Numpy array and then blending with original,
        # it recomputes the blending coefficients and applies one PIL image blend per chain.
        # TODO the results appear in the right ballpark but they differ by more than rounding.
        img_orig = img.copy()
        ws = self._calc_blended_weights(mixing_weights, m)
        for w in ws:
            depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
            ops = np.random.choice(self.ops, depth, replace=True)
            img_aug = img_orig  # no ops are in-place, deep copy not necessary
            for op in ops:
                img_aug = op(img_aug)
            img = Image.blend(img, img_aug, w)
        return img

    def _apply_basic(self, img, mixing_weights, m):
        """Accumulate the weighted chains in a float32 array, then blend with ``img`` by ``m``."""
        # This is a literal adaptation of the paper/official implementation without normalizations and
        # PIL <-> Numpy conversions between every op. It is still quite CPU compute heavy compared to the
        # typical augmentation transforms, could use a GPU / Kornia implementation.
        # PIL reports size as (width, height) while np.asarray(img) is laid out
        # (height, width, bands); the accumulator must use the array layout or the in-place add
        # below fails to broadcast for non-square images.
        width, height = img.size
        img_shape = height, width, len(img.getbands())
        mixed = np.zeros(img_shape, dtype=np.float32)
        for mw in mixing_weights:
            depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
            ops = np.random.choice(self.ops, depth, replace=True)
            img_aug = img  # no ops are in-place, deep copy not necessary
            for op in ops:
                img_aug = op(img_aug)
            mixed += mw * np.asarray(img_aug, dtype=np.float32)
        np.clip(mixed, 0, 255., out=mixed)
        mixed = Image.fromarray(mixed.astype(np.uint8))
        return Image.blend(img, mixed, m)

    def __call__(self, img):
        """Sample chain weights and the blend factor, then dispatch to the selected mixing path."""
        mixing_weights = np.float32(np.random.dirichlet([self.alpha] * self.width))
        m = np.float32(np.random.beta(self.alpha, self.alpha))
        if self.blended:
            mixed = self._apply_blended(img, mixing_weights, m)
        else:
            mixed = self._apply_basic(img, mixing_weights, m)
        return mixed
def augment_and_mix_transform(config_str, hparams):
    """ Create AugMix PyTorch transform

    :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
        dashes ('-'). The first section must be 'augmix'. The remaining sections, not order specific, determine
            'm' - integer magnitude (severity) of augmentation mix (default: 3)
            'w' - integer width of augmentation chain (default: 3)
            'd' - integer depth of augmentation chain (-1 is random [1, 3], default: -1)
            'a' - float alpha passed to AugMixAugment (default: 1.0)
            'b' - integer (bool), blend each branch of chain into end result without a final blend, less CPU (default: 0)
            'mstd' - float std deviation of magnitude noise applied (uniform sampling is used when not set)
        Ex 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2
        Sections that contain no digit are skipped.
    :param hparams: Other hparams (kwargs) for the Augmentation transforms. 'magnitude_std' is written into this
        mapping with ``setdefault``, so a key already present takes precedence.
    :return: A PyTorch compatible Transform
    :raises AssertionError: if the first section is not 'augmix' or a section key is unrecognised
    """
    magnitude = 3
    width = 3
    depth = -1
    alpha = 1.
    blended = False
    config = config_str.split('-')
    # Explicit raises instead of assert statements so validation survives `python -O`;
    # AssertionError is kept so existing callers see the same exception type.
    if config[0] != 'augmix':
        raise AssertionError("AugMix config must start with 'augmix'")
    for c in config[1:]:
        # split into the alphabetic key and the value that starts at the first digit
        cs = re.split(r'(\d.*)', c)
        if len(cs) < 2:
            continue
        key, val = cs[:2]
        if key == 'mstd':
            # noise param injected via hparams for now
            hparams.setdefault('magnitude_std', float(val))
        elif key == 'm':
            magnitude = int(val)
        elif key == 'w':
            width = int(val)
        elif key == 'd':
            depth = int(val)
        elif key == 'a':
            alpha = float(val)
        elif key == 'b':
            # val is a string: bool('0') would be True, so convert to int before testing
            blended = bool(int(val))
        else:
            raise AssertionError('Unknown AugMix config section')
    hparams.setdefault('magnitude_std', float('inf'))  # default to uniform sampling (if not set via mstd arg)
    ops = augmix_ops(magnitude=magnitude, hparams=hparams)
    return AugMixAugment(ops, alpha=alpha, width=width, depth=depth, blended=blended)