ai / training /model_lib.py
CHEN11102's picture
Upload 47 files
2061d64 verified
raw
history blame
2.1 kB
# Copyright 2022 Google LLC
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A library for instantiating the model for training frame interpolation.

All models are expected to take three inputs: the input image batches 'x0' and
'x1', and 'time', the fractional time at which the output should be generated.

The models are expected to output the prediction as a dictionary that contains
at least the predicted image batch as 'image', plus optional data for debugging,
analysis or custom losses.
"""
import gin.tf
from ..models.film_net import interpolator as film_net_interpolator
from ..models.film_net import options as film_net_options
import tensorflow as tf
@gin.configurable('model')
def create_model(name: str) -> tf.keras.Model:
  """Builds the frame interpolation model selected by `name`.

  Args:
    name: Identifier of the model to build. Only 'film_net' is handled.

  Returns:
    The model produced by `_create_film_net_model()`.

  Raises:
    ValueError: If `name` is anything other than 'film_net'.
  """
  # Reject unknown names up front so the supported path reads straight through.
  if name != 'film_net':
    raise ValueError(f'Model {name} not implemented.')
  return _create_film_net_model()
def _create_film_net_model() -> tf.keras.Model:
  """Builds the film_net interpolator on top of symbolic Keras inputs.

  Three `tf.keras.Input` placeholders named 'x0', 'x1' and 'time' are created
  and passed, together with a default-constructed `Options` object, to
  `film_net_interpolator.create_model`.

  Returns:
    The result of `film_net_interpolator.create_model`.
  """
  # Options() is called without arguments; its values are gin-configured in
  # the Options class directly.
  options = film_net_options.Options()

  # Both image inputs are declared identically apart from their name.
  image_input_kwargs = dict(
      shape=(None, None, 3), batch_size=None, dtype=tf.float32)
  x0 = tf.keras.Input(name='x0', **image_input_kwargs)
  x1 = tf.keras.Input(name='x1', **image_input_kwargs)

  time = tf.keras.Input(
      name='time', shape=(1,), batch_size=None, dtype=tf.float32)

  return film_net_interpolator.create_model(x0, x1, time, options)