# Copyright 2023 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Provides the `ExportSavedModel` action and associated helper classes."""

import os
import re

from typing import Callable, Optional

import tensorflow as tf, tf_keras

_GS_PREFIX = r'gs://'  # Google Cloud Storage Prefix


def safe_normpath(path: str) -> str:
  """Normalizes `path` while keeping a leading `gs://` prefix intact.

  `os.path.normpath` collapses repeated slashes, which would turn the
  `gs://` scheme prefix into `gs:/`. To avoid that, the prefix is split off,
  the remainder is normalized, and the prefix is re-attached.

  Args:
    path: A local path or a path starting with `gs://`.

  Returns:
    The normalized path, with any `gs://` prefix preserved verbatim.
  """
  if path.startswith(_GS_PREFIX):
    return _GS_PREFIX + os.path.normpath(path[len(_GS_PREFIX):])
  return os.path.normpath(path)


def _id_key(filename: str) -> int:
  """Returns the integer ID following the last '-' in `filename`."""
  _, id_num = filename.rsplit('-', maxsplit=1)
  return int(id_num)


def _find_managed_files(base_name):
  r"""Returns all files matching '{base_name}-\d+', in sorted order."""
  managed_file_regex = re.compile(rf'{re.escape(base_name)}-\d+$')
  # The glob is deliberately broad; the regex then keeps only the names whose
  # suffix is purely numeric, so `_id_key` can safely convert it to an int.
  candidates = tf.io.gfile.glob(f'{base_name}-*')
  filenames = [name for name in candidates if managed_file_regex.match(name)]
  return sorted(filenames, key=_id_key)


class _CounterIdFn:
  """Implements a counter-based ID function for `ExportFileManager`."""

  def __init__(self, base_name: str):
    """Starts the counter one past the largest existing ID, or at 0.

    Args:
      base_name: The base name whose existing `'{base_name}-{id}'` files
        determine the starting counter value.
    """
    managed_files = _find_managed_files(base_name)
    self.value = _id_key(managed_files[-1]) + 1 if managed_files else 0

  def __call__(self) -> int:
    """Returns the current counter value and then increments it."""
    output = self.value
    self.value += 1
    return output


class ExportFileManager:
  """Utility class that manages a group of files with a shared base name.

  For actions like SavedModel exporting, there are potentially many different
  file naming and cleanup strategies that may be desirable. This class provides
  a basic interface allowing SavedModel export to be decoupled from these
  details, and a default implementation that should work for many basic
  scenarios. Users may subclass this class to alter behavior and define more
  customized naming and cleanup strategies.
  """

  def __init__(
      self,
      base_name: str,
      max_to_keep: int = 5,
      next_id_fn: Optional[Callable[[], int]] = None,
      subdirectory: Optional[str] = None,
  ):
    """Initializes the instance.

    Args:
      base_name: A shared base name for file names generated by this class.
      max_to_keep: The maximum number of files matching `base_name` to keep
        after each call to `clean_up`. The `max_to_keep` files with the largest
        integer IDs are preserved; the rest are deleted. If 0, all managed
        files are deleted. If < 0, all files are preserved.
      next_id_fn: An optional callable that returns integer IDs to append to
        base name (formatted as `'{base_name}-{id}'`). The order of integers is
        used to sort files to determine the oldest ones deleted by `clean_up`.
        If not supplied, a default ID based on an incrementing counter is used.
        One common alternative may be to use the current global step count, for
        instance passing `next_id_fn=global_step.numpy`.
      subdirectory: An optional subdirectory to concat after the
        {base_name}-{id}. Then the file manager will manage
        {base_name}-{id}/{subdirectory} files.
    """
    self._base_name = safe_normpath(base_name)
    self._max_to_keep = max_to_keep
    # Compare against None explicitly so that a user-supplied callable object
    # which happens to be falsy is not silently replaced by the default.
    if next_id_fn is None:
      next_id_fn = _CounterIdFn(self._base_name)
    self._next_id_fn = next_id_fn
    self._subdirectory = subdirectory or ''

  @property
  def managed_files(self):
    """Returns all files managed by this instance, in sorted order.

    Returns:
      The list of files matching the `base_name` provided when constructing this
      `ExportFileManager` instance, sorted in increasing integer order of the
      IDs returned by `next_id_fn`. Each entry has the `subdirectory` (if any)
      joined onto it.
    """
    files = _find_managed_files(self._base_name)
    return [
        safe_normpath(os.path.join(f, self._subdirectory)) for f in files
    ]

  def clean_up(self) -> None:
    """Cleans up old files matching `{base_name}-*`.

    The most recent `max_to_keep` files (those with the largest IDs) are
    preserved. Nothing is deleted when `max_to_keep` is negative.
    """
    if self._max_to_keep < 0:
      return

    files = self.managed_files
    # Compute the count explicitly instead of slicing with `[:-max_to_keep]`:
    # for `max_to_keep == 0` that slice is `[:0]`, which is empty and would
    # delete nothing rather than everything.
    num_to_delete = max(len(files) - self._max_to_keep, 0)

    # Note that the base folder will remain intact, only the folder with suffix
    # is deleted.
    for filename in files[:num_to_delete]:
      tf.io.gfile.rmtree(filename)

  def next_name(self) -> str:
    """Returns a new file name based on `base_name` and `next_id_fn()`."""
    base_path = f'{self._base_name}-{self._next_id_fn()}'
    return safe_normpath(os.path.join(base_path, self._subdirectory))


class ExportSavedModel:
  """Action that exports the given model as a SavedModel."""

  def __init__(self,
               model: tf.Module,
               file_manager: ExportFileManager,
               signatures,
               options: Optional[tf.saved_model.SaveOptions] = None):
    """Initializes the instance.

    Args:
      model: The model to export.
      file_manager: An instance of `ExportFileManager` (or a subclass), that
        provides file naming and cleanup functionality.
      signatures: The signatures to forward to `tf.saved_model.save()`.
      options: Optional options to forward to `tf.saved_model.save()`.
    """
    self.model = model
    self.file_manager = file_manager
    self.signatures = signatures
    self.options = options

  def __call__(self, _):
    """Exports the SavedModel, then cleans up old exports.

    Args:
      _: Ignored; accepted so the instance can be called with one argument.
    """
    export_dir = self.file_manager.next_name()
    tf.saved_model.save(self.model, export_dir, self.signatures, self.options)
    # Clean up only after the new export has been written.
    self.file_manager.clean_up()