# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the flag helpers exposed by `official.utils.flags.core`."""

import unittest

from absl import flags
import tensorflow as tf, tf_keras

from official.utils.flags import core as flags_core  # pylint: disable=g-bad-import-order


def define_flags():
  """Registers the base, performance, image and benchmark flag groups."""
  flags_core.define_base(
      clean=True,
      num_gpu=False,
      stop_threshold=True,
      hooks=True,
      train_epochs=True,
      epochs_between_evals=True)
  flags_core.define_performance(
      num_parallel_calls=True,
      inter_op=True,
      intra_op=True,
      loss_scale=True,
      synthetic_data=True,
      dtype=True)
  flags_core.define_image()
  flags_core.define_benchmark()


class BaseTester(unittest.TestCase):
  """Checks flag definition, default overriding and parsing via flags_core."""

  @classmethod
  def setUpClass(cls):
    """Defines the flags once for the whole test class."""
    super().setUpClass()
    define_flags()

  def test_default_setting(self):
    """Test to ensure fields exist and defaults can be set."""
    defaults = dict(
        data_dir="dfgasf",
        model_dir="dfsdkjgbs",
        train_epochs=534,
        epochs_between_evals=15,
        batch_size=256,
        hooks=["LoggingTensorHook"],
        num_parallel_calls=18,
        inter_op_parallelism_threads=5,
        intra_op_parallelism_threads=10,
        data_format="channels_first")

    flags_core.set_defaults(**defaults)
    flags_core.parse_flags()

    # unittest assertions (rather than bare `assert`) keep the checks active
    # under `python -O` and report both values on failure.
    for key, value in defaults.items():
      self.assertEqual(
          flags.FLAGS.get_flag_value(name=key, default=None), value)

  def test_benchmark_setting(self):
    """Test that overridden benchmark-related defaults are readable."""
    defaults = dict(
        hooks=["LoggingMetricHook"],
        benchmark_log_dir="/tmp/12345",
        gcp_project="project_abc",
    )

    flags_core.set_defaults(**defaults)
    flags_core.parse_flags()

    for key, value in defaults.items():
      self.assertEqual(
          flags.FLAGS.get_flag_value(name=key, default=None), value)

  def test_booleans(self):
    """Test to ensure boolean flags trigger as expected."""
    flags_core.parse_flags([__file__, "--use_synthetic_data"])

    self.assertTrue(flags.FLAGS.use_synthetic_data)

  def test_parse_dtype_info(self):
    """Test dtype and loss-scale resolution for valid and invalid inputs."""
    # fp16 without an explicit loss scale falls back to default_for_fp16.
    flags_core.parse_flags([__file__, "--dtype", "fp16"])
    self.assertEqual(flags_core.get_tf_dtype(flags.FLAGS), tf.float16)
    self.assertEqual(
        flags_core.get_loss_scale(flags.FLAGS, default_for_fp16=2), 2)

    # An explicit numeric loss scale overrides the fp16 default.
    flags_core.parse_flags([__file__, "--dtype", "fp16", "--loss_scale", "5"])
    self.assertEqual(
        flags_core.get_loss_scale(flags.FLAGS, default_for_fp16=2), 5)

    # The literal "dynamic" is passed through unchanged.
    flags_core.parse_flags(
        [__file__, "--dtype", "fp16", "--loss_scale", "dynamic"])
    self.assertEqual(
        flags_core.get_loss_scale(flags.FLAGS, default_for_fp16=2), "dynamic")

    # fp32 without an explicit loss scale yields 1, not default_for_fp16.
    flags_core.parse_flags([__file__, "--dtype", "fp32"])
    self.assertEqual(flags_core.get_tf_dtype(flags.FLAGS), tf.float32)
    self.assertEqual(
        flags_core.get_loss_scale(flags.FLAGS, default_for_fp16=2), 1)

    flags_core.parse_flags([__file__, "--dtype", "fp32", "--loss_scale", "5"])
    self.assertEqual(
        flags_core.get_loss_scale(flags.FLAGS, default_for_fp16=2), 5)

    # Unsupported dtype and non-numeric loss scale both abort parsing.
    with self.assertRaises(SystemExit):
      flags_core.parse_flags([__file__, "--dtype", "int8"])

    with self.assertRaises(SystemExit):
      flags_core.parse_flags(
          [__file__, "--dtype", "fp16", "--loss_scale", "abc"])

  def test_get_nondefault_flags_as_str(self):
    """Test that only flags differing from their defaults are rendered."""
    defaults = dict(
        clean=True,
        data_dir="abc",
        hooks=["LoggingTensorHook"],
        stop_threshold=1.5,
        use_synthetic_data=False)
    flags_core.set_defaults(**defaults)
    flags_core.parse_flags()

    # Nothing has been changed yet, so the rendered string is empty.
    expected_flags = ""
    self.assertEqual(flags_core.get_nondefault_flags_as_str(), expected_flags)

    flags.FLAGS.clean = False
    expected_flags += "--noclean"
    self.assertEqual(flags_core.get_nondefault_flags_as_str(), expected_flags)

    flags.FLAGS.data_dir = "xyz"
    expected_flags += " --data_dir=xyz"
    self.assertEqual(flags_core.get_nondefault_flags_as_str(), expected_flags)

    flags.FLAGS.hooks = ["aaa", "bbb", "ccc"]
    expected_flags += " --hooks=aaa,bbb,ccc"
    self.assertEqual(flags_core.get_nondefault_flags_as_str(), expected_flags)

    flags.FLAGS.stop_threshold = 3.
    expected_flags += " --stop_threshold=3.0"
    self.assertEqual(flags_core.get_nondefault_flags_as_str(), expected_flags)

    flags.FLAGS.use_synthetic_data = True
    expected_flags += " --use_synthetic_data"
    self.assertEqual(flags_core.get_nondefault_flags_as_str(), expected_flags)

    # Assert that explicit setting a flag to its default value does not cause it
    # to appear in the string
    flags.FLAGS.use_synthetic_data = False
    expected_flags = expected_flags[:-len(" --use_synthetic_data")]
    self.assertEqual(flags_core.get_nondefault_flags_as_str(), expected_flags)


if __name__ == "__main__":
  unittest.main()