# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Parser for video and label datasets."""

from typing import Dict, Optional, Tuple, Union

from absl import logging
import tensorflow as tf, tf_keras

from official.vision.configs import video_classification as exp_cfg
from official.vision.dataloaders import decoder
from official.vision.dataloaders import parser
from official.vision.ops import augment
from official.vision.ops import preprocess_ops_3d

IMAGE_KEY = 'image/encoded'
LABEL_KEY = 'clip/label/index'


def process_image(image: tf.Tensor,
                  is_training: bool = True,
                  num_frames: int = 32,
                  stride: int = 1,
                  random_stride_range: int = 0,
                  num_test_clips: int = 1,
                  min_resize: int = 256,
                  crop_size: Union[int, Tuple[int, int]] = 224,
                  num_channels: int = 3,
                  num_crops: int = 1,
                  zero_centering_image: bool = False,
                  min_aspect_ratio: float = 0.5,
                  max_aspect_ratio: float = 2,
                  min_area_ratio: float = 0.49,
                  max_area_ratio: float = 1.0,
                  augmenter: Optional[augment.ImageAugment] = None,
                  seed: Optional[int] = None,
                  input_image_format: Optional[str] = 'jpeg') -> tf.Tensor:
"""Processes a serialized image tensor. | |
Args: | |
image: Input Tensor of shape [time-steps] and type tf.string of serialized | |
frames. | |
is_training: Whether or not in training mode. If True, random sample, crop | |
and left right flip is used. | |
num_frames: Number of frames per sub clip. | |
stride: Temporal stride to sample frames. | |
random_stride_range: An int indicating the min and max bounds to uniformly | |
sample different strides from the video. E.g., a value of 1 with stride=2 | |
will uniformly sample a stride in {1, 2, 3} for each video in a batch. | |
Only used enabled training for the purposes of frame-rate augmentation. | |
Defaults to 0, which disables random sampling. | |
num_test_clips: Number of test clips (1 by default). If more than 1, this | |
will sample multiple linearly spaced clips within each video at test time. | |
If 1, then a single clip in the middle of the video is sampled. The clips | |
are aggregated in the batch dimension. | |
min_resize: Frames are resized so that min(height, width) is min_resize. | |
crop_size: Final size of the frame after cropping the resized frames. | |
Optionally, specify a tuple of (crop_height, crop_width) if | |
crop_height != crop_width. | |
num_channels: Number of channels of the clip. | |
num_crops: Number of crops to perform on the resized frames. | |
zero_centering_image: If True, frames are normalized to values in [-1, 1]. | |
If False, values in [0, 1]. | |
min_aspect_ratio: The minimum aspect range for cropping. | |
max_aspect_ratio: The maximum aspect range for cropping. | |
min_area_ratio: The minimum area range for cropping. | |
max_area_ratio: The maximum area range for cropping. | |
augmenter: Image augmenter to distort each image. | |
seed: A deterministic seed to use when sampling. | |
input_image_format: The format of input image which could be jpeg, png or | |
none for unknown or mixed datasets. | |
Returns: | |
Processed frames. Tensor of shape | |
[num_frames * num_test_clips, crop_height, crop_width, num_channels]. | |
""" | |
  # Validate parameters.
  if is_training and num_test_clips != 1:
    logging.warning(
        '`num_test_clips` %d is ignored since `is_training` is `True`.',
        num_test_clips)
  if random_stride_range < 0:
    raise ValueError('Random stride range should be >= 0, got {}'.format(
        random_stride_range))
  if input_image_format not in ('jpeg', 'png', 'none'):
    raise ValueError('Unknown input image format: {}'.format(
        input_image_format))
  if isinstance(crop_size, int):
    crop_size = (crop_size, crop_size)
  crop_height, crop_width = crop_size

  # Temporal sampler.
  if is_training:
    if random_stride_range > 0:
      # Uniformly sample different frame-rates. `maxval` of tf.random.uniform
      # is exclusive, so add 1 to keep the documented inclusive range
      # [stride - random_stride_range, stride + random_stride_range].
      stride = tf.random.uniform(
          [],
          tf.maximum(stride - random_stride_range, 1),
          stride + random_stride_range + 1,
          dtype=tf.int32)
    # Sample random clip.
    image = preprocess_ops_3d.sample_sequence(image, num_frames, True, stride,
                                              seed)
  elif num_test_clips > 1:
    # Sample linspace clips.
    image = preprocess_ops_3d.sample_linspace_sequence(image, num_test_clips,
                                                       num_frames, stride)
  else:
    # Sample middle clip.
    image = preprocess_ops_3d.sample_sequence(image, num_frames, False, stride)

  # Decode the image strings to tf.uint8.
  if image.dtype == tf.string:
    image = preprocess_ops_3d.decode_image(image, num_channels)

  if is_training:
    # Standard image data augmentation: random resized crop and random flip.
    image = preprocess_ops_3d.random_crop_resize(
        image, crop_height, crop_width, num_frames, num_channels,
        (min_aspect_ratio, max_aspect_ratio),
        (min_area_ratio, max_area_ratio))
    image = preprocess_ops_3d.random_flip_left_right(image, seed)
    if augmenter is not None:
      image = augmenter.distort(image)
  else:
    # Resize images (resize happens only if necessary to save compute).
    image = preprocess_ops_3d.resize_smallest(image, min_resize)
    # Crop the frames.
    image = preprocess_ops_3d.crop_image(image, crop_height, crop_width, False,
                                         num_crops)

  # Cast the frames to float32, normalizing according to zero_centering_image.
  return preprocess_ops_3d.normalize_image(image, zero_centering_image)
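

# Illustrative usage sketch (not part of the original module): shows how
# `process_image` might be applied during training to a decoded sequence of
# JPEG-encoded frame strings. The name `raw_frames` and the parameter values
# below are assumptions chosen for the example.
def _example_process_image_train(raw_frames: tf.Tensor) -> tf.Tensor:
  """Example only: turns serialized frames into a normalized training clip."""
  clip = process_image(
      image=raw_frames,  # [num_raw_frames] tf.string tensor of encoded frames.
      is_training=True,
      num_frames=32,
      stride=2,
      crop_size=224,
      zero_centering_image=False)
  # Expected output: [32, 224, 224, 3] float32 frames with values in [0, 1].
  return clip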


def postprocess_image(image: tf.Tensor,
                      is_training: bool = True,
                      num_frames: int = 32,
                      num_test_clips: int = 1,
                      num_test_crops: int = 1) -> tf.Tensor:
"""Processes a batched Tensor of frames. | |
The same parameters used in process should be used here. | |
Args: | |
image: Input Tensor of shape [batch, time-steps, height, width, 3]. | |
is_training: Whether or not in training mode. If True, random sample, crop | |
and left right flip is used. | |
num_frames: Number of frames per sub clip. | |
num_test_clips: Number of test clips (1 by default). If more than 1, this | |
will sample multiple linearly spaced clips within each video at test time. | |
If 1, then a single clip in the middle of the video is sampled. The clips | |
are aggregated in the batch dimension. | |
num_test_crops: Number of test crops (1 by default). If more than 1, there | |
are multiple crops for each clip at test time. If 1, there is a single | |
central crop. The crops are aggregated in the batch dimension. | |
Returns: | |
Processed frames. Tensor of shape | |
[batch * num_test_clips * num_test_crops, num_frames, height, width, 3]. | |
""" | |
  num_views = num_test_clips * num_test_crops
  if num_views > 1 and not is_training:
    # In this case, the multiple views are merged into the batch dimension,
    # which becomes batch * num_views.
    image = tf.reshape(image, [-1, num_frames] + image.shape[2:].as_list())

  return image
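

# Illustrative sketch (assumed shapes): at evaluation time with
# `num_test_clips=2` and `num_test_crops=3`, `process_image` stacks the six
# views along the time axis, so a batch of 4 videos arrives here as
# [4, 6 * num_frames, height, width, 3]; `postprocess_image` folds the views
# into the batch dimension.
def _example_postprocess_image_eval(image: tf.Tensor) -> tf.Tensor:
  """Example only: merges test views into the batch dimension."""
  merged = postprocess_image(
      image,
      is_training=False,
      num_frames=32,
      num_test_clips=2,
      num_test_crops=3)
  # Expected output shape: [4 * 6, 32, height, width, 3].
  return merged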


def process_label(label: tf.Tensor,
                  one_hot_label: bool = True,
                  num_classes: Optional[int] = None,
                  label_dtype: tf.DType = tf.int32) -> tf.Tensor:
  """Processes label Tensor."""
  # Validate parameters.
  if one_hot_label and not num_classes:
    raise ValueError(
        '`num_classes` should be given when requesting one hot label.')

  # Cast to label_dtype (default = tf.int32).
  label = tf.cast(label, dtype=label_dtype)

  if one_hot_label:
    # Replace the label index by its one-hot representation.
    label = tf.one_hot(label, num_classes)
    if len(label.shape.as_list()) > 1:
      label = tf.reduce_sum(label, axis=0)
    if num_classes == 1:
      # Single-class case: tf.one_hot(label, 1) maps label 1 to [0.], so invert
      # to recover the original binary value.
      label = 1 - label

  return label
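

# Illustrative sketch: `process_label` applied to a decoded clip-level class
# index. The label value and number of classes are assumptions for the example.
def _example_process_label() -> tf.Tensor:
  """Example only: one-hot encodes a single class index."""
  label = tf.constant([2], dtype=tf.int64)  # e.g. a decoded 'clip/label/index'.
  one_hot = process_label(label, one_hot_label=True, num_classes=5)
  # tf.one_hot([2], 5) yields shape [1, 5]; the reduce_sum over axis 0 above
  # collapses it to the [5] vector [0., 0., 1., 0., 0.].
  return one_hot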


class Decoder(decoder.Decoder):
  """A tf.Example decoder for classification task."""

  def __init__(self, image_key: str = IMAGE_KEY, label_key: str = LABEL_KEY):
    self._context_description = {
        # One integer stored in context.
        label_key: tf.io.VarLenFeature(tf.int64),
    }
    self._sequence_description = {
        # Each image is a string encoding JPEG.
        image_key: tf.io.FixedLenSequenceFeature((), tf.string),
    }

  def add_feature(self, feature_name: str,
                  feature_type: Union[tf.io.VarLenFeature,
                                      tf.io.FixedLenFeature,
                                      tf.io.FixedLenSequenceFeature]):
    self._sequence_description[feature_name] = feature_type

  def add_context(self, feature_name: str,
                  feature_type: Union[tf.io.VarLenFeature,
                                      tf.io.FixedLenFeature,
                                      tf.io.FixedLenSequenceFeature]):
    self._context_description[feature_name] = feature_type

  def decode(self, serialized_example):
    """Parses a single tf.Example into image and label tensors."""
    result = {}
    context, sequences = tf.io.parse_single_sequence_example(
        serialized_example, self._context_description,
        self._sequence_description)
    result.update(context)
    result.update(sequences)

    for key, value in result.items():
      if isinstance(value, tf.SparseTensor):
        result[key] = tf.sparse.to_dense(value)

    return result
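

# Illustrative sketch (hypothetical file pattern): how the `Decoder` above
# might be mapped over a TFRecord dataset of serialized tf.SequenceExamples.
def _example_decode_dataset(file_pattern: str) -> tf.data.Dataset:
  """Example only: maps `Decoder.decode` over serialized examples."""
  video_decoder = Decoder()
  dataset = tf.data.TFRecordDataset(tf.io.gfile.glob(file_pattern))
  # Each element becomes a dict with dense 'image/encoded' (frame strings) and
  # 'clip/label/index' (class indices) tensors.
  return dataset.map(video_decoder.decode,
                     num_parallel_calls=tf.data.AUTOTUNE)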


class VideoTfdsDecoder(decoder.Decoder):
  """A tf.SequenceExample decoder for tfds video classification datasets."""

  def __init__(self, image_key: str = IMAGE_KEY, label_key: str = LABEL_KEY):
    self._image_key = image_key
    self._label_key = label_key

  def decode(self, features):
    """Decodes the TFDS FeatureDict.

    Args:
      features: features from a TFDS video dataset. See
        https://www.tensorflow.org/datasets/catalog/ucf101 for an example.

    Returns:
      Dict of tensors.
    """
    sample_dict = {
        self._image_key: features['video'],
        self._label_key: features['label'],
    }
    return sample_dict
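

# Illustrative sketch (assumes `tensorflow_datasets` is installed and the
# ucf101 dataset referenced in the docstring above is available): the decoder
# only remaps TFDS feature keys to the keys expected by `Parser`.
def _example_tfds_pipeline() -> tf.data.Dataset:
  """Example only: wires `VideoTfdsDecoder` into a TFDS dataset."""
  import tensorflow_datasets as tfds  # Assumed dependency for this example.
  tfds_decoder = VideoTfdsDecoder()
  dataset = tfds.load('ucf101', split='train')
  # Each element now carries 'image/encoded' (video frames) and
  # 'clip/label/index' (class index) keys.
  return dataset.map(tfds_decoder.decode)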


class Parser(parser.Parser):
  """Parses a video and label dataset."""

  def __init__(self,
               input_params: exp_cfg.DataConfig,
               image_key: str = IMAGE_KEY,
               label_key: str = LABEL_KEY):
    self._num_frames = input_params.feature_shape[0]
    self._stride = input_params.temporal_stride
    self._random_stride_range = input_params.random_stride_range
    self._num_test_clips = input_params.num_test_clips
    self._min_resize = input_params.min_image_size
    crop_height = input_params.feature_shape[1]
    crop_width = input_params.feature_shape[2]
    self._crop_size = crop_height if crop_height == crop_width else (
        crop_height, crop_width)
    self._num_channels = input_params.feature_shape[3]
    self._num_crops = input_params.num_test_crops
    self._zero_centering_image = input_params.zero_centering_image
    self._one_hot_label = input_params.one_hot
    self._num_classes = input_params.num_classes
    self._image_key = image_key
    self._label_key = label_key
    self._dtype = tf.dtypes.as_dtype(input_params.dtype)
    self._label_dtype = tf.dtypes.as_dtype(input_params.label_dtype)
    self._output_audio = input_params.output_audio
    self._min_aspect_ratio = input_params.aug_min_aspect_ratio
    self._max_aspect_ratio = input_params.aug_max_aspect_ratio
    self._min_area_ratio = input_params.aug_min_area_ratio
    self._max_area_ratio = input_params.aug_max_area_ratio
    self._input_image_format = input_params.input_image_format
    if self._output_audio:
      self._audio_feature = input_params.audio_feature
      self._audio_shape = input_params.audio_feature_shape

    aug_type = input_params.aug_type
    if aug_type is not None:
      if aug_type.type == 'autoaug':
        logging.info('Using AutoAugment.')
        self._augmenter = augment.AutoAugment(
            augmentation_name=aug_type.autoaug.augmentation_name,
            cutout_const=aug_type.autoaug.cutout_const,
            translate_const=aug_type.autoaug.translate_const)
      elif aug_type.type == 'randaug':
        logging.info('Using RandAugment.')
        self._augmenter = augment.RandAugment(
            num_layers=aug_type.randaug.num_layers,
            magnitude=aug_type.randaug.magnitude,
            cutout_const=aug_type.randaug.cutout_const,
            translate_const=aug_type.randaug.translate_const,
            prob_to_apply=aug_type.randaug.prob_to_apply,
            exclude_ops=aug_type.randaug.exclude_ops)
      else:
        raise ValueError(
            'Augmentation policy {} not supported.'.format(aug_type.type))
    else:
      self._augmenter = None

  def _parse_train_data(
      self, decoded_tensors: Dict[str, tf.Tensor]
  ) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]:
    """Parses data for training."""
    # Process image and label.
    image = decoded_tensors[self._image_key]
    image = process_image(
        image=image,
        is_training=True,
        num_frames=self._num_frames,
        stride=self._stride,
        random_stride_range=self._random_stride_range,
        num_test_clips=self._num_test_clips,
        min_resize=self._min_resize,
        crop_size=self._crop_size,
        num_channels=self._num_channels,
        min_aspect_ratio=self._min_aspect_ratio,
        max_aspect_ratio=self._max_aspect_ratio,
        min_area_ratio=self._min_area_ratio,
        max_area_ratio=self._max_area_ratio,
        augmenter=self._augmenter,
        zero_centering_image=self._zero_centering_image,
        input_image_format=self._input_image_format)
    image = tf.cast(image, dtype=self._dtype)
    features = {'image': image}

    label = decoded_tensors[self._label_key]
    label = process_label(label, self._one_hot_label, self._num_classes,
                          self._label_dtype)

    if self._output_audio:
      audio = decoded_tensors[self._audio_feature]
      audio = tf.cast(audio, dtype=self._dtype)
      # TODO(yeqing): synchronize audio/video sampling. Especially randomness.
      audio = preprocess_ops_3d.sample_sequence(
          audio, self._audio_shape[0], random=False, stride=1)
      audio = tf.ensure_shape(audio, self._audio_shape)
      features['audio'] = audio

    return features, label

  def _parse_eval_data(
      self, decoded_tensors: Dict[str, tf.Tensor]
  ) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]:
    """Parses data for evaluation."""
    image = decoded_tensors[self._image_key]
    image = process_image(
        image=image,
        is_training=False,
        num_frames=self._num_frames,
        stride=self._stride,
        num_test_clips=self._num_test_clips,
        min_resize=self._min_resize,
        crop_size=self._crop_size,
        num_channels=self._num_channels,
        num_crops=self._num_crops,
        zero_centering_image=self._zero_centering_image,
        input_image_format=self._input_image_format)
    image = tf.cast(image, dtype=self._dtype)
    features = {'image': image}

    label = decoded_tensors[self._label_key]
    label = process_label(label, self._one_hot_label, self._num_classes,
                          self._label_dtype)

    if self._output_audio:
      audio = decoded_tensors[self._audio_feature]
      audio = tf.cast(audio, dtype=self._dtype)
      audio = preprocess_ops_3d.sample_sequence(
          audio, self._audio_shape[0], random=False, stride=1)
      audio = tf.ensure_shape(audio, self._audio_shape)
      features['audio'] = audio

    return features, label
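

# Illustrative sketch (assumed config values and a hypothetical file pattern):
# a minimal tf.data pipeline that chains `Decoder` and `Parser` for training.
# `parse_fn` comes from the `parser.Parser` base class and dispatches to
# `_parse_train_data` when `is_training=True`.
def _example_parser_pipeline(file_pattern: str) -> tf.data.Dataset:
  """Example only: decodes and parses serialized training examples."""
  params = exp_cfg.DataConfig(
      feature_shape=(32, 224, 224, 3),  # (frames, height, width, channels).
      num_classes=400,                  # Assumed label space for the example.
      is_training=True)
  video_decoder = Decoder()
  video_parser = Parser(params)
  dataset = tf.data.TFRecordDataset(tf.io.gfile.glob(file_pattern))
  dataset = dataset.map(video_decoder.decode,
                        num_parallel_calls=tf.data.AUTOTUNE)
  dataset = dataset.map(video_parser.parse_fn(is_training=True),
                        num_parallel_calls=tf.data.AUTOTUNE)
  # Each element is now a ({'image': ...}, label) tuple.
  return dataset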


class PostBatchProcessor(object):
  """Processes a video and label dataset which is batched."""

  def __init__(self, input_params: exp_cfg.DataConfig):
    self._is_training = input_params.is_training
    self._num_frames = input_params.feature_shape[0]
    self._num_test_clips = input_params.num_test_clips
    self._num_test_crops = input_params.num_test_crops

  def __call__(self, features: Dict[str, tf.Tensor],
               label: tf.Tensor) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]:
    """Reshapes the batched image features to merge test views, if any."""
    for key in ['image']:
      if key in features:
        features[key] = postprocess_image(
            image=features[key],
            is_training=self._is_training,
            num_frames=self._num_frames,
            num_test_clips=self._num_test_clips,
            num_test_crops=self._num_test_crops)

    return features, label
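

# Illustrative sketch (assumed batch size): `PostBatchProcessor` is meant to
# run after `Dataset.batch`, folding any test views back into the batch
# dimension via `postprocess_image`.
def _example_post_batch(dataset: tf.data.Dataset,
                        params: exp_cfg.DataConfig) -> tf.data.Dataset:
  """Example only: applies the post-batch reshape to a batched dataset."""
  dataset = dataset.batch(8, drop_remainder=True)
  post_batch_fn = PostBatchProcessor(params)
  # Maps over the (features, label) tuples produced by `Parser`.
  return dataset.map(post_batch_fn, num_parallel_calls=tf.data.AUTOTUNE)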