# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils to sample tasks for interleaved optimization.""" | |
import abc | |
from typing import Union, Dict, Text | |
import tensorflow as tf, tf_keras | |
from official.modeling.multitask import configs | |
class TaskSampler(tf.Module, metaclass=abc.ABCMeta):
  """An abstract class defining task sampling API for interleaving trainer."""

  def __init__(self, task_weights: Dict[Text, Union[float, int]]):
    self._task_weights = task_weights

  @property
  def task_weights(self):
    return self._task_weights

  @abc.abstractmethod
  def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
    """Computes the cumulative distribution used to sample tasks.

    It calculates the cumulative distribution of the multinomial task
    distribution from which tasks are sampled.

    Args:
      global_step: A tensor indicating the current progress of training.

    Returns:
      A float tensor with shape (num_tasks,) that represents the cumulative
        sampling distribution.
    """
    pass
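# A minimal sketch (not part of this module's API) of how an interleaving
# trainer could consume `task_cumulative_distribution`: draw a uniform random
# number in [0, 1) and pick the first task whose cumulative mass covers it.
# The variable names below are illustrative assumptions.
#
#   cumulative = sampler.task_cumulative_distribution(global_step)
#   rnd = tf.random.uniform(shape=[], dtype=tf.float32)
#   task_index = tf.reduce_sum(tf.cast(cumulative <= rnd, tf.int32))
#   task_name = list(sampler.task_weights.keys())[int(task_index)]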
class UniformTaskSampler(TaskSampler):
  """Sample all tasks uniformly."""

  def __init__(self, task_weights: Dict[Text, Union[float, int]]):
    super(UniformTaskSampler, self).__init__(task_weights=task_weights)
    self._uniform_cumulative = tf.math.cumsum(
        tf.constant(
            [1.0 / len(self._task_weights)] * len(self._task_weights),
            dtype=tf.float32))

  def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
    del global_step
    return self._uniform_cumulative
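# For example (an illustrative calculation, not code from this module): with
# three tasks the uniform cumulative distribution is
# cumsum([1/3, 1/3, 1/3]) == [0.333..., 0.666..., 1.0], independent of the
# task weights and of the training step.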
class ProportionalTaskSampler(TaskSampler):
  """Sample tasks proportional to task weights."""

  def __init__(self,
               task_weights: Dict[Text, Union[float, int]],
               alpha: float = 1.0):
    super(ProportionalTaskSampler, self).__init__(task_weights=task_weights)
    self._alpha = tf.cast(alpha, dtype=tf.float32)
    task_weight_dict_ordered_list = tf.constant(
        [weight for _, weight in self._task_weights.items()], dtype=tf.float32)
    task_sizes = tf.math.pow(task_weight_dict_ordered_list, self._alpha)
    task_distribution = task_sizes / tf.reduce_sum(task_sizes)
    self._proportional_cumulative = tf.math.cumsum(task_distribution)

  def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
    del global_step
    return self._proportional_cumulative
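# Worked example (illustrative numbers, not code from this module): for task
# weights {'task_a': 1.0, 'task_b': 3.0} and alpha=1.0, the task distribution
# is [0.25, 0.75] and the cumulative distribution is [0.25, 1.0]. Choosing
# alpha < 1.0 flattens the distribution toward uniform, while alpha > 1.0
# sharpens it toward the most heavily weighted task.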
class AnnealingTaskSampler(TaskSampler):
  """Sample tasks according to task weights as well as training progress.

  See http://proceedings.mlr.press/v97/stickland19a/stickland19a.pdf
  """

  def __init__(self,
               task_weights: Dict[Text, Union[float, int]],
               steps_per_epoch: int,
               total_steps: int):
    super(AnnealingTaskSampler, self).__init__(task_weights=task_weights)
    self._steps_per_epoch = tf.cast(steps_per_epoch, dtype=tf.float32)
    self._total_epochs = tf.cast(
        total_steps / self._steps_per_epoch, dtype=tf.float32)

  def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
    cur_epoch = tf.math.floor(
        tf.cast(global_step, dtype=tf.float32) / self._steps_per_epoch)
    alpha = 1.0 - 0.8 * (cur_epoch - 1) / (self._total_epochs - 1 + 1e-10)
    task_weight_dict_ordered_list = [
        weight for _, weight in self._task_weights.items()
    ]
    task_sizes = tf.math.pow(
        tf.constant(task_weight_dict_ordered_list, dtype=tf.float32),
        tf.cast(alpha, dtype=tf.float32))
    dynamic_task_distribution = task_sizes / tf.reduce_sum(task_sizes)
    return tf.math.cumsum(dynamic_task_distribution)
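# Illustrative note on the schedule above (an interpretation of the formula,
# not code from this module): alpha decreases roughly linearly from about 1.0
# during the first epochs toward 0.2 as global_step approaches total_steps,
# so sampling starts close to proportional to the task weights and gradually
# flattens toward uniform as training progresses.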
def get_task_sampler(config: configs.TaskSamplingConfig,
                     task_weights: Dict[Text, float]) -> TaskSampler:
  """Creates a task sampler from the sampling configuration and task weights."""
  oneof_config = config.get()
  if config.type == 'uniform':
    return UniformTaskSampler(task_weights=task_weights)
  elif config.type == 'proportional':
    return ProportionalTaskSampler(
        task_weights=task_weights, alpha=oneof_config.alpha)
  elif config.type == 'annealing':
    return AnnealingTaskSampler(
        task_weights=task_weights,
        steps_per_epoch=oneof_config.steps_per_epoch,
        total_steps=oneof_config.total_steps)
  else:
    raise RuntimeError('Task sampler type not supported: %s' % config.type)
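# Example usage (a sketch under assumptions; the task names, weight values and
# step argument below are made up for illustration):
#
#   sampler = ProportionalTaskSampler(
#       task_weights={'classification': 1.0, 'retrieval': 3.0}, alpha=1.0)
#   cumulative = sampler.task_cumulative_distribution(
#       tf.constant(0, dtype=tf.int64))
#   # cumulative == [0.25, 1.0]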