# Copyright 2022 Google LLC

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Dataset creation for frame interpolation."""
from typing import Callable, Dict, List, Optional

from absl import logging
import gin.tf
import tensorflow as tf


def _create_feature_map() -> Dict[str, tf.io.FixedLenFeature]:
  """Builds the tf.Example parsing spec for a frame triplet.

  Each of frame_0, frame_1 and frame_2 contributes four scalar features:
  the encoded image bytes (default ''), the image format string (default
  'jpg'), and the height and width (default 0). A scalar string 'path'
  feature (default '') is appended last.

  Returns:
    Mapping from feature key to its tf.io.FixedLenFeature spec, in the order
    frame_0/*, frame_1/*, frame_2/*, path.
  """
  feature_map = {}
  for frame in ('frame_0', 'frame_1', 'frame_2'):
    feature_map[frame + '/encoded'] = tf.io.FixedLenFeature(
        (), tf.string, default_value='')
    feature_map[frame + '/format'] = tf.io.FixedLenFeature(
        (), tf.string, default_value='jpg')
    feature_map[frame + '/height'] = tf.io.FixedLenFeature(
        (), tf.int64, default_value=0)
    feature_map[frame + '/width'] = tf.io.FixedLenFeature(
        (), tf.int64, default_value=0)
  feature_map['path'] = tf.io.FixedLenFeature((), tf.string, default_value='')
  return feature_map


def _parse_example(sample):
  """Parses a serialized frame-triplet tf.Example into model inputs.

  The outer frames (frame_0 and frame_2) become the two model inputs and the
  middle frame (frame_1) becomes the ground truth. All three images are
  decoded with tf.io.decode_image using dtype=tf.float32.

  Args:
    sample: A serialized tf.Example to be parsed.

  Returns:
    A dictionary containing:
      'x0': decoded frame_0 (first input frame).
      'x1': decoded frame_2 (second input frame).
      'y': decoded frame_1 (ground-truth middle frame).
      'time': the constant 0.5, the fractional time of 'y' between 'x0' and
        'x1'.
      'path': the example's 'path' string feature.
  """
  feature_map = _create_feature_map()
  features = tf.io.parse_single_example(sample, feature_map)
  output_dict = {
      'x0': tf.io.decode_image(features['frame_0/encoded'], dtype=tf.float32),
      'x1': tf.io.decode_image(features['frame_2/encoded'], dtype=tf.float32),
      'y': tf.io.decode_image(features['frame_1/encoded'], dtype=tf.float32),
      # The fractional time value of frame_1 is not included in our tfrecords,
      # but is always at 0.5. The model expects this to be specified, so
      # we insert it here.
      'time': 0.5,
      # Store the original mid frame filepath for identifying examples.
      'path': features['path'],
  }

  return output_dict


def _random_crop_images(crop_size: int, images: tf.Tensor,
                        total_channel_size: int) -> tf.Tensor:
  """Randomly crops `images` to crop_size x crop_size x total_channel_size.

  A non-positive `crop_size` disables cropping and returns `images` as-is.
  """
  if crop_size <= 0:
    return images
  target_shape = tf.constant([crop_size, crop_size, total_channel_size])
  return tf.image.random_crop(images, target_shape)


def crop_example(example: Dict[str, tf.Tensor], crop_size: int,
                 crop_keys: Optional[List[str]] = None,
                 channels: Optional[List[int]] = None):
  """Random crops selected images in the example to given size and keys.

  The selected images are concatenated along the channel axis and cropped
  with a single random offset, so every selected image is cropped at the same
  spatial location.

  Args:
    example: Dictionary of tensors; the entries named in `crop_keys` are
      cropped.
    crop_size: The size to crop images to. This value is used for both
      height and width. A non-positive value disables cropping.
    crop_keys: The images in the input example to crop. Defaults to
      ['x0', 'x1', 'y'].
    channels: Number of channels of each image in `crop_keys`, in the same
      order. Defaults to 3 for every key (the original behavior).

  Returns:
    Example with cropping applied to selected images. The input dictionary is
    updated in place and also returned.

  Raises:
    ValueError: If `channels` and `crop_keys` have different lengths.
  """
  if crop_keys is None:
    crop_keys = ['x0', 'x1', 'y']
  if channels is None:
    # Previously `channels` was only set when crop_keys was None, so any
    # caller passing crop_keys crashed with an UnboundLocalError.
    channels = [3] * len(crop_keys)
  if len(channels) != len(crop_keys):
    raise ValueError(
        'crop_example: got {} channel counts for {} crop_keys.'.format(
            len(channels), len(crop_keys)))

  # Stack images along channel axis, and perform a random crop once.
  image_to_crop = [example[key] for key in crop_keys]
  stacked_images = tf.concat(image_to_crop, axis=-1)
  cropped_images = _random_crop_images(crop_size, stacked_images, sum(channels))
  # Undo the stacking: split the channel axis back into the original images.
  cropped_images = tf.split(
      cropped_images, num_or_size_splits=channels, axis=-1)
  for key, cropped_image in zip(crop_keys, cropped_images):
    example[key] = cropped_image
  return example


def apply_data_augmentation(
    augmentation_fns: Dict[str, Callable[..., tf.Tensor]],
    example: tf.Tensor,
    augmentation_keys: Optional[List[str]] = None) -> tf.Tensor:
  """Runs every augmentation function, in order, over the selected images.

  The selected entries of `example` are gathered into a dictionary, which is
  passed through each value of `augmentation_fns` in iteration order; each
  function receives the previous function's output. The final results are
  written back into `example`.

  Args:
    augmentation_fns: A Dict of Callables to data augmentation functions.
    example: Input example (indexed by key) whose images are augmented.
    augmentation_keys: The images in the input example to augment. Defaults
      to ['x0', 'x1', 'y'].

  Returns:
    `example`, updated in place with the augmented images.
  """
  keys = (['x0', 'x1', 'y']
          if augmentation_keys is None else augmentation_keys)

  images = {key: example[key] for key in keys}
  for augment in augmentation_fns.values():
    images = augment(images)

  for key in keys:
    example[key] = images[key]
  return example


def _create_from_tfrecord(batch_size, file, augmentation_fns,
                          crop_size) -> tf.data.Dataset:
  """Builds a batched dataset of parsed frame triplets from TFRecord `file`.

  The pipeline parses each record, optionally augments it (when
  `augmentation_fns` is not None), optionally crops it (when crop_size > 0),
  and finally batches. Batching uses drop_remainder=True, so a trailing
  partial batch is discarded.
  """
  parallelism = tf.data.experimental.AUTOTUNE

  def _augment(example):
    return apply_data_augmentation(augmentation_fns, example)

  def _crop(example):
    return crop_example(example, crop_size=crop_size)

  records = tf.data.TFRecordDataset(file).map(
      _parse_example, num_parallel_calls=parallelism)

  # Augmentation runs on the full-size frames, before cropping and batching.
  if augmentation_fns is not None:
    records = records.map(_augment, num_parallel_calls=parallelism)
  if crop_size > 0:
    records = records.map(_crop, num_parallel_calls=parallelism)
  return records.batch(batch_size, drop_remainder=True)


def _generate_sharded_filenames(filename: str) -> List[str]:
  """Generates filenames of the each file in the sharded filepath.

  Based on github.com/google/revisiting-self-supervised/blob/master/datasets.py.

  Args:
    filename: The sharded filepath.

  Returns:
    A list of filepaths for each file in the shard.
  """
  base, count = filename.split('@')
  count = int(count)
  return ['{}-{:05d}-of-{:05d}'.format(base, i, count) for i in range(count)]


def _create_from_sharded_tfrecord(batch_size,
                                  train_mode,
                                  file,
                                  augmentation_fns,
                                  crop_size,
                                  max_examples=-1) -> tf.data.Dataset:
  """Creates a dataset by interleaving the shards of a sharded tfrecord.

  Each shard is turned into its own batched dataset, and the shards are
  interleaved in parallel. Interleaving is deterministic unless `train_mode`
  is set. Because batching happens per shard, `max_examples` (when > 0)
  limits the number of elements of the interleaved dataset, i.e. batches.
  """
  shard_paths = _generate_sharded_filenames(file)

  def _per_shard_dataset(shard_path):
    return _create_from_tfrecord(
        batch_size,
        file=shard_path,
        augmentation_fns=augmentation_fns,
        crop_size=crop_size)

  dataset = tf.data.Dataset.from_tensor_slices(shard_paths).interleave(
      _per_shard_dataset,
      num_parallel_calls=tf.data.AUTOTUNE,
      deterministic=not train_mode).prefetch(buffer_size=2)
  return dataset.take(max_examples) if max_examples > 0 else dataset


@gin.configurable('training_dataset')
def create_training_dataset(
    batch_size: int,
    file: Optional[str] = None,
    files: Optional[List[str]] = None,
    crop_size: int = -1,
    crop_sizes: Optional[List[int]] = None,
    augmentation_fns: Optional[Dict[str, Callable[..., tf.Tensor]]] = None
) -> tf.data.Dataset:
  """Creates the training dataset.

  The given tfrecord should contain data in a format produced by
  frame_interpolation/datasets/create_*_tfrecord.py

  Args:
    batch_size: The number of images to batch per example.
    file: (deprecated) A path to a sharded tfrecord in <tfrecord>@N format.
      Deprecated. Use 'files' instead. If given, 'files' is ignored.
    files: A list of paths to sharded tfrecords in <tfrecord>@N format.
    crop_size: (deprecated) If > 0, images are cropped to crop_size x crop_size
      using tensorflow's random cropping. Deprecated: use 'files' and
      'crop_sizes' instead.
    crop_sizes: List of crop sizes, one per entry of 'files'. If > 0, images
      are cropped to crop_size x crop_size using tensorflow's random cropping.
    augmentation_fns: A Dict of Callables to data augmentation functions.
  Returns:
    A tensorflow dataset for accessing examples that contain the input images
    'x0', 'x1', ground truth 'y' and time of the ground truth 'time'=[0,1] in a
    dictionary of tensors. With 'files', the per-file datasets are combined
    with tf.data.experimental.sample_from_datasets.
  Raises:
    ValueError: If neither 'file' nor 'files' is given, if 'crop_sizes' is
      missing or differs in length from 'files', or if the deprecated
      'crop_size' is combined with 'files'.
  """
  if file:
    logging.warning('gin-configurable training_dataset.file is deprecated. '
                    'Use training_dataset.files instead.')
    if files:
      # Previously 'files' was dropped without any notice in this case.
      logging.warning('training_dataset.files is ignored because the '
                      'deprecated training_dataset.file is set.')
    return _create_from_sharded_tfrecord(batch_size, True, file,
                                         augmentation_fns, crop_size)

  # Without this check a missing 'files' crashed with TypeError in len(None).
  if not files:
    raise ValueError('Please pass training_dataset.files (or the deprecated '
                     'training_dataset.file).')
  if not crop_sizes or len(crop_sizes) != len(files):
    raise ValueError('Please pass crop_sizes[] with training_dataset.files.')
  if crop_size > 0:
    raise ValueError(
        'crop_size should not be used with files[], use crop_sizes[] instead.'
    )
  tables = [
      _create_from_sharded_tfrecord(batch_size, True, sharded_path,
                                    augmentation_fns, file_crop_size)
      for sharded_path, file_crop_size in zip(files, crop_sizes)
  ]
  return tf.data.experimental.sample_from_datasets(tables)


@gin.configurable('eval_datasets')
def create_eval_datasets(batch_size: int,
                         files: List[str],
                         names: List[str],
                         crop_size: int = -1,
                         max_examples: int = -1) -> Dict[str, tf.data.Dataset]:
  """Creates the evaluation datasets.

  As opposed to create_training_dataset this function makes sure that the
  examples for each dataset are always read in a deterministic (same) order.
  Note that when crop_size > 0 the crop offsets are still random.

  Each given tfrecord should contain data in a format produced by
  frame_interpolation/datasets/create_*_tfrecord.py

  The (batch_size, crop_size, max_examples) are specified for all eval datasets.

  Args:
    batch_size: The number of images to batch per example.
    files: List of paths to a sharded tfrecord in <tfrecord>@N format.
    names: List of names of eval datasets, one per entry of 'files'.
    crop_size: If > 0, images are cropped to crop_size x crop_size using
      tensorflow's random cropping.
    max_examples: If > 0, truncate the dataset to 'max_examples' elements.
      Truncation is applied after batching, so this counts batches. This
      can be useful for speeding up evaluation loop in case the tfrecord for the
      evaluation set is very large.
  Returns:
    A dict of name to tensorflow dataset for accessing examples that contain the
    input images 'x0', 'x1', ground truth 'y' and time of the ground truth
    'time'=[0,1] in a dictionary of tensors.
  Raises:
    ValueError: If 'names' and 'files' differ in length, or if 'names'
      contains duplicates.
  """
  # zip() would silently drop unmatched entries, hiding a misconfiguration.
  if len(names) != len(files):
    raise ValueError(
        'eval_datasets.names and eval_datasets.files must have the same '
        'length, got {} names and {} files.'.format(len(names), len(files)))
  # Duplicate names would silently overwrite each other in the result dict.
  if len(set(names)) != len(names):
    raise ValueError(
        'eval_datasets.names must be unique, got {}.'.format(names))
  return {
      name: _create_from_sharded_tfrecord(batch_size, False, file, None,
                                          crop_size, max_examples)
      for name, file in zip(names, files)
  }