# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for multitask.train_lib.""" | |
from absl.testing import parameterized | |
import tensorflow as tf, tf_keras | |
from tensorflow.python.distribute import combinations | |
from tensorflow.python.distribute import strategy_combinations | |
from official.core import task_factory | |
from official.modeling.hyperparams import params_dict | |
from official.modeling.multitask import configs | |
from official.modeling.multitask import multitask | |
from official.modeling.multitask import test_utils | |
from official.modeling.multitask import train_lib | |
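

# These tests drive `train_lib.run_experiment` and
# `train_lib.run_experiment_with_multitask_eval` end to end on the mock
# Foo/Bar tasks and model defined in `test_utils`, with trainer overrides
# that keep every run down to a few steps.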
class TrainLibTest(tf.test.TestCase, parameterized.TestCase):

  def setUp(self):
    super().setUp()
    # Minimal trainer overrides: a 10-step training run with a 5-step
    # validation pass every 10 steps, checkpoints and summaries at the same
    # cadence, and SGD with a constant learning rate.
    self._test_config = {
        'trainer': {
            'checkpoint_interval': 10,
            'steps_per_loop': 10,
            'summary_interval': 10,
            'train_steps': 10,
            'validation_steps': 5,
            'validation_interval': 10,
            'continuous_eval_timeout': 1,
            'optimizer_config': {
                'optimizer': {
                    'type': 'sgd',
                },
                'learning_rate': {
                    'type': 'constant'
                }
            }
        },
    }
  @combinations.generate(
      combinations.combine(
          distribution_strategy=[
              strategy_combinations.default_strategy,
              strategy_combinations.cloud_tpu_strategy,
              strategy_combinations.one_device_strategy_gpu,
          ],
          mode='eager',
          optimizer=['sgd', 'adam'],
          flag_mode=['train', 'eval', 'train_and_eval']))
  def test_end_to_end(self, distribution_strategy, optimizer, flag_mode):
    model_dir = self.get_temp_dir()
    experiment_config = configs.MultiTaskExperimentConfig(
        task=configs.MultiTaskConfig(
            task_routines=(
                configs.TaskRoutine(
                    task_name='foo', task_config=test_utils.FooConfig()),
                configs.TaskRoutine(
                    task_name='bar', task_config=test_utils.BarConfig()))))
    experiment_config = params_dict.override_params_dict(
        experiment_config, self._test_config, is_strict=False)
    experiment_config.trainer.optimizer_config.optimizer.type = optimizer
    with distribution_strategy.scope():
      test_multitask = multitask.MultiTask.from_config(experiment_config.task)
      model = test_utils.MockMultiTaskModel()
      train_lib.run_experiment(
          distribution_strategy=distribution_strategy,
          task=test_multitask,
          model=model,
          mode=flag_mode,
          params=experiment_config,
          model_dir=model_dir)
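
  # Trains a single Foo task while evaluating both Foo and Bar, each eval
  # task with its own step count.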
  @combinations.generate(
      combinations.combine(
          distribution_strategy=[
              strategy_combinations.default_strategy,
              strategy_combinations.cloud_tpu_strategy,
              strategy_combinations.one_device_strategy_gpu,
          ],
          mode='eager',
          flag_mode=['train', 'eval', 'train_and_eval']))
  def test_end_to_end_multi_eval(self, distribution_strategy, flag_mode):
    model_dir = self.get_temp_dir()
    experiment_config = configs.MultiEvalExperimentConfig(
        task=test_utils.FooConfig(),
        eval_tasks=(
            configs.TaskRoutine(
                task_name='foo',
                task_config=test_utils.FooConfig(),
                eval_steps=2),
            configs.TaskRoutine(
                task_name='bar',
                task_config=test_utils.BarConfig(),
                eval_steps=3)))
    experiment_config = params_dict.override_params_dict(
        experiment_config, self._test_config, is_strict=False)
    with distribution_strategy.scope():
      train_task = task_factory.get_task(experiment_config.task)
      eval_tasks = [
          task_factory.get_task(config.task_config, name=config.task_name)
          for config in experiment_config.eval_tasks
      ]
      train_lib.run_experiment_with_multitask_eval(
          distribution_strategy=distribution_strategy,
          train_task=train_task,
          eval_tasks=eval_tasks,
          mode=flag_mode,
          params=experiment_config,
          model_dir=model_dir)


if __name__ == '__main__':
  tf.test.main()