# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Common library to export a SavedModel from the export module."""

from typing import Dict, List, Optional, Union, Any

import tensorflow as tf, tf_keras

from official.core import export_base


get_timestamped_export_dir = export_base.get_timestamped_export_dir


def export(export_module: export_base.ExportModule,
           function_keys: Union[List[str], Dict[str, str]],
           export_savedmodel_dir: str,
           checkpoint_path: Optional[str] = None,
           timestamped: bool = True,
           module_key: Optional[str] = None,
           checkpoint_kwargs: Optional[Dict[str, Any]] = None) -> str:
  """Exports an ExportModule to the SavedModel format.

  The module's `serve` function is registered under the `tpu_candidate`
  function alias in the SaveOptions. If a checkpoint object is built (see
  `module_key` and `checkpoint_kwargs`), it is passed to `export_base.export`
  together with `checkpoint_path`. Otherwise `checkpoint=None` is passed.

  Args:
    export_module: an ExportModule holding the Keras model and the serving
      tf.functions.
    function_keys: a list of string keys used to look up pre-defined serving
      signatures. In that case the signature keys are set to defaults. If a
      dict is given instead, its values are used as the signature keys.
    export_savedmodel_dir: output directory for the SavedModel.
    checkpoint_path: object-based checkpoint path or directory.
    timestamped: whether to export the SavedModel into a timestamped
      directory.
    module_key: optional string identifying the checkpoint object under
      which `export_module.model` is loaded.
    checkpoint_kwargs: optional dict of keyword arguments used to build the
      checkpoint object. Ignored when `module_key` is set.

  Returns:
    The SavedModel directory path.
  """
  serving_aliases = {'tpu_candidate': export_module.serve}
  save_options = tf.saved_model.SaveOptions(function_aliases=serving_aliases)

  # `module_key` takes precedence over `checkpoint_kwargs`. When both are
  # falsy, no checkpoint object is handed to export_base.
  checkpoint = None
  if module_key:
    checkpoint = tf.train.Checkpoint(**{module_key: export_module.model})
  elif checkpoint_kwargs:
    checkpoint = tf.train.Checkpoint(**checkpoint_kwargs)

  return export_base.export(
      export_module,
      function_keys,
      export_savedmodel_dir,
      checkpoint_path,
      timestamped,
      save_options,
      checkpoint=checkpoint)