File size: 5,834 Bytes
5672777
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
# Copyright 2023 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Provides the `ExportSavedModel` action and associated helper classes."""

import os
import re

from typing import Callable, Optional

import tensorflow as tf, tf_keras


_GS_PREFIX = r'gs://'  # Scheme prefix for Google Cloud Storage paths.


def safe_normpath(path: str) -> str:
  """Normalizes `path`, keeping a leading 'gs://' scheme intact.

  `os.path.normpath` would collapse the '//' of the scheme (turning
  'gs://bucket' into 'gs:/bucket'), which breaks later `gfile.glob` calls.
  For such paths only the part after the prefix is normalized.

  Args:
    path: A local path or a 'gs://' URL.

  Returns:
    The normalized path.
  """
  if not path.startswith(_GS_PREFIX):
    return os.path.normpath(path)
  remainder = path[len(_GS_PREFIX):]
  return _GS_PREFIX + os.path.normpath(remainder)


def _id_key(filename):
  _, id_num = filename.rsplit('-', maxsplit=1)
  return int(id_num)


def _find_managed_files(base_name):
  r"""Lists files named '{base_name}-\d+', ordered by their integer suffix.

  Candidates come from a `tf.io.gfile.glob` on '{base_name}-*'. They are then
  filtered with a regex, because the glob alone would also accept non-numeric
  suffixes.
  """
  suffix_pattern = re.compile(re.escape(base_name) + r'-\d+$')
  candidates = tf.io.gfile.glob(f'{base_name}-*')
  matches = [name for name in candidates if suffix_pattern.match(name)]
  matches.sort(key=_id_key)
  return matches


class _CounterIdFn:
  """Implements a counter-based ID function for `ExportFileManager`."""

  def __init__(self, base_name: str):
    managed_files = _find_managed_files(base_name)
    self.value = _id_key(managed_files[-1]) + 1 if managed_files else 0

  def __call__(self):
    output = self.value
    self.value += 1
    return output


class ExportFileManager:
  """Utility class that manages a group of files with a shared base name.

  For actions like SavedModel exporting, there are potentially many different
  file naming and cleanup strategies that may be desirable. This class provides
  a basic interface allowing SavedModel export to be decoupled from these
  details, and a default implementation that should work for many basic
  scenarios. Users may subclass this class to alter behavior and define more
  customized naming and cleanup strategies.
  """

  def __init__(
      self,
      base_name: str,
      max_to_keep: int = 5,
      next_id_fn: Optional[Callable[[], int]] = None,
      subdirectory: Optional[str] = None,
  ):
    """Initializes the instance.

    Args:
      base_name: A shared base name for file names generated by this class.
      max_to_keep: The maximum number of managed files to keep after each call
        to `clean_up`. Files are ordered by the integer ID in their names (see
        `next_id_fn`). The `max_to_keep` files with the largest IDs are
        preserved and the rest are deleted. If 0, all managed files are
        deleted. If < 0, all files are preserved.
      next_id_fn: An optional callable that returns integer IDs to append to
        base name (formatted as `'{base_name}-{id}'`). The order of integers is
        used to sort files to determine the oldest ones deleted by `clean_up`.
        If not supplied, a default ID based on an incrementing counter is used,
        starting one past the largest ID already present on disk. One common
        alternative may be to use the current global step count, for instance
        passing `next_id_fn=global_step.numpy`.
      subdirectory: An optional subdirectory to concat after the
        {base_name}-{id}. Then the file manager will manage
        {base_name}-{id}/{subdirectory} files.
    """
    self._base_name = safe_normpath(base_name)
    self._max_to_keep = max_to_keep
    self._next_id_fn = next_id_fn or _CounterIdFn(self._base_name)
    self._subdirectory = subdirectory or ''

  @property
  def managed_files(self):
    """Returns all files managed by this instance, in sorted order.

    Returns:
      One path per existing `{base_name}-{id}` entry with an integer `id`.
      Each path is `{base_name}-{id}/{subdirectory}` when a subdirectory was
      given, otherwise `{base_name}-{id}`. Paths are sorted in increasing
      integer order of `id`.
    """
    files = _find_managed_files(self._base_name)
    return [
        safe_normpath(os.path.join(f, self._subdirectory)) for f in files
    ]

  def clean_up(self):
    """Deletes old managed files, keeping only the most recent ones.

    The `max_to_keep` files with the largest IDs are preserved and every older
    one is removed with `tf.io.gfile.rmtree`. Does nothing if `max_to_keep`
    is negative.
    """
    if self._max_to_keep < 0:
      return

    files = self.managed_files
    if self._subdirectory:
      # Only `{base_name}-{id}/{subdirectory}` is deleted; the parent
      # `{base_name}-{id}` folder is left intact. Where that parent persists
      # (e.g. on local filesystems), it keeps matching the glob in later calls.
      # Drop entries whose subdirectory is already gone, so they neither count
      # toward `max_to_keep` nor get handed to `rmtree` a second time.
      files = [f for f in files if tf.io.gfile.exists(f)]

    for filename in self._select_stale(files, self._max_to_keep):
      tf.io.gfile.rmtree(filename)

  @staticmethod
  def _select_stale(files, max_to_keep):
    """Returns the entries of ascending-sorted `files` older than the newest `max_to_keep`.

    `max_to_keep` is expected to be non-negative. The obvious
    `files[:-max_to_keep]` is wrong for 0: it becomes `files[:0]` and selects
    nothing instead of everything.
    """
    return files[:max(len(files) - max_to_keep, 0)]

  def next_name(self) -> str:
    """Returns a new file name based on `base_name` and `next_id_fn()`."""
    base_path = f'{self._base_name}-{self._next_id_fn()}'
    return safe_normpath(os.path.join(base_path, self._subdirectory))


class ExportSavedModel:
  """Action that exports the given model as a SavedModel.

  Each invocation writes a SavedModel to the next path handed out by the file
  manager. It then asks the file manager to prune older exports.
  """

  def __init__(self,
               model: tf.Module,
               file_manager: ExportFileManager,
               signatures,
               options: Optional[tf.saved_model.SaveOptions] = None):
    """Stores the export configuration.

    Args:
      model: The model to export.
      file_manager: An `ExportFileManager` (or subclass) supplying export
        directory names and cleanup of old exports.
      signatures: Passed through as `signatures` to `tf.saved_model.save()`.
      options: Optional; passed through as `options` to
        `tf.saved_model.save()`.
    """
    self.model, self.file_manager = model, file_manager
    self.signatures, self.options = signatures, options

  def __call__(self, _):
    """Exports the SavedModel, then cleans up old exports.

    Args:
      _: Unused. Accepted so the instance can be invoked with one argument.
    """
    target_dir = self.file_manager.next_name()
    tf.saved_model.save(
        self.model,
        target_dir,
        signatures=self.signatures,
        options=self.options,
    )
    self.file_manager.clean_up()