# Lint as: python2, python3
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Input reader builder.

Creates data sources for DetectionModels from an InputReader config. See
input_reader.proto for options.

Note: If users wish to also use their own InputReaders with the Object
Detection configuration framework, they should define their own builder
function that wraps the build function.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow.compat.v1 as tf
import tf_slim as slim

from object_detection.data_decoders import tf_example_decoder
from object_detection.data_decoders import tf_sequence_example_decoder
from object_detection.protos import input_reader_pb2

parallel_reader = slim.parallel_reader


def build(input_reader_config):
  """Builds a tensor dictionary based on the InputReader config.

  Args:
    input_reader_config: An input_reader_pb2.InputReader object.

  Returns:
    A tensor dict based on the input_reader_config.

  Raises:
    ValueError: On invalid input reader proto.
    ValueError: If no input paths are specified.
  """
  if not isinstance(input_reader_config, input_reader_pb2.InputReader):
    raise ValueError('input_reader_config not of type '
                     'input_reader_pb2.InputReader.')

  if input_reader_config.WhichOneof(
      'input_reader') == 'tf_record_input_reader':
    config = input_reader_config.tf_record_input_reader
    if not config.input_path:
      raise ValueError('At least one input path must be specified in '
                       '`input_reader_config`.')
    # Read serialized records from the configured TFRecord files using a
    # slim parallel reader queue.
    _, string_tensor = parallel_reader.parallel_read(
        config.input_path[:],  # Convert `RepeatedScalarContainer` to list.
        reader_class=tf.TFRecordReader,
        num_epochs=(input_reader_config.num_epochs
                    if input_reader_config.num_epochs else None),
        num_readers=input_reader_config.num_readers,
        shuffle=input_reader_config.shuffle,
        dtypes=[tf.string, tf.string],
        capacity=input_reader_config.queue_capacity,
        min_after_dequeue=input_reader_config.min_after_dequeue)

    label_map_proto_file = None
    if input_reader_config.HasField('label_map_path'):
      label_map_proto_file = input_reader_config.label_map_path
    # Decode the serialized records according to the configured input type.
    input_type = input_reader_config.input_type
    if input_type == input_reader_pb2.InputType.Value('TF_EXAMPLE'):
      decoder = tf_example_decoder.TfExampleDecoder(
          load_instance_masks=input_reader_config.load_instance_masks,
          instance_mask_type=input_reader_config.mask_type,
          label_map_proto_file=label_map_proto_file,
          load_context_features=input_reader_config.load_context_features)
      return decoder.decode(string_tensor)
    elif input_type == input_reader_pb2.InputType.Value('TF_SEQUENCE_EXAMPLE'):
      decoder = tf_sequence_example_decoder.TfSequenceExampleDecoder(
          label_map_proto_file=label_map_proto_file,
          load_context_features=input_reader_config.load_context_features)
      return decoder.decode(string_tensor)
    raise ValueError('Unsupported input_type.')
  raise ValueError('Unsupported input_reader_config.')
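
# ------------------------------------------------------------------------------
# Illustrative sketch only (not part of the Object Detection API): the module
# docstring suggests that users who want their own InputReaders should define
# a builder function that wraps `build`. One possible shape of such a wrapper
# is shown below; the `my_custom_input_reader` oneof field and the
# `_decode_my_custom_reader` helper are hypothetical names the user would
# supply in their own proto extension and code.
#
# def build_with_custom_reader(input_reader_config):
#   """Dispatches to a user-defined reader, falling back to `build`."""
#   if input_reader_config.WhichOneof(
#       'input_reader') == 'my_custom_input_reader':
#     # Return a tensor dict produced by the user-supplied decoder.
#     return _decode_my_custom_reader(input_reader_config)
#   return build(input_reader_config)
# ------------------------------------------------------------------------------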