"""Functions to export object detection inference graph.""" |
|
import os

import tensorflow.compat.v2 as tf

from object_detection.builders import model_builder
from object_detection.core import standard_fields as fields
from object_detection.data_decoders import tf_example_decoder
from object_detection.utils import config_util
|
|
def _decode_image(encoded_image_string_tensor):
  """Decodes an encoded image string tensor into a uint8 image tensor."""
  image_tensor = tf.image.decode_image(encoded_image_string_tensor,
                                       channels=3)
  image_tensor.set_shape((None, None, 3))
  return image_tensor
|
|
def _decode_tf_example(tf_example_string_tensor):
  """Decodes the image tensor from a serialized tf.Example string."""
  tensor_dict = tf_example_decoder.TfExampleDecoder().decode(
      tf_example_string_tensor)
  image_tensor = tensor_dict[fields.InputDataFields.image]
  return image_tensor
|
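# A sketch of the input `_decode_tf_example` expects (illustrative only; the
# `encoded_jpeg_bytes` variable is hypothetical). TfExampleDecoder reads the
# encoded image from the 'image/encoded' feature of a serialized tf.Example:
#
#   example = tf.train.Example(features=tf.train.Features(feature={
#       'image/encoded': tf.train.Feature(
#           bytes_list=tf.train.BytesList(value=[encoded_jpeg_bytes])),
#       'image/format': tf.train.Feature(
#           bytes_list=tf.train.BytesList(value=[b'jpeg'])),
#   }))
#   image = _decode_tf_example(tf.constant(example.SerializeToString()))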
|
class DetectionInferenceModule(tf.Module):
  """Detection Inference Module."""

  def __init__(self, detection_model):
    """Initializes a module for detection.

    Args:
      detection_model: The detection model to use for inference.
    """
    self._model = detection_model

  def _run_inference_on_images(self, image):
    """Casts the image to float32 and runs inference.

    Args:
      image: uint8 Tensor of shape [1, None, None, 3].

    Returns:
      Tensor dictionary holding detections.
    """
    # The model predicts 0-based class indices; shift them by one so they
    # line up with the 1-based label map IDs.
    label_id_offset = 1

    image = tf.cast(image, tf.float32)
    image, shapes = self._model.preprocess(image)
    prediction_dict = self._model.predict(image, shapes)
    detections = self._model.postprocess(prediction_dict, shapes)
    classes_field = fields.DetectionResultFields.detection_classes
    detections[classes_field] = (
        tf.cast(detections[classes_field], tf.float32) + label_id_offset)

    # Cast all outputs to float32 so the exported signature has a uniform
    # output dtype.
    for key, val in detections.items():
      detections[key] = tf.cast(val, tf.float32)

    return detections
|
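# Note (not part of the original interface documentation): for the standard
# detection models, the dictionary returned by `_run_inference_on_images`
# (and hence by the `__call__` of the concrete modules below) is expected to
# contain float32 tensors such as:
#   detection_boxes:   [1, max_detections, 4]
#   detection_scores:  [1, max_detections]
#   detection_classes: [1, max_detections]  (1-based label map IDs)
#   num_detections:    [1]
# The exact set of keys depends on the model's `postprocess` implementation.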
|
class DetectionFromImageModule(DetectionInferenceModule):
  """Detection Inference Module for uint8 image inputs."""

  @tf.function(
      input_signature=[
          tf.TensorSpec(shape=[1, None, None, 3], dtype=tf.uint8)])
  def __call__(self, input_tensor):
    return self._run_inference_on_images(input_tensor)
|
|
class DetectionFromFloatImageModule(DetectionInferenceModule):
  """Detection Inference Module for float image inputs."""

  @tf.function(
      input_signature=[
          tf.TensorSpec(shape=[1, None, None, 3], dtype=tf.float32)])
  def __call__(self, input_tensor):
    return self._run_inference_on_images(input_tensor)
|
|
class DetectionFromEncodedImageModule(DetectionInferenceModule):
  """Detection Inference Module for encoded image string inputs."""

  @tf.function(input_signature=[tf.TensorSpec(shape=[1], dtype=tf.string)])
  def __call__(self, input_tensor):
    with tf.device('cpu:0'):
      image = tf.map_fn(
          _decode_image,
          elems=input_tensor,
          dtype=tf.uint8,
          parallel_iterations=32,
          back_prop=False)
    return self._run_inference_on_images(image)
|
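# Example (a sketch, assuming `detection_model` was built with
# `model_builder.build` and restored from a checkpoint; the image path is
# hypothetical):
#
#   module = DetectionFromEncodedImageModule(detection_model)
#   with tf.io.gfile.GFile('/path/to/image.jpg', 'rb') as f:
#     encoded = f.read()
#   detections = module(tf.constant([encoded]))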
|
class DetectionFromTFExampleModule(DetectionInferenceModule):
  """Detection Inference Module for TF.Example inputs."""

  @tf.function(input_signature=[tf.TensorSpec(shape=[1], dtype=tf.string)])
  def __call__(self, input_tensor):
    with tf.device('cpu:0'):
      image = tf.map_fn(
          _decode_tf_example,
          elems=input_tensor,
          dtype=tf.uint8,
          parallel_iterations=32,
          back_prop=False)
    return self._run_inference_on_images(image)
|
|
DETECTION_MODULE_MAP = {
    'image_tensor': DetectionFromImageModule,
    'encoded_image_string_tensor': DetectionFromEncodedImageModule,
    'tf_example': DetectionFromTFExampleModule,
    'float_image_tensor': DetectionFromFloatImageModule
}
|
|
def export_inference_graph(input_type,
                           pipeline_config,
                           trained_checkpoint_dir,
                           output_directory):
  """Exports an inference graph for the model specified in the pipeline config.

  This function creates `output_directory` if it does not already exist.
  It will hold a copy of the pipeline config with filename `pipeline.config`,
  and two subdirectories named `checkpoint` and `saved_model`
  (containing the exported checkpoint and SavedModel respectively).

  Args:
    input_type: Type of input for the graph. Can be one of ['image_tensor',
      'encoded_image_string_tensor', 'tf_example', 'float_image_tensor'].
    pipeline_config: pipeline_pb2.TrainEvalPipelineConfig proto.
    trained_checkpoint_dir: Path to the directory containing the trained
      checkpoint.
    output_directory: Path to write outputs.

  Raises:
    ValueError: if input_type is invalid.
  """
  output_checkpoint_directory = os.path.join(output_directory, 'checkpoint')
  output_saved_model_directory = os.path.join(output_directory, 'saved_model')

  detection_model = model_builder.build(pipeline_config.model,
                                        is_training=False)

  # Restore the latest checkpoint from the training directory. `expect_partial`
  # silences warnings about training-only objects (e.g. the optimizer) that
  # have no counterpart in the inference graph.
  ckpt = tf.train.Checkpoint(model=detection_model)
  manager = tf.train.CheckpointManager(
      ckpt, trained_checkpoint_dir, max_to_keep=1)
  status = ckpt.restore(manager.latest_checkpoint).expect_partial()

  if input_type not in DETECTION_MODULE_MAP:
    raise ValueError('Unrecognized `input_type`: {}'.format(input_type))
  detection_module = DETECTION_MODULE_MAP[input_type](detection_model)

  # Getting the concrete function traces the `__call__` tf.function with its
  # declared input signature, building the model variables so the restore can
  # be verified below.
  concrete_function = detection_module.__call__.get_concrete_function()
  status.assert_existing_objects_matched()

  exported_checkpoint_manager = tf.train.CheckpointManager(
      ckpt, output_checkpoint_directory, max_to_keep=1)
  exported_checkpoint_manager.save(checkpoint_number=0)

  tf.saved_model.save(detection_module,
                      output_saved_model_directory,
                      signatures=concrete_function)
  config_util.save_pipeline_config(pipeline_config, output_directory)
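# Example usage (a sketch; the paths are hypothetical and the config parsing
# mirrors how exporter binaries typically drive this function):
#
#   from google.protobuf import text_format
#   from object_detection.protos import pipeline_pb2
#
#   pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
#   with tf.io.gfile.GFile('/path/to/pipeline.config', 'r') as f:
#     text_format.Merge(f.read(), pipeline_config)
#   export_inference_graph(
#       input_type='image_tensor',
#       pipeline_config=pipeline_config,
#       trained_checkpoint_dir='/path/to/checkpoint_dir',
#       output_directory='/path/to/exported_model')
#
#   # The exported SavedModel can be reloaded and called on a uint8 batch:
#   detect_fn = tf.saved_model.load('/path/to/exported_model/saved_model')
#   detections = detect_fn(tf.zeros([1, 320, 320, 3], dtype=tf.uint8))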