File size: 6,721 Bytes
2061d64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
# Copyright 2022 Google LLC

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Dataset augmentation for frame interpolation."""
from typing import Callable, Dict, List

import gin.tf
import numpy as np
import tensorflow as tf
import tensorflow.math as tfm
import tensorflow_addons.image as tfa_image

_PI = 3.141592653589793


def _rotate_flow_vectors(flow: tf.Tensor, angle_rad: float) -> tf.Tensor:
  r"""Rotates the (u, v) vector stored at each pixel by an angle in radians.

  Only the vector values change; the pixel grid itself is not resampled here.

  The flow components live in image coordinates, where v points downwards:
  . . . . u (x)
  .
  .
  . v (-y)

  The rotation is expressed in the conventional frame, where y points upwards:
  . y
  .
  .
  . . . . x

  Args:
    flow: Flow map whose last axis is split in two halves, u and v. The callers
      in this file pass a map whose pixel grid has already been rotated.
    angle_rad: The rotation angle in radians.

  Returns:
    A flow laid out like the input, with each (u, v) vector rotated by
    angle_rad.
  """
  u, v = tf.split(flow, 2, axis=-1)
  cos_a = tfm.cos(angle_rad)
  sin_a = tfm.sin(angle_rad)
  # Substituting y = -v into the standard 2D rotation:
  #   rot_u = u * cos(a) - (-v) * sin(a)
  rot_u = cos_a * u + sin_a * v
  #   rot_v = -(u * sin(a) + (-v) * cos(a))
  rot_v = cos_a * v - sin_a * u
  return tf.concat((rot_u, rot_v), axis=-1)


def flow_rot90(flow: tf.Tensor, k: int) -> tf.Tensor:
  """Rotates a flow map by k quarter turns (k * 90 degrees).

  The pixel grid is rotated with `tf.image.rot90`, then every flow vector is
  rotated by the matching angle so that it stays consistent with the new grid.

  Args:
    flow: The flow image shaped (H, W, 2) to rotate by multiples of 90 degrees.
    k: The number of 90 degree rotations to apply.

  Returns:
    The flow image after both its pixel grid and its vectors have been rotated
    by k * 90 degrees.
  """
  quarter_turns = tf.cast(k, dtype=tf.float32)
  # Convert quarter turns to degrees first, then degrees to radians.
  angle_rad = quarter_turns * 90. * (_PI/180.)
  rotated_grid = tf.image.rot90(flow, k)
  return _rotate_flow_vectors(rotated_grid, angle_rad)


def rotate_flow(flow: tf.Tensor, angle_rad: float) -> tf.Tensor:
  """Rotates a flow map by the provided angle in radians.

  The pixel grid is resampled with bilinear interpolation, using reflection to
  fill areas that fall outside the source image. The flow vectors are then
  rotated by the same angle.

  Args:
    flow: The flow image shaped (H, W, 2) to rotate.
    angle_rad: The angle to rotate the flow by, in radians.

  Returns:
    A flow image of the same shape as the input rotated by the provided angle in
    radians.
  """
  rotated_grid = tfa_image.rotate(
      flow, angles=angle_rad, interpolation='bilinear', fill_mode='reflect')
  return _rotate_flow_vectors(rotated_grid, angle_rad)


def flow_flip(flow: tf.Tensor) -> tf.Tensor:
  """Flips a flow left to right.

  The pixel grid is mirrored horizontally and the horizontal component u of
  every flow vector is negated, since mirroring reverses the x direction. The
  vertical component v is left unchanged.

  Args:
    flow: The flow image shaped (H, W, 2) to flip left to right.

  Returns:
    A flow image of the same shape as the input flipped left to right.
  """
  flow = tf.image.flip_left_right(tf.identity(flow))
  flow_u, flow_v = tf.split(flow, 2, axis=-1)
  # tf.split keeps the channel axis, so each component is shaped (H, W, 1).
  # Concatenating restores (H, W, 2); stacking would add a new trailing axis
  # and yield (H, W, 1, 2) instead.
  return tf.concat([-1 * flow_u, flow_v], axis=-1)


def random_image_rot90(images: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
  """Rotates every image in a dictionary by the same random multiple of 90°.

  A single rotation count is drawn (integer uniform with minval=0, maxval=4)
  and shared by all entries, so the entries stay aligned with each other. The
  dictionary is updated in place.

  Args:
    images: A dictionary of image tensors; each value is passed to
      `tf.image.rot90`.

  Returns:
    The same dictionary, with every value rotated counter-clockwise by the
    drawn multiple of 90 degrees.
  """
  turns = tf.random.uniform((), minval=0, maxval=4, dtype=tf.int32)
  for name, image in images.items():
    images[name] = tf.image.rot90(image, k=turns)
  return images


def random_flip(images: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
  """Randomly flips every image in a dictionary left to right.

  One random boolean is drawn and shared by all entries, so either every image
  is flipped or none is. The dictionary is updated in place.

  Args:
    images: A dictionary of image tensors; each value is passed to
      `tf.image.flip_left_right` when the flip is selected.

  Returns:
    The same dictionary, with every value either flipped left to right or left
    unchanged.
  """
  coin = tf.random.uniform((), minval=0, maxval=2, dtype=tf.int32)
  do_flip = tf.cast(coin, tf.bool)

  def _maybe_flip(image):
    # Taking the image as an argument keeps each branch bound to its own
    # tensor rather than to the loop variable.
    return tf.cond(do_flip,
                   lambda: tf.image.flip_left_right(image),
                   lambda: image)

  for name, image in images.items():
    images[name] = _maybe_flip(image)
  return images


def random_reverse(images: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
  """Randomly swaps the 'x0' and 'x1' entries of an image dictionary.

  Args:
    images: A dictionary of tf.Tensors, each shaped (H, W, num_channels), with
      each tensor being a stack of images along the last channel axis. It must
      contain the keys 'x0' and 'x1'.

  Returns:
    A dictionary of tf.Tensors with the same keys as the input. With equal
    probability either 'x0' and 'x1' are exchanged or the content is returned
    unchanged. All other entries are passed through untouched.
  """
  prob = tf.random.uniform((), minval=0, maxval=2, dtype=tf.int32)
  prob = tf.cast(prob, tf.bool)

  def _identity():
    return images

  def _reverse():
    # Swap on a shallow copy. In graph mode tf.cond traces both branches, so
    # mutating `images` in place here would leak the swap into the identity
    # branch and the pair would end up reversed regardless of `prob`.
    swapped = images.copy()
    swapped['x0'], swapped['x1'] = images['x1'], images['x0']
    return swapped

  return tf.cond(prob, _reverse, _identity)


def random_rotate(images: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
  """Randomly rotates every image in a dictionary by up to 45 degrees.

  A single angle is drawn uniformly between -0.25 * pi and 0.25 * pi radians
  (-45 to 45 degrees). It is multiplied by a random 0/1 gate, so in the gated
  case the applied angle is zero. The same angle is used for all entries, and
  the dictionary is updated in place.

  Args:
    images: A dictionary of image tensors; each value is passed to
      `tfa_image.rotate`.

  Returns:
    The same dictionary, with every value rotated by the shared angle using
    bilinear interpolation and constant fill.
  """
  coin = tf.random.uniform((), minval=0, maxval=2, dtype=tf.int32)
  gate = tf.cast(coin, tf.float32)
  random_angle = tf.random.uniform(
      (), minval=-0.25 * np.pi, maxval=0.25 * np.pi, dtype=tf.float32)
  # A zero gate collapses the angle to 0 instead of skipping the rotate op.
  applied_angle = random_angle * gate

  for name, image in images.items():
    images[name] = tfa_image.rotate(
        image,
        angles=applied_angle,
        interpolation='bilinear',
        fill_mode='constant')
  return images


@gin.configurable('data_augmentation')
def data_augmentations(
    names: List[str]) -> Dict[str, Callable[..., tf.Tensor]]:
  """Builds the mapping from augmentation names to augmentation functions.

  Args:
    names: The names of the augmentations to enable. Supported values are
      'random_image_rot90', 'random_rotate', 'random_flip' and
      'random_reverse'.

  Returns:
    A dictionary keyed by the requested names, in the order they were given,
    whose values are the corresponding augmentation Callables.

  Raises:
    AttributeError: If a requested name is not a supported augmentation.
  """
  available = {
      'random_image_rot90': random_image_rot90,
      'random_rotate': random_rotate,
      'random_flip': random_flip,
      'random_reverse': random_reverse,
  }
  augmentations = {}
  for name in names:
    if name not in available:
      raise AttributeError('Invalid augmentation function %s' % name)
    augmentations[name] = available[name]
  return augmentations