# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multitask trainer that interleaves each task's train step."""
from typing import Union

import gin
import orbit
import tensorflow as tf, tf_keras

from official.modeling.multitask import base_model
from official.modeling.multitask import base_trainer
from official.modeling.multitask import multitask
from official.modeling.multitask import task_sampler as sampler
class MultiTaskInterleavingTrainer(base_trainer.MultiTaskBaseTrainer):
  """MultiTask trainer that interleaves task update.

  At every global step exactly one task is sampled (according to the
  cumulative distribution supplied by `task_sampler`) and only that task's
  train step runs. This contrasts with joint training, where all tasks step
  together and losses are summed.
  """

  def __init__(self,
               multi_task: multitask.MultiTask,
               multi_task_model: Union[tf_keras.Model,
                                       base_model.MultiTaskBaseModel],
               optimizer: Union[tf.optimizers.Optimizer,
                                tf_keras.optimizers.experimental.Optimizer,
                                tf_keras.optimizers.legacy.Optimizer],
               task_sampler: sampler.TaskSampler,
               trainer_options=None):
    """Initializes the interleaving trainer.

    Args:
      multi_task: The `MultiTask` instance holding the per-task `Task`s.
      multi_task_model: Either a single Keras model shared by every task, or
        a `MultiTaskBaseModel` exposing one sub-model per task via
        `sub_tasks`.
      optimizer: Optimizer shared across all tasks.
      task_sampler: Provides `task_cumulative_distribution(global_step)`,
        the cumulative probabilities used to pick which task trains at each
        step.
      trainer_options: Optional options forwarded to the base trainer.
    """
    super().__init__(
        multi_task=multi_task,
        multi_task_model=multi_task_model,
        optimizer=optimizer,
        trainer_options=trainer_options)
    self._task_sampler = task_sampler

    # Build per task train step. The factory function is needed so each
    # closure captures its own `task_name`/`task` (a bare lambda in the dict
    # comprehension would late-bind to the last task).
    def _get_task_step(task_name, task):

      def step_fn(inputs):
        # Route to the task-specific sub-model when a MultiTaskBaseModel is
        # used; otherwise all tasks share the one model.
        if isinstance(self.multi_task_model, base_model.MultiTaskBaseModel):
          task_model = self.multi_task_model.sub_tasks[task_name]
        else:
          task_model = self.multi_task_model
        task_logs = task.train_step(
            inputs,
            model=task_model,
            optimizer=self.optimizer,
            metrics=self.training_metrics[task_name])
        self.training_losses[task_name].update_state(task_logs[task.loss])

      return step_fn

    self._task_train_step_map = {
        name: _get_task_step(name, task)
        for name, task in self.multi_task.tasks.items()
    }

    # TODO(haozhangthu): Add taskwise step counter to train_loop_end for
    # logging on TensorBoard.
    self._task_step_counters = {
        name: orbit.utils.create_global_step() for name in self.multi_task.tasks
    }

    # If the new Keras optimizer is used, we require all model variables are
    # created before the training and let the optimizer to create the slot
    # variable all together.
    if isinstance(optimizer, tf_keras.optimizers.experimental.Optimizer):
      multi_task_model.build()
      optimizer.build(multi_task_model.trainable_variables)

  def task_step_counter(self, name):
    """Returns the step-counter variable for the task called `name`."""
    return self._task_step_counters[name]

  def train_step(self, iterator_map):
    """Samples one task and runs only that task's train step.

    Args:
      iterator_map: Mapping from task name to that task's dataset iterator.
    """
    # Sample one task to train according to a multinomial distribution.
    # Seeding with the global step keeps the draw deterministic per step
    # while varying across steps.
    rn = tf.random.stateless_uniform(shape=[], seed=(0, self.global_step))
    cumulative_sample_distribution = (
        self._task_sampler.task_cumulative_distribution(self.global_step))
    # Prepend a [0.0] for indexing convenience.
    cumulative_sample_distribution = tf.concat(
        [tf.constant([0.0], dtype=tf.float32), cumulative_sample_distribution],
        axis=0)

    # Pick the task whose half-open bucket [begin, end) contains the draw.
    # NOTE(review): relies on AutoGraph converting the tensor `if` into a
    # `tf.cond`; the last bucket's upper edge is assumed to reach 1.0.
    for idx, (name, _) in enumerate(self.multi_task.tasks.items()):
      begin = cumulative_sample_distribution[idx]
      end = cumulative_sample_distribution[idx + 1]
      if rn >= begin and rn < end:
        self._strategy.run(
            self._task_train_step_map[name], args=(next(iterator_map[name]),))
        self.global_step.assign_add(1)
        self.task_step_counter(name).assign_add(1)

  def train_loop_end(self):
    """Record loss and metric values per task."""
    result = super().train_loop_end()
    # Interleaving training does not have a good semantic for `total_loss`. In
    # fact, it is always zero. To avoid confusion, we filter the `total_loss`
    # from the result logs.
    if 'total_loss' in result:
      result.pop('total_loss')
    return result