# Source: EdgeTA/utils/dl/augmix/__init__.py
# Uploaded by LINC-BIT ("Upload 1912 files", commit b84549f, verified).
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Reference implementation of AugMix's data augmentation method in numpy."""
from .augs import augmentations
import numpy as np
from PIL import Image
import torch
# Per-channel statistics of CIFAR-10, in the channel order the images are in
# (presumably R, G, B). Used by ``normalize`` as its default statistics.
MEAN = [
    0.4914,
    0.4822,
    0.4465,
]
STD = [
    0.2023,
    0.1994,
    0.2010,
]
def normalize(image, mean=None, std=None):
    """Normalize input image channel-wise to zero mean and unit variance.

    Args:
      image: np.ndarray of shape (h, w, c), channel-last.
      mean: Optional per-channel means (length c). Defaults to the module-level
        CIFAR-10 ``MEAN`` when omitted.
      std: Optional per-channel standard deviations (length c). Defaults to the
        module-level CIFAR-10 ``STD`` when omitted.

    Returns:
      np.ndarray of shape (h, w, c) holding ``(image - mean) / std`` computed
      independently for each channel.
    """
    # Only touch the module constants when the caller did not supply stats.
    mean = np.array(MEAN if mean is None else mean)
    std = np.array(STD if std is None else std)
    # Switch to channel-first so the 1-D stats broadcast over (h, w).
    chw = image.transpose(2, 0, 1)
    chw = (chw - mean[:, None, None]) / std[:, None, None]
    return chw.transpose(1, 2, 0)
def apply_op(image, op, severity):
    """Run one PIL-based augmentation ``op`` on a float image array.

    The array is scaled by 255, clipped to [0, 255] and cast to uint8 (the cast
    truncates rather than rounds). It is then wrapped in a PIL image and passed
    to ``op(pil_image, severity)``. The result is converted back to a numpy
    array divided by 255.
    """
    scaled = np.clip(image * 255., 0, 255)
    as_uint8 = scaled.astype(np.uint8)
    transformed = op(Image.fromarray(as_uint8), severity)
    return np.asarray(transformed) / 255.
def augment_and_mix_pil(image: np.ndarray, severity=3, width=3, depth=-1, alpha=1.):
    """Perform AugMix augmentations and compute the mixture.

    Despite the ``_pil`` suffix, the input and output are numpy arrays. Each
    individual operation converts to a PIL image and back inside ``apply_op``.

    Args:
      image: Raw input image as a float np.ndarray of shape (h, w, c). Values
        are presumably in [0, 1], since ``apply_op`` scales by 255 -- TODO
        confirm against callers.
      severity: Severity of underlying augmentation operators (between 1 to 10).
      width: Number of independent augmentation chains that are mixed.
      depth: Number of ops per chain. -1 (any value <= 0) samples the depth
        uniformly from [1, 3], independently for each chain.
      alpha: Concentration parameter for the Beta and Dirichlet distributions.

    Returns:
      mixed: Normalized (via ``normalize``) convex combination of the clean
        image and the weighted augmentation chains, shape (h, w, c).
    """
    # Per-chain mixing weights, and the clean-vs-augmented interpolation factor.
    ws = np.float32(np.random.dirichlet([alpha] * width))
    m = np.float32(np.random.beta(alpha, alpha))

    mix = np.zeros_like(image)
    for i in range(width):
        image_aug = image.copy()
        d = depth if depth > 0 else np.random.randint(1, 4)
        for _ in range(d):
            op = np.random.choice(augmentations)
            image_aug = apply_op(image_aug, op, severity)
        # Preprocessing commutes since all coefficients are convex, so each
        # chain can be normalized before being accumulated.
        mix = mix + ws[i] * normalize(image_aug)

    mixed = (1 - m) * normalize(image) + m * mix
    return mixed
# def augment_and_mix_tensor(image: torch.Tensor, severity=3, width=3, depth=-1, alpha=1.):
# """Perform AugMix augmentations and compute mixture.
# Args:
# image: Raw input image as float32 np.ndarray of shape (h, w, c)
# severity: Severity of underlying augmentation operators (between 1 to 10).
# width: Width of augmentation chain
# depth: Depth of augmentation chain. -1 enables stochastic depth uniformly
# from [1, 3]
# alpha: Probability coefficient for Beta and Dirichlet distributions.
# Returns:
# mixed: Augmented and mixed image.
# """
# image = Image.fromarray(image.cpu().numpy())
# image = augment_and_mix_pil(image, severity, width, depth, alpha)
# return torch.from_numpy(image)
def augment_and_mix_tensors(images: torch.Tensor, severity=3, width=3, depth=-1, alpha=1., *, device=None):
    """Apply AugMix independently to every image in a batch.

    Each image is moved to the CPU and converted to a channel-last numpy
    array. It is then augmented with ``augment_and_mix_pil`` and converted back
    to a float32 tensor. Single-channel images are replicated to three
    channels for the augmentation, and only the first channel is kept
    afterwards.

    Args:
      images: Batch tensor of shape (B, C, H, W), where C is 1 or 3. Values are
        presumably in [0, 1] -- TODO confirm against callers.
      severity: Severity of underlying augmentation operators (between 1 to 10).
      width: Width of each image's augmentation chain mixture.
      depth: Depth of augmentation chain. -1 enables stochastic depth uniformly
        from [1, 3].
      alpha: Probability coefficient for Beta and Dirichlet distributions.
      device: Device of the returned tensor. Defaults to the current CUDA
        device when CUDA is available (the historical ``.cuda()`` behaviour),
        otherwise to ``images.device``.

    Returns:
      float32 tensor on ``device``. Single-channel input yields (B, 1, H, W);
      three-channel input yields (B, H, W, 3), which is channel-last.
    """
    if device is None:
        # Keep CUDA as the default, but do not crash on CPU-only machines.
        device = 'cuda' if torch.cuda.is_available() else images.device

    if images.size(0) == 0:
        # torch.stack() rejects an empty list. Return an empty batch laid out
        # the way a non-empty batch of the same channel count would be.
        _, c, h, w = images.shape
        empty_shape = (0, 1, h, w) if c == 1 else (0, h, w, c)
        return torch.empty(empty_shape, dtype=torch.float32, device=device)

    res = []
    for image in images:
        gray_img = image.size(0) == 1
        if gray_img:
            # The AugMix ops and normalize() expect three channels.
            image = torch.cat([image, image, image])
        # (C, H, W) tensor -> (H, W, C) array. detach() lets tensors that
        # require grad be converted.
        np_image = image.detach().cpu().numpy().transpose(1, 2, 0)
        aug_image = augment_and_mix_pil(np_image, severity, width, depth, alpha)
        if gray_img:
            aug_image = aug_image.transpose(2, 0, 1)[0: 1]
        # NOTE(review): three-channel results are not transposed back, so they
        # stay (H, W, C), unlike the (1, H, W) gray path. This is preserved for
        # existing callers; confirm whether channel-first was intended.
        res.append(torch.from_numpy(aug_image).float())
    # Move the whole batch once instead of once per sample.
    return torch.stack(res).to(device)