# deanna-emery's picture
# updates
# 93528c6
# raw
# history blame
# 7.27 kB
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Library to facilitate TFLite model conversion."""
import functools
from typing import Iterator, List, Optional
from absl import logging
import tensorflow as tf, tf_keras
from official.core import base_task
from official.core import config_definitions as cfg
from official.vision import configs
from official.vision import tasks
def create_representative_dataset(
    params: cfg.ExperimentConfig,
    task: Optional[base_task.Task] = None) -> tf.data.Dataset:
  """Builds the input pipeline that supplies calibration images.

  Args:
    params: An ExperimentConfig whose `task.train_data` describes the input
      pipeline. Note that `params.task.train_data` is modified in place:
      `global_batch_size` is set to 1 and `dtype` to 'float32'.
    task: An optional task instance. When None, a task is constructed from
      `params.task`; image classification, RetinaNet, Mask R-CNN and semantic
      segmentation task configs are recognized.

  Returns:
    The tf.data.Dataset returned by `task.build_inputs`.

  Raises:
    ValueError: If `task` is None and `params.task` is not one of the
      recognized task config types.
  """
  if task is None:
    task_config = params.task
    # (config type, task class) pairs, tested in order with isinstance.
    known_tasks = (
        (configs.image_classification.ImageClassificationTask,
         tasks.image_classification.ImageClassificationTask),
        (configs.retinanet.RetinaNetTask,
         tasks.retinanet.RetinaNetTask),
        (configs.maskrcnn.MaskRCNNTask,
         tasks.maskrcnn.MaskRCNNTask),
        (configs.semantic_segmentation.SemanticSegmentationTask,
         tasks.semantic_segmentation.SemanticSegmentationTask),
    )
    task_cls = next(
        (candidate for config_cls, candidate in known_tasks
         if isinstance(task_config, config_cls)),
        None)
    if task_cls is None:
      raise ValueError('Task {} not supported.'.format(type(task_config)))
    task = task_cls(task_config)

  data_config = params.task.train_data
  # The TFLite model is fed one float32 image at a time.
  data_config.global_batch_size = 1
  data_config.dtype = 'float32'
  logging.info('Task config: %s', params.task.as_dict())
  return task.build_inputs(params=data_config)
def representative_dataset(
    params: cfg.ExperimentConfig,
    task: Optional[base_task.Task] = None,
    calibration_steps: int = 2000) -> Iterator[List[tf.Tensor]]:
  """Creates representative dataset for input calibration.

  Args:
    params: An ExperimentConfig used to build the calibration input pipeline
      (see `create_representative_dataset`).
    task: An optional task instance. If it is None, task will be built according
      to the task type in params.
    calibration_steps: The maximum number of dataset elements to read. Elements
      that are skipped (see below) still count toward this limit.

  Yields:
    A single-element list holding one input image tensor. Images whose last
    dimension is not 3 are skipped; the number skipped is logged once the
    dataset has been fully consumed.
  """
  dataset = create_representative_dataset(params=params, task=task)
  num_yielded = 0
  num_skipped = 0
  for image, _ in dataset.take(calibration_steps):
    # Skip images that do not have 3 channels.
    if image.shape[-1] != 3:
      num_skipped += 1
      continue
    num_yielded += 1
    yield [image]
  # Surface skipped samples: if every image is skipped, calibration receives no
  # data and the downstream converter error would otherwise be hard to trace.
  if num_skipped:
    logging.warning(
        'Skipped %d of %d calibration images that did not have 3 channels.',
        num_skipped, num_skipped + num_yielded)
def convert_tflite_model(
    saved_model_dir: Optional[str] = None,
    concrete_function: Optional[tf.types.experimental.ConcreteFunction] = None,
    model: Optional[tf.Module] = None,
    quant_type: Optional[str] = None,
    params: Optional[cfg.ExperimentConfig] = None,
    task: Optional[base_task.Task] = None,
    calibration_steps: Optional[int] = 2000,
    denylisted_ops: Optional[List[str]] = None,
) -> 'bytes':
  """Converts and returns a TFLite model.

  The conversion source is chosen in priority order: `saved_model_dir`, then
  `concrete_function`, then `model`.

  Args:
    saved_model_dir: The directory to the SavedModel.
    concrete_function: An optional concrete function to be exported. Used only
      when `saved_model_dir` is not set.
    model: An optional tf_keras.Model instance. If both `saved_model_dir` and
      `concrete_function` are not available, convert this model to TFLite.
    quant_type: The post training quantization (PTQ) method. It can be one of
      `default` (dynamic range), `fp16` (float16), `int8` (integer with float
      fallback), `int8_full` (integer only, uint8 input/output),
      `int8_full_int8_io` (integer only, int8 input/output), `uint8`, `qat`
      (uint8 input/output), `qat_fp32_io`, or None (no quantization). Any
      value starting with `int8` takes the integer quantization path.
    params: An ExperimentConfig to load and preprocess input images to do
      calibration. Required when `quant_type` starts with `int8`.
    task: An optional task instance. If it is None, task will be built according
      to the task type in params.
    calibration_steps: The steps to do calibration.
    denylisted_ops: A list of strings containing ops that are excluded from
      integer quantization. Only used for `int8*` quant types; when given, the
      model is produced through `tf.lite.experimental.QuantizationDebugger`.

  Returns:
    A converted TFLite model with optional PTQ.

  Raises:
    ValueError: If none of `saved_model_dir`, `concrete_function` or `model`
      is provided, if integer quantization is requested without `params`, or
      if `quant_type` is not supported.
  """
  if saved_model_dir:
    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
  elif concrete_function is not None:
    converter = tf.lite.TFLiteConverter.from_concrete_functions(
        [concrete_function]
    )
  elif model is not None:
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
  else:
    raise ValueError(
        '`saved_model_dir`, `model` or `concrete_function` must be specified.'
    )

  if not quant_type:
    return converter.convert()

  if quant_type.startswith('int8'):
    # Fail early with a clear error: building the calibration dataset
    # dereferences `params.task`, so it cannot work without `params`.
    if params is None:
      raise ValueError(
          '`params` must be provided to build the calibration dataset for '
          f'quantization type {quant_type}.')
    # One callable shared by the converter and the optional debugger, so both
    # calibrate with the same (possibly caller-supplied) task.
    calibration_data = functools.partial(
        representative_dataset,
        params=params,
        task=task,
        calibration_steps=calibration_steps)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.representative_dataset = calibration_data
    if quant_type.startswith('int8_full'):
      # Integer-only: restrict conversion to int8 builtin kernels.
      converter.target_spec.supported_ops = [
          tf.lite.OpsSet.TFLITE_BUILTINS_INT8
      ]
    if quant_type == 'int8_full':
      converter.inference_input_type = tf.uint8
      converter.inference_output_type = tf.uint8
    elif quant_type == 'int8_full_int8_io':
      converter.inference_input_type = tf.int8
      converter.inference_output_type = tf.int8
    if denylisted_ops:
      debug_options = tf.lite.experimental.QuantizationDebugOptions(
          denylisted_ops=denylisted_ops)
      debugger = tf.lite.experimental.QuantizationDebugger(
          converter=converter,
          debug_dataset=calibration_data,
          debug_options=debug_options)
      debugger.run()
      return debugger.get_nondebug_quantized_model()
  elif quant_type == 'uint8':
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    # NOTE(review): these attributes look like TF1-converter settings and the
    # input name 'input_placeholder' is hard-coded; confirm they take effect
    # with the converter constructed above.
    converter.default_ranges_stats = (-10, 10)
    converter.inference_type = tf.uint8
    converter.quantized_input_stats = {'input_placeholder': (0., 1.)}
  elif quant_type == 'fp16':
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.target_spec.supported_types = [tf.float16]
  elif quant_type in ('default', 'qat_fp32_io'):
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
  elif quant_type == 'qat':
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.inference_input_type = tf.uint8  # or tf.int8
    converter.inference_output_type = tf.uint8  # or tf.int8
  else:
    raise ValueError(f'quantization type {quant_type} is not supported.')
  return converter.convert()