Spaces:
Runtime error
Runtime error
# Copyright 2023 The TensorFlow Authors. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
r"""Vision models export binary for serving/inference. | |
To export a trained checkpoint in saved_model format (shell script): | |
EXPERIMENT_TYPE = XX | |
CHECKPOINT_PATH = XX | |
EXPORT_DIR_PATH = XX | |
export_saved_model --experiment=${EXPERIMENT_TYPE} \ | |
--export_dir=${EXPORT_DIR_PATH}/ \ | |
--checkpoint_path=${CHECKPOINT_PATH} \ | |
--batch_size=2 \ | |
--input_image_size=224,224 | |

To serve (python):

export_dir_path = XX
input_type = XX
input_images = XX
imported = tf.saved_model.load(export_dir_path)
model_fn = imported.signatures['serving_default']
# Loaded signature functions accept keyword arguments only; the input is
# named `inputs` unless --input_name was set at export time.
output = model_fn(inputs=input_images)
"""
from absl import app
from absl import flags
from absl import logging

from official.core import exp_factory
from official.modeling import hyperparams
from official.vision import registry_imports  # pylint: disable=unused-import
from official.vision.serving import export_saved_model_lib
FLAGS = flags.FLAGS

# Command-line flags. Each DEFINE_* call returns a holder whose `.value` is
# read in main() once absl has parsed argv.
_EXPERIMENT = flags.DEFINE_string(
    'experiment', None, 'experiment type, e.g. retinanet_resnetfpn_coco')
_EXPORT_DIR = flags.DEFINE_string('export_dir', None, 'The export directory.')
_CHECKPOINT_PATH = flags.DEFINE_string('checkpoint_path', None,
                                       'Checkpoint path.')
_CONFIG_FILE = flags.DEFINE_multi_string(
    'config_file',
    default=None,
    help='YAML/JSON files which specify overrides. The override order '
    'follows the order of args. Note that each file '
    'can be used as an override template to override the default parameters '
    'specified in Python. If the same parameter is specified in both '
    '`--config_file` and `--params_override`, `config_file` will be used '
    'first, followed by params_override.')
_PARAMS_OVERRIDE = flags.DEFINE_string(
    'params_override', '',
    'The JSON/YAML file or string which specifies the parameters to be'
    ' overridden on top of the `config_file` template.')
_BATCH_SIZE = flags.DEFINE_integer('batch_size', None, 'The batch size.')
# The Python name differs from the flag name: this holder backs --input_type.
_IMAGE_TYPE = flags.DEFINE_string(
    'input_type', 'image_tensor',
    'One of `image_tensor`, `image_bytes`, `tf_example` and `tflite`.')
_INPUT_IMAGE_SIZE = flags.DEFINE_string(
    'input_image_size', '224,224',
    'The comma-separated string of two integers representing the height,width '
    'of the input to the model.')
_EXPORT_CHECKPOINT_SUBDIR = flags.DEFINE_string(
    'export_checkpoint_subdir', 'checkpoint',
    'The subdirectory for checkpoints.')
_EXPORT_SAVED_MODEL_SUBDIR = flags.DEFINE_string(
    'export_saved_model_subdir', 'saved_model',
    'The subdirectory for saved model.')
_LOG_MODEL_FLOPS_AND_PARAMS = flags.DEFINE_bool(
    'log_model_flops_and_params', False,
    'If true, logs model flops and parameters.')
_INPUT_NAME = flags.DEFINE_string(
    'input_name', None,
    # Trailing space matters: these literals are concatenated.
    'Input tensor name in signature def. Defaults to None, which '
    'produces input tensor name `inputs`.')
_FUNCTION_KEYS = flags.DEFINE_string(
    'function_keys',
    '',
    (
        'An optional comma-separated string of one or more key:value pair'
        ' indicating the serving function key and corresponding signature_def'
        ' name. For example,'
        ' `tf_example:serving_default,image_tensor:serving_image_tensor` means'
        ' two serving functions are defined for `tf_example` and `image_tensor`'
        ' input types.'
    ),
)
_ADD_TPU_FUNCTION_ALIAS = flags.DEFINE_bool(
    'add_tpu_function_alias',
    False,
    (
        'Whether to add TPU function alias so later it can be converted to a'
        ' TPU SavedModel for inference.'
    ),
)
def _override_params(params, override):
  """Applies a single YAML/JSON override to `params` and returns the result.

  Tries `hyperparams.override_params_dict` with `is_strict=True` first. If that
  raises `KeyError` (presumably an override key missing from the config; see
  the `hyperparams` docs), it logs a warning and retries with `is_strict=False`.
  The warning replaces the previously silent fallback.

  Args:
    params: The experiment config being built.
    override: A YAML/JSON file path or string accepted by
      `hyperparams.override_params_dict`.

  Returns:
    The config returned by `hyperparams.override_params_dict`.
  """
  try:
    return hyperparams.override_params_dict(params, override, is_strict=True)
  except KeyError as e:
    logging.warning(
        'Strict params override failed (%s); retrying with is_strict=False.', e)
    return hyperparams.override_params_dict(params, override, is_strict=False)


def _parse_function_keys(spec):
  """Parses `--function_keys` into a {function_key: signature_name} dict.

  Args:
    spec: Comma-separated `key:value` pairs, e.g.
      'tf_example:serving_default,image_tensor:serving_image_tensor'.
      Whitespace around entries, keys and values is ignored, and so are empty
      entries (e.g. from a trailing comma).

  Returns:
    The parsed dict, or None when `spec` holds no pairs. None keeps the
    export library's default function keys.

  Raises:
    app.UsageError: If an entry is not exactly one `key:value` pair with a
      non-empty key and value.
  """
  function_keys = {}
  for entry in spec.split(','):
    entry = entry.strip()
    if not entry:
      continue
    parts = [part.strip() for part in entry.split(':')]
    if len(parts) != 2 or not all(parts):
      raise app.UsageError(
          f'Invalid --function_keys entry {entry!r}; expected key:value.')
    function_keys[parts[0]] = parts[1]
  return function_keys or None


def _parse_input_image_size(spec):
  """Parses `--input_image_size` (comma-separated integers) into a list.

  Args:
    spec: A string such as '224,224'.

  Returns:
    A list of ints, one per comma-separated field.

  Raises:
    app.UsageError: If any field is not an integer.
  """
  try:
    return [int(dim) for dim in spec.split(',')]
  except ValueError as e:
    raise app.UsageError(
        f'--input_image_size must be comma-separated integers, got {spec!r}.'
    ) from e


def main(_):
  """Builds the experiment config from flags and exports a SavedModel.

  The default config for `--experiment` is overridden by each `--config_file`
  in order, then by `--params_override`. It is then validated, locked and
  passed with the remaining export flags to
  `export_saved_model_lib.export_inference_graph`.

  Args:
    _: The argv passed by `app.run`; unused.
  """
  # Parse the string-valued flags first so malformed values fail before any
  # config is built.
  input_image_size = _parse_input_image_size(_INPUT_IMAGE_SIZE.value)
  function_keys = _parse_function_keys(_FUNCTION_KEYS.value)

  params = exp_factory.get_exp_config(_EXPERIMENT.value)
  # Config files first, in the order given; --params_override last, so it
  # takes precedence (matches the --config_file help text).
  for config_file in _CONFIG_FILE.value or []:
    params = _override_params(params, config_file)
  if _PARAMS_OVERRIDE.value:
    params = _override_params(params, _PARAMS_OVERRIDE.value)
  params.validate()
  params.lock()

  export_saved_model_lib.export_inference_graph(
      input_type=_IMAGE_TYPE.value,
      batch_size=_BATCH_SIZE.value,
      input_image_size=input_image_size,
      params=params,
      checkpoint_path=_CHECKPOINT_PATH.value,
      export_dir=_EXPORT_DIR.value,
      function_keys=function_keys,
      export_checkpoint_subdir=_EXPORT_CHECKPOINT_SUBDIR.value,
      export_saved_model_subdir=_EXPORT_SAVED_MODEL_SUBDIR.value,
      log_model_flops_and_params=_LOG_MODEL_FLOPS_AND_PARAMS.value,
      input_name=_INPUT_NAME.value,
      add_tpu_function_alias=_ADD_TPU_FUNCTION_ALIAS.value,
  )
if __name__ == '__main__':
  # Both flags default to None. main() needs an experiment name to look up a
  # config and a destination directory for the export. Requiring them makes
  # absl fail fast with a usage message instead of erroring deep in the export.
  flags.mark_flags_as_required(['experiment', 'export_dir'])
  app.run(main)