|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Various utilities used in the film_net frame interpolator model.""" |
|
from typing import List |
|
|
|
from .options import Options |
|
import tensorflow as tf |
|
import tensorflow_addons.image as tfa_image |
|
|
|
|
|
def build_image_pyramid(image: tf.Tensor,
                        options: Options) -> List[tf.Tensor]:
  """Creates a multi-resolution pyramid whose first entry is `image` itself.

  Each subsequent level is produced from the previous one by 2x2 average
  pooling with stride 2 and 'valid' padding, halving the spatial resolution
  (odd heights/widths are rounded down).

  Args:
    image: the input image batch.
    options: film_net options object; only `pyramid_levels` is read.

  Returns:
    A list of `options.pyramid_levels` tensors ordered from finest (the input
    image) to coarsest. An empty list is returned when `pyramid_levels` <= 0.
  """
  num_levels = options.pyramid_levels
  downsample = tf.keras.layers.AveragePooling2D(
      pool_size=2, strides=2, padding='valid')

  levels = []
  current = image
  for level_index in range(num_levels):
    # Every level after the first is a half-resolution copy of its predecessor,
    # so pooling happens exactly num_levels - 1 times.
    if level_index > 0:
      current = downsample(current)
    levels.append(current)
  return levels
|
|
|
|
|
def warp(image: tf.Tensor, flow: tf.Tensor) -> tf.Tensor:
  """Backward-warps `image` by sampling it at positions offset by `flow`.

  For batch b and pixel (y, x) the result is computed as

    (src_y, src_x) = (y + flow[b, y, x, 1], x + flow[b, y, x, 0])
    output[b, y, x] = bilinear_lookup(image, b, src_y, src_x)

  i.e. channel 0 of `flow` holds the x offset and channel 1 the y offset.

  Args:
    image: An image batch with shape BxHxWxC.
    flow: A flow field with shape BxHxWx2 holding relative offsets in
      (dx, dy) order.

  Returns:
    The warped image, reshaped to the dynamic shape of `image`.
  """
  # Reverse the channel order to (dy, dx) and negate. Per the TF Addons API
  # docs, dense_image_warp expects (dy, dx) offsets that it subtracts from the
  # pixel position, so negating yields the (y + dy, x + dx) lookup above.
  tfa_offsets = -flow[..., ::-1]

  # The warp runs inside a Keras Lambda layer, presumably so this function can
  # be applied to Keras symbolic tensors as part of a model.
  warp_layer = tf.keras.layers.Lambda(
      lambda pair: tfa_image.dense_image_warp(*pair))
  warped = warp_layer((image, tfa_offsets))

  # Reshape to the input image's shape; the Lambda wrapper may not propagate
  # static shape information on its own.
  return tf.reshape(warped, shape=tf.shape(image))
|
|
|
|
|
def multiply_pyramid(pyramid: List[tf.Tensor],
                     scalar: tf.Tensor) -> List[tf.Tensor]:
  """Scales every image batch in a pyramid by a per-example scalar.

  Args:
    pyramid: Pyramid of rank-4 image batches (BxHxWxC, per the module's
      convention).
    scalar: Batch of scalars; it is broadcast against the batch axis, so a
      length-B vector scales each example independently.

  Returns:
    A new pyramid with each image multiplied by its corresponding scalar.
  """
  # Swapping axes 0 and 3 moves the batch axis last (BxHxWxC -> CxHxWxB) so a
  # trailing-axis broadcast with `scalar` lines up with the batch. The same
  # permutation is its own inverse, so it also restores the original layout.
  swap_batch_and_channels = [3, 1, 2, 0]

  scaled_pyramid = []
  for level in pyramid:
    batch_last = tf.transpose(level, swap_batch_and_channels)
    scaled_pyramid.append(
        tf.transpose(batch_last * scalar, swap_batch_and_channels))
  return scaled_pyramid
|
|
|
|
|
def flow_pyramid_synthesis(
    residual_pyramid: List[tf.Tensor]) -> List[tf.Tensor]:
  """Accumulates a residual flow pyramid into a full flow pyramid.

  Starting from the last (coarsest) level, the running flow is doubled,
  resized to the next finer level's height/width and added to that level's
  residual.

  Args:
    residual_pyramid: residual flows ordered from finest to coarsest.

  Returns:
    Full flows in the same finest-to-coarsest order.
  """
  coarse_to_fine = residual_pyramid[::-1]
  accumulated = coarse_to_fine[0]
  flows = [accumulated]
  for residual in coarse_to_fine[1:]:
    target_hw = tf.shape(residual)[1:3]
    # Doubling compensates for the resolution doubling between levels, since
    # flow vectors are measured in pixels of their own level.
    upsampled = tf.image.resize(images=2 * accumulated, size=target_hw)
    accumulated = residual + upsampled
    flows.append(accumulated)

  flows.reverse()
  return flows
|
|
|
|
|
def pyramid_warp(feature_pyramid: List[tf.Tensor],
                 flow_pyramid: List[tf.Tensor]) -> List[tf.Tensor]:
  """Backward-warps each feature level with the flow at the same level.

  Levels are paired positionally; if the two pyramids differ in length, the
  extra levels of the longer one are ignored.

  Args:
    feature_pyramid: feature pyramid starting from the finest level.
    flow_pyramid: flow fields, starting from the finest level.

  Returns:
    The warped feature pyramid, one entry per paired level.
  """
  return [
      warp(level_features, level_flow)
      for level_features, level_flow in zip(feature_pyramid, flow_pyramid)
  ]
|
|
|
|
|
def concatenate_pyramids(pyramid1: List[tf.Tensor],
                         pyramid2: List[tf.Tensor]) -> List[tf.Tensor]:
  """Joins two pyramids level by level along the last (channel) axis.

  Levels are paired positionally; if the pyramids differ in length, the
  extra levels of the longer one are ignored.

  Args:
    pyramid1: first pyramid; its channels come first in each result level.
    pyramid2: second pyramid; its channels are appended after pyramid1's.

  Returns:
    A list with one concatenated tensor per paired level.
  """
  return [
      tf.concat([left, right], axis=-1)
      for left, right in zip(pyramid1, pyramid2)
  ]
|
|