# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Classification decoder and parser."""

from typing import Any, Dict, List, Optional, Tuple

# Import libraries
import tensorflow as tf, tf_keras

from official.vision.configs import common
from official.vision.dataloaders import decoder
from official.vision.dataloaders import parser
from official.vision.ops import augment
from official.vision.ops import preprocess_ops

DEFAULT_IMAGE_FIELD_KEY = 'image/encoded'
DEFAULT_LABEL_FIELD_KEY = 'image/class/label'


class Decoder(decoder.Decoder):
  """A tf.Example decoder for classification task."""

  def __init__(self,
               image_field_key: str = DEFAULT_IMAGE_FIELD_KEY,
               label_field_key: str = DEFAULT_LABEL_FIELD_KEY,
               is_multilabel: bool = False,
               keys_to_features: Optional[Dict[str, Any]] = None):
    if not keys_to_features:
      keys_to_features = {
          image_field_key:
              tf.io.FixedLenFeature((), tf.string, default_value=''),
      }
      if is_multilabel:
        keys_to_features.update(
            {label_field_key: tf.io.VarLenFeature(dtype=tf.int64)})
      else:
        keys_to_features.update({
            label_field_key:
                tf.io.FixedLenFeature((), tf.int64, default_value=-1)
        })
    self._keys_to_features = keys_to_features

  def decode(self, serialized_example):
    return tf.io.parse_single_example(serialized_example,
                                      self._keys_to_features)
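
# A minimal usage sketch for `Decoder` (illustrative only, not part of the
# library API; `serialized_example` is assumed to be a serialized
# `tf.train.Example` carrying the default image/label keys):
#
#   example_decoder = Decoder()
#   decoded_tensors = example_decoder.decode(serialized_example)
#   image_bytes = decoded_tensors[DEFAULT_IMAGE_FIELD_KEY]
#   label = decoded_tensors[DEFAULT_LABEL_FIELD_KEY]
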
class Parser(parser.Parser):
  """Parser to parse an image and its annotations into a dictionary of tensors."""

  def __init__(self,
               output_size: List[int],
               num_classes: int,
               image_field_key: str = DEFAULT_IMAGE_FIELD_KEY,
               label_field_key: str = DEFAULT_LABEL_FIELD_KEY,
               decode_jpeg_only: bool = True,
               aug_rand_hflip: bool = True,
               aug_crop: Optional[bool] = True,
               aug_type: Optional[common.Augmentation] = None,
               color_jitter: float = 0.,
               random_erasing: Optional[common.RandomErasing] = None,
               is_multilabel: bool = False,
               dtype: str = 'float32',
               crop_area_range: Optional[Tuple[float, float]] = (0.08, 1.0),
               center_crop_fraction: Optional[
                   float] = preprocess_ops.CENTER_CROP_FRACTION,
               tf_resize_method: str = 'bilinear',
               three_augment: bool = False):
    """Initializes parameters for parsing annotations in the dataset.

    Args:
      output_size: `Tensor` or `list` for [height, width] of output image. The
        output_size should be divisible by the largest feature stride
        2^max_level.
      num_classes: `int`, number of classes.
      image_field_key: `str`, the key name to encoded image or decoded image
        matrix in tf.Example.
      label_field_key: `str`, the key name to label in tf.Example.
      decode_jpeg_only: `bool`, if True, only JPEG format is decoded; this is
        faster than decoding other image types. Default is True.
      aug_rand_hflip: `bool`, if True, augment training with random horizontal
        flip.
      aug_crop: `bool`, if True, perform random cropping during training and
        center cropping during validation.
      aug_type: An optional Augmentation object to choose from AutoAugment and
        RandAugment.
      color_jitter: Magnitude of color jitter. If > 0, the value is used to
        generate a random scale factor for brightness, contrast and
        saturation. See `preprocess_ops.color_jitter` for more details.
      random_erasing: if not None, augment input image by random erasing. See
        `augment.RandomErasing` for more details.
      is_multilabel: A `bool`, whether or not each example has multiple labels.
      dtype: `str`, cast output image in dtype. It can be 'float32',
        'float16', or 'bfloat16'.
      crop_area_range: An optional `tuple` of (min_area, max_area) for the
        image random crop function to constrain the crop operation. The
        cropped area of the image must contain a fraction of the input image
        within this range. The default area range is (0.08, 1.0).
        https://arxiv.org/abs/2204.07118.
      center_crop_fraction: An optional `float`, the fraction of the image to
        retain when center cropping during evaluation.
      tf_resize_method: A `str`, interpolation method for resizing image.
      three_augment: A `bool`, whether to apply the DeiT III three-augment
        policy.
    """
    self._output_size = output_size
    self._aug_rand_hflip = aug_rand_hflip
    self._aug_crop = aug_crop
    self._num_classes = num_classes
    self._image_field_key = image_field_key
    if dtype == 'float32':
      self._dtype = tf.float32
    elif dtype == 'float16':
      self._dtype = tf.float16
    elif dtype == 'bfloat16':
      self._dtype = tf.bfloat16
    else:
      raise ValueError('dtype {!r} is not supported!'.format(dtype))
    if aug_type:
      if aug_type.type == 'autoaug':
        self._augmenter = augment.AutoAugment(
            augmentation_name=aug_type.autoaug.augmentation_name,
            cutout_const=aug_type.autoaug.cutout_const,
            translate_const=aug_type.autoaug.translate_const)
      elif aug_type.type == 'randaug':
        self._augmenter = augment.RandAugment(
            num_layers=aug_type.randaug.num_layers,
            magnitude=aug_type.randaug.magnitude,
            cutout_const=aug_type.randaug.cutout_const,
            translate_const=aug_type.randaug.translate_const,
            prob_to_apply=aug_type.randaug.prob_to_apply,
            exclude_ops=aug_type.randaug.exclude_ops)
      else:
        raise ValueError('Augmentation policy {} not supported.'.format(
            aug_type.type))
    else:
      self._augmenter = None
    self._label_field_key = label_field_key
    self._color_jitter = color_jitter
    if random_erasing:
      self._random_erasing = augment.RandomErasing(
          probability=random_erasing.probability,
          min_area=random_erasing.min_area,
          max_area=random_erasing.max_area,
          min_aspect=random_erasing.min_aspect,
          max_aspect=random_erasing.max_aspect,
          min_count=random_erasing.min_count,
          max_count=random_erasing.max_count,
          trials=random_erasing.trials)
    else:
      self._random_erasing = None
    self._is_multilabel = is_multilabel
    self._decode_jpeg_only = decode_jpeg_only
    self._crop_area_range = crop_area_range
    self._center_crop_fraction = center_crop_fraction
    self._tf_resize_method = tf_resize_method
    self._three_augment = three_augment

  def _parse_train_data(self, decoded_tensors):
    """Parses data for training."""
    image = self._parse_train_image(decoded_tensors)
    label = tf.cast(decoded_tensors[self._label_field_key], dtype=tf.int32)
    if self._is_multilabel:
      if isinstance(label, tf.sparse.SparseTensor):
        label = tf.sparse.to_dense(label)
      # Converts the class indices into a multi-hot vector of length
      # num_classes.
      label = tf.reduce_sum(tf.one_hot(label, self._num_classes), axis=0)
    return image, label

  def _parse_eval_data(self, decoded_tensors):
    """Parses data for evaluation."""
    image = self._parse_eval_image(decoded_tensors)
    label = tf.cast(decoded_tensors[self._label_field_key], dtype=tf.int32)
    if self._is_multilabel:
      if isinstance(label, tf.sparse.SparseTensor):
        label = tf.sparse.to_dense(label)
      label = tf.reduce_sum(tf.one_hot(label, self._num_classes), axis=0)
    return image, label
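
  # Illustration of the multilabel branch above (hypothetical values): with
  # num_classes=5 and sparse labels [1, 3], tf.one_hot yields
  # [[0, 1, 0, 0, 0], [0, 0, 0, 1, 0]], and tf.reduce_sum over axis 0
  # collapses it into the multi-hot target [0, 1, 0, 1, 0].
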
  def _parse_train_image(self, decoded_tensors):
    """Parses image data for training."""
    image_bytes = decoded_tensors[self._image_field_key]
    require_decoding = (
        not tf.is_tensor(image_bytes) or image_bytes.dtype == tf.dtypes.string
    )

    if (
        require_decoding
        and self._decode_jpeg_only
        and self._aug_crop
    ):
      # Reads the image shape from the JPEG header without a full decode.
      image_shape = tf.image.extract_jpeg_shape(image_bytes)

      # Crops image.
      cropped_image = preprocess_ops.random_crop_image_v2(
          image_bytes, image_shape, area_range=self._crop_area_range)
      # If the random crop was a no-op (the crop equals the full image), fall
      # back to a center crop.
      image = tf.cond(
          tf.reduce_all(tf.equal(tf.shape(cropped_image), image_shape)),
          lambda: preprocess_ops.center_crop_image_v2(image_bytes, image_shape),
          lambda: cropped_image)
    else:
      if require_decoding:
        # Decodes image.
        image = tf.io.decode_image(image_bytes, channels=3)
        image.set_shape([None, None, 3])
      else:
        # Already decoded image matrix.
        image = image_bytes

      # Crops image.
      if self._aug_crop:
        cropped_image = preprocess_ops.random_crop_image(
            image, area_range=self._crop_area_range)
        image = tf.cond(
            tf.reduce_all(tf.equal(tf.shape(cropped_image), tf.shape(image))),
            lambda: preprocess_ops.center_crop_image(image),
            lambda: cropped_image)

    if self._aug_rand_hflip:
      image = tf.image.random_flip_left_right(image)

    # Color jitter.
    if self._color_jitter > 0:
      image = preprocess_ops.color_jitter(image, self._color_jitter,
                                          self._color_jitter,
                                          self._color_jitter)

    # Resizes image.
    image = tf.image.resize(
        image, self._output_size, method=self._tf_resize_method)
    image.set_shape([self._output_size[0], self._output_size[1], 3])

    # Apply autoaug or randaug.
    if self._augmenter is not None:
      image = self._augmenter.distort(image)

    # Three-augment policy (DeiT III).
    if self._three_augment:
      image = augment.AutoAugment(
          augmentation_name='deit3_three_augment',
          translate_const=20,
      ).distort(image)

    # Normalizes image with mean and std pixel values.
    image = preprocess_ops.normalize_image(
        image, offset=preprocess_ops.MEAN_RGB, scale=preprocess_ops.STDDEV_RGB)

    # Random erasing after the image has been normalized.
    if self._random_erasing is not None:
      image = self._random_erasing.distort(image)

    # Converts image to self._dtype.
    image = tf.image.convert_image_dtype(image, self._dtype)

    return image

  def _parse_eval_image(self, decoded_tensors):
    """Parses image data for evaluation."""
    image_bytes = decoded_tensors[self._image_field_key]
    require_decoding = (
        not tf.is_tensor(image_bytes) or image_bytes.dtype == tf.dtypes.string
    )

    if (
        require_decoding
        and self._decode_jpeg_only
        and self._aug_crop
    ):
      image_shape = tf.image.extract_jpeg_shape(image_bytes)

      # Center crops.
      image = preprocess_ops.center_crop_image_v2(
          image_bytes, image_shape, self._center_crop_fraction)
    else:
      if require_decoding:
        # Decodes image.
        image = tf.io.decode_image(image_bytes, channels=3)
        image.set_shape([None, None, 3])
      else:
        # Already decoded image matrix.
        image = image_bytes

      # Center crops.
      if self._aug_crop:
        image = preprocess_ops.center_crop_image(
            image, self._center_crop_fraction)

    image = tf.image.resize(
        image, self._output_size, method=self._tf_resize_method)
    image.set_shape([self._output_size[0], self._output_size[1], 3])

    # Normalizes image with mean and std pixel values.
    image = preprocess_ops.normalize_image(
        image, offset=preprocess_ops.MEAN_RGB, scale=preprocess_ops.STDDEV_RGB)

    # Converts image to self._dtype.
    image = tf.image.convert_image_dtype(image, self._dtype)
    return image
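
  # The normalization above is elementwise per RGB channel. A rough sketch of
  # the equivalent arithmetic (assuming `preprocess_ops.normalize_image`
  # subtracts `offset` and divides by `scale`, with MEAN_RGB/STDDEV_RGB
  # expressed in [0, 255] pixel units):
  #
  #   mean = tf.constant(preprocess_ops.MEAN_RGB, shape=[1, 1, 3])
  #   stddev = tf.constant(preprocess_ops.STDDEV_RGB, shape=[1, 1, 3])
  #   normalized = (tf.cast(image, tf.float32) - mean) / stddev
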
  def parse_train_image(self,
                        decoded_tensors: Dict[str, tf.Tensor]) -> tf.Tensor:
    """Public interface for parsing image data for training."""
    return self._parse_train_image(decoded_tensors)

  @classmethod
  def inference_fn(cls,
                   image: tf.Tensor,
                   input_image_size: List[int],
                   num_channels: int = 3) -> tf.Tensor:
    """Builds image model inputs for serving."""
    image = tf.cast(image, dtype=tf.float32)
    image = preprocess_ops.center_crop_image(image)
    image = tf.image.resize(
        image, input_image_size, method=tf.image.ResizeMethod.BILINEAR)

    # Normalizes image with mean and std pixel values.
    image = preprocess_ops.normalize_image(
        image, offset=preprocess_ops.MEAN_RGB, scale=preprocess_ops.STDDEV_RGB)
    image.set_shape(input_image_size + [num_channels])
    return image
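
# A minimal end-to-end sketch wiring the decoder and parser into a tf.data
# pipeline (illustrative only; `tfrecord_path` and the hyperparameter values
# are placeholders, and `parse_fn(is_training=...)` is assumed to come from
# the `parser.Parser` base class):
#
#   example_decoder = Decoder()
#   example_parser = Parser(output_size=[224, 224], num_classes=1000)
#   dataset = (
#       tf.data.TFRecordDataset(tfrecord_path)
#       .map(example_decoder.decode, num_parallel_calls=tf.data.AUTOTUNE)
#       .map(example_parser.parse_fn(is_training=True),
#            num_parallel_calls=tf.data.AUTOTUNE)
#       .batch(64)
#       .prefetch(tf.data.AUTOTUNE)
#   )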