# Copyright 2023 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for orbit.actions.export_saved_model."""

import os

from orbit import actions
from orbit.actions import export_saved_model
import tensorflow as tf, tf_keras


def _id_key(name):
  """Returns the integer that follows the last '-' in `name`.

  Used as a sort key so that e.g. 'basename-200' orders before
  'basename-1000' (numeric rather than lexicographic order).

  Args:
    name: A string of the form '<prefix>-<integer>'.

  Returns:
    The trailing integer ID.

  Raises:
    ValueError: If `name` contains no '-' or the trailing part is not an
      integer.
  """
  _, id_num = name.rsplit('-', maxsplit=1)
  return int(id_num)


def _id_sorted_file_base_names(dir_path):
  """Lists the entries of `dir_path`, sorted by their trailing numeric ID."""
  return sorted(tf.io.gfile.listdir(dir_path), key=_id_key)


class TestModel(tf.Module):
  """Minimal module exposing one variable through a traced `__call__`."""

  def __init__(self):
    # Initialize the tf.Module base (which sets the module's name and name
    # scope) before creating any variables on the instance.
    super().__init__()
    self.value = tf.Variable(0)

  @tf.function(input_signature=[])
  def __call__(self):
    """Returns the `value` variable."""
    return self.value


class ExportSavedModelTest(tf.test.TestCase):
  """Tests for `ExportFileManager`, `ExportSavedModel` and `safe_normpath`."""

  def test_export_file_manager_default_ids(self):
    """Default IDs count up from 0 and only `max_to_keep` files survive."""
    directory = self.create_tempdir()
    base_name = os.path.join(directory.full_path, 'basename')
    manager = actions.ExportFileManager(base_name, max_to_keep=3)
    self.assertEmpty(tf.io.gfile.listdir(directory.full_path))

    # While at most `max_to_keep` files exist, clean_up must not remove any.
    for expected_count in (1, 2, 3):
      directory.create_file(manager.next_name())
      manager.clean_up()  # Shouldn't do anything...
      self.assertLen(tf.io.gfile.listdir(directory.full_path), expected_count)

    directory.create_file(manager.next_name())
    self.assertLen(tf.io.gfile.listdir(directory.full_path), 4)
    self.assertEqual(
        _id_sorted_file_base_names(directory.full_path),
        ['basename-0', 'basename-1', 'basename-2', 'basename-3'])
    manager.clean_up()  # Should delete file with lowest ID.
    self.assertEqual(
        _id_sorted_file_base_names(directory.full_path),
        ['basename-1', 'basename-2', 'basename-3'])

    # A fresh manager over the same base name continues after the highest ID
    # used so far.
    manager = actions.ExportFileManager(base_name, max_to_keep=3)
    self.assertEqual(os.path.basename(manager.next_name()), 'basename-4')

  def test_export_file_manager_custom_ids(self):
    """IDs come from `next_id_fn`; clean_up drops the lowest ID first."""
    directory = self.create_tempdir()
    base_name = os.path.join(directory.full_path, 'basename')

    id_num = 0

    def next_id():
      # Reads `id_num` at call time, so rebinding it below changes the ID the
      # manager uses for the next name.
      return id_num

    manager = actions.ExportFileManager(
        base_name, max_to_keep=2, next_id_fn=next_id)
    self.assertEmpty(tf.io.gfile.listdir(directory.full_path))

    id_num = 30
    directory.create_file(manager.next_name())
    self.assertLen(tf.io.gfile.listdir(directory.full_path), 1)
    manager.clean_up()  # Shouldn't do anything...
    self.assertEqual(
        _id_sorted_file_base_names(directory.full_path), ['basename-30'])

    id_num = 200
    directory.create_file(manager.next_name())
    self.assertLen(tf.io.gfile.listdir(directory.full_path), 2)
    manager.clean_up()  # Shouldn't do anything...
    self.assertEqual(
        _id_sorted_file_base_names(directory.full_path),
        ['basename-30', 'basename-200'])

    id_num = 1000
    directory.create_file(manager.next_name())
    self.assertLen(tf.io.gfile.listdir(directory.full_path), 3)
    self.assertEqual(
        _id_sorted_file_base_names(directory.full_path),
        ['basename-30', 'basename-200', 'basename-1000'])
    manager.clean_up()  # Should delete file with lowest ID.
    self.assertLen(tf.io.gfile.listdir(directory.full_path), 2)
    self.assertEqual(
        _id_sorted_file_base_names(directory.full_path),
        ['basename-200', 'basename-1000'])

  def test_export_file_manager_with_suffix(self):
    """With `subdirectory` set, clean_up leaves the base folders in place."""
    directory = self.create_tempdir()
    base_name = os.path.join(directory.full_path, 'basename')

    id_num = 0

    def next_id():
      # Reads `id_num` at call time, so rebinding it below changes the ID the
      # manager uses for the next name.
      return id_num

    subdirectory = 'sub'

    manager = actions.ExportFileManager(
        base_name, max_to_keep=2, next_id_fn=next_id, subdirectory=subdirectory
    )
    self.assertEmpty(tf.io.gfile.listdir(directory.full_path))

    id_num = 30
    directory.create_file(manager.next_name())
    self.assertLen(tf.io.gfile.listdir(directory.full_path), 1)
    manager.clean_up()  # Shouldn't do anything...
    self.assertEqual(
        _id_sorted_file_base_names(directory.full_path), ['basename-30']
    )

    id_num = 200
    directory.create_file(manager.next_name())
    self.assertLen(tf.io.gfile.listdir(directory.full_path), 2)
    manager.clean_up()  # Shouldn't do anything...
    self.assertEqual(
        _id_sorted_file_base_names(directory.full_path),
        ['basename-30', 'basename-200'],
    )

    id_num = 1000
    directory.create_file(manager.next_name())
    self.assertLen(tf.io.gfile.listdir(directory.full_path), 3)
    self.assertEqual(
        _id_sorted_file_base_names(directory.full_path),
        ['basename-30', 'basename-200', 'basename-1000'],
    )
    # Should remove only the subdirectory entry belonging to the lowest ID.
    manager.clean_up()
    self.assertLen(tf.io.gfile.listdir(directory.full_path), 3)
    # Note that the base folder is intact, only the suffix folder is deleted.
    self.assertEqual(
        _id_sorted_file_base_names(directory.full_path),
        ['basename-30', 'basename-200', 'basename-1000'],
    )
    step_folder = os.path.join(directory.full_path, 'basename-1000')
    self.assertIn(subdirectory, tf.io.gfile.listdir(step_folder))

  def test_export_file_manager_managed_files(self):
    """Pre-existing '<base>-<int>' files are managed, in ascending ID order."""
    directory = self.create_tempdir()
    directory.create_file('basename-5')
    directory.create_file('basename-10')
    directory.create_file('basename-50')
    directory.create_file('basename-1000')
    directory.create_file('basename-9')
    # Does not end in an integer ID; the length check below expects it to be
    # left out of `managed_files`.
    directory.create_file('basename-10-suffix')

    base_name = os.path.join(directory.full_path, 'basename')
    manager = actions.ExportFileManager(base_name, max_to_keep=3)
    self.assertLen(manager.managed_files, 5)
    self.assertEqual(manager.next_name(), f'{base_name}-1001')
    manager.clean_up()
    self.assertEqual(
        manager.managed_files,
        [f'{base_name}-10', f'{base_name}-50', f'{base_name}-1000'])

  def test_export_file_manager_managed_files_double_slash(self):
    """A base name containing '//' is reported in normalized form."""
    directory = self.create_tempdir('foo//bar')
    directory.create_file('basename-5')
    directory.create_file('basename-10')
    directory.create_file('basename-50')
    directory.create_file('basename-1000')
    directory.create_file('basename-9')
    # Does not end in an integer ID; the length check below expects it to be
    # left out of `managed_files`.
    directory.create_file('basename-10-suffix')

    base_name = os.path.join(directory.full_path, 'basename')
    expected_base_name = os.path.normpath(base_name)
    # Guard: the test is only meaningful if normalization changes the path.
    self.assertNotEqual(base_name, expected_base_name)
    manager = actions.ExportFileManager(base_name, max_to_keep=3)
    self.assertLen(manager.managed_files, 5)
    self.assertEqual(manager.next_name(), f'{expected_base_name}-1001')
    manager.clean_up()
    self.assertEqual(manager.managed_files, [
        f'{expected_base_name}-10', f'{expected_base_name}-50',
        f'{expected_base_name}-1000'
    ])

  def test_export_saved_model(self):
    """Each export captures the current variable value; old exports expire."""
    directory = self.create_tempdir()
    base_name = os.path.join(directory.full_path, 'basename')
    file_manager = actions.ExportFileManager(base_name, max_to_keep=2)
    model = TestModel()
    export_action = actions.ExportSavedModel(
        model, file_manager=file_manager, signatures=model.__call__)

    model.value.assign(3)
    self.assertEqual(model(), 3)
    self.assertEmpty(file_manager.managed_files)
    export_action({})
    self.assertLen(file_manager.managed_files, 1)
    reloaded_model = tf.saved_model.load(file_manager.managed_files[-1])
    self.assertEqual(reloaded_model(), 3)

    model.value.assign(5)
    self.assertEqual(model(), 5)
    export_action({})
    self.assertLen(file_manager.managed_files, 2)
    reloaded_model = tf.saved_model.load(file_manager.managed_files[-1])
    self.assertEqual(reloaded_model(), 5)

    model.value.assign(7)
    self.assertEqual(model(), 7)
    export_action({})
    self.assertLen(file_manager.managed_files, 2)  # Still 2, due to clean up.
    reloaded_model = tf.saved_model.load(file_manager.managed_files[-1])
    self.assertEqual(reloaded_model(), 7)

  def test_safe_normpath_gs(self):
    """`safe_normpath` collapses '//' in the path but keeps the 'gs://'."""
    path = export_saved_model.safe_normpath('gs://foo//bar')
    self.assertEqual(path, 'gs://foo/bar')


if __name__ == '__main__':
  tf.test.main()