# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for multitask.interleaving_trainer."""
from absl.testing import parameterized
import tensorflow as tf, tf_keras
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.modeling.multitask import configs
from official.modeling.multitask import interleaving_trainer
from official.modeling.multitask import multitask
from official.modeling.multitask import task_sampler
from official.modeling.multitask import test_utils


def all_strategy_combinations():
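  """Returns the distribution strategy and mode combinations for the tests."""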
  return combinations.combine(
      distribution=[
          strategy_combinations.default_strategy,
          strategy_combinations.cloud_tpu_strategy,
          strategy_combinations.one_device_strategy_gpu,
      ],
      mode="eager",
  )


class InterleavingTrainerTest(tf.test.TestCase, parameterized.TestCase):

  @combinations.generate(all_strategy_combinations())
  def test_multitask_interleaving_trainer(self, distribution):
    with distribution.scope():
      tasks = [
          test_utils.MockFooTask(params=test_utils.FooConfig(), name="foo"),
          test_utils.MockBarTask(params=test_utils.BarConfig(), name="bar")
      ]
      test_multitask = multitask.MultiTask(tasks=tasks)
      test_optimizer = tf_keras.optimizers.SGD(0.1)
      model = test_utils.MockMultiTaskModel()
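      # Uniform sampling: each task is equally likely to be selected at every
      # training step.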
      sampler = task_sampler.UniformTaskSampler(
          task_weights=test_multitask.task_weights)
      test_trainer = interleaving_trainer.MultiTaskInterleavingTrainer(
          multi_task=test_multitask,
          multi_task_model=model,
          optimizer=test_optimizer,
          task_sampler=sampler)
    results = test_trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
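    # Metrics are reported per task under the task's name; the interleaving
    # trainer does not report an aggregated "total_loss".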
self.assertContainsSubset(["training_loss", "bar_acc"], | |
results["bar"].keys()) | |
self.assertContainsSubset(["training_loss", "foo_acc"], | |
results["foo"].keys()) | |
self.assertNotIn("total_loss", results) | |

  @combinations.generate(all_strategy_combinations())
  def test_trainer_with_configs(self, distribution):
    config = configs.MultiTaskConfig(
        task_routines=(configs.TaskRoutine(
            task_name="foo",
            task_config=test_utils.FooConfig(),
            task_weight=3.0),
                       configs.TaskRoutine(
                           task_name="bar",
                           task_config=test_utils.BarConfig(),
                           task_weight=1.0)))
    with distribution.scope():
      test_multitask = multitask.MultiTask.from_config(config)
      test_optimizer = tf_keras.optimizers.SGD(0.1)
      model = test_utils.MockMultiTaskModel()
      num_step = 1000
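      # The annealing sampler varies each task's sampling probability over the
      # course of training, based on the configured task weights and schedule.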
      sampler = task_sampler.AnnealingTaskSampler(
          task_weights=test_multitask.task_weights,
          steps_per_epoch=num_step / 5,
          total_steps=num_step)
      test_trainer = interleaving_trainer.MultiTaskInterleavingTrainer(
          multi_task=test_multitask,
          multi_task_model=model,
          optimizer=test_optimizer,
          task_sampler=sampler)
    results = test_trainer.train(
        tf.convert_to_tensor(num_step, dtype=tf.int32))
    self.assertContainsSubset(["training_loss", "bar_acc"],
                              results["bar"].keys())
    self.assertContainsSubset(["training_loss", "foo_acc"],
                              results["foo"].keys())
    self.assertEqual(test_trainer.global_step.numpy(), num_step)
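    # Every global step trains exactly one sampled task, so the per-task step
    # counters must sum to the total number of steps.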
    bar_sampled_step = test_trainer.task_step_counter("bar").numpy()
    foo_sampled_step = test_trainer.task_step_counter("foo").numpy()
    self.assertEqual(bar_sampled_step + foo_sampled_step, num_step)
if __name__ == "__main__": | |
tf.test.main() | |