deanna-emery's picture
updates
93528c6
raw
history blame
3.5 kB
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A script to export an image classification model as a TF-Hub SavedModel."""
# Import libraries
from absl import app
from absl import flags
from absl import logging

from official.core import exp_factory
from official.modeling import hyperparams
from official.vision import registry_imports  # pylint: disable=unused-import
from official.vision.serving import export_tfhub_lib
# Module-level handle to the absl flag registry; not referenced below but kept
# as the conventional entry point for flag access.
FLAGS = flags.FLAGS

# --- Which experiment to export, and from/to where. ---
# This script exports image-classification models, so the example names an
# image-classification experiment.
_EXPERIMENT = flags.DEFINE_string(
    'experiment', None, 'experiment type, e.g. resnet_imagenet'
)
_EXPORT_DIR = flags.DEFINE_string('export_dir', None, 'The export directory.')
_CHECKPOINT_PATH = flags.DEFINE_string(
    'checkpoint_path', None, 'Checkpoint path.'
)

# --- Overrides applied on top of the experiment's registered config. ---
# `main` applies every --config_file in order, then --params_override.
_CONFIG_FILE = flags.DEFINE_multi_string(
    'config_file',
    default=None,
    help=(
        'YAML/JSON files which specifies overrides. The override order follows'
        ' the order of args. Note that each file can be used as an override'
        ' template to override the default parameters specified in Python. If'
        ' the same parameter is specified in both `--config_file` and'
        ' `--params_override`, `config_file` will be used first, followed by'
        ' params_override.'
    ),
)
_PARAMS_OVERRIDE = flags.DEFINE_string(
    'params_override',
    '',
    (
        'The JSON/YAML file or string which specifies the parameter to be'
        ' overridden on top of `config_file` template.'
    ),
)

# --- Shape/head options forwarded to the exporter. ---
# NOTE(review): a None batch size presumably leaves the batch dimension
# unspecified in the exported signature — confirm in export_tfhub_lib.
_BATCH_SIZE = flags.DEFINE_integer('batch_size', None, 'The batch size.')
_INPUT_IMAGE_SIZE = flags.DEFINE_string(
    'input_image_size',
    '224,224',
    (
        'The comma-separated string of two integers representing the'
        ' height,width of the input to the model.'
    ),
)
_SKIP_LOGITS_LAYER = flags.DEFINE_boolean(
    'skip_logits_layer',
    False,
    'Whether to skip the prediction layer and only output the feature vector.',
)
def _override_params(params, override):
  """Applies one override to `params`, strictly first and leniently on failure.

  Args:
    params: The experiment config being built.
    override: A YAML/JSON file path or string, as accepted by
      `hyperparams.override_params_dict`.

  Returns:
    The overridden config.
  """
  try:
    return hyperparams.override_params_dict(params, override, is_strict=True)
  except KeyError as e:
    # The original code fell back silently, which hid mistyped keys; keep the
    # fallback (callers rely on it) but make it visible.
    logging.warning(
        'Strict override with %s failed (%s); retrying with is_strict=False.',
        override,
        e,
    )
    return hyperparams.override_params_dict(params, override, is_strict=False)


def _parse_input_image_size(value):
  """Parses a 'height,width' flag value into a list of two positive ints.

  Args:
    value: The raw `--input_image_size` string, e.g. '224,224'.

  Returns:
    `[height, width]` as ints.

  Raises:
    ValueError: If `value` is not exactly two comma-separated positive
      integers.
  """
  error = (
      '--input_image_size must be two comma-separated positive integers'
      f' (height,width), got {value!r}.'
  )
  try:
    size = [int(x) for x in value.split(',')]
  except ValueError as e:
    raise ValueError(error) from e
  if len(size) != 2 or any(dim <= 0 for dim in size):
    raise ValueError(error)
  return size


def main(_):
  """Builds the experiment config from flags and exports it to TF-Hub.

  The registered config for `--experiment` is overridden by each
  `--config_file` in order, then by `--params_override`. It is then validated,
  locked, and handed to `export_tfhub_lib.export_model_to_tfhub` along with
  the checkpoint, export-path and input-shape flags.

  Raises:
    ValueError: If `--input_image_size` is malformed.
  """
  # Parse cheap flag values first so a typo fails before any config work.
  input_image_size = _parse_input_image_size(_INPUT_IMAGE_SIZE.value)

  params = exp_factory.get_exp_config(_EXPERIMENT.value)
  for config_file in _CONFIG_FILE.value or []:
    params = _override_params(params, config_file)
  if _PARAMS_OVERRIDE.value:
    params = _override_params(params, _PARAMS_OVERRIDE.value)
  params.validate()
  params.lock()

  export_tfhub_lib.export_model_to_tfhub(
      params=params,
      batch_size=_BATCH_SIZE.value,
      input_image_size=input_image_size,
      checkpoint_path=_CHECKPOINT_PATH.value,
      export_path=_EXPORT_DIR.value,
      # Input channel count is fixed at 3 (presumably RGB images).
      num_channels=3,
      skip_logits_layer=_SKIP_LOGITS_LAYER.value,
  )
if __name__ == '__main__':
  # --experiment and --export_dir default to None, but main() uses both
  # unconditionally (config lookup and export destination). Reject a missing
  # value at flag-parsing time with a clear message instead of failing later.
  flags.mark_flags_as_required(['experiment', 'export_dir'])
  app.run(main)