ASL-MoViNet-T5-translator / official /nlp /serving /export_savedmodel_util.py
deanna-emery's picture
updates
93528c6
raw
history blame
2.59 kB
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common library to export a SavedModel from the export module."""
from typing import Dict, List, Optional, Union, Any
import tensorflow as tf, tf_keras
from official.core import export_base
# Module-level alias of `export_base.get_timestamped_export_dir`, so the same
# function can also be reached through this module.
get_timestamped_export_dir = export_base.get_timestamped_export_dir
def export(export_module: export_base.ExportModule,
           function_keys: Union[List[str], Dict[str, str]],
           export_savedmodel_dir: str,
           checkpoint_path: Optional[str] = None,
           timestamped: bool = True,
           module_key: Optional[str] = None,
           checkpoint_kwargs: Optional[Dict[str, Any]] = None) -> str:
  """Exports `export_module` in SavedModel format via `export_base.export`.

  The SavedModel is saved with a `tpu_candidate` function alias bound to
  `export_module.serve`. A checkpoint object is built here only when
  `module_key` or `checkpoint_kwargs` is supplied; otherwise `checkpoint=None`
  is forwarded to `export_base.export`.

  Args:
    export_module: an ExportModule holding the Keras model and the serving
      tf.functions.
    function_keys: a list of string keys used to retrieve pre-defined serving
      signatures; the signature keys are then set with defaults. If a
      dictionary is provided instead, its values are used as signature keys.
    export_savedmodel_dir: output SavedModel directory.
    checkpoint_path: optional object-based checkpoint path or directory.
    timestamped: whether to export the SavedModel to a timestamped directory.
    module_key: optional name under which `export_module.model` is tracked in
      the checkpoint object to load. Takes precedence over
      `checkpoint_kwargs`; an empty string is treated as not provided.
    checkpoint_kwargs: optional keyword arguments used to construct the
      checkpoint object. Ignored when `module_key` is set; an empty dict is
      treated as not provided.

  Returns:
    The SavedModel directory path.
  """
  aliases = {'tpu_candidate': export_module.serve}
  options = tf.saved_model.SaveOptions(function_aliases=aliases)

  # `module_key` wins over `checkpoint_kwargs`. Both are tested for
  # truthiness, so empty values fall through to "no checkpoint object".
  checkpoint = None
  if module_key:
    checkpoint = tf.train.Checkpoint(**{module_key: export_module.model})
  elif checkpoint_kwargs:
    checkpoint = tf.train.Checkpoint(**checkpoint_kwargs)

  return export_base.export(
      export_module,
      function_keys,
      export_savedmodel_dir,
      checkpoint_path,
      timestamped,
      options,
      checkpoint=checkpoint)