Spaces:
Runtime error
Runtime error
# Copyright 2023 The TensorFlow Authors. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Common library to export a SavedModel from the export module.""" | |
from typing import Dict, List, Optional, Union, Any | |
import tensorflow as tf, tf_keras | |
from official.core import export_base | |
# Module-level alias of `export_base.get_timestamped_export_dir`, presumably
# re-exported so callers of this library need not import export_base directly.
get_timestamped_export_dir = export_base.get_timestamped_export_dir
def export(export_module: export_base.ExportModule,
           function_keys: Union[List[str], Dict[str, str]],
           export_savedmodel_dir: str,
           checkpoint_path: Optional[str] = None,
           timestamped: bool = True,
           module_key: Optional[str] = None,
           checkpoint_kwargs: Optional[Dict[str, Any]] = None) -> str:
  """Exports to SavedModel format.

  Args:
    export_module: an ExportModule with the keras Model and serving
      tf.functions.
    function_keys: a list of string keys to retrieve pre-defined serving
      signatures. The signature keys will be set with defaults. If a
      dictionary is provided, the values will be used as signature keys.
    export_savedmodel_dir: Output saved model directory.
    checkpoint_path: Object-based checkpoint path or directory.
    timestamped: Whether to export the savedmodel to a timestamped directory.
    module_key: Optional string to identify a checkpoint object to load for
      the model in the export module. When set (non-empty), it takes
      precedence over `checkpoint_kwargs`.
    checkpoint_kwargs: Optional dict used as keyword args to create the
      checkpoint object. Not used if module_key is present.

  Returns:
    The savedmodel directory path.
  """
  # Register `export_module.serve` under the alias 'tpu_candidate' so tools
  # reading the SavedModel can look up every concrete function traced from it
  # by that name (see `tf.saved_model.SaveOptions.function_aliases`).
  save_options = tf.saved_model.SaveOptions(function_aliases={
      'tpu_candidate': export_module.serve,
  })

  if module_key:
    # Attach the model under `module_key` so it is matched against the
    # checkpoint entry stored under that same name when restoring.
    checkpoint = tf.train.Checkpoint(**{module_key: export_module.model})
  elif checkpoint_kwargs:
    checkpoint = tf.train.Checkpoint(**checkpoint_kwargs)
  else:
    # NOTE(review): presumably export_base.export builds its own default
    # checkpoint object when given None -- confirm in export_base.
    checkpoint = None

  return export_base.export(
      export_module,
      function_keys,
      export_savedmodel_dir,
      checkpoint_path,
      timestamped,
      save_options,
      checkpoint=checkpoint)