deanna-emery's picture
updates
93528c6
raw
history blame
5.76 kB
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the progressive train_lib."""
import os
from absl import flags
from absl.testing import parameterized
import dataclasses
import orbit
import tensorflow as tf, tf_keras
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.common import flags as tfm_flags
# pylint: disable=unused-import
from official.common import registry_imports
# pylint: enable=unused-import
from official.core import config_definitions as cfg
from official.core import task_factory
from official.modeling import optimization
from official.modeling.hyperparams import params_dict
from official.modeling.fast_training.progressive import policies
from official.modeling.fast_training.progressive import train_lib
from official.modeling.fast_training.progressive import trainer as prog_trainer_lib
from official.utils.testing import mock_task
# Module-level alias for absl's global flag container. It is not referenced
# again in the code visible in this file.
FLAGS = flags.FLAGS
# NOTE(review): presumably registers the Model Garden's common command-line
# flags at import time -- confirm against `official.common.flags`.
tfm_flags.define_flags()
@dataclasses.dataclass
class ProgTaskConfig(cfg.TaskConfig):
  """Task config for the progressive mock task; declares no extra fields."""
@task_factory.register_task_cls(ProgTaskConfig)
class ProgMockTask(policies.ProgressivePolicy, mock_task.MockTask):
  """Mock task with a fixed two-stage progressive policy, for testing only."""

  def __init__(self, params: cfg.TaskConfig, logging_dir: str = None):
    """Initializes the mock task, then the progressive policy.

    Args:
      params: Task configuration forwarded to `MockTask.__init__`.
      logging_dir: Logging directory forwarded to `MockTask.__init__`.
    """
    # Both base initializers are called explicitly (not through `super()`),
    # with the task set up before the policy.
    mock_task.MockTask.__init__(self, params=params, logging_dir=logging_dir)
    policies.ProgressivePolicy.__init__(self)

  def num_stages(self):
    """Returns the number of progressive stages, which is always 2."""
    return 2

  def num_steps(self, stage_id):
    """Returns 2 steps for stage 0 and 4 steps for every other stage."""
    if stage_id == 0:
      return 2
    return 4

  def get_model(self, stage_id, old_model=None):
    """Returns the result of `self.build_model()`; arguments are ignored."""
    del stage_id
    del old_model
    return self.build_model()

  def get_optimizer(self, stage_id):
    """Builds the optimizer; the same configuration is used for any stage.

    The configuration requests an 'adamw' optimizer with a polynomial
    learning-rate schedule (0.01 -> 0.0, power 1.0, 10 decay steps) and a
    polynomial warmup (power 1, 2 warmup steps). `stage_id` is not consulted.

    Args:
      stage_id: Stage index; unused.

    Returns:
      The optimizer produced by `optimization.OptimizerFactory`.
    """
    learning_rate_cfg = {
        'type': 'polynomial',
        'polynomial': {
            'initial_learning_rate': 0.01,
            'end_learning_rate': 0.0,
            'power': 1.0,
            'decay_steps': 10,
        },
    }
    warmup_cfg = {
        'polynomial': {
            'power': 1,
            'warmup_steps': 2,
        },
        'type': 'polynomial',
    }
    optimization_config = optimization.OptimizationConfig({
        'optimizer': {
            'type': 'adamw',
        },
        'learning_rate': learning_rate_cfg,
        'warmup': warmup_cfg,
    })
    factory = optimization.OptimizerFactory(optimization_config)
    learning_rate = factory.build_learning_rate()
    return factory.build_optimizer(learning_rate)

  def get_train_dataset(self, stage_id):
    """Returns `self.build_inputs` distributed under the current strategy.

    Args:
      stage_id: Stage index; unused.
    """
    del stage_id
    return orbit.utils.make_distributed_dataset(
        tf.distribute.get_strategy(), self.build_inputs, None)

  def get_eval_dataset(self, stage_id):
    """Returns the eval dataset, built exactly like the training dataset.

    Args:
      stage_id: Stage index; unused.
    """
    del stage_id
    return orbit.utils.make_distributed_dataset(
        tf.distribute.get_strategy(), self.build_inputs, None)
class TrainTest(tf.test.TestCase, parameterized.TestCase):
  """End-to-end tests of `train_lib.run_experiment` with `ProgMockTask`."""

  def setUp(self):
    """Stores the trainer overrides applied to the experiment config."""
    super().setUp()
    optimizer_overrides = {
        'optimizer': {
            'type': 'sgd',
        },
        'learning_rate': {
            'type': 'constant'
        }
    }
    trainer_overrides = {
        'checkpoint_interval': 10,
        'steps_per_loop': 10,
        'summary_interval': 10,
        'train_steps': 10,
        'validation_steps': 5,
        'validation_interval': 10,
        'continuous_eval_timeout': 1,
        'optimizer_config': optimizer_overrides,
    }
    self._test_config = {'trainer': trainer_overrides}

  @combinations.generate(
      combinations.combine(
          distribution_strategy=[
              strategy_combinations.default_strategy,
              strategy_combinations.cloud_tpu_strategy,
              strategy_combinations.one_device_strategy_gpu,
          ],
          flag_mode=['train', 'eval', 'train_and_eval'],
          run_post_eval=[True, False]))
  def test_end_to_end(self, distribution_strategy, flag_mode, run_post_eval):
    """Runs the experiment in `flag_mode`, then again in continuous-eval mode.

    Checks that the returned logs are non-empty exactly when `run_post_eval`
    is set and, for the modes that train, that a file named 'checkpoint'
    exists in the model directory before continuous evaluation is run.
    """
    model_dir = self.get_temp_dir()
    base_config = cfg.ExperimentConfig(
        trainer=prog_trainer_lib.ProgressiveTrainerConfig(),
        task=ProgTaskConfig())
    experiment_config = params_dict.override_params_dict(
        base_config, self._test_config, is_strict=False)

    # Only the task is constructed under the strategy scope.
    with distribution_strategy.scope():
      task = task_factory.get_task(
          experiment_config.task, logging_dir=model_dir)

    # Arguments common to both `run_experiment` calls; only `mode` differs.
    shared_kwargs = dict(
        distribution_strategy=distribution_strategy,
        task=task,
        params=experiment_config,
        model_dir=model_dir,
        run_post_eval=run_post_eval)

    _, logs = train_lib.run_experiment(mode=flag_mode, **shared_kwargs)

    check_logs = self.assertNotEmpty if run_post_eval else self.assertEmpty
    check_logs(logs)

    # Nothing further is checked for the eval-only mode.
    if flag_mode == 'eval':
      return

    checkpoint_pattern = os.path.join(model_dir, 'checkpoint')
    self.assertNotEmpty(tf.io.gfile.glob(checkpoint_pattern))

    # Tests continuous evaluation.
    _, logs = train_lib.run_experiment(mode='continuous_eval', **shared_kwargs)
    print(logs)
# Run the tests through TensorFlow's test runner when executed as a script.
if __name__ == '__main__':
  tf.test.main()