# ai / eval /interpolator_cli.py
# CHEN11102's picture
# Upload 47 files
# 2061d64 verified
# Copyright 2022 Google LLC
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Runs the FILM frame interpolator on a pair of frames on beam.
This script is used to evaluate the output quality of the FILM TensorFlow frame
interpolator. Optionally, it outputs a video of the interpolated frames.
A beam pipeline for invoking the frame interpolator on a set of directories
identified by a glob (--pattern). Each directory is expected to contain two
input frames that are the inputs to the frame interpolator. If a directory has
more than two frames, then each contiguous frame pair is treated as input to
generate in-between frames.
The interpolated frames are written to the interpolated_frames/ subdirectory of
each directory and, when --output_video is set, a video is stored to
interpolated.mp4 in each directory. The number of frames is determined by
--times_to_interpolate, which controls the number of times the frame
interpolator is invoked. When the number of input frames is 2, the number of
output frames is 2^times_to_interpolate+1.
This expects a directory structure such as:
  <root directory of the eval>/01/frame1.png
                                  frame2.png
  <root directory of the eval>/02/frame1.png
                                  frame2.png
  <root directory of the eval>/03/frame1.png
                                  frame2.png
  ...
And will produce:
  <root directory of the eval>/01/interpolated_frames/frame_000.png
                                                      frame_001.png
                                                      frame_002.png
  <root directory of the eval>/02/interpolated_frames/frame_000.png
                                                      frame_001.png
                                                      frame_002.png
  <root directory of the eval>/03/interpolated_frames/frame_000.png
                                                      frame_001.png
                                                      frame_002.png
  ...
And optionally will produce:
<root directory of the eval>/01/interpolated.mp4
<root directory of the eval>/02/interpolated.mp4
<root directory of the eval>/03/interpolated.mp4
...
Usage example:
  python3 -m frame_interpolation.eval.interpolator_cli \
    --model_path <path to TF2 saved model> \
    --pattern "<root directory of the eval>/*" \
    --times_to_interpolate <Number of times to interpolate>
"""
import functools
import os
from typing import List, Sequence
from . import interpolator as interpolator_lib
from . import util
from absl import app
from absl import flags
from absl import logging
import apache_beam as beam
import mediapy as media
import natsort
import numpy as np
import tensorflow as tf
from tqdm.auto import tqdm
# Controls the TF_CPP log level.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'

# Command-line flags. Each DEFINE_* call passes (name, default, help)
# positionally; --pattern is the only flag marked as required.
_PATTERN = flags.DEFINE_string(
    'pattern',
    None,
    'The pattern to determine the directories with the input frames.',
    required=True)

_MODEL_PATH = flags.DEFINE_string(
    'model_path',
    None,
    'The path of the TF2 saved model to use.')

_TIMES_TO_INTERPOLATE = flags.DEFINE_integer(
    'times_to_interpolate',
    5,
    'The number of times to run recursive midpoint interpolation. '
    'The number of output frames will be 2^times_to_interpolate+1.')

_FPS = flags.DEFINE_integer(
    'fps',
    30,
    'Frames per second to play interpolated videos in slow motion.')

_ALIGN = flags.DEFINE_integer(
    'align',
    64,
    'If >1, pad the input size so it is evenly divisible by this value.')

_BLOCK_HEIGHT = flags.DEFINE_integer(
    'block_height',
    1,
    'An int >= 1, number of patches along height, '
    'patch_height = height//block_height, should be evenly divisible.')

_BLOCK_WIDTH = flags.DEFINE_integer(
    'block_width',
    1,
    'An int >= 1, number of patches along width, '
    'patch_width = width//block_width, should be evenly divisible.')

_OUTPUT_VIDEO = flags.DEFINE_boolean(
    'output_video',
    False,
    'If true, creates a video of the frames in the interpolated_frames/ '
    'subdirectory')

# File extensions globbed for input frames in each directory. Add further
# extensions here if the inputs use a different image format.
_INPUT_EXT = [
    'png',
    'jpg',
    'jpeg',
]
def _output_frames(frames: List[np.ndarray], frames_dir: str):
    """Saves a list of frames as numbered PNG files in `frames_dir`.

    The directory is created when it does not exist yet. When it already
    exists, any files in it matching `frame_*.png` are deleted before the new
    frames are written.

    Args:
      frames: Images to write, in output order.
      frames_dir: Directory that receives the `frame_NNN.png` files.
    """
    if not tf.io.gfile.isdir(frames_dir):
        tf.io.gfile.makedirs(frames_dir)
    else:
        stale_paths = tf.io.gfile.glob(f'{frames_dir}/frame_*.png')
        if stale_paths:
            logging.info('Removing existing frames from %s.', frames_dir)
            for stale_path in stale_paths:
                tf.io.gfile.remove(stale_path)

    # Wrap the enumeration in a progress bar sized to the number of frames.
    progress = tqdm(
        enumerate(frames), total=len(frames), ncols=100, colour='green')
    for index, image in progress:
        # Zero-padded three-digit index: frame_000.png, frame_001.png, ...
        util.write_image(f'{frames_dir}/frame_{index:03d}.png', image)
    logging.info('Output frames saved in %s.', frames_dir)
class ProcessDirectory(beam.DoFn):
    """DoFn for running the interpolator on a single directory at a time."""

    def setup(self):
        """Creates the interpolator and, for video output, configures ffmpeg.

        The interpolator is built from the --model_path, --align,
        --block_height and --block_width flags. The ffmpeg path is only looked
        up and registered with mediapy when --output_video is set.
        """
        self.interpolator = interpolator_lib.Interpolator(
            _MODEL_PATH.value, _ALIGN.value,
            [_BLOCK_HEIGHT.value, _BLOCK_WIDTH.value])

        if _OUTPUT_VIDEO.value:
            ffmpeg_path = util.get_ffmpeg_path()
            media.set_ffmpeg(ffmpeg_path)

    def process(self, directory: str):
        """Interpolates the frames found in `directory` and saves the result.

        Input frames are collected per extension in `_INPUT_EXT`; the files of
        each extension are natural-sorted and the per-extension lists are
        concatenated in `_INPUT_EXT` order. The generated frames are written to
        `<directory>/interpolated_frames` and, when --output_video is set, also
        to `<directory>/interpolated.mp4`.

        If no input frame is found, a warning is logged and the element is
        skipped without writing anything.

        Args:
          directory: Path matched by --pattern that holds the input frames.
        """
        input_frames = [
            path
            for ext in _INPUT_EXT
            for path in natsort.natsorted(
                tf.io.gfile.glob(f'{directory}/*.{ext}'))
        ]
        if not input_frames:
            # Nothing to interpolate (e.g. --pattern matched a path without
            # images): skip instead of handing an empty list to the
            # interpolation helper and writing empty outputs.
            logging.warning(
                'No input frames with extensions %s found in %s; skipping.',
                _INPUT_EXT, directory)
            return

        logging.info('Generating in-between frames for %s.', directory)
        # Materialize the frames once; they are used for both the PNG output
        # and the optional video.
        frames = list(
            util.interpolate_recursively_from_files(
                input_frames, _TIMES_TO_INTERPOLATE.value, self.interpolator))
        _output_frames(frames, f'{directory}/interpolated_frames')

        if _OUTPUT_VIDEO.value:
            media.write_video(
                f'{directory}/interpolated.mp4', frames, fps=_FPS.value)
            logging.info(
                'Output video saved at %s/interpolated.mp4.', directory)
def _run_pipeline() -> None:
    """Runs ProcessDirectory over every path matched by --pattern.

    The matched paths are fed into a Beam pipeline created with the
    'DirectRunner' runner; the call blocks until the pipeline has finished.
    """
    matched_dirs = tf.io.gfile.glob(_PATTERN.value)

    pipeline = beam.Pipeline('DirectRunner')
    dir_names = (
        pipeline | 'Create directory names' >> beam.Create(matched_dirs))
    # The ParDo is applied for its side effects; its output is not consumed.
    _ = dir_names | 'Process directories' >> beam.ParDo(ProcessDirectory())

    pipeline.run().wait_until_finish()
def main(argv: Sequence[str]) -> None:
    """Rejects extra positional arguments, then runs the pipeline.

    Args:
      argv: Command-line arguments as passed in by `app.run`; anything beyond
        the first element is treated as an error.

    Raises:
      app.UsageError: If more than one element is present in `argv`.
    """
    unexpected_args = argv[1:]
    if unexpected_args:
        raise app.UsageError('Too many command-line arguments.')
    _run_pipeline()


if __name__ == '__main__':
    app.run(main)