File size: 5,331 Bytes
1772f26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
# Copyright 2022 Google LLC

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Various utilities used in the film_net frame interpolator model."""
from typing import List

from .options import Options
import tensorflow as tf
import tensorflow_addons.image as tfa_image


def build_image_pyramid(image: tf.Tensor,
                        options: Options) -> List[tf.Tensor]:
  """Builds an image pyramid from a given image.

  The original image is included in the pyramid and the rest are generated by
  successively halving the resolution.

  Args:
    image: the input image.
    options: film_net options object

  Returns:
    A list of images starting from the finest with options.pyramid_levels items
  """
  levels = options.pyramid_levels
  pyramid = []
  pool = tf.keras.layers.AveragePooling2D(
      pool_size=2, strides=2, padding='valid')
  for i in range(0, levels):
    pyramid.append(image)
    if i < levels-1:
      image = pool(image)
  return pyramid


def warp(image: tf.Tensor, flow: tf.Tensor) -> tf.Tensor:
  """Backward warps the image using the given flow.

  Specifically, the output pixel in batch b, at position x, y will be computed
  as follows:
    (flowed_y, flowed_x) = (y+flow[b, y, x, 1], x+flow[b, y, x, 0])
    output[b, y, x] = bilinear_lookup(image, b, flowed_y, flowed_x)

  Note that the flow vectors are expected as [x, y], e.g. x in position 0 and
  y in position 1.

  Args:
    image: An image with shape BxHxWxC.
    flow: A flow with shape BxHxWx2, with the two channels denoting the relative
      offset in order: (dx, dy).
  Returns:
    A warped image.
  """
  # tfa_image.dense_image_warp expects unconventional negated optical flow, so
  # negate the flow here. Also revert x and y for compatibility with older saved
  # models trained with custom warp op that stored (x, y) instead of (y, x) flow
  # vectors.
  flow = -flow[..., ::-1]

  # Note: we have to wrap tfa_image.dense_image_warp into a Keras Lambda,
  # because it is not compatible with Keras symbolic tensors and we want to use
  # this code as part of a Keras model.  Wrapping it into a lambda has the
  # consequence that tfa_image.dense_image_warp is only called once the tensors
  # are concrete, e.g. actually contain data. The inner lambda is a workaround
  # for passing two parameters, e.g you would really want to write:
  # tf.keras.layers.Lambda(tfa_image.dense_image_warp)(image, flow), but this is
  # not supported by the Keras Lambda.
  warped = tf.keras.layers.Lambda(
      lambda x: tfa_image.dense_image_warp(*x))((image, flow))
  return tf.reshape(warped, shape=tf.shape(image))


def multiply_pyramid(pyramid: List[tf.Tensor],
                     scalar: tf.Tensor) -> List[tf.Tensor]:
  """Multiplies all image batches in the pyramid by a batch of scalars.

  Args:
    pyramid: Pyramid of image batches.
    scalar: Batch of scalars.

  Returns:
    An image pyramid with all images multiplied by the scalar.
  """
  # To multiply each image with its corresponding scalar, we first transpose
  # the batch of images from BxHxWxC-format to CxHxWxB. This can then be
  # multiplied with a batch of scalars, then we transpose back to the standard
  # BxHxWxC form.
  return [
      tf.transpose(tf.transpose(image, [3, 1, 2, 0]) * scalar, [3, 1, 2, 0])
      for image in pyramid
  ]


def flow_pyramid_synthesis(
    residual_pyramid: List[tf.Tensor]) -> List[tf.Tensor]:
  """Converts a residual flow pyramid into a flow pyramid."""
  flow = residual_pyramid[-1]
  flow_pyramid = [flow]
  for residual_flow in reversed(residual_pyramid[:-1]):
    level_size = tf.shape(residual_flow)[1:3]
    flow = tf.image.resize(images=2*flow, size=level_size)
    flow = residual_flow + flow
    flow_pyramid.append(flow)
  # Use reversed() to return in the 'standard' finest-first-order:
  return list(reversed(flow_pyramid))


def pyramid_warp(feature_pyramid: List[tf.Tensor],
                 flow_pyramid: List[tf.Tensor]) -> List[tf.Tensor]:
  """Warps the feature pyramid using the flow pyramid.

  Args:
    feature_pyramid: feature pyramid starting from the finest level.
    flow_pyramid: flow fields, starting from the finest level.

  Returns:
    Reverse warped feature pyramid.
  """
  warped_feature_pyramid = []
  for features, flow in zip(feature_pyramid, flow_pyramid):
    warped_feature_pyramid.append(warp(features, flow))
  return warped_feature_pyramid


def concatenate_pyramids(pyramid1: List[tf.Tensor],
                         pyramid2: List[tf.Tensor]) -> List[tf.Tensor]:
  """Concatenates each pyramid level together in the channel dimension."""
  result = []
  for features1, features2 in zip(pyramid1, pyramid2):
    result.append(tf.concat([features1, features2], axis=-1))
  return result