Spaces:
Runtime error
Runtime error
# Copyright 2023 The TensorFlow Authors. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""A script to export the image classification as a TF-Hub SavedModel.""" | |
# Import libraries | |
from absl import app | |
from absl import flags | |
from official.core import exp_factory | |
from official.modeling import hyperparams | |
from official.vision import registry_imports # pylint: disable=unused-import | |
from official.vision.serving import export_tfhub_lib | |
FLAGS = flags.FLAGS

# Name of a registered experiment; `main` resolves it with
# `exp_factory.get_exp_config` before applying any overrides.
_EXPERIMENT = flags.DEFINE_string(
    'experiment', None, 'experiment type, e.g. resnet_imagenet'
)
# Forwarded to `export_tfhub_lib.export_model_to_tfhub` as `export_path`.
_EXPORT_DIR = flags.DEFINE_string('export_dir', None, 'The export directory.')
# Forwarded to `export_tfhub_lib.export_model_to_tfhub` as `checkpoint_path`.
_CHECKPOINT_PATH = flags.DEFINE_string(
    'checkpoint_path', None, 'Checkpoint path.'
)
# May be repeated; `main` applies each file in command-line order, before
# `--params_override`.
_CONFIG_FILE = flags.DEFINE_multi_string(
    'config_file',
    default=None,
    help=(
        'YAML/JSON files which specify overrides. The override order follows'
        ' the order of args. Note that each file can be used as an override'
        ' template to override the default parameters specified in Python. If'
        ' the same parameter is specified in both `--config_file` and'
        ' `--params_override`, `config_file` will be used first, followed by'
        ' params_override.'
    ),
)
# Applied last by `main`; the empty-string default means "no override".
_PARAMS_OVERRIDE = flags.DEFINE_string(
    'params_override',
    '',
    (
        'The JSON/YAML file or string which specifies the parameter to be'
        ' overridden on top of `config_file` template.'
    ),
)
# Forwarded unchanged to the exporter. NOTE(review): `None` presumably means
# a dynamic batch dimension -- confirm in export_tfhub_lib.
_BATCH_SIZE = flags.DEFINE_integer('batch_size', None, 'The batch size.')
# Parsed by `main` into a [height, width] list of ints.
_INPUT_IMAGE_SIZE = flags.DEFINE_string(
    'input_image_size',
    '224,224',
    (
        'The comma-separated string of two integers representing the'
        ' height,width of the input to the model.'
    ),
)
# Forwarded to the exporter as `skip_logits_layer`.
_SKIP_LOGITS_LAYER = flags.DEFINE_boolean(
    'skip_logits_layer',
    False,
    'Whether to skip the prediction layer and only output the feature vector.',
)
def _override_params(params, override):
  """Applies one override (YAML/JSON file path or string) to `params`.

  A strict override is attempted first; if it raises `KeyError`, the same
  override is retried with `is_strict=False`.

  Args:
    params: the experiment config being built.
    override: a YAML/JSON file path or string, as accepted by
      `hyperparams.override_params_dict`.

  Returns:
    The config returned by `hyperparams.override_params_dict`.
  """
  try:
    return hyperparams.override_params_dict(params, override, is_strict=True)
  except KeyError:
    return hyperparams.override_params_dict(params, override, is_strict=False)


def _parse_input_image_size(value):
  """Parses a 'height,width' flag value into a list of two positive ints.

  Args:
    value: the raw `--input_image_size` string, e.g. '224,224'.

  Returns:
    A two-element list `[height, width]`.

  Raises:
    app.UsageError: if `value` is not exactly two comma-separated positive
      integers. Raised from `main`, absl reports it as a usage error.
  """
  message = (
      '--input_image_size must be two comma-separated positive integers'
      f' "height,width", got {value!r}.'
  )
  try:
    size = [int(dim) for dim in value.split(',')]
  except ValueError as e:
    raise app.UsageError(message) from e
  # A single value or a non-positive dimension would otherwise be passed to
  # the exporter unchecked.
  if len(size) != 2 or min(size) <= 0:
    raise app.UsageError(message)
  return size


def main(_):
  """Builds the experiment config from flags and exports it to TF-Hub.

  The config for `--experiment` is overridden by each `--config_file` (in
  command-line order) and then by `--params_override`. It is then validated,
  locked, and handed to `export_tfhub_lib.export_model_to_tfhub`.
  """
  # Parse first so a malformed size fails before any config work is done.
  input_image_size = _parse_input_image_size(_INPUT_IMAGE_SIZE.value)

  params = exp_factory.get_exp_config(_EXPERIMENT.value)
  for config_file in _CONFIG_FILE.value or []:
    params = _override_params(params, config_file)
  if _PARAMS_OVERRIDE.value:
    params = _override_params(params, _PARAMS_OVERRIDE.value)
  params.validate()
  params.lock()

  export_tfhub_lib.export_model_to_tfhub(
      params=params,
      batch_size=_BATCH_SIZE.value,
      input_image_size=input_image_size,
      checkpoint_path=_CHECKPOINT_PATH.value,
      export_path=_EXPORT_DIR.value,
      num_channels=3,
      skip_logits_layer=_SKIP_LOGITS_LAYER.value,
  )
if __name__ == '__main__':
  # `main` passes these values straight to `exp_factory.get_exp_config` and
  # the exporter. Requiring them makes a missing flag fail at parse time with
  # a usage message. Marking them here keeps plain imports of this module
  # free of that constraint.
  flags.mark_flags_as_required(['experiment', 'export_dir'])
  app.run(main)