ai / eval /util.py
CHEN11102's picture
Upload 47 files
2061d64 verified
# Copyright 2022 Google LLC
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions for frame interpolation on a set of video frames."""
import os
import shutil
from typing import Generator, Iterable, List, Optional
from . import interpolator as interpolator_lib
import numpy as np
import tensorflow as tf
from tqdm import tqdm
_UINT8_MAX_F = float(np.iinfo(np.uint8).max)
_CONFIG_FFMPEG_NAME_OR_PATH = 'ffmpeg'
def read_image(filename: str) -> np.ndarray:
    """Reads an sRGB 8-bit image.

    Args:
        filename: The input filename to read; any format accepted by
            tf.io.decode_image.

    Returns:
        A float32 3-channel (RGB) ndarray of shape (H, W, 3) with colors in the
        [0..1] range. For animated images (GIF) only the first frame is
        returned.
    """
    image_data = tf.io.read_file(filename)
    # channels=3 forces an RGB result. expand_animations=False keeps animated
    # GIFs 3-D (first frame) instead of producing a 4-D (frames, H, W, 3)
    # tensor, which would violate the (H, W, 3) contract callers rely on.
    image = tf.io.decode_image(image_data, channels=3, expand_animations=False)
    image_numpy = tf.cast(image, dtype=tf.float32).numpy()
    return image_numpy / _UINT8_MAX_F
def write_image(filename: str, image: np.ndarray) -> None:
    """Writes a float32 3-channel RGB ndarray image, with colors in range [0..1].

    The file is JPEG-encoded when `filename` ends in `.jpg` or `.jpeg`
    (case-insensitive) and PNG-encoded otherwise.

    Args:
        filename: The output filename to save.
        image: A float32 3-channel (RGB) ndarray with colors in the [0..1]
            range. Values outside that range are clipped.
    """
    # Scale to [0, 255] and clip; adding 0.5 before the truncating uint8 cast
    # rounds to the nearest integer.
    image_in_uint8_range = np.clip(image * _UINT8_MAX_F, 0.0, _UINT8_MAX_F)
    image_in_uint8 = (image_in_uint8_range + 0.5).astype(np.uint8)
    # Compare case-insensitively and accept both common JPEG extensions so that
    # e.g. 'out.JPG' or 'out.jpeg' are not silently written as PNG data.
    extension = os.path.splitext(filename)[1].lower()
    if extension in ('.jpg', '.jpeg'):
        image_data = tf.io.encode_jpeg(image_in_uint8)
    else:
        image_data = tf.io.encode_png(image_in_uint8)
    tf.io.write_file(filename, image_data)
def _recursive_generator(
    frame1: np.ndarray, frame2: np.ndarray, num_recursions: int,
    interpolator: interpolator_lib.Interpolator,
    bar: Optional[tqdm] = None
) -> Generator[np.ndarray, None, None]:
    """Splits halfway to repeatedly generate more frames.

    Each level calls `interpolator` once at time 0.5 between the current pair
    and then recurses on both halves. A call with `num_recursions=k` therefore
    invokes the interpolator 2**k - 1 times and yields 2**k frames.

    Args:
        frame1: Input image 1.
        frame2: Input image 2.
        num_recursions: How many times to interpolate the consecutive image
            pairs. Must be non-negative.
        interpolator: The frame interpolator instance.
        bar: Optional progress bar, advanced by one per interpolator call.

    Yields:
        The interpolated frames, including the first frame (frame1), but
        excluding the final frame2.

    Raises:
        ValueError: If `num_recursions` is negative. Because this is a
            generator, the error surfaces on the first iteration.
    """
    if num_recursions < 0:
        # Without this guard the recursion never reaches the base case and
        # keeps calling the interpolator until RecursionError.
        raise ValueError(
            f'num_recursions must be non-negative, got {num_recursions}.')
    if num_recursions == 0:
        yield frame1
        return
    # Adds the batch dimension to all inputs before calling the interpolator,
    # and removes it afterwards.
    time = np.full(shape=(1,), fill_value=0.5, dtype=np.float32)
    mid_frame = interpolator(frame1[np.newaxis, ...], frame2[np.newaxis, ...],
                             time)[0]
    if bar is not None:
        bar.update(1)
    yield from _recursive_generator(frame1, mid_frame, num_recursions - 1,
                                    interpolator, bar)
    yield from _recursive_generator(mid_frame, frame2, num_recursions - 1,
                                    interpolator, bar)
def interpolate_recursively_from_files(
    frames: List[str], times_to_interpolate: int,
    interpolator: interpolator_lib.Interpolator) -> Iterable[np.ndarray]:
    """Generates interpolated frames by repeatedly interpolating the midpoint.

    Loads the files on demand and uses the yield paradigm to return the frames
    to allow streamed processing of longer videos. Each file is decoded exactly
    once; the decoded image is reused as the start of the next pair.

    Recursive interpolation is useful if the interpolator is trained to predict
    frames at midpoint only and is thus expected to perform poorly elsewhere.

    Args:
        frames: List of input image filenames, in order. Each file is decoded
            with read_image(), i.e. into an (H, W, 3) float32 sRGB (gamma-space)
            array with colors in the range [0, 1].
        times_to_interpolate: Number of times to do recursive midpoint
            interpolation.
        interpolator: The frame interpolation model to use.

    Yields:
        The interpolated frames (including the inputs). Nothing is yielded when
        `frames` is empty.
    """
    n = len(frames)
    if n == 0:
        # Avoids a progress bar with a negative total and an IndexError on
        # frames[-1].
        return
    # Each consecutive pair costs 2**times_to_interpolate - 1 interpolator calls.
    num_frames = (n - 1) * (2**times_to_interpolate - 1)
    # The context manager closes the progress bar when iteration finishes or
    # when the generator is closed early.
    with tqdm(total=num_frames, ncols=100, colour='green') as bar:
        previous = read_image(frames[0])
        for filename in frames[1:]:
            current = read_image(filename)
            yield from _recursive_generator(previous, current,
                                            times_to_interpolate, interpolator,
                                            bar)
            previous = current
    # Separately yield the final frame.
    yield previous
def interpolate_recursively_from_memory(
    frames: List[np.ndarray], times_to_interpolate: int,
    interpolator: interpolator_lib.Interpolator) -> Iterable[np.ndarray]:
    """Generates interpolated frames by repeatedly interpolating the midpoint.

    This is functionally equivalent to interpolate_recursively_from_files(), but
    expects the input frames in memory, instead of loading them on demand.

    Recursive interpolation is useful if the interpolator is trained to predict
    frames at midpoint only and is thus expected to perform poorly elsewhere.

    Args:
        frames: List of input frames. Expected shape (H, W, 3). The colors
            should be in the range [0, 1] and in gamma space.
        times_to_interpolate: Number of times to do recursive midpoint
            interpolation.
        interpolator: The frame interpolation model to use.

    Yields:
        The interpolated frames (including the inputs). Nothing is yielded when
        `frames` is empty.
    """
    n = len(frames)
    if n == 0:
        # Avoids a progress bar with a negative total and an IndexError on
        # frames[-1].
        return
    # Each consecutive pair costs 2**times_to_interpolate - 1 interpolator calls.
    num_frames = (n - 1) * (2**times_to_interpolate - 1)
    # The context manager closes the progress bar when iteration finishes or
    # when the generator is closed early.
    with tqdm(total=num_frames, ncols=100, colour='green') as bar:
        for frame1, frame2 in zip(frames, frames[1:]):
            yield from _recursive_generator(frame1, frame2,
                                            times_to_interpolate, interpolator,
                                            bar)
    # Separately yield the final frame.
    yield frames[-1]
def get_ffmpeg_path() -> str:
    """Locates the ffmpeg executable.

    Returns:
        The path that shutil.which() resolves for _CONFIG_FFMPEG_NAME_OR_PATH.

    Raises:
        RuntimeError: If the program cannot be found.
    """
    resolved = shutil.which(_CONFIG_FFMPEG_NAME_OR_PATH)
    if resolved:
        return resolved
    raise RuntimeError(
        f"Program '{_CONFIG_FFMPEG_NAME_OR_PATH}' is not found;"
        " perhaps install ffmpeg using 'apt-get install ffmpeg'.")