"""Functions to export object detection inference graph."""

import os
import tempfile
import tensorflow.compat.v1 as tf
import tf_slim as slim
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.tools import freeze_graph
from object_detection.builders import graph_rewriter_builder
from object_detection.builders import model_builder
from object_detection.core import standard_fields as fields
from object_detection.data_decoders import tf_example_decoder
from object_detection.utils import config_util
from object_detection.utils import shape_utils


try:
  from tensorflow.contrib import tfprof as contrib_tfprof
  from tensorflow.contrib.quantize.python import graph_matcher
except ImportError:
  pass


freeze_graph_with_def_protos = freeze_graph.freeze_graph_with_def_protos


def parse_side_inputs(side_input_shapes_string, side_input_names_string,
                      side_input_types_string):
  """Parses side input flags.

  Args:
    side_input_shapes_string: The shape of the side input tensors, provided as
      a comma-separated list of integers. A value of -1 is used for unknown
      dimensions. A `/` denotes a break, starting the shape of the next side
      input tensor.
    side_input_names_string: The names of the side input tensors, provided as
      a comma-separated list of strings.
    side_input_types_string: The type of the side input tensors, provided as a
      comma-separated list of types, each of `string`, `int`, or `float`.

  Returns:
    side_input_shapes: A list of shapes.
    side_input_names: A list of strings.
    side_input_types: A list of tensorflow dtypes.
  """
  if side_input_shapes_string:
    side_input_shapes = []
    for side_input_shape_list in side_input_shapes_string.split('/'):
      side_input_shape = [
          int(dim) if dim != '-1' else None
          for dim in side_input_shape_list.split(',')
      ]
      side_input_shapes.append(side_input_shape)
  else:
    raise ValueError('When using side_inputs, side_input_shapes must be '
                     'specified in the input flags.')
  if side_input_names_string:
    side_input_names = list(side_input_names_string.split(','))
  else:
    raise ValueError('When using side_inputs, side_input_names must be '
                     'specified in the input flags.')
  if side_input_types_string:
    typelookup = {'float': tf.float32, 'int': tf.int32, 'string': tf.string}
    side_input_types = [
        typelookup[side_input_type]
        for side_input_type in side_input_types_string.split(',')
    ]
  else:
    raise ValueError('When using side_inputs, side_input_types must be '
                     'specified in the input flags.')
  return side_input_shapes, side_input_names, side_input_types
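
# Example (an illustrative sketch, not part of the original module): how the
# three flag strings map to parsed side inputs. The flag values below are
# made up for illustration.
#
#   shapes, names, dtypes = parse_side_inputs(
#       side_input_shapes_string='1,-1,4/-1',
#       side_input_names_string='side_boxes,side_length',
#       side_input_types_string='float,int')
#   # shapes == [[1, None, 4], [None]]
#   # names  == ['side_boxes', 'side_length']
#   # dtypes == [tf.float32, tf.int32]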


def rewrite_nn_resize_op(is_quantized=False):
  """Replaces a custom nearest-neighbor resize op with the Tensorflow version.

  Some graphs use this custom version for TPU compatibility.

  Args:
    is_quantized: True if the default graph is quantized.
  """
  def remove_nn():
    """Remove nearest neighbor upsampling structures and replace with TF op."""
    input_pattern = graph_matcher.OpTypePattern(
        'FakeQuantWithMinMaxVars' if is_quantized else '*')
    stack_1_pattern = graph_matcher.OpTypePattern(
        'Pack', inputs=[input_pattern, input_pattern], ordered_inputs=False)
    stack_2_pattern = graph_matcher.OpTypePattern(
        'Pack', inputs=[stack_1_pattern, stack_1_pattern], ordered_inputs=False)
    reshape_pattern = graph_matcher.OpTypePattern(
        'Reshape', inputs=[stack_2_pattern, 'Const'], ordered_inputs=False)
    consumer_pattern1 = graph_matcher.OpTypePattern(
        'Add|AddV2|Max|Mul', inputs=[reshape_pattern, '*'],
        ordered_inputs=False)
    consumer_pattern2 = graph_matcher.OpTypePattern(
        'StridedSlice', inputs=[reshape_pattern, '*', '*', '*'],
        ordered_inputs=False)

    def replace_matches(consumer_pattern):
      """Search for nearest neighbor pattern and replace with TF op."""
      match_counter = 0
      matcher = graph_matcher.GraphMatcher(consumer_pattern)
      for match in matcher.match_graph(tf.get_default_graph()):
        match_counter += 1
        projection_op = match.get_op(input_pattern)
        reshape_op = match.get_op(reshape_pattern)
        consumer_op = match.get_op(consumer_pattern)
        nn_resize = tf.image.resize_nearest_neighbor(
            projection_op.outputs[0],
            reshape_op.outputs[0].shape.dims[1:3],
            align_corners=False,
            name=os.path.split(reshape_op.name)[0] + '/resize_nearest_neighbor')

        for index, op_input in enumerate(consumer_op.inputs):
          if op_input == reshape_op.outputs[0]:
            consumer_op._update_input(index, nn_resize)
            break

      return match_counter

    match_counter = replace_matches(consumer_pattern1)
    match_counter += replace_matches(consumer_pattern2)

    tf.logging.info('Found and fixed {} matches'.format(match_counter))
    return match_counter

  total_removals = 0
  while remove_nn():
    total_removals += 1
    if total_removals > 4:
      raise ValueError('Graph removal encountered an infinite loop.')
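
# Usage note (an illustrative sketch): rewrite_nn_resize_op rewrites the
# *default* graph in place and needs the tf.contrib graph_matcher import
# above, so one possible call site is right after the detection graph has
# been built and before it is frozen, e.g.:
#
#   with tf.Graph().as_default():
#     build_detection_graph(...)  # or tf.import_graph_def(...)
#     rewrite_nn_resize_op(is_quantized=False)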


def replace_variable_values_with_moving_averages(graph,
                                                 current_checkpoint_file,
                                                 new_checkpoint_file,
                                                 no_ema_collection=None):
  """Replaces variable values in the checkpoint with their moving averages.

  If the current checkpoint has shadow variables maintaining moving averages
  of the variables defined in the graph, this function generates a new
  checkpoint where the variables contain the values of their moving averages.

  Args:
    graph: a tf.Graph object.
    current_checkpoint_file: a checkpoint containing both original variables
      and their moving averages.
    new_checkpoint_file: file path to write a new checkpoint.
    no_ema_collection: A list of namescope substrings to match; variables
      whose names match keep their original values instead of being replaced
      by their moving averages.
  """
  with graph.as_default():
    variable_averages = tf.train.ExponentialMovingAverage(0.0)
    ema_variables_to_restore = variable_averages.variables_to_restore()
    ema_variables_to_restore = config_util.remove_unecessary_ema(
        ema_variables_to_restore, no_ema_collection)
    with tf.Session() as sess:
      read_saver = tf.train.Saver(ema_variables_to_restore)
      read_saver.restore(sess, current_checkpoint_file)
      write_saver = tf.train.Saver()
      write_saver.save(sess, new_checkpoint_file)
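
# Example (an illustrative sketch; the checkpoint paths and the
# no_ema_collection value are made up): producing an EMA-baked checkpoint for
# a graph that was trained with moving averages.
#
#   graph = tf.get_default_graph()
#   replace_variable_values_with_moving_averages(
#       graph,
#       current_checkpoint_file='/tmp/train/model.ckpt-10000',
#       new_checkpoint_file='/tmp/export/ema_model.ckpt',
#       no_ema_collection=['global_step'])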


def _image_tensor_input_placeholder(input_shape=None):
  """Returns input placeholder and a 4-D uint8 image tensor."""
  if input_shape is None:
    input_shape = (None, None, None, 3)
  input_tensor = tf.placeholder(
      dtype=tf.uint8, shape=input_shape, name='image_tensor')
  return input_tensor, input_tensor


def _side_input_tensor_placeholder(side_input_shape, side_input_name,
                                   side_input_type):
  """Returns side input placeholder and side input tensor."""
  side_input_tensor = tf.placeholder(
      dtype=side_input_type, shape=side_input_shape, name=side_input_name)
  return side_input_tensor, side_input_tensor


def _tf_example_input_placeholder(input_shape=None):
  """Returns input that accepts a batch of strings with tf examples.

  Args:
    input_shape: the shape to resize the output decoded images to (optional).

  Returns:
    a tuple of input placeholder and the output decoded images.
  """
  batch_tf_example_placeholder = tf.placeholder(
      tf.string, shape=[None], name='tf_example')
  def decode(tf_example_string_tensor):
    tensor_dict = tf_example_decoder.TfExampleDecoder().decode(
        tf_example_string_tensor)
    image_tensor = tensor_dict[fields.InputDataFields.image]
    if input_shape is not None:
      image_tensor = tf.image.resize(image_tensor, input_shape[1:3])
    return image_tensor
  return (batch_tf_example_placeholder,
          shape_utils.static_or_dynamic_map_fn(
              decode,
              elems=batch_tf_example_placeholder,
              dtype=tf.uint8,
              parallel_iterations=32,
              back_prop=False))


def _encoded_image_string_tensor_input_placeholder(input_shape=None):
  """Returns input that accepts a batch of PNG or JPEG strings.

  Args:
    input_shape: the shape to resize the output decoded images to (optional).

  Returns:
    a tuple of input placeholder and the output decoded images.
  """
  batch_image_str_placeholder = tf.placeholder(
      dtype=tf.string,
      shape=[None],
      name='encoded_image_string_tensor')
  def decode(encoded_image_string_tensor):
    image_tensor = tf.image.decode_image(encoded_image_string_tensor,
                                         channels=3)
    image_tensor.set_shape((None, None, 3))
    if input_shape is not None:
      image_tensor = tf.image.resize(image_tensor, input_shape[1:3])
    return image_tensor
  return (batch_image_str_placeholder,
          tf.map_fn(
              decode,
              elems=batch_image_str_placeholder,
              dtype=tf.uint8,
              parallel_iterations=32,
              back_prop=False))
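
# Example (an illustrative sketch): each placeholder builder above returns a
# (placeholder, decoded image batch) pair. For encoded image strings,
# `encoded_jpeg_bytes` below is an assumed JPEG-encoded byte string:
#
#   placeholder, images = _encoded_image_string_tensor_input_placeholder()
#   with tf.Session() as sess:
#     decoded = sess.run(images,
#                        feed_dict={placeholder: [encoded_jpeg_bytes]})
#   # `decoded` is a uint8 batch of shape [1, height, width, 3].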


input_placeholder_fn_map = {
    'image_tensor': _image_tensor_input_placeholder,
    'encoded_image_string_tensor':
        _encoded_image_string_tensor_input_placeholder,
    'tf_example': _tf_example_input_placeholder
}


def add_output_tensor_nodes(postprocessed_tensors,
                            output_collection_name='inference_op'):
  """Adds output nodes for detection boxes and scores.

  Adds the following nodes for output tensors:
    * num_detections: float32 tensor of shape [batch_size].
    * detection_boxes: float32 tensor of shape [batch_size, num_boxes, 4]
      containing detected boxes.
    * detection_scores: float32 tensor of shape [batch_size, num_boxes]
      containing scores for the detected boxes.
    * detection_multiclass_scores: (Optional) float32 tensor of shape
      [batch_size, num_boxes, num_classes_with_background] containing class
      score distributions for detected boxes, including background if any.
    * detection_features: (Optional) float32 tensor of shape
      [batch, num_boxes, roi_height, roi_width, depth] containing classifier
      features for each detected box.
    * detection_classes: float32 tensor of shape [batch_size, num_boxes]
      containing class predictions for the detected boxes.
    * detection_keypoints: (Optional) float32 tensor of shape
      [batch_size, num_boxes, num_keypoints, 2] containing keypoints for each
      detection box.
    * detection_masks: (Optional) float32 tensor of shape
      [batch_size, num_boxes, mask_height, mask_width] containing masks for
      each detection box.

  Args:
    postprocessed_tensors: a dictionary containing the following fields
      'detection_boxes': [batch, max_detections, 4]
      'detection_scores': [batch, max_detections]
      'detection_multiclass_scores': [batch, max_detections,
        num_classes_with_background]
      'detection_features': [batch, num_boxes, roi_height, roi_width, depth]
      'detection_classes': [batch, max_detections]
      'detection_masks': [batch, max_detections, mask_height, mask_width]
        (optional).
      'detection_keypoints': [batch, max_detections, num_keypoints, 2]
        (optional).
      'num_detections': [batch]
    output_collection_name: Name of collection to add output tensors to.

  Returns:
    A tensor dict containing the added output tensor nodes.
  """
  detection_fields = fields.DetectionResultFields
  label_id_offset = 1
  boxes = postprocessed_tensors.get(detection_fields.detection_boxes)
  scores = postprocessed_tensors.get(detection_fields.detection_scores)
  multiclass_scores = postprocessed_tensors.get(
      detection_fields.detection_multiclass_scores)
  box_classifier_features = postprocessed_tensors.get(
      detection_fields.detection_features)
  raw_boxes = postprocessed_tensors.get(detection_fields.raw_detection_boxes)
  raw_scores = postprocessed_tensors.get(detection_fields.raw_detection_scores)
  classes = postprocessed_tensors.get(
      detection_fields.detection_classes) + label_id_offset
  keypoints = postprocessed_tensors.get(detection_fields.detection_keypoints)
  masks = postprocessed_tensors.get(detection_fields.detection_masks)
  num_detections = postprocessed_tensors.get(detection_fields.num_detections)
  outputs = {}
  outputs[detection_fields.detection_boxes] = tf.identity(
      boxes, name=detection_fields.detection_boxes)
  outputs[detection_fields.detection_scores] = tf.identity(
      scores, name=detection_fields.detection_scores)
  if multiclass_scores is not None:
    outputs[detection_fields.detection_multiclass_scores] = tf.identity(
        multiclass_scores, name=detection_fields.detection_multiclass_scores)
  if box_classifier_features is not None:
    outputs[detection_fields.detection_features] = tf.identity(
        box_classifier_features,
        name=detection_fields.detection_features)
  outputs[detection_fields.detection_classes] = tf.identity(
      classes, name=detection_fields.detection_classes)
  outputs[detection_fields.num_detections] = tf.identity(
      num_detections, name=detection_fields.num_detections)
  if raw_boxes is not None:
    outputs[detection_fields.raw_detection_boxes] = tf.identity(
        raw_boxes, name=detection_fields.raw_detection_boxes)
  if raw_scores is not None:
    outputs[detection_fields.raw_detection_scores] = tf.identity(
        raw_scores, name=detection_fields.raw_detection_scores)
  if keypoints is not None:
    outputs[detection_fields.detection_keypoints] = tf.identity(
        keypoints, name=detection_fields.detection_keypoints)
  if masks is not None:
    outputs[detection_fields.detection_masks] = tf.identity(
        masks, name=detection_fields.detection_masks)
  for output_key in outputs:
    tf.add_to_collection(output_collection_name, outputs[output_key])

  return outputs
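
# Example (an illustrative sketch): because every output is wrapped in
# tf.identity() with a fixed name, the exported graph can be queried by those
# names at inference time:
#
#   boxes = graph.get_tensor_by_name('detection_boxes:0')
#   scores = graph.get_tensor_by_name('detection_scores:0')
#   num_detections = graph.get_tensor_by_name('num_detections:0')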


def write_saved_model(saved_model_path,
                      frozen_graph_def,
                      inputs,
                      outputs):
  """Writes SavedModel to disk.

  The weights from the frozen graph are baked into the SavedModel, so the
  exported model does not require separate checkpoint files at inference
  time.

  Args:
    saved_model_path: Path to write SavedModel.
    frozen_graph_def: tf.GraphDef holding frozen graph.
    inputs: A tensor dictionary containing the inputs to a DetectionModel, or
      a single input tensor.
    outputs: A tensor dictionary containing the outputs of a DetectionModel.
  """
  with tf.Graph().as_default():
    with tf.Session() as sess:

      tf.import_graph_def(frozen_graph_def, name='')

      builder = tf.saved_model.builder.SavedModelBuilder(saved_model_path)

      tensor_info_inputs = {}
      if isinstance(inputs, dict):
        for k, v in inputs.items():
          tensor_info_inputs[k] = tf.saved_model.utils.build_tensor_info(v)
      else:
        tensor_info_inputs['inputs'] = tf.saved_model.utils.build_tensor_info(
            inputs)
      tensor_info_outputs = {}
      for k, v in outputs.items():
        tensor_info_outputs[k] = tf.saved_model.utils.build_tensor_info(v)

      detection_signature = (
          tf.saved_model.signature_def_utils.build_signature_def(
              inputs=tensor_info_inputs,
              outputs=tensor_info_outputs,
              method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME
          ))

      builder.add_meta_graph_and_variables(
          sess,
          [tf.saved_model.tag_constants.SERVING],
          signature_def_map={
              tf.saved_model.signature_constants
              .DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                  detection_signature,
          },
      )
      builder.save()
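
# Example (an illustrative sketch, assuming the model was exported with the
# `image_tensor` input type; `image_batch` is an assumed uint8 numpy array):
# loading the SavedModel written above with the TF1 loader.
#
#   with tf.Session(graph=tf.Graph()) as sess:
#     tf.saved_model.loader.load(
#         sess, [tf.saved_model.tag_constants.SERVING], saved_model_path)
#     boxes = sess.run('detection_boxes:0',
#                      feed_dict={'image_tensor:0': image_batch})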


def write_graph_and_checkpoint(inference_graph_def,
                               model_path,
                               input_saver_def,
                               trained_checkpoint_prefix):
  """Writes the graph and the checkpoint to disk."""
  for node in inference_graph_def.node:
    node.device = ''
  with tf.Graph().as_default():
    tf.import_graph_def(inference_graph_def, name='')
    with tf.Session() as sess:
      saver = tf.train.Saver(
          saver_def=input_saver_def, save_relative_paths=True)
      saver.restore(sess, trained_checkpoint_prefix)
      saver.save(sess, model_path)


def _get_outputs_from_inputs(input_tensors, detection_model,
                             output_collection_name, **side_inputs):
  inputs = tf.cast(input_tensors, dtype=tf.float32)
  preprocessed_inputs, true_image_shapes = detection_model.preprocess(inputs)
  output_tensors = detection_model.predict(
      preprocessed_inputs, true_image_shapes, **side_inputs)
  postprocessed_tensors = detection_model.postprocess(
      output_tensors, true_image_shapes)
  return add_output_tensor_nodes(postprocessed_tensors,
                                 output_collection_name)


def build_detection_graph(input_type, detection_model, input_shape,
                          output_collection_name, graph_hook_fn,
                          use_side_inputs=False, side_input_shapes=None,
                          side_input_names=None, side_input_types=None):
  """Build the detection graph."""
  if input_type not in input_placeholder_fn_map:
    raise ValueError('Unknown input type: {}'.format(input_type))
  placeholder_args = {}
  side_inputs = {}
  if input_shape is not None:
    if (input_type != 'image_tensor' and
        input_type != 'encoded_image_string_tensor' and
        input_type != 'tf_example' and
        input_type != 'tf_sequence_example'):
      raise ValueError('Can only specify input shape for `image_tensor`, '
                       '`encoded_image_string_tensor`, `tf_example`, '
                       'or `tf_sequence_example` inputs.')
    placeholder_args['input_shape'] = input_shape
  placeholder_tensor, input_tensors = input_placeholder_fn_map[input_type](
      **placeholder_args)
  placeholder_tensors = {'inputs': placeholder_tensor}
  if use_side_inputs:
    for idx, side_input_name in enumerate(side_input_names):
      side_input_placeholder, side_input = _side_input_tensor_placeholder(
          side_input_shapes[idx], side_input_name, side_input_types[idx])
      side_inputs[side_input_name] = side_input
      placeholder_tensors[side_input_name] = side_input_placeholder
  outputs = _get_outputs_from_inputs(
      input_tensors=input_tensors,
      detection_model=detection_model,
      output_collection_name=output_collection_name,
      **side_inputs)

  slim.get_or_create_global_step()

  if graph_hook_fn: graph_hook_fn()

  return outputs, placeholder_tensors
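
# Example (an illustrative sketch; `detection_model` is assumed to come from
# model_builder.build(...)): building an inference graph with a fixed batch
# size of 1 and `image_tensor` inputs.
#
#   with tf.Graph().as_default():
#     outputs, placeholders = build_detection_graph(
#         input_type='image_tensor',
#         detection_model=detection_model,
#         input_shape=[1, None, None, 3],
#         output_collection_name='inference_op',
#         graph_hook_fn=None)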


def _export_inference_graph(input_type,
                            detection_model,
                            use_moving_averages,
                            trained_checkpoint_prefix,
                            output_directory,
                            additional_output_tensor_names=None,
                            input_shape=None,
                            output_collection_name='inference_op',
                            graph_hook_fn=None,
                            write_inference_graph=False,
                            temp_checkpoint_prefix='',
                            use_side_inputs=False,
                            side_input_shapes=None,
                            side_input_names=None,
                            side_input_types=None):
  """Export helper."""
  tf.gfile.MakeDirs(output_directory)
  frozen_graph_path = os.path.join(output_directory,
                                   'frozen_inference_graph.pb')
  saved_model_path = os.path.join(output_directory, 'saved_model')
  model_path = os.path.join(output_directory, 'model.ckpt')

  outputs, placeholder_tensor_dict = build_detection_graph(
      input_type=input_type,
      detection_model=detection_model,
      input_shape=input_shape,
      output_collection_name=output_collection_name,
      graph_hook_fn=graph_hook_fn,
      use_side_inputs=use_side_inputs,
      side_input_shapes=side_input_shapes,
      side_input_names=side_input_names,
      side_input_types=side_input_types)

  profile_inference_graph(tf.get_default_graph())
  saver_kwargs = {}
  if use_moving_averages:
    if not temp_checkpoint_prefix:
      if os.path.isfile(trained_checkpoint_prefix):
        saver_kwargs['write_version'] = saver_pb2.SaverDef.V1
        temp_checkpoint_prefix = tempfile.NamedTemporaryFile().name
      else:
        temp_checkpoint_prefix = tempfile.mkdtemp()
    replace_variable_values_with_moving_averages(
        tf.get_default_graph(), trained_checkpoint_prefix,
        temp_checkpoint_prefix)
    checkpoint_to_use = temp_checkpoint_prefix
  else:
    checkpoint_to_use = trained_checkpoint_prefix

  saver = tf.train.Saver(**saver_kwargs)
  input_saver_def = saver.as_saver_def()

  write_graph_and_checkpoint(
      inference_graph_def=tf.get_default_graph().as_graph_def(),
      model_path=model_path,
      input_saver_def=input_saver_def,
      trained_checkpoint_prefix=checkpoint_to_use)
  if write_inference_graph:
    inference_graph_def = tf.get_default_graph().as_graph_def()
    inference_graph_path = os.path.join(output_directory,
                                        'inference_graph.pbtxt')
    for node in inference_graph_def.node:
      node.device = ''
    with tf.gfile.GFile(inference_graph_path, 'wb') as f:
      f.write(str(inference_graph_def))

  if additional_output_tensor_names is not None:
    output_node_names = ','.join(list(outputs.keys()) +
                                 additional_output_tensor_names)
  else:
    output_node_names = ','.join(outputs.keys())

  frozen_graph_def = freeze_graph.freeze_graph_with_def_protos(
      input_graph_def=tf.get_default_graph().as_graph_def(),
      input_saver_def=input_saver_def,
      input_checkpoint=checkpoint_to_use,
      output_node_names=output_node_names,
      restore_op_name='save/restore_all',
      filename_tensor_name='save/Const:0',
      output_graph=frozen_graph_path,
      clear_devices=True,
      initializer_nodes='')

  write_saved_model(saved_model_path, frozen_graph_def,
                    placeholder_tensor_dict, outputs)


def export_inference_graph(input_type,
                           pipeline_config,
                           trained_checkpoint_prefix,
                           output_directory,
                           input_shape=None,
                           output_collection_name='inference_op',
                           additional_output_tensor_names=None,
                           write_inference_graph=False,
                           use_side_inputs=False,
                           side_input_shapes=None,
                           side_input_names=None,
                           side_input_types=None):
  """Exports inference graph for the model specified in the pipeline config.

  Args:
    input_type: Type of input for the graph. Can be one of ['image_tensor',
      'encoded_image_string_tensor', 'tf_example'].
    pipeline_config: pipeline_pb2.TrainEvalPipelineConfig proto.
    trained_checkpoint_prefix: Path to the trained checkpoint file.
    output_directory: Path to write outputs.
    input_shape: Sets a fixed shape for an `image_tensor` input. If not
      specified, will default to [None, None, None, 3].
    output_collection_name: Name of collection to add output tensors to.
      If None, does not add output tensors to a collection.
    additional_output_tensor_names: list of additional output tensors to
      include in the frozen graph.
    write_inference_graph: If true, writes inference graph to disk.
    use_side_inputs: If True, the model requires side_inputs.
    side_input_shapes: List of shapes of the side input tensors,
      required if use_side_inputs is True.
    side_input_names: List of names of the side input tensors,
      required if use_side_inputs is True.
    side_input_types: List of types of the side input tensors,
      required if use_side_inputs is True.
  """
  detection_model = model_builder.build(pipeline_config.model,
                                        is_training=False)
  graph_rewriter_fn = None
  if pipeline_config.HasField('graph_rewriter'):
    graph_rewriter_config = pipeline_config.graph_rewriter
    graph_rewriter_fn = graph_rewriter_builder.build(graph_rewriter_config,
                                                     is_training=False)
  _export_inference_graph(
      input_type,
      detection_model,
      pipeline_config.eval_config.use_moving_averages,
      trained_checkpoint_prefix,
      output_directory,
      additional_output_tensor_names,
      input_shape,
      output_collection_name,
      graph_hook_fn=graph_rewriter_fn,
      write_inference_graph=write_inference_graph,
      use_side_inputs=use_side_inputs,
      side_input_shapes=side_input_shapes,
      side_input_names=side_input_names,
      side_input_types=side_input_types)
  pipeline_config.eval_config.use_moving_averages = False
  config_util.save_pipeline_config(pipeline_config, output_directory)
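
# Example (an illustrative sketch; the paths are made up): typical use from a
# driver script such as export_inference_graph.py.
#
#   from google.protobuf import text_format
#   from object_detection.protos import pipeline_pb2
#
#   pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
#   with tf.gfile.GFile('/tmp/pipeline.config', 'r') as f:
#     text_format.Merge(f.read(), pipeline_config)
#   export_inference_graph(
#       input_type='image_tensor',
#       pipeline_config=pipeline_config,
#       trained_checkpoint_prefix='/tmp/train/model.ckpt-10000',
#       output_directory='/tmp/exported_model')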


def profile_inference_graph(graph):
  """Profiles the inference graph.

  Prints model parameters and computation FLOPs given an inference graph.
  BatchNorms are excluded from the parameter count because they are usually
  folded at inference time. BatchNorm, Initializer, Regularizer and BiasAdd
  ops are excluded from the FLOP count.

  Args:
    graph: the inference graph.
  """
  tfprof_vars_option = (
      contrib_tfprof.model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS)
  tfprof_flops_option = contrib_tfprof.model_analyzer.FLOAT_OPS_OPTIONS

  tfprof_vars_option['trim_name_regexes'] = ['.*BatchNorm.*']
  tfprof_flops_option['trim_name_regexes'] = [
      '.*BatchNorm.*', '.*Initializer.*', '.*Regularizer.*', '.*BiasAdd.*'
  ]

  contrib_tfprof.model_analyzer.print_model_analysis(
      graph, tfprof_options=tfprof_vars_option)

  contrib_tfprof.model_analyzer.print_model_analysis(
      graph, tfprof_options=tfprof_flops_option)