# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base class for model export.""" | |
import abc | |
import functools | |
import time | |
from typing import Any, Callable, Dict, Mapping, List, Optional, Text, Union | |
from absl import logging | |
import tensorflow as tf, tf_keras | |
MAX_DIRECTORY_CREATION_ATTEMPTS = 10 | |


class ExportModule(tf.Module, metaclass=abc.ABCMeta):
  """Base Export Module."""

  def __init__(self,
               params,
               model: Union[tf.Module, tf_keras.Model],
               inference_step: Optional[Callable[..., Any]] = None,
               *,
               preprocessor: Optional[Callable[..., Any]] = None,
               postprocessor: Optional[Callable[..., Any]] = None):
"""Instantiates an ExportModel. | |
Examples: | |
`inference_step` must be a function that has `model` as an kwarg or the | |
second positional argument. | |
``` | |
def _inference_step(inputs, model=None): | |
return model(inputs, training=False) | |
module = ExportModule(params, model, inference_step=_inference_step) | |
``` | |
`preprocessor` and `postprocessor` could be either functions or `tf.Module`. | |
The usages of preprocessor and postprocessor are managed by the | |
implementation of `serve()` method. | |
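
    For illustration only (a sketch, not prescribed by this base class), a
    subclass's `serve()` might apply them around the inference step like this:

    ```
    def serve(self, inputs):
      if self.preprocessor is not None:
        inputs = self.preprocessor(inputs)
      outputs = self.inference_step(inputs)
      if self.postprocessor is not None:
        outputs = self.postprocessor(outputs)
      return outputs
    ```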

    Args:
      params: A dataclass for parameters to the module.
      model: A model instance which contains weights and forward computation.
      inference_step: An optional callable that runs the forward pass of the
        model. If not specified, a partial function is created with `model` as
        a required kwarg.
      preprocessor: An optional callable to preprocess the inputs.
      postprocessor: An optional callable to postprocess the model outputs.
    """
    super().__init__(name=None)
    self.model = model
    self.params = params

    if inference_step is not None:
      self.inference_step = functools.partial(inference_step, model=self.model)
    else:
      if issubclass(type(model), tf_keras.Model):
        # Default to self.model.call instead of self.model.__call__ to avoid
        # the Keras tracing logic designed for training.
        # Since most Model Garden models' `call()` either has no `training`
        # kwarg or defaults it to False, we don't pass anything here.
        # Please pass a custom inference step if your model defaults to
        # training=True.
        self.inference_step = self.model.call
      else:
        self.inference_step = functools.partial(
            self.model.__call__, training=False)
    self.preprocessor = preprocessor
    self.postprocessor = postprocessor

  @abc.abstractmethod
  def serve(self) -> Mapping[Text, tf.Tensor]:
    """The bare inference function which should run on all devices.

    Tensors are expected to be passed in through keyword arguments. Returns a
    dictionary of tensors whose keys will be used inside the SignatureDef.
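
    For example (illustrative only; the argument names and output keys are up
    to the subclass), an implementation might look like:

    ```
    @tf.function
    def serve(self, images) -> Mapping[Text, tf.Tensor]:
      return {'outputs': self.inference_step(images)}
    ```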
""" | |

  @abc.abstractmethod
  def get_inference_signatures(
      self, function_keys: Dict[Text, Text]) -> Mapping[Text, Any]:
    """Get defined function signatures."""


def export(export_module: ExportModule,
           function_keys: Union[List[Text], Dict[Text, Text]],
           export_savedmodel_dir: Text,
           checkpoint_path: Optional[Text] = None,
           timestamped: bool = True,
           save_options: Optional[tf.saved_model.SaveOptions] = None,
           checkpoint: Optional[tf.train.Checkpoint] = None) -> Text:
"""Exports to SavedModel format. | |
Args: | |
export_module: a ExportModule with the keras Model and serving tf.functions. | |
function_keys: a list of string keys to retrieve pre-defined serving | |
signatures. The signaute keys will be set with defaults. If a dictionary | |
is provided, the values will be used as signature keys. | |
export_savedmodel_dir: Output saved model directory. | |
checkpoint_path: Object-based checkpoint path or directory. | |
timestamped: Whether to export the savedmodel to a timestamped directory. | |
save_options: `SaveOptions` for `tf.saved_model.save`. | |
checkpoint: An optional tf.train.Checkpoint. If provided, the export module | |
will use it to read the weights. | |
Returns: | |
The savedmodel directory path. | |
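
  Example (illustrative only; `MyExportModule` is a hypothetical subclass of
  `ExportModule` implementing the abstract methods above):

  ```
  module = MyExportModule(params, model)
  export_dir = export(
      module,
      function_keys=['serve'],
      export_savedmodel_dir='/tmp/saved_model',
      checkpoint_path='/tmp/checkpoints',
      timestamped=True)
  ```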
""" | |
  ckpt_dir_or_file = checkpoint_path
  if ckpt_dir_or_file is not None and tf.io.gfile.isdir(ckpt_dir_or_file):
    ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
  if ckpt_dir_or_file:
    if checkpoint is None:
      checkpoint = tf.train.Checkpoint(model=export_module.model)
    checkpoint.read(
        ckpt_dir_or_file).assert_existing_objects_matched().expect_partial()

  if isinstance(function_keys, list):
    if len(function_keys) == 1:
      function_keys = {
          function_keys[0]: tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
      }
    else:
      raise ValueError(
          'If function_keys is a list, it must contain a single element. %s' %
          function_keys)

  signatures = export_module.get_inference_signatures(function_keys)
  if timestamped:
    export_dir = get_timestamped_export_dir(export_savedmodel_dir).decode(
        'utf-8')
  else:
    export_dir = export_savedmodel_dir
  tf.saved_model.save(
      export_module, export_dir, signatures=signatures, options=save_options)
  return export_dir


def get_timestamped_export_dir(export_dir_base):
  """Builds a path to a new subdirectory within the base directory.

  Args:
    export_dir_base: A string containing a directory to write the exported
      graph and checkpoints.

  Returns:
    The full path of the new subdirectory (which is not actually created yet).

  Raises:
    RuntimeError: if repeated attempts fail to obtain a unique timestamped
      directory name.
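
  Example (illustrative only; the timestamp depends on the current time, and
  the result is a bytes path that `export()` later decodes):

  ```
  get_timestamped_export_dir('/tmp/saved_model')  # e.g. b'/tmp/saved_model/1700000000'
  ```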
""" | |
  attempts = 0
  while attempts < MAX_DIRECTORY_CREATION_ATTEMPTS:
    timestamp = int(time.time())
    result_dir = tf.io.gfile.join(
        tf.compat.as_bytes(export_dir_base), tf.compat.as_bytes(str(timestamp)))
    if not tf.io.gfile.exists(result_dir):
      # Collisions are still possible (though extremely unlikely): this
      # directory is not actually created yet, but it will be almost
      # instantly on return from this function.
      return result_dir
    time.sleep(1)
    attempts += 1
    logging.warning('Directory %s already exists; retrying (attempt %s/%s)',
                    str(result_dir), attempts, MAX_DIRECTORY_CREATION_ATTEMPTS)
  raise RuntimeError('Failed to obtain a unique export directory name after '
                     f'{MAX_DIRECTORY_CREATION_ATTEMPTS} attempts.')