# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data loader and input processing."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from typing import Optional, Text

import tensorflow as tf, tf_keras

from official.legacy.detection.dataloader import factory
from official.legacy.detection.dataloader import mode_keys as ModeKeys
from official.modeling.hyperparams import params_dict
class InputFn(object):
  """Input function that creates dataset from files."""

  def __init__(self,
               file_pattern: Text,
               params: params_dict.ParamsDict,
               mode: Text,
               batch_size: int,
               num_examples: Optional[int] = -1):
    """Initialize.

    Args:
      file_pattern: the file pattern for the data example (TFRecords).
      params: the parameter object for constructing example parser and model.
      mode: ModeKeys.TRAIN or ModeKeys.EVAL.
      batch_size: the data batch size.
      num_examples: If positive, only takes this number of examples and raises
        tf.errors.OutOfRangeError after that. If non-positive, it will be
        ignored.
    """
    assert file_pattern is not None
    assert mode is not None
    assert batch_size is not None
    self._file_pattern = file_pattern
    self._mode = mode
    self._is_training = (mode == ModeKeys.TRAIN)
    self._batch_size = batch_size
    self._num_examples = num_examples
    self._parser_fn = factory.parser_generator(params, mode)
    self._dataset_fn = tf.data.TFRecordDataset

    # Default: shard the file list across input pipelines only for eval;
    # training reads all files everywhere and relies on shuffling instead.
    self._input_sharding = (not self._is_training)
    try:
      if self._is_training:
        self._input_sharding = params.train.input_sharding
      else:
        self._input_sharding = params.eval.input_sharding
    except AttributeError:
      # The input_sharding fields are optional in the params object; keep
      # the mode-based default when they are absent.
      pass

  def __call__(self, ctx=None, batch_size: Optional[int] = None):
    """Provides tf.data.Dataset object.

    Args:
      ctx: context object (e.g. tf.distribute input context) providing
        `num_input_pipelines` and `input_pipeline_id` for sharding; may be
        None when no sharding is desired.
      batch_size: expected batch size input data. Falls back to the batch
        size given at construction time when None (or 0).

    Returns:
      tf.data.Dataset object.
    """
    if not batch_size:
      batch_size = self._batch_size
    assert batch_size is not None
    dataset = tf.data.Dataset.list_files(
        self._file_pattern, shuffle=self._is_training)
    if self._input_sharding and ctx and ctx.num_input_pipelines > 1:
      dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
    # Cache the (possibly sharded) file list; the records themselves are
    # re-read each epoch.
    dataset = dataset.cache()

    if self._is_training:
      dataset = dataset.repeat()

    # Read records from many files in parallel.
    dataset = dataset.interleave(
        map_func=self._dataset_fn,
        cycle_length=32,
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

    if self._is_training:
      dataset = dataset.shuffle(1000)
    if self._num_examples > 0:
      dataset = dataset.take(self._num_examples)

    # Parses the fetched records to input tensors for model function.
    dataset = dataset.map(
        self._parser_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    return dataset