# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Example classification decoder and parser.

This file defines the Decoder and Parser to load data. The example is shown on
loading standard tf.Example data, but non-standard tf.Example or other data
formats can be supported by implementing a proper decoder and parser.
"""

from typing import List, Mapping, Sequence, Tuple

# Import libraries

import tensorflow as tf, tf_keras

from official.vision.dataloaders import decoder
from official.vision.dataloaders import parser
from official.vision.ops import preprocess_ops


class Decoder(decoder.Decoder):
  """A tf.Example decoder for classification task."""

  def __init__(self):
    """Initializes the decoder.

    The constructor defines the mapping between the field name and the value
    from an input tf.Example. For example, we define two fields for image bytes
    and labels. There is no limit on the number of fields to decode.
    """
    # Both features are scalars; a missing field decodes to its default value
    # ('' for the image bytes, -1 for the label).
    self._keys_to_features = {
        'image/encoded':
            tf.io.FixedLenFeature((), tf.string, default_value=''),
        'image/class/label':
            tf.io.FixedLenFeature((), tf.int64, default_value=-1),
    }

  def decode(self, serialized_example: tf.Tensor) -> Mapping[str, tf.Tensor]:
    """Decodes a tf.Example to a dictionary.

    This function decodes a serialized tf.Example to a dictionary. The output
    will be consumed by `_parse_train_data` and `_parse_eval_data` in Parser.

    Args:
      serialized_example: A serialized tf.Example, as accepted by
        `tf.io.parse_single_example`.

    Returns:
      A dictionary of field key name and decoded tensor mapping.
    """
    return tf.io.parse_single_example(
        serialized_example, self._keys_to_features)


class Parser(parser.Parser):
  """Parser to parse an image and its annotations.

  To define own Parser, client should override _parse_train_data and
  _parse_eval_data functions, where decoded tensors are parsed with optional
  pre-processing steps. The output from the two functions can be any structure
  like tuple, list or dictionary.
  """

  def __init__(self, output_size: Sequence[int], num_classes: int):
    """Initializes parameters for parsing annotations in the dataset.

    This example only takes two arguments but one can freely add as many
    arguments as needed. For example, pre-processing and augmentations usually
    happen in Parser, and related parameters can be passed in by this
    constructor.

    Args:
      output_size: `list` or `tuple` of [height, width] of the output image.
      num_classes: `int`, number of classes. Stored on the instance; not read
        by the parsing methods in this file.
    """
    self._output_size = output_size
    self._num_classes = num_classes
    self._dtype = tf.float32

  def _parse_data(
      self,
      decoded_tensors: Mapping[str, tf.Tensor]) -> Tuple[tf.Tensor, tf.Tensor]:
    """Converts decoded fields into a preprocessed (image, label) pair.

    Shared by `_parse_train_data` and `_parse_eval_data`.

    Args:
      decoded_tensors: A dictionary of field key name and decoded tensor
        mapping from Decoder. Must contain 'image/encoded' and
        'image/class/label'.

    Returns:
      A tuple of (image, label) tensors: the image resized to `output_size`
      with 3 channels, normalized, and converted to `self._dtype`; the label
      cast to `tf.int32`.
    """
    label = tf.cast(decoded_tensors['image/class/label'], dtype=tf.int32)

    # NOTE(review): the bytes are always decoded with `decode_jpeg`, so this
    # assumes 'image/encoded' holds JPEG data -- confirm against the dataset.
    image_bytes = decoded_tensors['image/encoded']
    image = tf.io.decode_jpeg(image_bytes, channels=3)
    image = tf.image.resize(
        image, self._output_size, method=tf.image.ResizeMethod.BILINEAR)

    # Pin the static shape. Unpacking (rather than `output_size + [3]`) keeps
    # the list behavior unchanged and also accepts a tuple, which would
    # otherwise raise a TypeError on tuple + list concatenation.
    # NOTE(review): a Tensor `output_size` is not handled specially here --
    # TODO confirm whether Tensor inputs need to be supported.
    image = tf.ensure_shape(image, [*self._output_size, 3])

    # Normalizes image with mean and std pixel values.
    image = preprocess_ops.normalize_image(
        image,
        offset=preprocess_ops.MEAN_RGB,
        scale=preprocess_ops.STDDEV_RGB)

    image = tf.image.convert_image_dtype(image, self._dtype)
    return image, label

  def _parse_train_data(
      self,
      decoded_tensors: Mapping[str, tf.Tensor]) -> Tuple[tf.Tensor, tf.Tensor]:
    """Parses data for training.

    Args:
      decoded_tensors: A dictionary of field key name and decoded tensor
        mapping from Decoder.

    Returns:
      A tuple of (image, label) tensors.
    """
    return self._parse_data(decoded_tensors)

  def _parse_eval_data(
      self,
      decoded_tensors: Mapping[str, tf.Tensor]) -> Tuple[tf.Tensor, tf.Tensor]:
    """Parses data for evaluation.

    Args:
      decoded_tensors: A dictionary of field key name and decoded tensor
        mapping from Decoder.

    Returns:
      A tuple of (image, label) tensors.
    """
    return self._parse_data(decoded_tensors)