File size: 4,799 Bytes
5672777
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Example classification decoder and parser.

This file defines the Decoder and Parser used to load data. The example loads
standard tf.Example data, but non-standard tf.Example or other data formats can
be supported by implementing a suitable decoder and parser.
"""
from typing import Mapping, List, Tuple
# Import libraries
import tensorflow as tf, tf_keras

from official.vision.dataloaders import decoder
from official.vision.dataloaders import parser
from official.vision.ops import preprocess_ops


class Decoder(decoder.Decoder):
  """A tf.Example decoder for classification task.

  Reads the `image/encoded` and `image/class/label` fields of a serialized
  tf.Example and returns them as a dictionary of tensors.
  """

  def __init__(self):
    """Initializes the decoder.

    The constructor defines the mapping between the field name and the value
    from an input tf.Example. For example, we define two fields for image bytes
    and labels. There is no limit on the number of fields to decode.
    """
    self._keys_to_features = {
        # Scalar string field; falls back to '' when absent from the example.
        'image/encoded':
            tf.io.FixedLenFeature((), tf.string, default_value=''),
        # Scalar int64 field; falls back to -1 when absent from the example.
        'image/class/label':
            tf.io.FixedLenFeature((), tf.int64, default_value=-1),
    }

  def decode(self, serialized_example: tf.Tensor) -> Mapping[str, tf.Tensor]:
    """Decodes a tf.Example to a dictionary.

    This function decodes a serialized tf.Example to a dictionary. The output
    will be consumed by `_parse_train_data` and `_parse_eval_data` in Parser.

    Args:
      serialized_example: A scalar string tensor holding one serialized
        tf.Example.

    Returns:
      A dictionary of field key name and decoded tensor mapping, with one entry
      per key declared in the constructor.
    """
    return tf.io.parse_single_example(
        serialized_example, self._keys_to_features)


class Parser(parser.Parser):
  """Parser to parse an image and its annotations.

  To define a custom Parser, clients should override `_parse_train_data` and
  `_parse_eval_data`, where decoded tensors are parsed with optional
  pre-processing steps. The output of the two functions can be any structure,
  such as a tuple, list or dictionary.
  """

  def __init__(self, output_size: List[int], num_classes: float):
    """Initializes parameters for parsing annotations in the dataset.

    This example only takes two arguments but one can freely add as many
    arguments as needed. For example, pre-processing and augmentations usually
    happen in Parser, and related parameters can be passed in by this
    constructor.

    Args:
      output_size: `list` or `tuple` of two ints, [height, width] of the output
        image. It is also used to build a static shape, so the values should
        be known Python ints.
      num_classes: Number of classes. It is stored but not read anywhere in
        this class.
        NOTE(review): annotated as `float` although a class count is normally
        an integer -- confirm against callers before changing.
    """
    self._output_size = output_size
    self._num_classes = num_classes
    self._dtype = tf.float32

  def _parse_data(
      self, decoded_tensors: Mapping[str,
                                     tf.Tensor]) -> Tuple[tf.Tensor, tf.Tensor]:
    """Builds the (image, label) pair shared by training and evaluation.

    Args:
      decoded_tensors: A dictionary of field key name and decoded tensor mapping
        from Decoder. Must contain 'image/encoded' and 'image/class/label'.

    Returns:
      A tuple of (image, label) tensors. The image is resized to
      `output_size` with 3 channels, normalized, and has dtype `self._dtype`;
      the label is cast to int32.
    """
    label = tf.cast(decoded_tensors['image/class/label'], dtype=tf.int32)

    # Decode to 3 channels, then resize to the requested output size.
    image = tf.io.decode_jpeg(decoded_tensors['image/encoded'], channels=3)
    image = tf.image.resize(
        image, self._output_size, method=tf.image.ResizeMethod.BILINEAR)

    # Copy into a new list before appending the channel dimension, so a tuple
    # `output_size` is accepted as well as a list (`tuple + list` raises
    # TypeError).
    static_shape = list(self._output_size) + [3]
    image = tf.ensure_shape(image, static_shape)

    # Normalizes image with mean and std pixel values.
    image = preprocess_ops.normalize_image(
        image, offset=preprocess_ops.MEAN_RGB, scale=preprocess_ops.STDDEV_RGB)

    image = tf.image.convert_image_dtype(image, self._dtype)
    return image, label

  def _parse_train_data(
      self, decoded_tensors: Mapping[str,
                                     tf.Tensor]) -> Tuple[tf.Tensor, tf.Tensor]:
    """Parses data for training.

    Args:
      decoded_tensors: A dictionary of field key name and decoded tensor mapping
        from Decoder.

    Returns:
      A tuple of (image, label) tensors.
    """
    # Training applies no extra augmentation in this example.
    return self._parse_data(decoded_tensors)

  def _parse_eval_data(
      self, decoded_tensors: Mapping[str,
                                     tf.Tensor]) -> Tuple[tf.Tensor, tf.Tensor]:
    """Parses data for evaluation.

    Args:
      decoded_tensors: A dictionary of field key name and decoded tensor mapping
        from Decoder.

    Returns:
      A tuple of (image, label) tensors.
    """
    # Evaluation uses the same processing as training.
    return self._parse_data(decoded_tensors)