# Copyright 2023 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Parser for video and label datasets.""" from typing import Dict, Optional, Tuple, Union from absl import logging import tensorflow as tf, tf_keras from official.vision.configs import video_classification as exp_cfg from official.vision.dataloaders import decoder from official.vision.dataloaders import parser from official.vision.ops import augment from official.vision.ops import preprocess_ops_3d IMAGE_KEY = 'image/encoded' LABEL_KEY = 'clip/label/index' def process_image(image: tf.Tensor, is_training: bool = True, num_frames: int = 32, stride: int = 1, random_stride_range: int = 0, num_test_clips: int = 1, min_resize: int = 256, crop_size: Union[int, Tuple[int, int]] = 224, num_channels: int = 3, num_crops: int = 1, zero_centering_image: bool = False, min_aspect_ratio: float = 0.5, max_aspect_ratio: float = 2, min_area_ratio: float = 0.49, max_area_ratio: float = 1.0, augmenter: Optional[augment.ImageAugment] = None, seed: Optional[int] = None, input_image_format: Optional[str] = 'jpeg') -> tf.Tensor: """Processes a serialized image tensor. Args: image: Input Tensor of shape [time-steps] and type tf.string of serialized frames. is_training: Whether or not in training mode. If True, random sample, crop and left right flip is used. num_frames: Number of frames per sub clip. stride: Temporal stride to sample frames. random_stride_range: An int indicating the min and max bounds to uniformly sample different strides from the video. E.g., a value of 1 with stride=2 will uniformly sample a stride in {1, 2, 3} for each video in a batch. Only used enabled training for the purposes of frame-rate augmentation. Defaults to 0, which disables random sampling. num_test_clips: Number of test clips (1 by default). If more than 1, this will sample multiple linearly spaced clips within each video at test time. If 1, then a single clip in the middle of the video is sampled. The clips are aggregated in the batch dimension. min_resize: Frames are resized so that min(height, width) is min_resize. crop_size: Final size of the frame after cropping the resized frames. Optionally, specify a tuple of (crop_height, crop_width) if crop_height != crop_width. num_channels: Number of channels of the clip. num_crops: Number of crops to perform on the resized frames. zero_centering_image: If True, frames are normalized to values in [-1, 1]. If False, values in [0, 1]. min_aspect_ratio: The minimum aspect range for cropping. max_aspect_ratio: The maximum aspect range for cropping. min_area_ratio: The minimum area range for cropping. max_area_ratio: The maximum area range for cropping. augmenter: Image augmenter to distort each image. seed: A deterministic seed to use when sampling. input_image_format: The format of input image which could be jpeg, png or none for unknown or mixed datasets. Returns: Processed frames. Tensor of shape [num_frames * num_test_clips, crop_height, crop_width, num_channels]. """ # Validate parameters. if is_training and num_test_clips != 1: logging.warning( '`num_test_clips` %d is ignored since `is_training` is `True`.', num_test_clips) if random_stride_range < 0: raise ValueError('Random stride range should be >= 0, got {}'.format( random_stride_range)) if input_image_format not in ('jpeg', 'png', 'none'): raise ValueError('Unknown input image format: {}'.format( input_image_format)) if isinstance(crop_size, int): crop_size = (crop_size, crop_size) crop_height, crop_width = crop_size # Temporal sampler. if is_training: if random_stride_range > 0: # Uniformly sample different frame-rates stride = tf.random.uniform( [], tf.maximum(stride - random_stride_range, 1), stride + random_stride_range, dtype=tf.int32) # Sample random clip. image = preprocess_ops_3d.sample_sequence(image, num_frames, True, stride, seed) elif num_test_clips > 1: # Sample linspace clips. image = preprocess_ops_3d.sample_linspace_sequence(image, num_test_clips, num_frames, stride) else: # Sample middle clip. image = preprocess_ops_3d.sample_sequence(image, num_frames, False, stride) # Decode JPEG string to tf.uint8. if image.dtype == tf.string: image = preprocess_ops_3d.decode_image(image, num_channels) if is_training: # Standard image data augmentation: random resized crop and random flip. image = preprocess_ops_3d.random_crop_resize( image, crop_height, crop_width, num_frames, num_channels, (min_aspect_ratio, max_aspect_ratio), (min_area_ratio, max_area_ratio)) image = preprocess_ops_3d.random_flip_left_right(image, seed) if augmenter is not None: image = augmenter.distort(image) else: # Resize images (resize happens only if necessary to save compute). image = preprocess_ops_3d.resize_smallest(image, min_resize) # Crop of the frames. image = preprocess_ops_3d.crop_image(image, crop_height, crop_width, False, num_crops) # Cast the frames in float32, normalizing according to zero_centering_image. return preprocess_ops_3d.normalize_image(image, zero_centering_image) def postprocess_image(image: tf.Tensor, is_training: bool = True, num_frames: int = 32, num_test_clips: int = 1, num_test_crops: int = 1) -> tf.Tensor: """Processes a batched Tensor of frames. The same parameters used in process should be used here. Args: image: Input Tensor of shape [batch, time-steps, height, width, 3]. is_training: Whether or not in training mode. If True, random sample, crop and left right flip is used. num_frames: Number of frames per sub clip. num_test_clips: Number of test clips (1 by default). If more than 1, this will sample multiple linearly spaced clips within each video at test time. If 1, then a single clip in the middle of the video is sampled. The clips are aggregated in the batch dimension. num_test_crops: Number of test crops (1 by default). If more than 1, there are multiple crops for each clip at test time. If 1, there is a single central crop. The crops are aggregated in the batch dimension. Returns: Processed frames. Tensor of shape [batch * num_test_clips * num_test_crops, num_frames, height, width, 3]. """ num_views = num_test_clips * num_test_crops if num_views > 1 and not is_training: # In this case, multiple views are merged together in batch dimension which # will be batch * num_views. image = tf.reshape(image, [-1, num_frames] + image.shape[2:].as_list()) return image def process_label(label: tf.Tensor, one_hot_label: bool = True, num_classes: Optional[int] = None, label_dtype: tf.DType = tf.int32) -> tf.Tensor: """Processes label Tensor.""" # Validate parameters. if one_hot_label and not num_classes: raise ValueError( '`num_classes` should be given when requesting one hot label.') # Cast to label_dtype (default = tf.int32). label = tf.cast(label, dtype=label_dtype) if one_hot_label: # Replace label index by one hot representation. label = tf.one_hot(label, num_classes) if len(label.shape.as_list()) > 1: label = tf.reduce_sum(label, axis=0) if num_classes == 1: # The trick for single label. label = 1 - label return label class Decoder(decoder.Decoder): """A tf.Example decoder for classification task.""" def __init__(self, image_key: str = IMAGE_KEY, label_key: str = LABEL_KEY): self._context_description = { # One integer stored in context. label_key: tf.io.VarLenFeature(tf.int64), } self._sequence_description = { # Each image is a string encoding JPEG. image_key: tf.io.FixedLenSequenceFeature((), tf.string), } def add_feature(self, feature_name: str, feature_type: Union[tf.io.VarLenFeature, tf.io.FixedLenFeature, tf.io.FixedLenSequenceFeature]): self._sequence_description[feature_name] = feature_type def add_context(self, feature_name: str, feature_type: Union[tf.io.VarLenFeature, tf.io.FixedLenFeature, tf.io.FixedLenSequenceFeature]): self._context_description[feature_name] = feature_type def decode(self, serialized_example): """Parses a single tf.Example into image and label tensors.""" result = {} context, sequences = tf.io.parse_single_sequence_example( serialized_example, self._context_description, self._sequence_description) result.update(context) result.update(sequences) for key, value in result.items(): if isinstance(value, tf.SparseTensor): result[key] = tf.sparse.to_dense(value) return result class VideoTfdsDecoder(decoder.Decoder): """A tf.SequenceExample decoder for tfds video classification datasets.""" def __init__(self, image_key: str = IMAGE_KEY, label_key: str = LABEL_KEY): self._image_key = image_key self._label_key = label_key def decode(self, features): """Decode the TFDS FeatureDict. Args: features: features from TFDS video dataset. See https://www.tensorflow.org/datasets/catalog/ucf101 for example. Returns: Dict of tensors. """ sample_dict = { self._image_key: features['video'], self._label_key: features['label'], } return sample_dict class Parser(parser.Parser): """Parses a video and label dataset.""" def __init__(self, input_params: exp_cfg.DataConfig, image_key: str = IMAGE_KEY, label_key: str = LABEL_KEY): self._num_frames = input_params.feature_shape[0] self._stride = input_params.temporal_stride self._random_stride_range = input_params.random_stride_range self._num_test_clips = input_params.num_test_clips self._min_resize = input_params.min_image_size crop_height = input_params.feature_shape[1] crop_width = input_params.feature_shape[2] self._crop_size = crop_height if crop_height == crop_width else ( crop_height, crop_width) self._num_channels = input_params.feature_shape[3] self._num_crops = input_params.num_test_crops self._zero_centering_image = input_params.zero_centering_image self._one_hot_label = input_params.one_hot self._num_classes = input_params.num_classes self._image_key = image_key self._label_key = label_key self._dtype = tf.dtypes.as_dtype(input_params.dtype) self._label_dtype = tf.dtypes.as_dtype(input_params.label_dtype) self._output_audio = input_params.output_audio self._min_aspect_ratio = input_params.aug_min_aspect_ratio self._max_aspect_ratio = input_params.aug_max_aspect_ratio self._min_area_ratio = input_params.aug_min_area_ratio self._max_area_ratio = input_params.aug_max_area_ratio self._input_image_format = input_params.input_image_format if self._output_audio: self._audio_feature = input_params.audio_feature self._audio_shape = input_params.audio_feature_shape aug_type = input_params.aug_type if aug_type is not None: if aug_type.type == 'autoaug': logging.info('Using AutoAugment.') self._augmenter = augment.AutoAugment( augmentation_name=aug_type.autoaug.augmentation_name, cutout_const=aug_type.autoaug.cutout_const, translate_const=aug_type.autoaug.translate_const) elif aug_type.type == 'randaug': logging.info('Using RandAugment.') self._augmenter = augment.RandAugment( num_layers=aug_type.randaug.num_layers, magnitude=aug_type.randaug.magnitude, cutout_const=aug_type.randaug.cutout_const, translate_const=aug_type.randaug.translate_const, prob_to_apply=aug_type.randaug.prob_to_apply, exclude_ops=aug_type.randaug.exclude_ops) else: raise ValueError( 'Augmentation policy {} not supported.'.format(aug_type.type)) else: self._augmenter = None def _parse_train_data( self, decoded_tensors: Dict[str, tf.Tensor] ) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]: """Parses data for training.""" # Process image and label. image = decoded_tensors[self._image_key] image = process_image( image=image, is_training=True, num_frames=self._num_frames, stride=self._stride, random_stride_range=self._random_stride_range, num_test_clips=self._num_test_clips, min_resize=self._min_resize, crop_size=self._crop_size, num_channels=self._num_channels, min_aspect_ratio=self._min_aspect_ratio, max_aspect_ratio=self._max_aspect_ratio, min_area_ratio=self._min_area_ratio, max_area_ratio=self._max_area_ratio, augmenter=self._augmenter, zero_centering_image=self._zero_centering_image, input_image_format=self._input_image_format) image = tf.cast(image, dtype=self._dtype) features = {'image': image} label = decoded_tensors[self._label_key] label = process_label(label, self._one_hot_label, self._num_classes, self._label_dtype) if self._output_audio: audio = decoded_tensors[self._audio_feature] audio = tf.cast(audio, dtype=self._dtype) # TODO(yeqing): synchronize audio/video sampling. Especially randomness. audio = preprocess_ops_3d.sample_sequence( audio, self._audio_shape[0], random=False, stride=1) audio = tf.ensure_shape(audio, self._audio_shape) features['audio'] = audio return features, label def _parse_eval_data( self, decoded_tensors: Dict[str, tf.Tensor] ) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]: """Parses data for evaluation.""" image = decoded_tensors[self._image_key] image = process_image( image=image, is_training=False, num_frames=self._num_frames, stride=self._stride, num_test_clips=self._num_test_clips, min_resize=self._min_resize, crop_size=self._crop_size, num_channels=self._num_channels, num_crops=self._num_crops, zero_centering_image=self._zero_centering_image, input_image_format=self._input_image_format) image = tf.cast(image, dtype=self._dtype) features = {'image': image} label = decoded_tensors[self._label_key] label = process_label(label, self._one_hot_label, self._num_classes, self._label_dtype) if self._output_audio: audio = decoded_tensors[self._audio_feature] audio = tf.cast(audio, dtype=self._dtype) audio = preprocess_ops_3d.sample_sequence( audio, self._audio_shape[0], random=False, stride=1) audio = tf.ensure_shape(audio, self._audio_shape) features['audio'] = audio return features, label class PostBatchProcessor(object): """Processes a video and label dataset which is batched.""" def __init__(self, input_params: exp_cfg.DataConfig): self._is_training = input_params.is_training self._num_frames = input_params.feature_shape[0] self._num_test_clips = input_params.num_test_clips self._num_test_crops = input_params.num_test_crops def __call__(self, features: Dict[str, tf.Tensor], label: tf.Tensor) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]: """Parses a single tf.Example into image and label tensors.""" for key in ['image']: if key in features: features[key] = postprocess_image( image=features[key], is_training=self._is_training, num_frames=self._num_frames, num_test_clips=self._num_test_clips, num_test_crops=self._num_test_crops) return features, label