# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The film_net frame interpolator main model code.

Basics
======
The film_net is an end-to-end learned neural frame interpolator implemented as
a TF2 model. It has the following inputs and outputs:

Inputs:
  x0: image A.
  x1: image B.
  time: desired sub-frame time.

Outputs:
  image: the predicted in-between image at the chosen time in range [0, 1].

Additional outputs include forward and backward warped image pyramids, flow
pyramids, etc., that can be visualized for debugging and analysis.

Note that many training sets only contain triplets with ground truth at
time=0.5. If a model has been trained with such a training set, it will only
work well for synthesizing frames at time=0.5. Such models can generate more
in-between frames only by recursive invocation.

Architecture
============
The inference consists of three main stages: 1) feature extraction, 2) warping,
3) fusion. At a high level, the architecture has similarities to Context-aware
Synthesis for Video Frame Interpolation [1], but the exact architecture is
closer to Multi-view Image Fusion [2] with some modifications for the frame
interpolation use-case.

The feature extraction stage employs the cascaded multi-scale architecture
described in [2]. The advantage of this architecture is that coarse-level flow
prediction can be learned from finer resolution image samples. This is
especially useful to avoid overfitting with moderately sized datasets.

The warping stage uses a residual flow prediction idea that is similar to
PWC-Net [3], Multi-view Image Fusion [2] and many others.

The fusion stage is similar to U-Net's decoder, where the skip connections are
connected to the warped image and feature pyramids. This is described in [2].

Implementation Conventions
==========================

Pyramids
--------
Throughout the model, all image and feature pyramids are stored as python
lists with the finest level first, followed by downscaled versions obtained by
successively halving the resolution. The depths of all pyramids are determined
by options.pyramid_levels. The only exception to this is internal to the
feature extractor, where smaller feature pyramids are temporarily constructed
with depth options.sub_levels.

Color ranges & gamma
--------------------
The model code makes no assumptions about whether the images are in gamma or
linearized space, or what the range of the RGB color values is, so a model can
be trained with different choices. This does not mean that all choices lead to
similar results. In practice, the model has been shown to work well with RGB
values in [0, 1] and gamma-space (i.e. not linearized) images.

[1] Context-aware Synthesis for Video Frame Interpolation, Niklaus and Liu,
    2018
[2] Multi-view Image Fusion, Trinidad et al, 2019
[3] PWC-Net: CNNs for Optical Flow Using Pyramid, Warping, and Cost Volume
"""

from . import feature_extractor
from . import fusion
from . import options
from . import pyramid_flow_estimator
from . import util
import tensorflow as tf
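
# Illustrative sketch (not part of the model): the pyramids described above are
# plain python lists with the finest level first and each subsequent level at
# half the spatial resolution. A minimal, hypothetical construction of such a
# pyramid could look like the following, assuming `image` is a BxHxWxC tensor
# and `levels` corresponds to options.pyramid_levels (the actual construction
# is done by util.build_image_pyramid):
#
#   pyramid = [image]
#   for _ in range(levels - 1):
#     image = tf.keras.layers.AveragePooling2D(pool_size=2)(image)
#     pyramid.append(image)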


def create_model(x0: tf.Tensor, x1: tf.Tensor, time: tf.Tensor,
                 config: options.Options) -> tf.keras.Model:
  """Creates a frame interpolator model.

  The frame interpolator is used to warp the two images to the in-between
  frame at the given time. Note that training data is often restricted such
  that supervision only exists at 'time'=0.5. If trained with such data, the
  model will overfit to predicting images that are halfway between the two
  inputs and will not be as accurate elsewhere.

  Args:
    x0: first input image as BxHxWxC tensor.
    x1: second input image as BxHxWxC tensor.
    time: ignored by film_net. We always infer a frame at t = 0.5.
    config: FilmNetOptions object.

  Returns:
    A tf.Model that takes 'x0', 'x1', and 'time' as input and returns a
    dictionary with the interpolated result in 'image'. For additional
    diagnostics or supervision, the following intermediate results are also
    stored in the dictionary:
      'x0_warped': an intermediate result obtained by warping from x0
      'x1_warped': an intermediate result obtained by warping from x1
      'forward_residual_flow_pyramid': pyramid with forward residual flows
      'backward_residual_flow_pyramid': pyramid with backward residual flows
      'forward_flow_pyramid': pyramid with forward flows
      'backward_flow_pyramid': pyramid with backward flows

  Raises:
    ValueError, if config.pyramid_levels < config.fusion_pyramid_levels.
  """
  if config.pyramid_levels < config.fusion_pyramid_levels:
    raise ValueError('config.pyramid_levels must be greater than or equal to '
                     'config.fusion_pyramid_levels.')

  x0_decoded = x0
  x1_decoded = x1

  # Build image pyramids for both inputs.
  image_pyramids = [
      util.build_image_pyramid(x0_decoded, config),
      util.build_image_pyramid(x1_decoded, config)
  ]

  # Siamese feature pyramids:
  extract = feature_extractor.FeatureExtractor('feat_net', config)
  feature_pyramids = [extract(image_pyramids[0]), extract(image_pyramids[1])]

  predict_flow = pyramid_flow_estimator.PyramidFlowEstimator(
      'predict_flow', config)

  # Predict forward flow.
  forward_residual_flow_pyramid = predict_flow(feature_pyramids[0],
                                               feature_pyramids[1])
  # Predict backward flow.
  backward_residual_flow_pyramid = predict_flow(feature_pyramids[1],
                                                feature_pyramids[0])

  # Concatenate features and images:

  # Note that we keep up to 'fusion_pyramid_levels' levels as only those
  # are used by the fusion module.
  fusion_pyramid_levels = config.fusion_pyramid_levels

  forward_flow_pyramid = util.flow_pyramid_synthesis(
      forward_residual_flow_pyramid)[:fusion_pyramid_levels]
  backward_flow_pyramid = util.flow_pyramid_synthesis(
      backward_residual_flow_pyramid)[:fusion_pyramid_levels]

  # We multiply the flows by t and 1-t to warp to the desired fractional time.
  #
  # Note: In film_net we fix time to be 0.5, and recursively invoke the
  # interpolator for multi-frame interpolation. Below, we create a constant
  # tensor of shape [B]. We use the `time` tensor to infer the batch size.
  mid_time = tf.keras.layers.Lambda(lambda x: tf.ones_like(x) * 0.5)(time)
  backward_flow = util.multiply_pyramid(backward_flow_pyramid, mid_time[:, 0])
  forward_flow = util.multiply_pyramid(forward_flow_pyramid,
                                       1 - mid_time[:, 0])

  pyramids_to_warp = [
      util.concatenate_pyramids(image_pyramids[0][:fusion_pyramid_levels],
                                feature_pyramids[0][:fusion_pyramid_levels]),
      util.concatenate_pyramids(image_pyramids[1][:fusion_pyramid_levels],
                                feature_pyramids[1][:fusion_pyramid_levels])
  ]

  # Warp features and images using the flow. Note that we use backward warping
  # and backward flow is used to read from image 0 and forward flow from
  # image 1.
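  #
  # Illustrative note: backward warping samples the source at positions shifted
  # by the flow, conceptually warped[b, y, x] = source[b, y + flow_y, x + flow_x]
  # with bilinear interpolation. Reading x0 through the backward flow (scaled by
  # t) and x1 through the forward flow (scaled by 1-t) therefore yields two
  # pyramids aligned to the in-between frame.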
  forward_warped_pyramid = util.pyramid_warp(pyramids_to_warp[0],
                                             backward_flow)
  backward_warped_pyramid = util.pyramid_warp(pyramids_to_warp[1],
                                              forward_flow)

  aligned_pyramid = util.concatenate_pyramids(forward_warped_pyramid,
                                              backward_warped_pyramid)
  aligned_pyramid = util.concatenate_pyramids(aligned_pyramid, backward_flow)
  aligned_pyramid = util.concatenate_pyramids(aligned_pyramid, forward_flow)

  fuse = fusion.Fusion('fusion', config)
  prediction = fuse(aligned_pyramid)

  output_color = prediction[..., :3]
  outputs = {'image': output_color}

  if config.use_aux_outputs:
    outputs.update({
        'x0_warped': forward_warped_pyramid[0][..., 0:3],
        'x1_warped': backward_warped_pyramid[0][..., 0:3],
        'forward_residual_flow_pyramid': forward_residual_flow_pyramid,
        'backward_residual_flow_pyramid': backward_residual_flow_pyramid,
        'forward_flow_pyramid': forward_flow_pyramid,
        'backward_flow_pyramid': backward_flow_pyramid,
    })

  model = tf.keras.Model(
      inputs={
          'x0': x0,
          'x1': x1,
          'time': time
      },
      outputs=outputs)
  return model
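

# Example usage (illustrative sketch only; the exact Options fields and input
# shapes are assumptions, not guaranteed by this module):
#
#   config = options.Options()
#   x0 = tf.keras.Input(shape=(None, None, 3), name='x0')
#   x1 = tf.keras.Input(shape=(None, None, 3), name='x1')
#   time = tf.keras.Input(shape=(1,), name='time')
#   model = create_model(x0, x1, time, config)
#   outputs = model({'x0': frame_a, 'x1': frame_b,
#                    'time': tf.fill([tf.shape(frame_a)[0], 1], 0.5)})
#   mid_frame = outputs['image']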