deanna-emery's picture
updates
93528c6
raw
history blame
5.76 kB
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Vision models export binary for serving/inference.
To export a trained checkpoint in saved_model format (shell script):
EXPERIMENT_TYPE=XX
CHECKPOINT_PATH=XX
EXPORT_DIR_PATH=XX
export_saved_model --experiment=${EXPERIMENT_TYPE} \
--export_dir=${EXPORT_DIR_PATH}/ \
--checkpoint_path=${CHECKPOINT_PATH} \
--batch_size=2 \
--input_image_size=224,224
To serve (python):
export_dir_path = XX
input_type = XX
input_images = XX
imported = tf.saved_model.load(export_dir_path)
model_fn = imported.signatures['serving_default']
output = model_fn(input_images)
"""
from absl import app
from absl import flags
from official.core import exp_factory
from official.modeling import hyperparams
from official.vision import registry_imports # pylint: disable=unused-import
from official.vision.serving import export_saved_model_lib
FLAGS = flags.FLAGS

# --- Experiment / checkpoint selection -------------------------------------
_EXPERIMENT = flags.DEFINE_string(
    'experiment', None, 'experiment type, e.g. retinanet_resnetfpn_coco')
_EXPORT_DIR = flags.DEFINE_string('export_dir', None, 'The export directory.')
_CHECKPOINT_PATH = flags.DEFINE_string('checkpoint_path', None,
                                       'Checkpoint path.')

# --- Config overrides (applied in order: config files, then params_override)
_CONFIG_FILE = flags.DEFINE_multi_string(
    'config_file',
    default=None,
    help='YAML/JSON files which specify overrides. The override order '
    'follows the order of args. Note that each file '
    'can be used as an override template to override the default parameters '
    'specified in Python. If the same parameter is specified in both '
    '`--config_file` and `--params_override`, `config_file` will be used '
    'first, followed by params_override.')
_PARAMS_OVERRIDE = flags.DEFINE_string(
    'params_override', '',
    'The JSON/YAML file or string which specifies the parameter to be '
    'overridden on top of `config_file` template.')

# --- Serving input description ---------------------------------------------
_BATCH_SIZE = flags.DEFINE_integer('batch_size', None, 'The batch size.')
_IMAGE_TYPE = flags.DEFINE_string(
    'input_type', 'image_tensor',
    'One of `image_tensor`, `image_bytes`, `tf_example` and `tflite`.')
_INPUT_IMAGE_SIZE = flags.DEFINE_string(
    'input_image_size', '224,224',
    'The comma-separated string of two integers representing the height,width '
    'of the input to the model.')

# --- Export layout and options ---------------------------------------------
_EXPORT_CHECKPOINT_SUBDIR = flags.DEFINE_string(
    'export_checkpoint_subdir', 'checkpoint',
    'The subdirectory for checkpoints.')
_EXPORT_SAVED_MODEL_SUBDIR = flags.DEFINE_string(
    'export_saved_model_subdir', 'saved_model',
    'The subdirectory for saved model.')
_LOG_MODEL_FLOPS_AND_PARAMS = flags.DEFINE_bool(
    'log_model_flops_and_params', False,
    'If true, logs model flops and parameters.')
_INPUT_NAME = flags.DEFINE_string(
    'input_name', None,
    # Trailing space matters: the two literals are concatenated.
    'Input tensor name in signature def. Default at None which '
    'produces input tensor name `inputs`.')
_FUNCTION_KEYS = flags.DEFINE_string(
    'function_keys',
    '',
    (
        'An optional comma-separated string of one or more key:value pair'
        ' indicating the serving function key and corresponding signature_def'
        ' name. For example,'
        ' `tf_example:serving_default,image_tensor:serving_image_tensor` means'
        ' two serving functions are defined for `tf_example` and `image_tensor`'
        ' input types.'
    ),
)
_ADD_TPU_FUNCTION_ALIAS = flags.DEFINE_bool(
    'add_tpu_function_alias',
    False,
    (
        'Whether to add TPU function alias so later it can be converted to a'
        ' TPU SavedModel for inference.'
    ),
)
def _override_params(params, override):
  """Applies one YAML/JSON override (file path or string) to `params`.

  The override is first applied strictly. If that raises `KeyError`, it is
  re-applied with `is_strict=False`. The assumption is that the strict mode
  rejects keys not already present in `params`; confirm against
  `hyperparams.override_params_dict`.

  Args:
    params: The experiment config to override.
    override: A YAML/JSON file path or string, as accepted by
      `hyperparams.override_params_dict`.

  Returns:
    The overridden config.
  """
  try:
    return hyperparams.override_params_dict(params, override, is_strict=True)
  except KeyError:
    return hyperparams.override_params_dict(
        params, override, is_strict=False)


def _parse_function_keys(function_keys_str):
  """Parses `--function_keys` into a {function key: signature name} dict.

  Args:
    function_keys_str: Comma-separated `key:value` pairs, e.g.
      `tf_example:serving_default,image_tensor:serving_image_tensor`.
      Whitespace around keys and values is ignored. Empty entries, such as
      one produced by a trailing comma, are skipped.

  Returns:
    The parsed dict, or None when no pair is given.

  Raises:
    app.UsageError: If an entry is not exactly one non-empty `key:value` pair.
  """
  if not function_keys_str:
    return None
  function_keys = {}
  for entry in function_keys_str.split(','):
    entry = entry.strip()
    if not entry:
      continue
    parts = [part.strip() for part in entry.split(':')]
    # Reject entries like `a`, `a:`, `:b` or `a:b:c` instead of crashing with
    # an IndexError or silently dropping part of the value.
    if len(parts) != 2 or not all(parts):
      raise app.UsageError(
          f'Invalid --function_keys entry {entry!r}: expected `key:value`.')
    function_keys[parts[0]] = parts[1]
  return function_keys or None


def main(_):
  """Builds the experiment config from flags and exports a SavedModel.

  Args:
    _: The argv list passed in by `app.run`; unused.
  """
  # Start from the registered experiment config, then layer the overrides:
  # every --config_file in the order given, followed by --params_override.
  params = exp_factory.get_exp_config(_EXPERIMENT.value)
  for config_file in _CONFIG_FILE.value or []:
    params = _override_params(params, config_file)
  if _PARAMS_OVERRIDE.value:
    params = _override_params(params, _PARAMS_OVERRIDE.value)
  params.validate()
  params.lock()

  function_keys = _parse_function_keys(_FUNCTION_KEYS.value)
  input_image_size = [int(x) for x in _INPUT_IMAGE_SIZE.value.split(',')]

  export_saved_model_lib.export_inference_graph(
      input_type=_IMAGE_TYPE.value,
      batch_size=_BATCH_SIZE.value,
      input_image_size=input_image_size,
      params=params,
      checkpoint_path=_CHECKPOINT_PATH.value,
      export_dir=_EXPORT_DIR.value,
      function_keys=function_keys,
      export_checkpoint_subdir=_EXPORT_CHECKPOINT_SUBDIR.value,
      export_saved_model_subdir=_EXPORT_SAVED_MODEL_SUBDIR.value,
      log_model_flops_and_params=_LOG_MODEL_FLOPS_AND_PARAMS.value,
      input_name=_INPUT_NAME.value,
      add_tpu_function_alias=_ADD_TPU_FUNCTION_ALIAS.value,
  )
if __name__ == '__main__':
  # Hand control to absl: it parses the command-line flags, then calls `main`.
  app.run(main=main)