File size: 10,716 Bytes
5672777
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Dataset reader for vision model garden."""

from typing import Any, Callable, Mapping, Optional, Tuple, Union

from absl import logging
import tensorflow as tf, tf_keras

from official.core import config_definitions as cfg
from official.core import input_reader


InputReader = input_reader.InputReader


def build_weighted_sampling_combine_fn(
    weights: Mapping[Any, Any], stop_on_empty_dataset=True
) -> Callable[[tf.data.Dataset], tf.data.Dataset]:
  """Builds a combine_fn using weighted sampling."""

  def combine_fn(datasets: Mapping[Any, tf.data.Dataset]) -> tf.data.Dataset:
    """Combines multiple datasets using weighted sampling."""
    ds = []
    ws = []
    for k, dataset in datasets.items():
      ds.append(dataset)
      ws.append(weights[k])
    return tf.data.Dataset.sample_from_datasets(
        ds, ws, stop_on_empty_dataset=stop_on_empty_dataset)

  return combine_fn


def create_combine_fn(
    params: cfg.DataConfig
) -> Union[None, Callable[[tf.data.Dataset], tf.data.Dataset]]:
  """Creates and returns a combine_fn for dataset mixing."""
  if (
      hasattr(params, 'stop_on_empty_dataset')
      and params.stop_on_empty_dataset is not None
  ):
    stop_on_empty_dataset = params.stop_on_empty_dataset
  else:
    stop_on_empty_dataset = True

  if params.weights:
    # Combine multiple datasets using weighted sampling.
    if (not isinstance(params.input_path, cfg.base_config.Config) or
        not isinstance(params.weights, cfg.base_config.Config)):
      raise ValueError(
          'input_path and weights must both be a Config to use weighted '
          'sampling.')
    input_paths = params.input_path.as_dict()
    weights = params.weights.as_dict()
    if len(input_paths) != len(weights):
      raise ValueError(
          'The number of input_path and weights must be the same, but got %d '
          'input_paths and %d weights.' % (len(input_paths), len(weights)))

    for k in input_paths.keys():
      if k not in weights:
        raise ValueError(
            'input_path key \'%s\' does not have a corresponding weight.' % k)

    return build_weighted_sampling_combine_fn(weights, stop_on_empty_dataset)
  return None


def calculate_batch_sizes(total_batch_size: int,
                          pseudo_label_ratio: float,
                          pseudo_label_batch_size: int = 0) -> Tuple[int, int]:
  """Calculates labeled and pseudo-labeled dataset batch sizes.

  Returns (labeled_batch_size, pseudo_labeled_batch_size) given a
  total batch size and pseudo-label data ratio.

  Args:
   total_batch_size: The total batch size for all data.
   pseudo_label_ratio: A float ratio of pseudo-labeled to labeled data in a
     batch. If it is negative, use `pseudo_label_batch_size` instead.
   pseudo_label_batch_size: The batch size of pseudo-labeled data. It is ignored
     if `pseudo_label_ratio` is valid. If not, it will be used and it cannot be
     larger than total global batch size or less than 0 if pseudo_label_ratio is
     also less than 0.

  Returns:
    (labeled_batch_size, pseudo_labeled_batch_size) as ints.

  Raises:
    ValueError: If total_batch_size is negative, or both If pseudo_label_ratio
      is negative and pseudo-label global_batch_size is negative or larger than
      total batch size.
  """
  if total_batch_size < 0:
    raise ValueError('Invalid total_batch_size: {}'.format(total_batch_size))
  if pseudo_label_ratio >= 0.0:
    ratio_factor = pseudo_label_ratio / (1.0 + pseudo_label_ratio)
    pseudo_label_batch_size = int(total_batch_size * ratio_factor)
    label_batch_size = total_batch_size - pseudo_label_batch_size
  else:
    if pseudo_label_batch_size > total_batch_size or pseudo_label_batch_size < 0:
      raise ValueError(
          'The batch size of pseudo-label dataset should not be larger than '
          'total global batch size.')
    logging.info('data_ratio for pseudo-label dataset is less than 0. '
                 'Use global_batch_size from pseudo_label data config instead.')
    label_batch_size = total_batch_size - pseudo_label_batch_size
  return label_batch_size, pseudo_label_batch_size


class CombinationDatasetInputReader(input_reader.InputReader):
  """Combination dataset input reader."""

  def __init__(self,
               params: cfg.DataConfig,
               dataset_fn=tf.data.TFRecordDataset,
               pseudo_label_dataset_fn=tf.data.TFRecordDataset,
               decoder_fn: Optional[Callable[..., Any]] = None,
               combine_fn: Optional[Callable[..., Any]] = None,
               sample_fn: Optional[Callable[..., Any]] = None,
               parser_fn: Optional[Callable[..., Any]] = None,
               transform_and_batch_fn: Optional[Callable[
                   [tf.data.Dataset, Optional[tf.distribute.InputContext]],
                   tf.data.Dataset]] = None,
               postprocess_fn: Optional[Callable[..., Any]] = None):
    """Initializes an CombinationDatasetInputReader instance.

    This class mixes a labeled and pseudo-labeled dataset. The params
    must contain "pseudo_label_data.input_path" to specify the
    pseudo-label dataset files and "pseudo_label_data.data_ratio"
    to specify a per-batch mixing ratio of pseudo-label examples to
    labeled dataset examples.

    Args:
      params: A config_definitions.DataConfig object.
      dataset_fn: A `tf.data.Dataset` that consumes the input files. For
        example, it can be `tf.data.TFRecordDataset`.
      pseudo_label_dataset_fn: A `tf.data.Dataset` that consumes the input
        files. For example, it can be `tf.data.TFRecordDataset`.
      decoder_fn: An optional `callable` that takes the serialized data string
        and decodes them into the raw tensor dictionary.
      combine_fn: An optional `callable` that takes a dictionarty of
        `tf.data.Dataset` objects as input and outputs a combined dataset. It
        will be executed after the decoder_fn and before the sample_fn.
      sample_fn: An optional `callable` that takes a `tf.data.Dataset` object as
        input and outputs the transformed dataset. It performs sampling on the
        decoded raw tensors dict before the parser_fn.
      parser_fn: An optional `callable` that takes the decoded raw tensors dict
        and parse them into a dictionary of tensors that can be consumed by the
        model. It will be executed after decoder_fn.
      transform_and_batch_fn: An optional `callable` that takes a
        `tf.data.Dataset` object and an optional `tf.distribute.InputContext` as
        input, and returns a `tf.data.Dataset` object. It will be executed after
        `parser_fn` to transform and batch the dataset; if None, after
        `parser_fn` is executed, the dataset will be batched into per-replica
        batch size.
      postprocess_fn: A optional `callable` that processes batched tensors. It
        will be executed after batching.

    Raises:
      ValueError: If drop_remainder is False.
    """
    super().__init__(
        params=params,
        dataset_fn=dataset_fn,
        decoder_fn=decoder_fn,
        combine_fn=combine_fn,
        sample_fn=sample_fn,
        parser_fn=parser_fn,
        transform_and_batch_fn=transform_and_batch_fn,
        postprocess_fn=postprocess_fn)

    self._pseudo_label_file_pattern = params.pseudo_label_data.input_path
    self._pseudo_label_dataset_fn = pseudo_label_dataset_fn
    self._pseudo_label_data_ratio = params.pseudo_label_data.data_ratio
    self._pseudo_label_batch_size = params.pseudo_label_data.global_batch_size
    self._pseudo_label_matched_files = input_reader.match_files(
        self._pseudo_label_file_pattern)
    if not self._drop_remainder:
      raise ValueError(
          'Must use drop_remainder=True with CombinationDatasetInputReader')

  def read(
      self,
      input_context: Optional[tf.distribute.InputContext] = None
  ) -> tf.data.Dataset:
    """Generates a tf.data.Dataset object."""

    labeled_batch_size, pl_batch_size = calculate_batch_sizes(
        self._global_batch_size, self._pseudo_label_data_ratio,
        self._pseudo_label_batch_size)

    if not labeled_batch_size and pl_batch_size:
      raise ValueError(
          'Invalid batch_size: {} and pseudo_label_data_ratio: {}, '
          'resulting in a 0 batch size for one of the datasets.'.format(
              self._global_batch_size, self._pseudo_label_data_ratio))

    def _read_decode_and_parse_dataset(matched_files, dataset_fn, batch_size,
                                       input_context):
      dataset = self._read_data_source(matched_files, dataset_fn, input_context)
      return self._decode_and_parse_dataset(dataset, batch_size, input_context)

    labeled_dataset = _read_decode_and_parse_dataset(
        matched_files=self._matched_files,
        dataset_fn=self._dataset_fn,
        batch_size=labeled_batch_size,
        input_context=input_context)

    pseudo_labeled_dataset = _read_decode_and_parse_dataset(
        matched_files=self._pseudo_label_matched_files,
        dataset_fn=self._pseudo_label_dataset_fn,
        batch_size=pl_batch_size,
        input_context=input_context)

    def concat_fn(d1, d2):
      return tf.nest.map_structure(
          lambda x1, x2: tf.concat([x1, x2], axis=0), d1, d2)

    dataset_concat = tf.data.Dataset.zip(
        (labeled_dataset, pseudo_labeled_dataset))
    dataset_concat = dataset_concat.map(
        concat_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)

    def maybe_map_fn(dataset, fn):
      return dataset if fn is None else dataset.map(
          fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)

    dataset_concat = maybe_map_fn(dataset_concat, self._postprocess_fn)
    dataset_concat = self._maybe_apply_data_service(dataset_concat,
                                                    input_context)

    if self._deterministic is not None:
      options = tf.data.Options()
      options.experimental_deterministic = self._deterministic
      dataset_concat = dataset_concat.with_options(options)

    return dataset_concat.prefetch(tf.data.experimental.AUTOTUNE)