# NCTCMumbai's picture
# Upload 2583 files
# 18ddfe2 verified
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A common dataset reader."""
from typing import Any, Callable, List, Optional
import tensorflow as tf
import tensorflow_datasets as tfds
from official.modeling.hyperparams import config_definitions as cfg
class InputReader:
  """Input reader that returns a tf.data.Dataset instance.

  Data can come from an explicit list of files (`shards`), from the
  comma-separated glob patterns in `params.input_path`, or from a TFDS
  dataset named by `params.tfds_name` (at most one of `input_path` and
  `tfds_name` may be set).
  """

  def __init__(self,
               params: cfg.DataConfig,
               shards: Optional[List[str]] = None,
               dataset_fn=tf.data.TFRecordDataset,
               decoder_fn: Optional[Callable[..., Any]] = None,
               parser_fn: Optional[Callable[..., Any]] = None,
               dataset_transform_fn: Optional[Callable[[tf.data.Dataset],
                                                       tf.data.Dataset]] = None,
               postprocess_fn: Optional[Callable[..., Any]] = None):
    """Initializes an InputReader instance.

    Args:
      params: A config_definitions.DataConfig object.
      shards: A list of files to be read. If given, read from these files.
        Otherwise, read from params.input_path.
      dataset_fn: A callable (for example the `tf.data.TFRecordDataset`
        class) that takes file names and returns a `tf.data.Dataset` of
        their records.
      decoder_fn: An optional `callable` that takes the serialized data string
        and decodes them into the raw tensor dictionary.
      parser_fn: An optional `callable` that takes the decoded raw tensors dict
        and parse them into a dictionary of tensors that can be consumed by the
        model. It will be executed after decoder_fn.
      dataset_transform_fn: An optional `callable` that takes a
        `tf.data.Dataset` object and returns a `tf.data.Dataset`. It will be
        executed after parser_fn.
      postprocess_fn: An optional `callable` that processes batched tensors. It
        will be executed after batching.

    Raises:
      ValueError: if both `input_path` and `tfds_name` are set; if, when
        reading from `input_path`, any pattern matches no files or no file is
        matched at all; or if `tfds_name` is set without `tfds_split`.
    """
    if params.input_path and params.tfds_name:
      raise ValueError('At most one of `input_path` and `tfds_name` can be '
                       'specified, but got %s and %s.' % (
                           params.input_path, params.tfds_name))
    self._shards = shards
    self._tfds_builder = None
    # Stripped, non-empty glob patterns from `input_path` and the concrete
    # files they expand to. Only populated when reading from `input_path`.
    self._input_patterns = []
    self._matched_files = []
    if self._shards:
      self._num_files = len(self._shards)
    elif not params.tfds_name:
      self._input_patterns, self._matched_files = self._match_files(
          params.input_path)
      self._num_files = len(self._matched_files)
    else:
      if not params.tfds_split:
        raise ValueError(
            '`tfds_name` is %s, but `tfds_split` is not specified.' %
            params.tfds_name)
      self._tfds_builder = tfds.builder(
          params.tfds_name, data_dir=params.tfds_data_dir)

    self._global_batch_size = params.global_batch_size
    self._is_training = params.is_training
    self._drop_remainder = params.drop_remainder
    self._shuffle_buffer_size = params.shuffle_buffer_size
    self._cache = params.cache
    self._cycle_length = params.cycle_length
    self._block_length = params.block_length
    self._sharding = params.sharding
    self._examples_consume = params.examples_consume
    self._tfds_split = params.tfds_split
    self._tfds_download = params.tfds_download
    self._tfds_as_supervised = params.tfds_as_supervised
    self._tfds_skip_decoding_feature = params.tfds_skip_decoding_feature

    self._dataset_fn = dataset_fn
    self._decoder_fn = decoder_fn
    self._parser_fn = parser_fn
    self._dataset_transform_fn = dataset_transform_fn
    self._postprocess_fn = postprocess_fn

  @staticmethod
  def _match_files(input_path):
    """Expands a comma-separated string of glob patterns into files.

    Args:
      input_path: A string of comma-separated glob patterns. Surrounding
        whitespace around each pattern and empty entries are ignored.

    Returns:
      A `(patterns, files)` pair: the list of stripped, non-empty patterns, and
      the concatenated list of files matched by each pattern, in order.

    Raises:
      ValueError: if any pattern matches no files, or if no file is matched.
    """
    patterns = [
        pattern.strip()
        for pattern in (input_path or '').split(',')
        if pattern.strip()
    ]
    matched_files = []
    for pattern in patterns:
      files = tf.io.gfile.glob(pattern)
      if not files:
        raise ValueError('%s does not match any files.' % pattern)
      matched_files.extend(files)
    if not matched_files:
      raise ValueError('%s does not match any files.' % input_path)
    return patterns, matched_files

  def _read_sharded_files(
      self,
      input_context: Optional[tf.distribute.InputContext] = None):
    """Reads a dataset from multiple files.

    Files come from `self._shards` when provided, otherwise from listing the
    input patterns (file order shuffled when training). When sharding is
    enabled and there are several input pipelines, the file list is sharded
    per pipeline. Records are then interleaved across files via `dataset_fn`.
    """
    # Read from `self._shards` if it is provided.
    if self._shards:
      dataset = tf.data.Dataset.from_tensor_slices(self._shards)
    else:
      dataset = tf.data.Dataset.list_files(
          self._input_patterns, shuffle=self._is_training)
    if self._sharding and input_context and (
        input_context.num_input_pipelines > 1):
      dataset = dataset.shard(input_context.num_input_pipelines,
                              input_context.input_pipeline_id)
    if self._is_training:
      dataset = dataset.repeat()

    dataset = dataset.interleave(
        map_func=self._dataset_fn,
        cycle_length=self._cycle_length,
        block_length=self._block_length,
        num_parallel_calls=tf.data.experimental.AUTOTUNE)
    return dataset

  def _read_single_file(
      self,
      input_context: Optional[tf.distribute.InputContext] = None):
    """Reads a dataset from a single file.

    Auto-sharding is disabled so the same file reaches every worker; when
    sharding is enabled and there are several input pipelines, the records
    themselves are sharded per pipeline instead.
    """
    # Pass concrete file names (not glob patterns) to `dataset_fn`: readers
    # such as `TFRecordDataset` do not expand globs, so a pattern that
    # matches exactly one file must be resolved to that file first.
    dataset = self._dataset_fn(self._shards or self._matched_files)

    # When `input_file` is a path to a single file, disable auto sharding
    # so that same input file is sent to all workers.
    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = (
        tf.data.experimental.AutoShardPolicy.OFF)
    dataset = dataset.with_options(options)
    if self._sharding and input_context and (
        input_context.num_input_pipelines > 1):
      dataset = dataset.shard(input_context.num_input_pipelines,
                              input_context.input_pipeline_id)
    if self._is_training:
      dataset = dataset.repeat()
    return dataset

  def _read_tfds(
      self,
      input_context: Optional[tf.distribute.InputContext] = None
  ) -> tf.data.Dataset:
    """Reads a dataset from tfds.

    Downloads and prepares the dataset first when `tfds_download` is set.
    Features listed (comma-separated) in `tfds_skip_decoding_feature` are
    left undecoded via `tfds.decode.SkipDecoding`.
    """
    if self._tfds_download:
      self._tfds_builder.download_and_prepare()

    read_config = tfds.ReadConfig(
        interleave_cycle_length=self._cycle_length,
        interleave_block_length=self._block_length,
        input_context=input_context)
    decoders = {}
    if self._tfds_skip_decoding_feature:
      for skip_feature in self._tfds_skip_decoding_feature.split(','):
        decoders[skip_feature.strip()] = tfds.decode.SkipDecoding()
    dataset = self._tfds_builder.as_dataset(
        split=self._tfds_split,
        shuffle_files=self._is_training,
        as_supervised=self._tfds_as_supervised,
        decoders=decoders,
        read_config=read_config)
    return dataset

  @property
  def tfds_info(self) -> tfds.core.DatasetInfo:
    """Returns TFDS dataset info, if available.

    Raises:
      ValueError: if the dataset is not loaded from tfds.
    """
    if self._tfds_builder:
      return self._tfds_builder.info
    else:
      raise ValueError('tfds_info is not available, because the dataset '
                       'is not loaded from tfds.')

  def read(
      self,
      input_context: Optional[tf.distribute.InputContext] = None
  ) -> tf.data.Dataset:
    """Generates a tf.data.Dataset object.

    Pipeline order: read -> optional cache -> shuffle (training only) ->
    optional `take(examples_consume)` -> decoder_fn -> parser_fn ->
    dataset_transform_fn -> batch (per-replica size when `input_context` is
    given) -> postprocess_fn -> prefetch.

    Args:
      input_context: Optional `tf.distribute.InputContext` used for sharding
        and for computing the per-replica batch size.

    Returns:
      The batched and prefetched `tf.data.Dataset`.

    Raises:
      ValueError: if there are no input files to read.
    """
    if self._tfds_builder:
      dataset = self._read_tfds(input_context)
    elif self._num_files > 1:
      dataset = self._read_sharded_files(input_context)
    elif self._num_files == 1:
      dataset = self._read_single_file(input_context)
    else:
      # Explicit check instead of `assert`, which is stripped under `-O`.
      raise ValueError('No input files to read.')

    if self._cache:
      dataset = dataset.cache()

    if self._is_training:
      dataset = dataset.shuffle(self._shuffle_buffer_size)

    if self._examples_consume > 0:
      dataset = dataset.take(self._examples_consume)

    def maybe_map_fn(dataset, fn):
      # Applies `fn` in parallel when given; passes the dataset through
      # unchanged otherwise.
      return dataset if fn is None else dataset.map(
          fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)

    dataset = maybe_map_fn(dataset, self._decoder_fn)
    dataset = maybe_map_fn(dataset, self._parser_fn)

    if self._dataset_transform_fn is not None:
      dataset = self._dataset_transform_fn(dataset)

    per_replica_batch_size = input_context.get_per_replica_batch_size(
        self._global_batch_size) if input_context else self._global_batch_size

    dataset = dataset.batch(
        per_replica_batch_size, drop_remainder=self._drop_remainder)
    dataset = maybe_map_fn(dataset, self._postprocess_fn)
    return dataset.prefetch(tf.data.experimental.AUTOTUNE)