# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Exports trained model to TensorFlow frozen graph."""
import os
import tensorflow as tf
from tensorflow.contrib import quantize as contrib_quantize
from tensorflow.python.tools import freeze_graph
from deeplab import common
from deeplab import input_preprocess
from deeplab import model
slim = tf.contrib.slim
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('checkpoint_path', None, 'Checkpoint path')
flags.DEFINE_string('export_path', None,
                    'Path to the output TensorFlow frozen graph.')
flags.DEFINE_integer('num_classes', 21, 'Number of classes.')
flags.DEFINE_multi_integer('crop_size', [513, 513],
'Crop size [height, width].')
# For `xception_65`, use atrous_rates = [12, 24, 36] if output_stride = 8, or
# atrous_rates = [6, 12, 18] if output_stride = 16. For `mobilenet_v2`, use
# None. Note that one could use different atrous_rates/output_stride during
# training/evaluation.
flags.DEFINE_multi_integer('atrous_rates', None,
'Atrous rates for atrous spatial pyramid pooling.')
flags.DEFINE_integer('output_stride', 8,
'The ratio of input to output spatial resolution.')
# Change to [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] for multi-scale inference.
flags.DEFINE_multi_float('inference_scales', [1.0],
'The scales to resize images for inference.')
flags.DEFINE_bool('add_flipped_images', False,
'Add flipped images during inference or not.')
flags.DEFINE_integer(
'quantize_delay_step', -1,
'Steps to start quantized training. If < 0, will not quantize model.')
flags.DEFINE_bool('save_inference_graph', False,
'Save inference graph in text proto.')
# Input name of the exported model.
_INPUT_NAME = 'ImageTensor'
# Output name of the exported predictions.
_OUTPUT_NAME = 'SemanticPredictions'
_RAW_OUTPUT_NAME = 'RawSemanticPredictions'
# Output name of the exported probabilities.
_OUTPUT_PROB_NAME = 'SemanticProbabilities'
_RAW_OUTPUT_PROB_NAME = 'RawSemanticProbabilities'
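# The 'Raw*' tensors cover the full padded crop; the non-raw tensors are
# cropped to the valid region and resized back to the original image size.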


def _create_input_tensors():
"""Creates and prepares input tensors for DeepLab model.
This method creates a 4-D uint8 image tensor 'ImageTensor' with shape
[1, None, None, 3]. The actual input tensor name to use during inference is
'ImageTensor:0'.
Returns:
image: Preprocessed 4-D float32 tensor with shape [1, crop_height,
crop_width, 3].
original_image_size: Original image shape tensor [height, width].
resized_image_size: Resized image shape tensor [height, width].
"""
# input_preprocess takes 4-D image tensor as input.
input_image = tf.placeholder(tf.uint8, [1, None, None, 3], name=_INPUT_NAME)
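  # [height, width] of the raw uint8 input, before any preprocessing resize.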
original_image_size = tf.shape(input_image)[1:3]
# Squeeze the dimension in axis=0 since `preprocess_image_and_label` assumes
# image to be 3-D.
image = tf.squeeze(input_image, axis=0)
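  # Preprocessing (roughly): resize within the min/max resize bounds, pad to
  # crop_size, and normalize pixel values for the chosen model_variant.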
resized_image, image, _ = input_preprocess.preprocess_image_and_label(
image,
label=None,
crop_height=FLAGS.crop_size[0],
crop_width=FLAGS.crop_size[1],
min_resize_value=FLAGS.min_resize_value,
max_resize_value=FLAGS.max_resize_value,
resize_factor=FLAGS.resize_factor,
is_training=False,
model_variant=FLAGS.model_variant)
resized_image_size = tf.shape(resized_image)[:2]
# Expand the dimension in axis=0, since the following operations assume the
# image to be 4-D.
image = tf.expand_dims(image, 0)
return image, original_image_size, resized_image_size


def main(unused_argv):
tf.logging.set_verbosity(tf.logging.INFO)
tf.logging.info('Prepare to export model to: %s', FLAGS.export_path)
with tf.Graph().as_default():
image, image_size, resized_image_size = _create_input_tensors()
model_options = common.ModelOptions(
outputs_to_num_classes={common.OUTPUT_TYPE: FLAGS.num_classes},
crop_size=FLAGS.crop_size,
atrous_rates=FLAGS.atrous_rates,
output_stride=FLAGS.output_stride)
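    # With the single scale 1.0 the graph has one inference branch; any other
    # set of scales builds one branch per scale (plus optional flipped copies)
    # and merges their probabilities.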
if tuple(FLAGS.inference_scales) == (1.0,):
tf.logging.info('Exported model performs single-scale inference.')
predictions = model.predict_labels(
image,
model_options=model_options,
image_pyramid=FLAGS.image_pyramid)
else:
tf.logging.info('Exported model performs multi-scale inference.')
if FLAGS.quantize_delay_step >= 0:
raise ValueError(
'Quantize mode is not supported with multi-scale test.')
predictions = model.predict_labels_multi_scale(
image,
model_options=model_options,
eval_scales=FLAGS.inference_scales,
add_flipped_images=FLAGS.add_flipped_images)
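    # tf.identity pins stable node names on the raw outputs so they can be
    # looked up by name in the frozen graph.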
raw_predictions = tf.identity(
tf.cast(predictions[common.OUTPUT_TYPE], tf.float32),
_RAW_OUTPUT_NAME)
raw_probabilities = tf.identity(
predictions[common.OUTPUT_TYPE + model.PROB_SUFFIX],
_RAW_OUTPUT_PROB_NAME)
# Crop the valid regions from the predictions.
semantic_predictions = raw_predictions[
:, :resized_image_size[0], :resized_image_size[1]]
semantic_probabilities = raw_probabilities[
:, :resized_image_size[0], :resized_image_size[1]]
# Resize back the prediction to the original image size.
def _resize_label(label, label_size):
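      """Resizes `label` to `label_size` using nearest-neighbor sampling."""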
# Expand dimension of label to [1, height, width, 1] for resize operation.
label = tf.expand_dims(label, 3)
resized_label = tf.image.resize_images(
label,
label_size,
method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
align_corners=True)
return tf.cast(tf.squeeze(resized_label, 3), tf.int32)
semantic_predictions = _resize_label(semantic_predictions, image_size)
semantic_predictions = tf.identity(semantic_predictions, name=_OUTPUT_NAME)
semantic_probabilities = tf.image.resize_bilinear(
semantic_probabilities, image_size, align_corners=True,
name=_OUTPUT_PROB_NAME)
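    # Rewrite the graph with fake-quantization ops so that a checkpoint from
    # quantization-aware training can be exported for inference.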
if FLAGS.quantize_delay_step >= 0:
contrib_quantize.create_eval_graph()
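    # Freeze: restore variables from the checkpoint and fold them into the
    # exported GraphDef as constants.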
    saver = tf.train.Saver(tf.global_variables())
dirname = os.path.dirname(FLAGS.export_path)
tf.gfile.MakeDirs(dirname)
graph_def = tf.get_default_graph().as_graph_def(add_shapes=True)
freeze_graph.freeze_graph_with_def_protos(
graph_def,
saver.as_saver_def(),
FLAGS.checkpoint_path,
_OUTPUT_NAME + ',' + _OUTPUT_PROB_NAME,
restore_op_name=None,
filename_tensor_name=None,
output_graph=FLAGS.export_path,
clear_devices=True,
initializer_nodes=None)
if FLAGS.save_inference_graph:
tf.train.write_graph(graph_def, dirname, 'inference_graph.pbtxt')


if __name__ == '__main__':
flags.mark_flag_as_required('checkpoint_path')
flags.mark_flag_as_required('export_path')
tf.app.run()
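

# A minimal sketch of running the exported frozen graph (a TF1 runtime is
# assumed; the .pb and image paths are illustrative):
#
#   import numpy as np
#   import tensorflow as tf
#   from PIL import Image
#
#   graph_def = tf.GraphDef()
#   with tf.gfile.GFile('/path/to/frozen_inference_graph.pb', 'rb') as f:
#     graph_def.ParseFromString(f.read())
#   with tf.Graph().as_default() as graph:
#     tf.import_graph_def(graph_def, name='')
#   with tf.Session(graph=graph) as sess:
#     image = np.expand_dims(np.asarray(Image.open('input.jpg')), 0)
#     seg_map = sess.run('SemanticPredictions:0',
#                        feed_dict={'ImageTensor:0': image})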