# Copyright 2023 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """A script to export a TF-Hub SavedModel.""" from typing import List, Optional # Import libraries import tensorflow as tf, tf_keras from official.core import config_definitions as cfg from official.vision import configs from official.vision.modeling import factory def build_model(batch_size: Optional[int], input_image_size: List[int], params: cfg.ExperimentConfig, num_channels: int = 3, skip_logits_layer: bool = False) -> tf_keras.Model: """Builds a model for TF Hub export. Args: batch_size: The batch size of input. input_image_size: A list of [height, width] specifying the input image size. params: The config used to train the model. num_channels: The number of input image channels. skip_logits_layer: Whether to skip the logits layer for image classification model. Default is False. Returns: A tf_keras.Model instance. Raises: ValueError: If the task is not supported. """ input_specs = tf_keras.layers.InputSpec(shape=[batch_size] + input_image_size + [num_channels]) if isinstance(params.task, configs.image_classification.ImageClassificationTask): model = factory.build_classification_model( input_specs=input_specs, model_config=params.task.model, l2_regularizer=None, skip_logits_layer=skip_logits_layer) else: raise ValueError('Export module not implemented for {} task.'.format( type(params.task))) return model def export_model_to_tfhub(batch_size: Optional[int], input_image_size: List[int], params: cfg.ExperimentConfig, checkpoint_path: str, export_path: str, num_channels: int = 3, skip_logits_layer: bool = False): """Export a TF2 model to TF-Hub.""" model = build_model(batch_size, input_image_size, params, num_channels, skip_logits_layer) checkpoint = tf.train.Checkpoint(model=model) checkpoint.restore(checkpoint_path).assert_existing_objects_matched() model.save(export_path, include_optimizer=False, save_format='tf')