Spaces:
Runtime error
Runtime error
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Binary to convert a saved model to tflite model.

It requires a SavedModel exported using export_saved_model.py with batch size 1
and input type `tflite`, and using the same config file used for exporting saved
model. It includes optional post-training quantization. When using integer
quantization, calibration steps need to be provided to calibrate model input.

To convert a SavedModel to a TFLite model:

EXPERIMENT_TYPE = XX
TFLITE_PATH = XX
SAVED_MODEL_DIR = XX
CONFIG_FILE = XX
export_tflite --experiment=${EXPERIMENT_TYPE} \
              --saved_model_dir=${SAVED_MODEL_DIR} \
              --tflite_path=${TFLITE_PATH} \
              --config_file=${CONFIG_FILE} \
              --quant_type=fp16 \
              --calibration_steps=500
"""
from absl import app | |
from absl import flags | |
from absl import logging | |
import tensorflow as tf, tf_keras | |
from official.core import exp_factory | |
from official.modeling import hyperparams | |
from official.vision import registry_imports # pylint: disable=unused-import | |
from official.vision.serving import export_tflite_lib | |
FLAGS = flags.FLAGS

# Command-line flags for the converter. Each holder returned by
# `flags.DEFINE_*` is read through its `.value` attribute inside `main`.

# Name of the registered experiment whose config is used as the base params.
_EXPERIMENT = flags.DEFINE_string(
    'experiment',
    None,
    'experiment type, e.g. retinanet_resnetfpn_coco',
    required=True)

# Zero or more override files, applied in the order given on the command line.
# NOTE(review): the default is '' rather than None for a multi-string flag;
# kept as-is so the parsed default value does not change.
_CONFIG_FILE = flags.DEFINE_multi_string(
    'config_file',
    default='',
    help='YAML/JSON files which specifies overrides. The override order '
    'follows the order of args. Note that each file '
    'can be used as an override template to override the default parameters '
    'specified in Python. If the same parameter is specified in both '
    '`--config_file` and `--params_override`, `config_file` will be used '
    'first, followed by params_override.')

# A final override applied after every `--config_file`.
_PARAMS_OVERRIDE = flags.DEFINE_string(
    'params_override', '',
    'The JSON/YAML file or string which specifies the parameter to be '
    'overridden on top of `config_file` template.')

# Input SavedModel directory and output TFLite file path.
_SAVED_MODEL_DIR = flags.DEFINE_string(
    'saved_model_dir', None, 'The directory to the saved model.', required=True)
_TFLITE_PATH = flags.DEFINE_string(
    'tflite_path', None, 'The path to the output tflite model.', required=True)

# Quantization mode; None (the default) is passed through to the converter
# library unchanged.
_QUANT_TYPE = flags.DEFINE_string(
    'quant_type',
    default=None,
    help='Post training quantization type. Support `int8_fallback`, '
    '`int8_full_fp32_io`, `int8_full`, `fp16`, `qat`, `qat_fp32_io`, '
    '`int8_full_int8_io` and `default`. See '
    'https://www.tensorflow.org/lite/performance/post_training_quantization '
    'for more details.')

_CALIBRATION_STEPS = flags.DEFINE_integer(
    'calibration_steps', 500,
    'The number of calibration steps for integer model.')

# Comma-separated op names; split into a list inside `main`.
_DENYLISTED_OPS = flags.DEFINE_string(
    'denylisted_ops', '', 'The comma-separated string of ops '
    'that are excluded from integer quantization. The name of '
    'ops should be all capital letters, such as CAST or GREATER. '
    'This is useful to exclude certain ops that affect quality or latency. '
    'Valid ops that should not be included are quantization friendly ops, such '
    'as CONV_2D, DEPTHWISE_CONV_2D, FULLY_CONNECTED, etc.')
def _override_params(params, override):
  """Applies a single override to `params`, trying strict mode first.

  Args:
    params: The experiment config to override.
    override: A config file path or override string, passed straight to
      `hyperparams.override_params_dict`.

  Returns:
    Whatever `hyperparams.override_params_dict` returns for the attempt that
    succeeded.
  """
  try:
    return hyperparams.override_params_dict(params, override, is_strict=True)
  except KeyError as e:
    # Same fallback as before, but no longer silent: the user is told that the
    # strict override was rejected and a non-strict retry is happening.
    logging.warning(
        'Strict override with %r raised KeyError (%s); retrying with '
        'is_strict=False.', override, e)
    return hyperparams.override_params_dict(params, override, is_strict=False)


def _parse_denylisted_ops(raw):
  """Splits a comma-separated op string into a list of op names.

  Surrounding whitespace is stripped from every name and empty entries (for
  example from a trailing comma) are dropped, so `'CAST, GREATER,'` yields
  `['CAST', 'GREATER']`.

  Args:
    raw: The raw flag value; may be empty or None.

  Returns:
    The list of op names, or None when no op name was given.
  """
  if not raw:
    return None
  ops = [token.strip() for token in raw.split(',')]
  ops = [op for op in ops if op]
  return ops or None


def main(_) -> None:
  """Converts the SavedModel named by the flags and writes the TFLite file.

  Builds the experiment config from `--experiment`, layers every
  `--config_file` and then `--params_override` on top of it, validates and
  locks the result, runs the conversion and writes the returned model to
  `--tflite_path`.

  Args:
    _: Unused positional command-line arguments supplied by `app.run`.
  """
  params = exp_factory.get_exp_config(_EXPERIMENT.value)

  # `or []` already covers an unset (None) flag value, so no separate None
  # check is needed. Files are applied in command-line order.
  for config_file in _CONFIG_FILE.value or []:
    params = _override_params(params, config_file)

  # `--params_override` is applied last so it takes effect on top of the
  # config files.
  if _PARAMS_OVERRIDE.value:
    params = _override_params(params, _PARAMS_OVERRIDE.value)

  params.validate()
  params.lock()

  logging.info('Converting SavedModel from %s to TFLite model...',
               _SAVED_MODEL_DIR.value)

  tflite_model = export_tflite_lib.convert_tflite_model(
      saved_model_dir=_SAVED_MODEL_DIR.value,
      quant_type=_QUANT_TYPE.value,
      params=params,
      calibration_steps=_CALIBRATION_STEPS.value,
      denylisted_ops=_parse_denylisted_ops(_DENYLISTED_OPS.value))

  # The output is opened in binary mode and written in a single call.
  with tf.io.gfile.GFile(_TFLITE_PATH.value, 'wb') as fw:
    fw.write(tflite_model)

  logging.info('TFLite model converted and saved to %s.', _TFLITE_PATH.value)


if __name__ == '__main__':
  app.run(main)