Spaces:
Runtime error
Runtime error
# Copyright 2023 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides the `ExportSavedModel` action and associated helper classes.""" | |
import os | |
import re | |
from typing import Callable, Optional | |
import tensorflow as tf, tf_keras | |
_GS_PREFIX = r'gs://'  # Google Cloud Storage Prefix


def safe_normpath(path: str) -> str:
  """Normalizes `path` without damaging a leading `gs://` prefix.

  This exists to get around gfile.glob limitations. `os.path.normpath`
  collapses repeated slashes, which would turn `gs://bucket` into
  `gs:/bucket`, so for Google Cloud Storage paths the prefix is split off,
  only the remainder is normalized, and the prefix is re-attached.

  Args:
    path: A local path or a `gs://` path.

  Returns:
    The normalized path, still starting with `gs://` if the input did.
  """
  if not path.startswith(_GS_PREFIX):
    return os.path.normpath(path)
  remainder = path[len(_GS_PREFIX):]
  return _GS_PREFIX + os.path.normpath(remainder)
def _id_key(filename): | |
_, id_num = filename.rsplit('-', maxsplit=1) | |
return int(id_num) | |
def _find_managed_files(base_name):
  r"""Returns all files matching '{base_name}-\d+', in sorted order.

  Candidates are found by globbing `'{base_name}-*'`; only those that consist
  of exactly `base_name`, a '-', and one or more digits are kept.

  Args:
    base_name: The shared base name of the managed files.

  Returns:
    A list of matching file names, sorted using `_id_key`.
  """
  pattern = re.compile(rf'{re.escape(base_name)}-\d+$')
  candidates = tf.io.gfile.glob(f'{base_name}-*')
  # The glob is broader than the regex (it also matches non-numeric suffixes),
  # so the candidates are narrowed down before sorting.
  matching = [name for name in candidates if pattern.match(name)]
  matching.sort(key=_id_key)
  return matching
class _CounterIdFn: | |
"""Implements a counter-based ID function for `ExportFileManager`.""" | |
def __init__(self, base_name: str): | |
managed_files = _find_managed_files(base_name) | |
self.value = _id_key(managed_files[-1]) + 1 if managed_files else 0 | |
def __call__(self): | |
output = self.value | |
self.value += 1 | |
return output | |
class ExportFileManager:
  """Utility class that manages a group of files with a shared base name.

  For actions like SavedModel exporting, there are potentially many different
  file naming and cleanup strategies that may be desirable. This class provides
  a basic interface allowing SavedModel export to be decoupled from these
  details, and a default implementation that should work for many basic
  scenarios. Users may subclass this class to alter behavior and define more
  customized naming and cleanup strategies.
  """

  def __init__(
      self,
      base_name: str,
      max_to_keep: int = 5,
      next_id_fn: Optional[Callable[[], int]] = None,
      subdirectory: Optional[str] = None,
  ):
    """Initializes the instance.

    Args:
      base_name: A shared base name for file names generated by this class.
      max_to_keep: The maximum number of files matching `base_name` to keep
        after each call to `clean_up`. The `max_to_keep` files with the
        highest integer IDs are preserved; the rest are deleted. If 0, every
        managed file is deleted. If < 0, all files are preserved.
      next_id_fn: An optional callable that returns integer IDs to append to
        base name (formatted as `'{base_name}-{id}'`). The order of integers is
        used to sort files to determine the oldest ones deleted by `clean_up`.
        If not supplied, a default ID based on an incrementing counter is used.
        One common alternative may be to use the current global step count,
        for instance passing `next_id_fn=global_step.numpy`.
      subdirectory: An optional subdirectory to concat after the
        {base_name}-{id}. Then the file manager will manage
        {base_name}-{id}/{subdirectory} files.
    """
    self._base_name = safe_normpath(base_name)
    self._max_to_keep = max_to_keep
    self._next_id_fn = next_id_fn or _CounterIdFn(self._base_name)
    self._subdirectory = subdirectory or ''

  def managed_files(self):
    """Returns all files managed by this instance, in sorted order.

    Returns:
      The list of files matching the `base_name` provided when constructing this
      `ExportFileManager` instance, sorted in increasing integer order of the
      IDs returned by `next_id_fn`. Each entry has the configured subdirectory
      (if any) joined onto it.
    """
    return [
        safe_normpath(os.path.join(f, self._subdirectory))
        for f in _find_managed_files(self._base_name)
    ]

  def clean_up(self):
    """Cleans up old files matching `{base_name}-*`.

    The `max_to_keep` managed files with the highest IDs are preserved and all
    others are removed. A negative `max_to_keep` disables cleanup entirely.
    """
    if self._max_to_keep < 0:
      return

    # `managed_files` is a regular method, so it must be called before the
    # result can be sliced.
    files = self.managed_files()
    # Count stale entries from the front instead of slicing with
    # `[:-max_to_keep]`: `[:-0]` is an empty slice, which would silently keep
    # every file when `max_to_keep` is 0.
    num_stale = max(len(files) - self._max_to_keep, 0)
    # Note that the base folder will remain intact, only the folder with suffix
    # is deleted.
    for filename in files[:num_stale]:
      tf.io.gfile.rmtree(filename)

  def next_name(self) -> str:
    """Returns a new file name based on `base_name` and `next_id_fn()`."""
    base_path = f'{self._base_name}-{self._next_id_fn()}'
    return safe_normpath(os.path.join(base_path, self._subdirectory))
class ExportSavedModel:
  """Action that exports the given model as a SavedModel."""

  def __init__(self,
               model: tf.Module,
               file_manager: ExportFileManager,
               signatures,
               options: Optional[tf.saved_model.SaveOptions] = None):
    """Stores everything needed to perform an export later.

    Args:
      model: The model to export.
      file_manager: An instance of `ExportFileManager` (or a subclass), that
        provides file naming and cleanup functionality.
      signatures: The signatures to forward to `tf.saved_model.save()`.
      options: Optional options to forward to `tf.saved_model.save()`.
    """
    self.file_manager = file_manager
    self.model = model
    self.options = options
    self.signatures = signatures

  def __call__(self, _):
    """Exports the SavedModel, then cleans up old exports.

    Args:
      _: Ignored; the export does not depend on the value passed in.
    """
    target_dir = self.file_manager.next_name()
    tf.saved_model.save(
        self.model,
        target_dir,
        signatures=self.signatures,
        options=self.options,
    )
    # Clean up only after the new export has been written.
    self.file_manager.clean_up()