deanna-emery's picture
updates
93528c6
raw
history blame
8.12 kB
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Vision models export utility function for serving/inference."""
import os
from typing import Optional, List, Union, Text, Dict
from absl import logging
import tensorflow as tf, tf_keras
from official.core import config_definitions as cfg
from official.core import export_base
from official.core import train_utils
from official.vision import configs
from official.vision.serving import detection
from official.vision.serving import image_classification
from official.vision.serving import semantic_segmentation
from official.vision.serving import video_classification
def _create_export_module(
    params: cfg.ExperimentConfig,
    input_type: str,
    batch_size: Optional[int],
    input_image_size: List[int],
    num_channels: Optional[int],
    input_name: Optional[str],
) -> export_base.ExportModule:
    """Builds the serving module that matches the type of `params.task`.

    Args:
        params: Experiment params; `params.task` selects the module class.
        input_type: Forwarded to the module constructor.
        batch_size: Forwarded to the module constructor.
        input_image_size: Forwarded to the module constructor.
        num_channels: Forwarded to the module constructor.
        input_name: Forwarded to the module constructor.

    Returns:
        A newly constructed export module for the task.

    Raises:
        ValueError: If `params.task` is not one of the supported task configs.
    """
    task = params.task
    if isinstance(task, configs.image_classification.ImageClassificationTask):
        module_cls = image_classification.ClassificationModule
    elif isinstance(
        task, (configs.retinanet.RetinaNetTask, configs.maskrcnn.MaskRCNNTask)
    ):
        module_cls = detection.DetectionModule
    elif isinstance(
        task, configs.semantic_segmentation.SemanticSegmentationTask
    ):
        module_cls = semantic_segmentation.SegmentationModule
    elif isinstance(
        task, configs.video_classification.VideoClassificationTask
    ):
        module_cls = video_classification.VideoClassificationModule
    else:
        raise ValueError('Export module not implemented for {} task.'.format(
            type(params.task)))

    # Every supported module is constructed with the same keyword arguments.
    return module_cls(
        params=params,
        batch_size=batch_size,
        input_image_size=input_image_size,
        input_type=input_type,
        num_channels=num_channels,
        input_name=input_name)


def _select_tpu_inference_function(export_module, input_type: str):
    """Returns the export module method to alias as the TPU candidate.

    Args:
        export_module: The export module whose inference method is looked up.
        input_type: The serving input type requested by the caller.

    Raises:
        ValueError: If `input_type` is not `image_tensor`, `image_bytes` or
            `tf_example`.
    """
    if input_type == 'image_tensor':
        return export_module.inference_from_image_tensors
    if input_type == 'image_bytes':
        return export_module.inference_from_image_bytes
    if input_type == 'tf_example':
        return export_module.inference_from_tf_example
    raise ValueError(
        'add_tpu_function_alias is only allowed for input_type of:'
        ' image_tensor, image_bytes, tf_example.'
    )


def _write_model_flops_and_params(
    export_module,
    params: cfg.ExperimentConfig,
    input_image_size: List[int],
    num_channels: Optional[int],
    export_dir: str,
) -> None:
    """Writes model_flops.txt and model_params.txt under `export_dir`.

    Only RetinaNet and Mask R-CNN tasks are handled; for any other task an
    info message is logged and nothing is written.

    Args:
        export_module: Export module whose `model` is measured.
        params: Experiment params; `params.task` decides whether logging is
            supported.
        input_image_size: List or tuple with the spatial size of the input.
        num_channels: The number of input image channels.
        export_dir: Directory that receives the two text files.
    """
    if not isinstance(
        params.task,
        (configs.retinanet.RetinaNetTask, configs.maskrcnn.MaskRCNNTask)
    ):
        logging.info(
            'Logging model flops and params not implemented for %s task.',
            type(params.task))
        return

    # We need to create inputs_kwargs argument to specify the input shapes for
    # subclass model that overrides model.call to take multiple inputs,
    # e.g., RetinaNet model.
    inputs_kwargs = {
        # Unpacking (rather than list concatenation) accepts both a list and
        # a tuple for input_image_size.
        'images': tf.TensorSpec(
            [1, *input_image_size, num_channels], tf.float32),
        'image_shape': tf.TensorSpec([1, 2], tf.float32),
    }
    dummy_inputs = {
        name: tf.ones(spec.shape.as_list(), tf.float32)
        for name, spec in inputs_kwargs.items()
    }
    # Must do forward pass to build the model.
    export_module.model(**dummy_inputs)

    train_utils.try_count_flops(
        export_module.model, inputs_kwargs,
        os.path.join(export_dir, 'model_flops.txt'))
    train_utils.write_model_params(
        export_module.model, os.path.join(export_dir, 'model_params.txt'))


def export_inference_graph(
    input_type: str,
    batch_size: Optional[int],
    input_image_size: List[int],
    params: cfg.ExperimentConfig,
    checkpoint_path: str,
    export_dir: str,
    num_channels: Optional[int] = 3,
    export_module: Optional[export_base.ExportModule] = None,
    export_checkpoint_subdir: Optional[str] = None,
    export_saved_model_subdir: Optional[str] = None,
    save_options: Optional[tf.saved_model.SaveOptions] = None,
    log_model_flops_and_params: bool = False,
    checkpoint: Optional[tf.train.Checkpoint] = None,
    input_name: Optional[str] = None,
    function_keys: Optional[Union[List[Text], Dict[Text, Text]]] = None,
    add_tpu_function_alias: Optional[bool] = False,
):
    """Exports inference graph for the model specified in the exp config.

    The saved model is written to `export_dir`, or to
    `export_dir/export_saved_model_subdir` when that subdirectory is given.
    A checkpoint is additionally saved under
    `export_dir/export_checkpoint_subdir` only when that subdirectory is
    given. The experiment params are serialized into `export_dir`.

    Args:
        input_type: One of `image_tensor`, `image_bytes`, `tf_example` or
            `tflite`.
        batch_size: 'int', or None.
        input_image_size: List or Tuple of height and width.
        params: Experiment params.
        checkpoint_path: Trained checkpoint path or directory.
        export_dir: Export directory path.
        num_channels: The number of input image channels.
        export_module: Optional export module to be used instead of using
            params to create one. If None, the params will be used to create
            an export module.
        export_checkpoint_subdir: Optional subdirectory under export_dir to
            store checkpoint.
        export_saved_model_subdir: Optional subdirectory under export_dir to
            store saved model.
        save_options: `SaveOptions` for `tf.saved_model.save`. Replaced by
            options carrying only the TPU function alias when
            `add_tpu_function_alias` is True.
        log_model_flops_and_params: If True, writes model FLOPs to
            model_flops.txt and model parameters to model_params.txt.
        checkpoint: An optional tf.train.Checkpoint. If provided, the export
            module will use it to read the weights.
        input_name: The input tensor name, default at `None` which produces
            input tensor name `inputs`.
        function_keys: a list of string keys to retrieve pre-defined serving
            signatures. The signature keys will be set with defaults. If a
            dictionary is provided, the values will be used as signature
            keys. Defaults to `[input_type]` when empty or None.
        add_tpu_function_alias: Whether to add TPU function alias so that it
            can be converted to a TPU compatible saved model later. Default
            is False.

    Raises:
        ValueError: If no export module is implemented for `params.task`, or
            if `add_tpu_function_alias` is set with an `input_type` other
            than `image_tensor`, `image_bytes` or `tf_example`.
    """
    output_checkpoint_directory = (
        os.path.join(export_dir, export_checkpoint_subdir)
        if export_checkpoint_subdir else None)
    output_saved_model_directory = (
        os.path.join(export_dir, export_saved_model_subdir)
        if export_saved_model_subdir else export_dir)

    # TODO(arashwan): Offers a direct path to use ExportModule with Task
    # objects.
    if not export_module:
        export_module = _create_export_module(
            params=params,
            input_type=input_type,
            batch_size=batch_size,
            input_image_size=input_image_size,
            num_channels=num_channels,
            input_name=input_name)

    if add_tpu_function_alias:
        inference_func = _select_tpu_inference_function(
            export_module, input_type)
        if save_options is not None:
            # The caller's options are not merged; make the override visible
            # instead of dropping them silently.
            logging.warning(
                'add_tpu_function_alias is set: the provided save_options '
                'are replaced by options that only carry the TPU function '
                'alias.')
        save_options = tf.saved_model.SaveOptions(
            function_aliases={
                'tpu_candidate': inference_func,
            }
        )

    export_base.export(
        export_module,
        function_keys=function_keys or [input_type],
        export_savedmodel_dir=output_saved_model_directory,
        checkpoint=checkpoint,
        checkpoint_path=checkpoint_path,
        timestamped=False,
        save_options=save_options)

    if output_checkpoint_directory:
        ckpt = tf.train.Checkpoint(model=export_module.model)
        ckpt.save(os.path.join(output_checkpoint_directory, 'ckpt'))

    train_utils.serialize_config(params, export_dir)

    if log_model_flops_and_params:
        _write_model_flops_and_params(
            export_module=export_module,
            params=params,
            input_image_size=input_image_size,
            num_channels=num_channels,
            export_dir=export_dir)