deanna-emery's picture
updates
93528c6
raw
history blame
3.66 kB
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data loader and input processing."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from typing import Optional, Text
import tensorflow as tf, tf_keras
from official.legacy.detection.dataloader import factory
from official.legacy.detection.dataloader import mode_keys as ModeKeys
from official.modeling.hyperparams import params_dict
class InputFn(object):
  """Input function that creates dataset from files.

  An instance is a callable: invoking it lists the files matching the
  configured pattern, optionally shards them across input pipelines, reads
  and parses the records, and returns a batched, prefetched dataset.
  """

  def __init__(self,
               file_pattern: Text,
               params: params_dict.ParamsDict,
               mode: Text,
               batch_size: int,
               num_examples: Optional[int] = -1):
    """Initialize.

    Args:
      file_pattern: the file pattern for the data example (TFRecords).
      params: the parameter object for constructing example parser and model.
      mode: the mode key. `ModeKeys.TRAIN` turns on the training behavior
        (shuffling and repeating); any other value is treated as
        non-training.
      batch_size: the data batch size used when `__call__` is not given one.
      num_examples: If positive, only takes this number of examples and raise
        tf.errors.OutOfRangeError after that. If non-positive or None, it
        will be ignored.

    Raises:
      ValueError: if `file_pattern`, `mode` or `batch_size` is None.
    """
    # Explicit raises instead of `assert`, which is stripped under `python -O`.
    if file_pattern is None:
      raise ValueError('`file_pattern` must not be None.')
    if mode is None:
      raise ValueError('`mode` must not be None.')
    if batch_size is None:
      raise ValueError('`batch_size` must not be None.')

    self._file_pattern = file_pattern
    self._mode = mode
    self._is_training = (mode == ModeKeys.TRAIN)
    self._batch_size = batch_size
    self._num_examples = num_examples
    self._parser_fn = factory.parser_generator(params, mode)
    self._dataset_fn = tf.data.TFRecordDataset

    # Default: shard the input files only when not training. The default is
    # overridden by `params.train.input_sharding` / `params.eval.input_sharding`
    # when the corresponding entry exists.
    self._input_sharding = (not self._is_training)
    try:
      if self._is_training:
        self._input_sharding = params.train.input_sharding
      else:
        self._input_sharding = params.eval.input_sharding
    except AttributeError:
      # Deliberate best effort: the config does not define the option, so the
      # default computed above is kept.
      pass

  def __call__(self, ctx=None, batch_size: Optional[int] = None):
    """Provides tf.data.Dataset object.

    Args:
      ctx: context object. When it is provided, input sharding is enabled and
        `ctx.num_input_pipelines > 1`, the file list is sharded using
        `ctx.input_pipeline_id`.
      batch_size: expected batch size input data. When None (or 0), the batch
        size given at construction time is used.

    Returns:
      tf.data.Dataset object.

    Raises:
      ValueError: if no batch size is available.
    """
    if not batch_size:
      batch_size = self._batch_size
    if batch_size is None:
      raise ValueError('`batch_size` must not be None.')

    # File names are shuffled only for training.
    dataset = tf.data.Dataset.list_files(
        self._file_pattern, shuffle=self._is_training)

    if self._input_sharding and ctx and ctx.num_input_pipelines > 1:
      dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
    # Caches the (possibly sharded) list of file names, not the records.
    dataset = dataset.cache()

    if self._is_training:
      dataset = dataset.repeat()

    # Reads records from several files concurrently.
    dataset = dataset.interleave(
        map_func=self._dataset_fn,
        cycle_length=32,
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

    if self._is_training:
      dataset = dataset.shuffle(1000)
    # `num_examples` is Optional: None means "no limit", same as non-positive.
    if self._num_examples is not None and self._num_examples > 0:
      dataset = dataset.take(self._num_examples)

    # Parses the fetched records to input tensors for model function.
    dataset = dataset.map(
        self._parser_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    return dataset