# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Library to facilitate TFLite model conversion."""

import functools
from typing import Iterator, List, Optional

from absl import logging
import tensorflow as tf, tf_keras

from official.core import base_task
from official.core import config_definitions as cfg
from official.vision import configs
from official.vision import tasks

def create_representative_dataset(
    params: cfg.ExperimentConfig,
    task: Optional[base_task.Task] = None) -> tf.data.Dataset:
  """Creates a tf.data.Dataset that loads images for the representative dataset.

  Args:
    params: An ExperimentConfig.
    task: An optional task instance. If it is None, the task is built according
      to the task type in `params`.

  Returns:
    A tf.data.Dataset instance.

  Raises:
    ValueError: If the task type is not supported.
  """
  if task is None:
    if isinstance(params.task,
                  configs.image_classification.ImageClassificationTask):
      task = tasks.image_classification.ImageClassificationTask(params.task)
    elif isinstance(params.task, configs.retinanet.RetinaNetTask):
      task = tasks.retinanet.RetinaNetTask(params.task)
    elif isinstance(params.task, configs.maskrcnn.MaskRCNNTask):
      task = tasks.maskrcnn.MaskRCNNTask(params.task)
    elif isinstance(params.task,
                    configs.semantic_segmentation.SemanticSegmentationTask):
      task = tasks.semantic_segmentation.SemanticSegmentationTask(params.task)
    else:
      raise ValueError('Task {} not supported.'.format(type(params.task)))
  # Ensure the batch size is 1, since TFLite models are exported with a fixed
  # batch size of 1 and calibration inputs must match.
  params.task.train_data.global_batch_size = 1
  params.task.train_data.dtype = 'float32'
  logging.info('Task config: %s', params.task.as_dict())
  return task.build_inputs(params=params.task.train_data)
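

# Usage example (a minimal sketch; `exp_factory.get_exp_config` is a real
# Model Garden helper, but the experiment name below is a placeholder and
# must match an experiment registered in your setup):
#
#   from official.core import exp_factory
#
#   params = exp_factory.get_exp_config('resnet_imagenet')
#   dataset = create_representative_dataset(params)
#   for image, _ in dataset.take(1):
#     print(image.shape)  # Batch size is forced to 1 for TFLite calibration.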


def representative_dataset(
    params: cfg.ExperimentConfig,
    task: Optional[base_task.Task] = None,
    calibration_steps: int = 2000) -> Iterator[List[tf.Tensor]]:
  """Creates a representative dataset generator for input calibration.

  Args:
    params: An ExperimentConfig.
    task: An optional task instance. If it is None, the task is built according
      to the task type in `params`.
    calibration_steps: The number of calibration steps to run.

  Yields:
    An input image tensor wrapped in a single-element list.
  """
  dataset = create_representative_dataset(params=params, task=task)
  for image, _ in dataset.take(calibration_steps):
    # Skip images that do not have 3 channels.
    if image.shape[-1] != 3:
      continue
    yield [image]
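

# Usage example (sketch): wiring the generator into a TFLiteConverter for
# post-training integer quantization. `saved_model_dir` and `params` are
# assumed to be defined by the caller.
#
#   converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
#   converter.optimizations = [tf.lite.Optimize.DEFAULT]
#   converter.representative_dataset = functools.partial(
#       representative_dataset, params=params, calibration_steps=100)
#   tflite_model = converter.convert()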


def convert_tflite_model(
    saved_model_dir: Optional[str] = None,
    concrete_function: Optional[tf.types.experimental.ConcreteFunction] = None,
    model: Optional[tf.Module] = None,
    quant_type: Optional[str] = None,
    params: Optional[cfg.ExperimentConfig] = None,
    task: Optional[base_task.Task] = None,
    calibration_steps: Optional[int] = 2000,
    denylisted_ops: Optional[List[str]] = None,
) -> bytes:
  """Converts and returns a TFLite model.

  Args:
    saved_model_dir: The directory containing the SavedModel.
    concrete_function: An optional concrete function to be converted.
    model: An optional tf_keras.Model instance. If neither `saved_model_dir`
      nor `concrete_function` is provided, this model is converted to TFLite.
    quant_type: The post-training quantization (PTQ) method. It can be one of
      `default` (dynamic range), `fp16` (float16), `int8` (integer with float
      fallback), `int8_full` (integer only, uint8 input/output),
      `int8_full_int8_io` (integer only, int8 input/output), `uint8` (legacy
      uint8 quantization with dummy range stats), `qat` (quantization-aware
      trained models, uint8 input/output), `qat_fp32_io` (quantization-aware
      trained models, float32 input/output), or None (no quantization).
    params: An optional ExperimentConfig used to load and preprocess input
      images for calibration during integer quantization.
    task: An optional task instance. If it is None, the task is built according
      to the task type in `params`.
    calibration_steps: The number of calibration steps to run.
    denylisted_ops: A list of strings containing ops that are excluded from
      integer quantization.

  Returns:
    A converted TFLite model with optional PTQ.

  Raises:
    ValueError: If none of `saved_model_dir`, `concrete_function` or `model`
      is provided, or if `quant_type` is not supported.
  """
  if saved_model_dir:
    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
  elif concrete_function is not None:
    converter = tf.lite.TFLiteConverter.from_concrete_functions(
        [concrete_function]
    )
  elif model is not None:
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
  else:
    raise ValueError(
        '`saved_model_dir`, `model` or `concrete_function` must be specified.'
    )
  if quant_type:
    if quant_type.startswith('int8'):
      converter.optimizations = [tf.lite.Optimize.DEFAULT]
      converter.representative_dataset = functools.partial(
          representative_dataset,
          params=params,
          task=task,
          calibration_steps=calibration_steps)
      if quant_type.startswith('int8_full'):
        converter.target_spec.supported_ops = [
            tf.lite.OpsSet.TFLITE_BUILTINS_INT8
        ]
      if quant_type == 'int8_full':
        converter.inference_input_type = tf.uint8
        converter.inference_output_type = tf.uint8
      if quant_type == 'int8_full_int8_io':
        converter.inference_input_type = tf.int8
        converter.inference_output_type = tf.int8
      if denylisted_ops:
        debug_options = tf.lite.experimental.QuantizationDebugOptions(
            denylisted_ops=denylisted_ops)
        debugger = tf.lite.experimental.QuantizationDebugger(
            converter=converter,
            debug_dataset=functools.partial(
                representative_dataset,
                params=params,
                task=task,
                calibration_steps=calibration_steps),
            debug_options=debug_options)
        debugger.run()
        return debugger.get_nondebug_quantized_model()
    elif quant_type == 'uint8':
      converter.optimizations = [tf.lite.Optimize.DEFAULT]
      converter.default_ranges_stats = (-10, 10)
      converter.inference_type = tf.uint8
      converter.quantized_input_stats = {'input_placeholder': (0., 1.)}
    elif quant_type == 'fp16':
      converter.optimizations = [tf.lite.Optimize.DEFAULT]
      converter.target_spec.supported_types = [tf.float16]
    elif quant_type in ('default', 'qat_fp32_io'):
      converter.optimizations = [tf.lite.Optimize.DEFAULT]
    elif quant_type == 'qat':
      converter.optimizations = [tf.lite.Optimize.DEFAULT]
      converter.inference_input_type = tf.uint8  # or tf.int8
      converter.inference_output_type = tf.uint8  # or tf.int8
    else:
      raise ValueError(f'quantization type {quant_type} is not supported.')

  return converter.convert()
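

# Usage example (a minimal sketch; the experiment name, SavedModel path, and
# output path below are placeholders):
#
#   from official.core import exp_factory
#
#   params = exp_factory.get_exp_config('retinanet_resnetfpn_coco')
#   tflite_model = convert_tflite_model(
#       saved_model_dir='/tmp/saved_model',
#       quant_type='int8_full',
#       params=params,
#       calibration_steps=100)
#   with open('/tmp/model.tflite', 'wb') as f:
#     f.write(tflite_model)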