Spaces:
Runtime error
Runtime error
File size: 5,834 Bytes
5672777 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
# Copyright 2023 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides the `ExportSavedModel` action and associated helper classes."""
import os
import re
from typing import Callable, Optional
import tensorflow as tf, tf_keras
_GS_PREFIX = r'gs://' # Google Cloud Storage Prefix
def safe_normpath(path: str) -> str:
"""Normalize path safely to get around gfile.glob limitations."""
if path.startswith(_GS_PREFIX):
return _GS_PREFIX + os.path.normpath(path[len(_GS_PREFIX):])
return os.path.normpath(path)
def _id_key(filename):
_, id_num = filename.rsplit('-', maxsplit=1)
return int(id_num)
def _find_managed_files(base_name):
  r"""Lists paths named `'{base_name}-\d+'`, ordered by increasing integer ID.

  `tf.io.gfile.glob` first narrows candidates to `'{base_name}-*'`. A regex
  anchored at the start of the path then keeps only those whose text after
  `'{base_name}-'` is all digits.
  """
  id_pattern = re.compile(rf'{re.escape(base_name)}-\d+$')
  candidates = tf.io.gfile.glob(f'{base_name}-*')
  matching = [name for name in candidates if id_pattern.match(name)]
  matching.sort(key=_id_key)
  return matching
class _CounterIdFn:
"""Implements a counter-based ID function for `ExportFileManager`."""
def __init__(self, base_name: str):
managed_files = _find_managed_files(base_name)
self.value = _id_key(managed_files[-1]) + 1 if managed_files else 0
def __call__(self):
output = self.value
self.value += 1
return output
class ExportFileManager:
  """Utility class that manages a group of files with a shared base name.

  For actions like SavedModel exporting, there are potentially many different
  file naming and cleanup strategies that may be desirable. This class provides
  a basic interface allowing SavedModel export to be decoupled from these
  details, and a default implementation that should work for many basic
  scenarios. Users may subclass this class to alter behavior and define more
  customized naming and cleanup strategies.
  """

  def __init__(
      self,
      base_name: str,
      max_to_keep: int = 5,
      next_id_fn: Optional[Callable[[], int]] = None,
      subdirectory: Optional[str] = None,
  ):
    """Initializes the instance.

    Args:
      base_name: A shared base name for file names generated by this class.
      max_to_keep: The maximum number of files matching `base_name` to keep
        after each call to `clean_up`. The files with the highest IDs (see
        `next_id_fn`) are preserved; the rest are deleted. If 0, all managed
        files are deleted. If < 0, all files are preserved.
      next_id_fn: An optional callable that returns integer IDs to append to
        base name (formatted as `'{base_name}-{id}'`). The order of integers is
        used to sort files to determine the oldest ones deleted by `clean_up`.
        If not supplied, a default ID based on an incrementing counter is used.
        One common alternative may be to use the current global step count,
        for instance passing `next_id_fn=global_step.numpy`.
      subdirectory: An optional subdirectory to append after
        `{base_name}-{id}`. The file manager then manages
        `{base_name}-{id}/{subdirectory}` paths.
    """
    self._base_name = safe_normpath(base_name)
    self._max_to_keep = max_to_keep
    self._next_id_fn = next_id_fn or _CounterIdFn(self._base_name)
    self._subdirectory = subdirectory or ''

  @property
  def managed_files(self):
    """Returns all files managed by this instance, in sorted order.

    Returns:
      The list of files matching the `base_name` provided when constructing this
      `ExportFileManager` instance, sorted in increasing integer order of the
      IDs returned by `next_id_fn`. When a `subdirectory` was given, each entry
      is `{base_name}-{id}/{subdirectory}`.
    """
    files = _find_managed_files(self._base_name)
    return [
        safe_normpath(os.path.join(f, self._subdirectory)) for f in files
    ]

  def clean_up(self):
    """Cleans up old files matching `{base_name}-*`.

    All but the `max_to_keep` managed files with the highest IDs are deleted.
    Nothing is deleted when `max_to_keep` is negative.
    """
    # Only the managed paths themselves are removed: `{base_name}-{id}`, or
    # `{base_name}-{id}/{subdirectory}` when a subdirectory is configured. In
    # the subdirectory case, the `{base_name}-{id}` parents survive and keep
    # matching `{base_name}-*`. Paths already removed by an earlier clean-up
    # therefore reappear in `managed_files`; skip them rather than letting
    # `rmtree` fail on a missing path.
    stale = self._files_to_delete(self.managed_files, self._max_to_keep)
    for filename in stale:
      if tf.io.gfile.exists(filename):
        tf.io.gfile.rmtree(filename)

  @staticmethod
  def _files_to_delete(files, max_to_keep):
    """Returns the leading entries of `files` beyond the last `max_to_keep`.

    Args:
      files: Managed file names sorted by increasing ID.
      max_to_keep: Number of trailing (newest) entries to keep. If negative,
        every entry is kept.

    Returns:
      A list of the entries that should be deleted, oldest first.
    """
    if max_to_keep < 0:
      return []
    # Slicing with `[:-max_to_keep]` would yield `[:0]` (nothing) for
    # `max_to_keep == 0`, so compute the number of deletions explicitly.
    return files[:max(len(files) - max_to_keep, 0)]

  def next_name(self) -> str:
    """Returns a new file name based on `base_name` and `next_id_fn()`.

    Each call invokes `next_id_fn` once, so with the default counter the ID
    advances on every call.
    """
    base_path = f'{self._base_name}-{self._next_id_fn()}'
    return safe_normpath(os.path.join(base_path, self._subdirectory))
class ExportSavedModel:
  """Action that exports the given model as a SavedModel.

  Each invocation saves the model to the next path produced by the file
  manager and then asks the manager to clean up older exports.
  """

  def __init__(self,
               model: tf.Module,
               file_manager: ExportFileManager,
               signatures,
               options: Optional[tf.saved_model.SaveOptions] = None):
    """Initializes the instance.

    Args:
      model: The model to export.
      file_manager: An instance of `ExportFileManager` (or a subclass), that
        provides file naming and cleanup functionality.
      signatures: The signatures to forward to `tf.saved_model.save()`.
      options: Optional options to forward to `tf.saved_model.save()`.
    """
    self.model, self.file_manager = model, file_manager
    self.signatures, self.options = signatures, options

  def __call__(self, _):
    """Exports the SavedModel, then runs the file manager's clean-up.

    Args:
      _: Ignored. NOTE(review): presumably the output handed to this action
        by whatever triggers it (e.g. an Orbit controller) — confirm.
    """
    destination = self.file_manager.next_name()
    tf.saved_model.save(
        self.model,
        destination,
        signatures=self.signatures,
        options=self.options,
    )
    self.file_manager.clean_up()
|