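# NOTE: "Spaces: Runtime error" -- apparently a status banner captured from the
# hosting page along with this file; it is not part of the module.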
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the progressive train_lib."""
import os | |
from absl import flags | |
from absl.testing import parameterized | |
import dataclasses | |
import orbit | |
import tensorflow as tf, tf_keras | |
from tensorflow.python.distribute import combinations | |
from tensorflow.python.distribute import strategy_combinations | |
from official.common import flags as tfm_flags | |
# pylint: disable=unused-import | |
from official.common import registry_imports | |
# pylint: enable=unused-import | |
from official.core import config_definitions as cfg | |
from official.core import task_factory | |
from official.modeling import optimization | |
from official.modeling.hyperparams import params_dict | |
from official.modeling.fast_training.progressive import policies | |
from official.modeling.fast_training.progressive import train_lib | |
from official.modeling.fast_training.progressive import trainer as prog_trainer_lib | |
from official.utils.testing import mock_task | |
# Module-level handle to absl's global flag container.
FLAGS = flags.FLAGS
# Define the flags declared in `official.common.flags` at import time, so they
# exist before the test runner parses the command line.
tfm_flags.define_flags()
@dataclasses.dataclass
class ProgTaskConfig(cfg.TaskConfig):
  """Task config for the progressive mock task.

  Adds no fields of its own; it exists as a distinct config type so the
  progressive mock task can be looked up separately from other tasks that
  use `cfg.TaskConfig`.
  """
@task_factory.register_task_cls(ProgTaskConfig)
class ProgMockTask(policies.ProgressivePolicy, mock_task.MockTask):
  """Progressive task for testing.

  Combines `MockTask` with a two-stage `ProgressivePolicy`. It is registered
  for `ProgTaskConfig` so that `task_factory.get_task` can build it from an
  experiment config.
  """

  def __init__(self, params: cfg.TaskConfig, logging_dir: str = None):
    """Initializes both base classes explicitly.

    Args:
      params: Task configuration forwarded to `MockTask`.
      logging_dir: Optional logging directory forwarded to `MockTask`.
    """
    # The two bases take different constructor arguments, so each initializer
    # is called directly rather than through a cooperative `super()` chain.
    # NOTE(review): MockTask is initialized first; presumably the policy
    # initializer relies on task state -- confirm before reordering.
    mock_task.MockTask.__init__(
        self, params=params, logging_dir=logging_dir)
    policies.ProgressivePolicy.__init__(self)

  def num_stages(self):
    """Returns the number of progressive stages (always 2)."""
    return 2

  def num_steps(self, stage_id):
    """Returns the step count for a stage: 2 for stage 0, otherwise 4."""
    return 2 if stage_id == 0 else 4

  def get_model(self, stage_id, old_model=None):
    """Returns a freshly built model; both arguments are ignored."""
    del stage_id, old_model
    return self.build_model()

  def get_optimizer(self, stage_id):
    """Build optimizer for each stage.

    The same configuration is used for every stage: AdamW with a polynomial
    learning-rate schedule (0.01 -> 0.0 over 10 decay steps, power 1.0) and a
    2-step polynomial warmup.

    Args:
      stage_id: Unused; every stage gets the same optimizer configuration.

    Returns:
      The optimizer built by `optimization.OptimizerFactory`.
    """
    del stage_id
    params = optimization.OptimizationConfig({
        'optimizer': {
            'type': 'adamw',
        },
        'learning_rate': {
            'type': 'polynomial',
            'polynomial': {
                'initial_learning_rate': 0.01,
                'end_learning_rate': 0.0,
                'power': 1.0,
                'decay_steps': 10,
            },
        },
        'warmup': {
            'polynomial': {
                'power': 1,
                'warmup_steps': 2,
            },
            'type': 'polynomial',
        }
    })
    opt_factory = optimization.OptimizerFactory(params)
    optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
    return optimizer

  def get_train_dataset(self, stage_id):
    """Returns the distributed training dataset; `stage_id` is ignored."""
    del stage_id
    strategy = tf.distribute.get_strategy()
    # NOTE(review): the trailing `None` is presumably forwarded to
    # `build_inputs` as its params argument -- confirm against orbit.
    return orbit.utils.make_distributed_dataset(
        strategy, self.build_inputs, None)

  def get_eval_dataset(self, stage_id):
    """Returns the distributed eval dataset; `stage_id` is ignored.

    Built exactly like the training dataset.
    """
    del stage_id
    strategy = tf.distribute.get_strategy()
    return orbit.utils.make_distributed_dataset(
        strategy, self.build_inputs, None)
class TrainTest(tf.test.TestCase, parameterized.TestCase):
  """End-to-end test of `train_lib.run_experiment` with a progressive task."""

  def setUp(self):
    """Builds the trainer overrides shared by every test case."""
    super().setUp()
    self._test_config = {
        'trainer': {
            'checkpoint_interval': 10,
            'steps_per_loop': 10,
            'summary_interval': 10,
            'train_steps': 10,
            'validation_steps': 5,
            'validation_interval': 10,
            'continuous_eval_timeout': 1,
            'optimizer_config': {
                'optimizer': {
                    'type': 'sgd',
                },
                'learning_rate': {
                    'type': 'constant'
                }
            }
        },
    }

  # The test method takes three arguments, so it must be expanded by a
  # combinations decorator; a plain test runner would call it with none.
  # NOTE(review): the strategy list below is reconstructed -- confirm it
  # against the upstream test before relying on the exact set.
  @combinations.generate(
      combinations.combine(
          distribution_strategy=[
              strategy_combinations.default_strategy,
              strategy_combinations.cloud_tpu_strategy,
              strategy_combinations.one_device_strategy_gpu,
          ],
          flag_mode=['train', 'eval', 'train_and_eval'],
          run_post_eval=[True, False]))
  def test_end_to_end(self, distribution_strategy, flag_mode, run_post_eval):
    """Runs an experiment in `flag_mode`, then a continuous-eval pass.

    Args:
      distribution_strategy: Strategy under which the task is created and the
        experiment is run.
      flag_mode: Mode string passed to `train_lib.run_experiment`.
      run_post_eval: Whether `run_experiment` should run a final evaluation;
        the returned logs must be non-empty exactly when this is true.
    """
    model_dir = self.get_temp_dir()
    experiment_config = cfg.ExperimentConfig(
        trainer=prog_trainer_lib.ProgressiveTrainerConfig(),
        task=ProgTaskConfig())
    # Non-strict override: apply the shared trainer settings from setUp().
    experiment_config = params_dict.override_params_dict(
        experiment_config, self._test_config, is_strict=False)

    with distribution_strategy.scope():
      task = task_factory.get_task(experiment_config.task,
                                   logging_dir=model_dir)

    _, logs = train_lib.run_experiment(
        distribution_strategy=distribution_strategy,
        task=task,
        mode=flag_mode,
        params=experiment_config,
        model_dir=model_dir,
        run_post_eval=run_post_eval)

    if run_post_eval:
      self.assertNotEmpty(logs)
    else:
      self.assertEmpty(logs)

    # The checkpoint and continuous-eval checks below are skipped in
    # eval-only mode.
    if flag_mode == 'eval':
      return
    self.assertNotEmpty(
        tf.io.gfile.glob(os.path.join(model_dir, 'checkpoint')))

    # Tests continuous evaluation.
    _, logs = train_lib.run_experiment(
        distribution_strategy=distribution_strategy,
        task=task,
        mode='continuous_eval',
        params=experiment_config,
        model_dir=model_dir,
        run_post_eval=run_post_eval)
    print(logs)
if __name__ == '__main__':
  # Run the test cases through TensorFlow's test runner when this file is
  # executed as a script.
  tf.test.main()