"""The final fusion stage for the film_net frame interpolator. |
|
|
|
The inputs to this module are the warped input images, image features and |
|
flow fields, all aligned to the target frame (often midway point between the |
|
two original inputs). The output is the final image. FILM has no explicit |
|
occlusion handling -- instead using the abovementioned information this module |
|
automatically decides how to best blend the inputs together to produce content |
|
in areas where the pixels can only be borrowed from one of the inputs. |
|
|
|
Similarly, this module also decides on how much to blend in each input in case |
|
of fractional timestep that is not at the halfway point. For example, if the two |
|
inputs images are at t=0 and t=1, and we were to synthesize a frame at t=0.1, |
|
it often makes most sense to favor the first input. However, this is not |
|
always the case -- in particular in occluded pixels. |
|
|
|
The architecture of the Fusion module follows U-net [1] architecture's decoder |
|
side, e.g. each pyramid level consists of concatenation with upsampled coarser |
|
level output, and two 3x3 convolutions. |
|
|
|
The upsampling is implemented as 'resize convolution', e.g. nearest neighbor |
|
upsampling followed by 2x2 convolution as explained in [2]. The classic U-net |
|
uses max-pooling which has a tendency to create checkerboard artifacts. |
|
|
|
[1] Ronneberger et al. U-Net: Convolutional Networks for Biomedical Image |
|
Segmentation, 2015, https://arxiv.org/pdf/1505.04597.pdf |
|
[2] https://distill.pub/2016/deconv-checkerboard/ |
|
""" |
|
|
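
# A minimal sketch of the per-level decoder step described above (illustrative
# only; the real implementation is Fusion.call below). The names `skip`,
# `conv_2x2`, `conv_3x3_a` and `conv_3x3_b` are hypothetical stand-ins for the
# current level's feature tensor and the Conv2D layers built in
# Fusion.__init__:
#
#   # Nearest neighbor upsample the coarser output to the skip's size, then
#   # apply the 2x2 'resize convolution' in place of a transposed convolution.
#   net = tf.image.resize(net, tf.shape(skip)[1:3],
#                         tf.image.ResizeMethod.NEAREST_NEIGHBOR)
#   net = conv_2x2(net)
#   # Concatenate the skip connection and refine with two 3x3 convolutions.
#   net = tf.concat([skip, net], axis=-1)
#   net = conv_3x3_b(conv_3x3_a(net))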
|
from typing import List

from . import options
import tensorflow as tf


def _relu(x: tf.Tensor) -> tf.Tensor:
  return tf.nn.leaky_relu(x, alpha=0.2)


_NUMBER_OF_COLOR_CHANNELS = 3
|
|
class Fusion(tf.keras.layers.Layer):
  """The decoder."""

  def __init__(self, name: str, config: options.Options):
    super().__init__(name=name)

    # convs[i] will hold the convolutions applied at pyramid level i.
    self.convs: List[List[tf.keras.layers.Layer]] = []

    # Store the level count so call() can validate its input pyramid.
    self.levels = config.fusion_pyramid_levels

    # Create the convolutions in fine-to-coarse order, so that self.convs[i]
    # matches pyramid level i. The filter count starts at config.filters on
    # the finest level and doubles at each coarser level up to
    # config.specialized_levels, after which it stays constant. For example,
    # with filters=64 and specialized_levels=3 the counts are 64, 128, 256,
    # 512, 512, ...
    for i in range(config.fusion_pyramid_levels - 1):
      m = config.specialized_levels
      k = config.filters
      num_filters = (k << i) if i < m else (k << m)

      convs: List[tf.keras.layers.Layer] = []
      # 2x2 convolution applied right after nearest neighbor upsampling
      # (the 'resize convolution' described in the module docstring).
      convs.append(
          tf.keras.layers.Conv2D(
              filters=num_filters, kernel_size=[2, 2], padding='same'))
      # Two 3x3 convolutions applied after concatenating the skip connection.
      convs.append(
          tf.keras.layers.Conv2D(
              filters=num_filters,
              kernel_size=[3, 3],
              padding='same',
              activation=_relu))
      convs.append(
          tf.keras.layers.Conv2D(
              filters=num_filters,
              kernel_size=[3, 3],
              padding='same',
              activation=_relu))
      self.convs.append(convs)

    # Final 1x1 projection from features to RGB.
    self.output_conv = tf.keras.layers.Conv2D(
        filters=_NUMBER_OF_COLOR_CHANNELS, kernel_size=1)
|
  def call(self, pyramid: List[tf.Tensor]) -> tf.Tensor:
    """Runs the fusion module.

    Args:
      pyramid: The input feature pyramid as a list of tensors. Each tensor is
        in (B x H x W x C) format, with the finest level tensor first.

    Returns:
      A batch of RGB images.

    Raises:
      ValueError: If len(pyramid) does not match the number of pyramid levels
        configured in the constructor.
    """
    if len(pyramid) != self.levels:
      raise ValueError(
          'Fusion called with different number of pyramid levels '
          f'{len(pyramid)} than it was configured for, {self.levels}.')

    # Start from the coarsest level. It receives no convolutions of its own;
    # it is only upsampled and fed into the finer levels below.
    net = pyramid[-1]

    # Walk the pyramid from the second coarsest level to the finest.
    for i in reversed(range(0, self.levels - 1)):
      # Nearest neighbor upsample to this level's spatial size, then apply
      # the 2x2 'resize convolution'.
      level_size = tf.shape(pyramid[i])[1:3]
      net = tf.image.resize(net, level_size,
                            tf.image.ResizeMethod.NEAREST_NEIGHBOR)
      net = self.convs[i][0](net)
      # Concatenate the skip connection and refine with two 3x3 convolutions.
      net = tf.concat([pyramid[i], net], axis=-1)
      net = self.convs[i][1](net)
      net = self.convs[i][2](net)
    # Project the finest level features to RGB.
    net = self.output_conv(net)
    return net
|
|
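# Usage sketch (illustrative only). This assumes `options.Options` can be
# constructed with the `fusion_pyramid_levels`, `specialized_levels` and
# `filters` fields referenced above; the shapes and channel counts are made
# up, since the real pyramid comes from the warping stage upstream.
#
#   config = options.Options(
#       fusion_pyramid_levels=5, specialized_levels=3, filters=64)
#   fuse = Fusion('fusion', config)
#   # Finest level first; spatial size halves at each coarser level.
#   pyramid = [
#       tf.zeros([1, 256 // 2**i, 256 // 2**i, 16]) for i in range(5)
#   ]
#   rgb = fuse(pyramid)  # -> shape (1, 256, 256, 3)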