deanna-emery's picture
updates
93528c6
raw
history blame
7.04 kB
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base class for model export."""
import abc
import functools
import time
from typing import Any, Callable, Dict, Mapping, List, Optional, Text, Union
from absl import logging
import tensorflow as tf, tf_keras
# Upper bound on how many timestamped directory names
# `get_timestamped_export_dir` tries before raising a RuntimeError.
MAX_DIRECTORY_CREATION_ATTEMPTS = 10
class ExportModule(tf.Module, metaclass=abc.ABCMeta):
  """Base Export Module.

  Holds a model together with the callable used to run inference on it and
  optional pre-/post-processing callables. Subclasses must implement `serve()`
  and `get_inference_signatures()`.
  """

  def __init__(self,
               params,
               model: Union[tf.Module, tf_keras.Model],
               inference_step: Optional[Callable[..., Any]] = None,
               *,
               preprocessor: Optional[Callable[..., Any]] = None,
               postprocessor: Optional[Callable[..., Any]] = None):
    """Instantiates an ExportModule.

    Examples:

    `inference_step` must be a function that has `model` as a kwarg or the
    second positional argument.

    ```
    def _inference_step(inputs, model=None):
      return model(inputs, training=False)

    module = ExportModule(params, model, inference_step=_inference_step)
    ```

    `preprocessor` and `postprocessor` could be either functions or `tf.Module`.
    The usages of preprocessor and postprocessor are managed by the
    implementation of `serve()` method.

    Args:
      params: A dataclass for parameters to the module.
      model: A model instance which contains weights and forward computation.
      inference_step: An optional callable to forward-pass the model. If
        specified, it is wrapped in a partial function with `model` bound as a
        keyword argument. If not specified, a default is derived from `model`
        (see the comments in the body).
      preprocessor: An optional callable to preprocess the inputs.
      postprocessor: An optional callable to postprocess the model outputs.
    """
    super().__init__(name=None)
    self.model = model
    self.params = params

    if inference_step is not None:
      # Bind the model so callers only have to supply the inputs.
      self.inference_step = functools.partial(inference_step, model=self.model)
    elif isinstance(model, tf_keras.Model):
      # Default to self.model.call instead of self.model.__call__ to avoid
      # keras tracing logic designed for training.
      # Since most of Model Garden's `call` methods either do not have a
      # `training` kwarg or default it to False, we don't pass anything here.
      # Please pass a custom inference step if your model has training=True as
      # default.
      self.inference_step = self.model.call
    else:
      # Non-Keras modules are invoked through __call__ with training disabled.
      self.inference_step = functools.partial(
          self.model.__call__, training=False)
    self.preprocessor = preprocessor
    self.postprocessor = postprocessor

  @abc.abstractmethod
  def serve(self) -> Mapping[Text, tf.Tensor]:
    """The bare inference function which should run on all devices.

    Expects tensors to be passed in through keyword arguments.

    Returns:
      A dictionary of tensors, where the keys will be used inside the
      SignatureDef.
    """

  @abc.abstractmethod
  def get_inference_signatures(
      self, function_keys: Dict[Text, Text]) -> Mapping[Text, Any]:
    """Gets defined function signatures.

    Args:
      function_keys: A dictionary whose keys select pre-defined serving
        functions and whose values are the signature keys to export them
        under.

    Returns:
      A mapping of signatures; `export()` passes it as the `signatures`
      argument of `tf.saved_model.save`.
    """
def export(export_module: ExportModule,
           function_keys: Union[List[Text], Dict[Text, Text]],
           export_savedmodel_dir: Text,
           checkpoint_path: Optional[Text] = None,
           timestamped: bool = True,
           save_options: Optional[tf.saved_model.SaveOptions] = None,
           checkpoint: Optional[tf.train.Checkpoint] = None) -> Text:
  """Exports to SavedModel format.

  Args:
    export_module: an ExportModule with the keras Model and serving
      tf.functions.
    function_keys: a list of string keys to retrieve pre-defined serving
      signatures. The signature keys will be set with defaults. If a dictionary
      is provided, the values will be used as signature keys.
    export_savedmodel_dir: Output saved model directory.
    checkpoint_path: Object-based checkpoint path or directory. When a
      directory is given, the latest checkpoint inside it is used.
    timestamped: Whether to export the savedmodel to a timestamped directory.
    save_options: `SaveOptions` for `tf.saved_model.save`.
    checkpoint: An optional tf.train.Checkpoint. If provided, the export module
      will use it to read the weights.

  Returns:
    The savedmodel directory path.

  Raises:
    ValueError: if `function_keys` is a list that does not contain exactly one
      element.
  """
  ckpt_dir_or_file = checkpoint_path
  if ckpt_dir_or_file is not None and tf.io.gfile.isdir(ckpt_dir_or_file):
    ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
    if ckpt_dir_or_file is None:
      # Without this the restore below is skipped silently and the model is
      # exported with whatever values its variables currently hold.
      logging.warning(
          'No checkpoint found in directory %s; exporting the model without '
          'restoring weights.', checkpoint_path)

  if ckpt_dir_or_file:
    if checkpoint is None:
      checkpoint = tf.train.Checkpoint(model=export_module.model)
    # Every object in `checkpoint` must be matched by the file, but the file
    # may hold extra values that are not consumed.
    checkpoint.read(
        ckpt_dir_or_file).assert_existing_objects_matched().expect_partial()

  if isinstance(function_keys, list):
    if len(function_keys) == 1:
      # A single key is exported under the default serving signature name.
      function_keys = {
          function_keys[0]: tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
      }
    else:
      raise ValueError(
          'If the function_keys is a list, it must contain a single element. %s'
          % function_keys)

  signatures = export_module.get_inference_signatures(function_keys)

  if timestamped:
    # The helper returns a bytes path; decode it for the string-typed return.
    export_dir = get_timestamped_export_dir(export_savedmodel_dir).decode(
        'utf-8')
  else:
    export_dir = export_savedmodel_dir

  tf.saved_model.save(
      export_module, export_dir, signatures=signatures, options=save_options)
  return export_dir
def get_timestamped_export_dir(export_dir_base):
  """Builds a path to a new subdirectory within the base directory.

  The subdirectory name is the current Unix time in whole seconds. If that
  path already exists, the function sleeps for one second and tries again, up
  to `MAX_DIRECTORY_CREATION_ATTEMPTS` times.

  Args:
    export_dir_base: A string containing a directory to write the exported graph
      and checkpoints.

  Returns:
    The full path of the new subdirectory (which is not actually created yet).
    The path is assembled from byte-encoded components, so callers should
    decode it if they need a text string.

  Raises:
    RuntimeError: if repeated attempts fail to obtain a unique timestamped
      directory name.
  """
  # The base directory does not change between attempts; encode it once.
  base_dir = tf.compat.as_bytes(export_dir_base)
  for attempt in range(1, MAX_DIRECTORY_CREATION_ATTEMPTS + 1):
    timestamp = int(time.time())
    result_dir = tf.io.gfile.join(base_dir, tf.compat.as_bytes(str(timestamp)))
    if not tf.io.gfile.exists(result_dir):
      # Collisions are still possible (though extremely unlikely): this
      # directory is not actually created yet, but it will be almost
      # instantly on return from this function.
      return result_dir
    # Wait for the clock to move to the next second so the name changes.
    time.sleep(1)
    # Convert to text so the log shows the path rather than a bytes repr.
    logging.warning('Directory %s already exists; retrying (attempt %s/%s)',
                    tf.compat.as_text(result_dir), attempt,
                    MAX_DIRECTORY_CREATION_ATTEMPTS)
  raise RuntimeError('Failed to obtain a unique export directory name after '
                     f'{MAX_DIRECTORY_CREATION_ATTEMPTS} attempts.')