ai / training /augmentation_lib.py
CHEN11102's picture
Upload 47 files
2061d64 verified
raw
history blame
6.72 kB
# Copyright 2022 Google LLC
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Dataset augmentation for frame interpolation."""
from typing import Callable, Dict, List
import gin.tf
import numpy as np
import tensorflow as tf
import tensorflow.math as tfm
import tensorflow_addons.image as tfa_image
# Pi; used by flow_rot90 to convert an angle in degrees to radians.
_PI = 3.141592653589793
def _rotate_flow_vectors(flow: tf.Tensor, angle_rad: float) -> tf.Tensor:
    r"""Rotates every per-pixel (u, v) flow vector by `angle_rad` radians.

    Only the vector values are changed here; the spatial layout of the map is
    left untouched (the caller is expected to have rotated the map already).

    The flow is stored in image coordinates, where v grows downwards:

      . . . . u (x)
      .
      .
      . v (-y)

    while the rotation is expressed in the usual y-up frame:

      . y
      .
      .
      . . . . x

    Rotating (u, -v) in the y-up frame and converting back gives
      rot_u =  cos(angle) * u + sin(angle) * v
      rot_v = -sin(angle) * u + cos(angle) * v

    Args:
      flow: Flow map which has been image-rotated; its last axis is split into
        the u and v halves.
      angle_rad: The rotation angle in radians.

    Returns:
      A flow with the same map but each (u, v) vector rotated by angle_rad.
    """
    cos_angle = tfm.cos(angle_rad)
    sin_angle = tfm.sin(angle_rad)
    horizontal, vertical = tf.split(flow, 2, axis=-1)
    # The sign pattern differs from a textbook rotation matrix because the
    # vertical flow component is measured along -y.
    rotated_horizontal = cos_angle * horizontal + sin_angle * vertical
    rotated_vertical = -sin_angle * horizontal + cos_angle * vertical
    return tf.concat((rotated_horizontal, rotated_vertical), axis=-1)
def flow_rot90(flow: tf.Tensor, k: int) -> tf.Tensor:
    """Rotates a flow field by k quarter turns.

    The flow map is rotated with `tf.image.rot90` and then each flow vector is
    rotated by the matching angle so that it stays consistent with the rotated
    map.

    Args:
      flow: The flow image shaped (H, W, 2) to rotate by multiples of 90
        degrees.
      k: The number of 90 degree rotations to apply.

    Returns:
      A flow image of the same shape as the input rotated by multiples of 90
      degrees.
    """
    # k quarter turns -> degrees -> radians.
    angle_deg = tf.cast(k, dtype=tf.float32) * 90.
    angle_rad = angle_deg * (_PI / 180.)
    rotated_map = tf.image.rot90(flow, k)
    return _rotate_flow_vectors(rotated_map, angle_rad)
def rotate_flow(flow: tf.Tensor, angle_rad: float) -> tf.Tensor:
    """Rotates a flow field by an arbitrary angle given in radians.

    The flow map is resampled with bilinear interpolation, using reflection to
    fill pixels that fall outside the source, and then each flow vector is
    rotated by the same angle.

    Args:
      flow: The flow image shaped (H, W, 2) to rotate.
      angle_rad: The angle to rotate the flow by, in radians.

    Returns:
      A flow image of the same shape as the input rotated by the provided angle
      in radians.
    """
    rotated_map = tfa_image.rotate(
        flow, angles=angle_rad, interpolation='bilinear', fill_mode='reflect')
    return _rotate_flow_vectors(rotated_map, angle_rad)
def flow_flip(flow: tf.Tensor) -> tf.Tensor:
    """Flips a flow left to right.

    The flow map is mirrored horizontally and the horizontal component of each
    flow vector is negated so that it still points at the mirrored target.

    Args:
      flow: The flow image shaped (H, W, 2) to flip left to right.

    Returns:
      A flow image of the same shape as the input flipped left to right.
    """
    flipped = tf.image.flip_left_right(tf.identity(flow))
    flow_u, flow_v = tf.split(flipped, 2, axis=-1)
    # tf.split keeps the channel axis on each half, so the halves must be
    # joined with concat; tf.stack would append a new axis and return a
    # (H, W, 1, 2) tensor instead of (H, W, 2).
    return tf.concat([-1 * flow_u, flow_v], axis=-1)
def random_image_rot90(images: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
    """Rotates every image in a dictionary by the same random multiple of 90°.

    A single rotation count is drawn uniformly from {0, 1, 2, 3} and applied to
    all entries, so the entries stay aligned with each other. The input
    dictionary is updated in place and also returned.

    Args:
      images: A dictionary of tf.Tensors accepted by `tf.image.rot90`
        (presumably each shaped (H, W, num_channels) -- confirm with caller).

    Returns:
      The `images` dictionary with every entry rotated counter-clockwise by the
      same random multiple of 90 degrees.
    """
    num_quarter_turns = tf.random.uniform(
        (), minval=0, maxval=4, dtype=tf.int32)
    for key, image in list(images.items()):
        images[key] = tf.image.rot90(image, k=num_quarter_turns)
    return images
def random_flip(images: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
    """Flips every image in a dictionary left to right, or none of them.

    One fair coin is tossed for the whole dictionary so that all entries are
    either flipped together or left untouched together. The input dictionary is
    updated in place and also returned.

    Args:
      images: A dictionary of tf.Tensors accepted by
        `tf.image.flip_left_right` (presumably each shaped
        (H, W, num_channels) -- confirm with caller).

    Returns:
      The `images` dictionary after the random left to right flip.
    """
    coin = tf.random.uniform((), minval=0, maxval=2, dtype=tf.int32)
    should_flip = tf.cast(coin, tf.bool)

    for key, image in list(images.items()):
        # The current image is bound through a default argument so each branch
        # refers to this iteration's tensor.
        images[key] = tf.cond(
            should_flip,
            lambda image=image: tf.image.flip_left_right(image),
            lambda image=image: image)
    return images
def random_reverse(images: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
    """Randomly swaps the 'x0' and 'x1' entries of a dictionary of images.

    A fair coin is tossed; on heads the tensors stored under 'x0' and 'x1'
    trade places, on tails they are left as they are. All other entries are
    untouched. The input dictionary is updated in place and also returned.

    Args:
      images: A dictionary of tf.Tensors, each shaped (H, W, num_channels),
        with each tensor being a stack of images along the last channel axis.
        Must contain the keys 'x0' and 'x1'; the two tensors need matching
        dtypes and compatible shapes because either may end up in either slot.

    Returns:
      A dictionary of tf.Tensors, each shaped the same as the input images
      dict.
    """
    prob = tf.random.uniform((), minval=0, maxval=2, dtype=tf.int32)
    prob = tf.cast(prob, tf.bool)

    first, second = images['x0'], images['x1']
    # The swap is expressed as a choice between tensors inside tf.cond and the
    # dictionary is only written afterwards. Mutating the dictionary inside a
    # branch function is unsafe: when building a graph both branches are
    # executed as Python during tracing, so a swap performed while tracing one
    # branch would already be visible in the dictionary returned by the other,
    # making the result independent of `prob`.
    images['x0'], images['x1'] = tf.cond(
        prob,
        lambda: (second, first),
        lambda: (first, second))
    return images
def random_rotate(images: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
    """Rotates every image in a dictionary by a random angle in [-45°, 45°).

    A fair coin decides whether a rotation is applied at all: the sampled angle
    is multiplied by 0 or 1, so half of the time the rotation angle is zero.
    The same angle is used for every entry. Pixels that fall outside the source
    image use the 'constant' fill mode. The input dictionary is updated in
    place and also returned.

    Args:
      images: A dictionary of tf.Tensors accepted by `tfa.image.rotate`
        (presumably each shaped (H, W, num_channels) -- confirm with caller).

    Returns:
      The `images` dictionary after the random rotation, bounded to -45 to 45
      degrees.
    """
    coin = tf.random.uniform((), minval=0, maxval=2, dtype=tf.int32)
    gate = tf.cast(coin, tf.float32)
    quarter_pi = 0.25 * np.pi
    random_angle = tf.random.uniform(
        (), minval=-quarter_pi, maxval=quarter_pi, dtype=tf.float32)
    # A gate of 0 turns the rotation into a zero-angle rotation.
    applied_angle = random_angle * gate

    for key, image in list(images.items()):
        images[key] = tfa_image.rotate(
            image,
            angles=applied_angle,
            interpolation='bilinear',
            fill_mode='constant')
    return images
@gin.configurable('data_augmentation')
def data_augmentations(
    names: List[str]) -> Dict[str, Callable[..., tf.Tensor]]:
    """Looks up augmentation functions by name.

    Args:
      names: The list of augmentation function names. Supported names are
        'random_image_rot90', 'random_rotate', 'random_flip' and
        'random_reverse'.

    Returns:
      A dictionary of Callables to the augmentation functions, keyed by their
      names.

    Raises:
      AttributeError: If a name is not one of the supported augmentations.
    """
    available = {
        'random_image_rot90': random_image_rot90,
        'random_rotate': random_rotate,
        'random_flip': random_flip,
        'random_reverse': random_reverse,
    }
    augmentations = dict()
    for name in names:
        if name not in available:
            raise AttributeError('Invalid augmentation function %s' % name)
        augmentations[name] = available[name]
    return augmentations