# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Vision models export binary for serving/inference.

To export a trained checkpoint in saved_model format (shell script):

EXPERIMENT_TYPE=XX
CHECKPOINT_PATH=XX
EXPORT_DIR_PATH=XX
export_saved_model --experiment=${EXPERIMENT_TYPE} \
                   --export_dir=${EXPORT_DIR_PATH}/ \
                   --checkpoint_path=${CHECKPOINT_PATH} \
                   --batch_size=2 \
                   --input_image_size=224,224

To serve (python):

export_dir_path = XX
input_type = XX
input_images = XX
imported = tf.saved_model.load(export_dir_path)
model_fn = imported.signatures['serving_default']
output = model_fn(input_images)
"""

from absl import app
from absl import flags

from official.core import exp_factory
from official.modeling import hyperparams
from official.vision import registry_imports  # pylint: disable=unused-import
from official.vision.serving import export_saved_model_lib

FLAGS = flags.FLAGS

_EXPERIMENT = flags.DEFINE_string(
    'experiment', None, 'experiment type, e.g. retinanet_resnetfpn_coco')
_EXPORT_DIR = flags.DEFINE_string('export_dir', None, 'The export directory.')
_CHECKPOINT_PATH = flags.DEFINE_string('checkpoint_path', None,
                                       'Checkpoint path.')
_CONFIG_FILE = flags.DEFINE_multi_string(
    'config_file',
    default=None,
    help='YAML/JSON files which specify overrides. The override order '
    'follows the order of args. Note that each file '
    'can be used as an override template to override the default parameters '
    'specified in Python. If the same parameter is specified in both '
    '`--config_file` and `--params_override`, `config_file` will be used '
    'first, followed by params_override.')
_PARAMS_OVERRIDE = flags.DEFINE_string(
    'params_override', '',
    'The JSON/YAML file or string which specifies the parameter to be '
    'overridden on top of `config_file` template.')
_BATCH_SIZE = flags.DEFINE_integer('batch_size', None, 'The batch size.')
_IMAGE_TYPE = flags.DEFINE_string(
    'input_type', 'image_tensor',
    'One of `image_tensor`, `image_bytes`, `tf_example` and `tflite`.')
_INPUT_IMAGE_SIZE = flags.DEFINE_string(
    'input_image_size', '224,224',
    'The comma-separated string of two integers representing the height,width '
    'of the input to the model.')
_EXPORT_CHECKPOINT_SUBDIR = flags.DEFINE_string(
    'export_checkpoint_subdir', 'checkpoint',
    'The subdirectory for checkpoints.')
_EXPORT_SAVED_MODEL_SUBDIR = flags.DEFINE_string(
    'export_saved_model_subdir', 'saved_model',
    'The subdirectory for saved model.')
_LOG_MODEL_FLOPS_AND_PARAMS = flags.DEFINE_bool(
    'log_model_flops_and_params', False,
    'If true, logs model flops and parameters.')
_INPUT_NAME = flags.DEFINE_string(
    'input_name', None,
    'Input tensor name in signature def. Default at None which '
    'produces input tensor name `inputs`.')
_FUNCTION_KEYS = flags.DEFINE_string(
    'function_keys',
    '',
    (
        'An optional comma-separated string of one or more key:value pair'
        ' indicating the serving function key and corresponding signature_def'
        ' name. For example,'
        ' `tf_example:serving_default,image_tensor:serving_image_tensor` means'
        ' two serving functions are defined for `tf_example` and `image_tensor`'
        ' input types.'
    ),
)
_ADD_TPU_FUNCTION_ALIAS = flags.DEFINE_bool(
    'add_tpu_function_alias',
    False,
    (
        'Whether to add TPU function alias so later it can be converted to a'
        ' TPU SavedModel for inference.'
    ),
)


def _override_params(params, override):
  """Applies one YAML/JSON file or string override on top of `params`.

  A strict override is attempted first. If it raises `KeyError` (presumably
  because the override names keys that `params` does not define), the same
  override is re-applied with `is_strict=False`.

  Args:
    params: The experiment config to override.
    override: A YAML/JSON file path or an inline YAML/JSON string.

  Returns:
    The config returned by `hyperparams.override_params_dict`.
  """
  try:
    return hyperparams.override_params_dict(params, override, is_strict=True)
  except KeyError:
    return hyperparams.override_params_dict(params, override, is_strict=False)


def _parse_function_keys(spec):
  """Parses the `--function_keys` flag value into a dict.

  Args:
    spec: Comma-separated `key:value` pairs, e.g.
      `tf_example:serving_default,image_tensor:serving_image_tensor`.
      Whitespace around keys and values is ignored, and so are empty entries
      such as the one produced by a trailing comma.

  Returns:
    A dict mapping each serving function key to its signature_def name, or
    None when `spec` contains no entries (the value used when the flag is
    unset).

  Raises:
    ValueError: If an entry is not exactly one `key:value` pair with both
      sides non-empty.
  """
  if not spec:
    return None
  function_keys = {}
  for entry in spec.split(','):
    entry = entry.strip()
    if not entry:
      continue
    parts = [part.strip() for part in entry.split(':')]
    # Reject `key`, `key:` and `key:a:b` instead of crashing with IndexError
    # or silently dropping everything after the second colon.
    if len(parts) != 2 or not all(parts):
      raise ValueError(
          f'Invalid --function_keys entry {entry!r}; expected `key:value`.')
    key, value = parts
    function_keys[key] = value
  return function_keys or None


def main(_):
  """Builds the experiment config from flags and exports a SavedModel."""
  params = exp_factory.get_exp_config(_EXPERIMENT.value)
  # Config files are applied in the order given, then `--params_override`
  # is applied on top of them.
  for config_file in _CONFIG_FILE.value or []:
    params = _override_params(params, config_file)
  if _PARAMS_OVERRIDE.value:
    params = _override_params(params, _PARAMS_OVERRIDE.value)

  params.validate()
  params.lock()

  export_saved_model_lib.export_inference_graph(
      input_type=_IMAGE_TYPE.value,
      batch_size=_BATCH_SIZE.value,
      input_image_size=[int(x) for x in _INPUT_IMAGE_SIZE.value.split(',')],
      params=params,
      checkpoint_path=_CHECKPOINT_PATH.value,
      export_dir=_EXPORT_DIR.value,
      function_keys=_parse_function_keys(_FUNCTION_KEYS.value),
      export_checkpoint_subdir=_EXPORT_CHECKPOINT_SUBDIR.value,
      export_saved_model_subdir=_EXPORT_SAVED_MODEL_SUBDIR.value,
      log_model_flops_and_params=_LOG_MODEL_FLOPS_AND_PARAMS.value,
      input_name=_INPUT_NAME.value,
      add_tpu_function_alias=_ADD_TPU_FUNCTION_ALIAS.value,
  )


if __name__ == '__main__':
  # `main` unconditionally uses both values (to look up the experiment config
  # and as the export destination), so fail fast at flag-parsing time rather
  # than deep inside the config registry or export library.
  flags.mark_flags_as_required(['experiment', 'export_dir'])
  app.run(main)