Spaces:
Runtime error
Runtime error
# Copyright 2023 The Orbit Authors. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Tests for orbit.actions.export_saved_model.""" | |
import os | |
from orbit import actions | |
from orbit.actions import export_saved_model | |
import tensorflow as tf, tf_keras | |
def _id_key(name): | |
_, id_num = name.rsplit('-', maxsplit=1) | |
return int(id_num) | |
def _id_sorted_file_base_names(dir_path):
  """Lists the entries of ``dir_path`` ordered by their trailing integer ID.

  Args:
    dir_path: Directory whose entries are all named ``'<prefix>-<id>'``.

  Returns:
    A new list of entry base names sorted numerically by ``_id_key``.
  """
  entries = list(tf.io.gfile.listdir(dir_path))
  entries.sort(key=_id_key)
  return entries
class TestModel(tf.Module):
  """Minimal exportable module whose call returns a single integer variable."""

  def __init__(self):
    super().__init__()
    self.value = tf.Variable(0)

  # `tf.saved_model.save` only accepts a `tf.function` (with an input
  # signature) or a concrete function for `signatures=`, and only
  # `tf.function`s are restored as callables on the object returned by
  # `tf.saved_model.load`. The empty input signature lets it be traced
  # without example arguments.
  @tf.function(input_signature=[])
  def __call__(self):
    return self.value
class ExportSavedModelTest(tf.test.TestCase):
  """Tests for `actions.ExportFileManager` and `actions.ExportSavedModel`."""

  def test_export_file_manager_default_ids(self):
    """Default IDs count 0, 1, 2, ...; clean_up drops the lowest ID."""
    directory = self.create_tempdir()
    base_name = os.path.join(directory.full_path, 'basename')
    manager = actions.ExportFileManager(base_name, max_to_keep=3)
    self.assertEmpty(tf.io.gfile.listdir(directory.full_path))
    directory.create_file(manager.next_name())
    manager.clean_up()  # Shouldn't do anything...
    self.assertLen(tf.io.gfile.listdir(directory.full_path), 1)
    directory.create_file(manager.next_name())
    manager.clean_up()  # Shouldn't do anything...
    self.assertLen(tf.io.gfile.listdir(directory.full_path), 2)
    directory.create_file(manager.next_name())
    manager.clean_up()  # Shouldn't do anything...
    self.assertLen(tf.io.gfile.listdir(directory.full_path), 3)
    directory.create_file(manager.next_name())
    self.assertLen(tf.io.gfile.listdir(directory.full_path), 4)
    self.assertEqual(
        _id_sorted_file_base_names(directory.full_path),
        ['basename-0', 'basename-1', 'basename-2', 'basename-3'])
    manager.clean_up()  # Should delete file with lowest ID.
    self.assertEqual(
        _id_sorted_file_base_names(directory.full_path),
        ['basename-1', 'basename-2', 'basename-3'])
    # A fresh manager over the same base name continues after the highest
    # existing ID instead of restarting at 0.
    manager = actions.ExportFileManager(base_name, max_to_keep=3)
    self.assertEqual(os.path.basename(manager.next_name()), 'basename-4')

  def test_export_file_manager_custom_ids(self):
    """A custom `next_id_fn` picks IDs; clean_up still drops the lowest."""
    directory = self.create_tempdir()
    base_name = os.path.join(directory.full_path, 'basename')

    id_num = 0

    # `next_id` reads `id_num` at call time (closure late binding), so the
    # reassignments below choose the ID used by each later `next_name()`.
    def next_id():
      return id_num

    manager = actions.ExportFileManager(
        base_name, max_to_keep=2, next_id_fn=next_id)
    self.assertEmpty(tf.io.gfile.listdir(directory.full_path))
    id_num = 30
    directory.create_file(manager.next_name())
    self.assertLen(tf.io.gfile.listdir(directory.full_path), 1)
    manager.clean_up()  # Shouldn't do anything...
    self.assertEqual(
        _id_sorted_file_base_names(directory.full_path), ['basename-30'])
    id_num = 200
    directory.create_file(manager.next_name())
    self.assertLen(tf.io.gfile.listdir(directory.full_path), 2)
    manager.clean_up()  # Shouldn't do anything...
    self.assertEqual(
        _id_sorted_file_base_names(directory.full_path),
        ['basename-30', 'basename-200'])
    id_num = 1000
    directory.create_file(manager.next_name())
    self.assertLen(tf.io.gfile.listdir(directory.full_path), 3)
    self.assertEqual(
        _id_sorted_file_base_names(directory.full_path),
        ['basename-30', 'basename-200', 'basename-1000'])
    manager.clean_up()  # Should delete file with lowest ID.
    self.assertLen(tf.io.gfile.listdir(directory.full_path), 2)
    self.assertEqual(
        _id_sorted_file_base_names(directory.full_path),
        ['basename-200', 'basename-1000'])

  def test_export_file_manager_with_suffix(self):
    """With `subdirectory`, clean_up removes only the oldest subdirectory."""
    directory = self.create_tempdir()
    base_name = os.path.join(directory.full_path, 'basename')

    id_num = 0

    # Late-bound closure: reassigning `id_num` selects the next ID.
    def next_id():
      return id_num

    subdirectory = 'sub'
    manager = actions.ExportFileManager(
        base_name, max_to_keep=2, next_id_fn=next_id, subdirectory=subdirectory
    )
    self.assertEmpty(tf.io.gfile.listdir(directory.full_path))
    id_num = 30
    directory.create_file(manager.next_name())
    self.assertLen(tf.io.gfile.listdir(directory.full_path), 1)
    manager.clean_up()  # Shouldn't do anything...
    self.assertEqual(
        _id_sorted_file_base_names(directory.full_path), ['basename-30']
    )
    id_num = 200
    directory.create_file(manager.next_name())
    self.assertLen(tf.io.gfile.listdir(directory.full_path), 2)
    manager.clean_up()  # Shouldn't do anything...
    self.assertEqual(
        _id_sorted_file_base_names(directory.full_path),
        ['basename-30', 'basename-200'],
    )
    id_num = 1000
    directory.create_file(manager.next_name())
    self.assertLen(tf.io.gfile.listdir(directory.full_path), 3)
    self.assertEqual(
        _id_sorted_file_base_names(directory.full_path),
        ['basename-30', 'basename-200', 'basename-1000'],
    )
    manager.clean_up()  # Should delete the subdirectory of the lowest ID.
    self.assertLen(tf.io.gfile.listdir(directory.full_path), 3)
    # Note that the base folder is intact, only the suffix folder is deleted.
    self.assertEqual(
        _id_sorted_file_base_names(directory.full_path),
        ['basename-30', 'basename-200', 'basename-1000'],
    )
    # The oldest export must actually have lost its subdirectory; otherwise
    # clean_up did nothing and the checks above would still pass.
    oldest_folder = os.path.join(directory.full_path, 'basename-30')
    self.assertNotIn(subdirectory, tf.io.gfile.listdir(oldest_folder))
    # The `max_to_keep` newest exports keep their subdirectories.
    for kept_name in ('basename-200', 'basename-1000'):
      step_folder = os.path.join(directory.full_path, kept_name)
      self.assertIn(subdirectory, tf.io.gfile.listdir(step_folder))

  def test_export_file_manager_managed_files(self):
    """managed_files matches only `<base>-<int>` names, sorted numerically."""
    directory = self.create_tempdir()
    directory.create_file('basename-5')
    directory.create_file('basename-10')
    directory.create_file('basename-50')
    directory.create_file('basename-1000')
    directory.create_file('basename-9')
    # Not of the form `<base>-<int>`, so it is expected to be ignored.
    directory.create_file('basename-10-suffix')
    base_name = os.path.join(directory.full_path, 'basename')
    manager = actions.ExportFileManager(base_name, max_to_keep=3)
    self.assertLen(manager.managed_files, 5)
    self.assertEqual(manager.next_name(), f'{base_name}-1001')
    manager.clean_up()
    # Numeric (not lexicographic) order: 10 < 50 < 1000.
    self.assertEqual(
        manager.managed_files,
        [f'{base_name}-10', f'{base_name}-50', f'{base_name}-1000'])

  def test_export_file_manager_managed_files_double_slash(self):
    """Paths containing '//' are normalized by the manager."""
    directory = self.create_tempdir('foo//bar')
    directory.create_file('basename-5')
    directory.create_file('basename-10')
    directory.create_file('basename-50')
    directory.create_file('basename-1000')
    directory.create_file('basename-9')
    directory.create_file('basename-10-suffix')
    base_name = os.path.join(directory.full_path, 'basename')
    expected_base_name = os.path.normpath(base_name)
    # Guard: the test is only meaningful if the raw path really needs
    # normalization.
    self.assertNotEqual(base_name, expected_base_name)
    manager = actions.ExportFileManager(base_name, max_to_keep=3)
    self.assertLen(manager.managed_files, 5)
    self.assertEqual(manager.next_name(), f'{expected_base_name}-1001')
    manager.clean_up()
    self.assertEqual(manager.managed_files, [
        f'{expected_base_name}-10', f'{expected_base_name}-50',
        f'{expected_base_name}-1000'
    ])

  def test_export_saved_model(self):
    """Each call exports a reloadable SavedModel and prunes old exports."""
    directory = self.create_tempdir()
    base_name = os.path.join(directory.full_path, 'basename')
    file_manager = actions.ExportFileManager(base_name, max_to_keep=2)
    model = TestModel()
    export_action = actions.ExportSavedModel(
        model, file_manager=file_manager, signatures=model.__call__)

    model.value.assign(3)
    self.assertEqual(model(), 3)
    self.assertEmpty(file_manager.managed_files)
    export_action({})
    self.assertLen(file_manager.managed_files, 1)
    reloaded_model = tf.saved_model.load(file_manager.managed_files[-1])
    self.assertEqual(reloaded_model(), 3)

    model.value.assign(5)
    self.assertEqual(model(), 5)
    export_action({})
    self.assertLen(file_manager.managed_files, 2)
    reloaded_model = tf.saved_model.load(file_manager.managed_files[-1])
    self.assertEqual(reloaded_model(), 5)

    model.value.assign(7)
    self.assertEqual(model(), 7)
    export_action({})
    self.assertLen(file_manager.managed_files, 2)  # Still 2, due to clean up.
    reloaded_model = tf.saved_model.load(file_manager.managed_files[-1])
    self.assertEqual(reloaded_model(), 7)

  def test_safe_normpath_gs(self):
    """safe_normpath collapses '//' without breaking the 'gs://' scheme."""
    path = export_saved_model.safe_normpath('gs://foo//bar')
    self.assertEqual(path, 'gs://foo/bar')
if '__main__' == __name__:
  # Executed only as a script: hand off to TensorFlow's test runner, which
  # discovers and runs the test cases defined in this module.
  tf.test.main()