Spaces:
Runtime error
Runtime error
File size: 6,078 Bytes
5672777 93528c6 5672777 93528c6 5672777 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 |
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multitask Evaluator implementation.
The evaluator implements the Orbit `AbstractEvaluator` interface.
"""
from typing import Dict, List, Optional, Union
import gin
import orbit
import tensorflow as tf, tf_keras
from official.core import base_task
from official.core import train_utils
from official.modeling.multitask import base_model
@gin.configurable
class MultiTaskEvaluator(orbit.AbstractEvaluator):
"""Implements the common trainer shared for TensorFlow models."""
def __init__(
self,
eval_tasks: List[base_task.Task],
model: Union[tf_keras.Model, base_model.MultiTaskBaseModel],
global_step: Optional[tf.Variable] = None,
eval_steps: Optional[Dict[str, int]] = None,
checkpoint_exporter: Optional[train_utils.BestCheckpointExporter] = None):
"""Initialize common trainer for TensorFlow models.
Args:
eval_tasks: A list of tasks to evaluate.
model: tf_keras.Model instance.
global_step: the global step variable.
eval_steps: a dictionary of steps to run eval keyed by task names.
checkpoint_exporter: an object that has the `maybe_export_checkpoint`
interface.
"""
# Gets the current distribution strategy. If not inside any strategy scope,
# it gets a single-replica no-op strategy.
self._strategy = tf.distribute.get_strategy()
self._tasks = eval_tasks
self._model = model
self._global_step = global_step or orbit.utils.create_global_step()
self._checkpoint_exporter = checkpoint_exporter
if hasattr(self.model, "checkpoint_items"):
checkpoint_items = self.model.checkpoint_items
else:
checkpoint_items = {}
self._checkpoint = tf.train.Checkpoint(
model=self.model,
global_step=self.global_step,
**checkpoint_items)
self._validation_losses = None
self._validation_metrics = None
# Builds per-task datasets.
self.eval_datasets = {}
self.eval_steps = eval_steps or {}
for task in self.tasks:
self.eval_datasets[task.name] = orbit.utils.make_distributed_dataset(
self.strategy, task.build_inputs, task.task_config.validation_data)
# Builds per-task validation loops.
def get_function(task_name, task):
task_metrics = self.validation_metrics[task_name]
task_loss = self.validation_losses[task_name]
if isinstance(self.model, base_model.MultiTaskBaseModel):
model = self.model.sub_tasks[task_name]
else:
model = self.model
def step_fn(inputs):
logs = task.validation_step(inputs, model=model, metrics=task_metrics)
task_loss.update_state(logs[task.loss])
return logs
@tf.function
def eval_step_fn(iterator):
distributed_outputs = self.strategy.run(step_fn, args=(next(iterator),))
return tf.nest.map_structure(self.strategy.experimental_local_results,
distributed_outputs)
return orbit.utils.create_loop_fn(eval_step_fn)
self.task_fns = {
task.name: get_function(task.name, task) for task in self.tasks
}
@property
def strategy(self):
return self._strategy
@property
def tasks(self):
return self._tasks
@property
def model(self):
return self._model
@property
def global_step(self):
return self._global_step
@property
def validation_losses(self):
"""Accesses the validation loss metric object."""
if self._validation_losses is None:
# Builds the per-task metrics and losses.
self._validation_losses = {}
for task in self.tasks:
self._validation_losses[task.name] = tf_keras.metrics.Mean(
"validation_loss", dtype=tf.float32)
return self._validation_losses
@property
def validation_metrics(self):
"""Accesses all validation metric metric objects."""
if self._validation_metrics is None:
# Builds the per-task metrics and losses.
self._validation_metrics = {}
for task in self.tasks:
self._validation_metrics[task.name] = task.build_metrics(training=False)
return self._validation_metrics
@property
def checkpoint(self):
"""Accesses the training checkpoint."""
return self._checkpoint
def evaluate(self, num_steps: tf.Tensor):
"""Performs evaluation for each `EvalTask`."""
for metric in self.validation_losses.values():
metric.reset_states()
for metrics in self.validation_metrics.values():
for metric in metrics:
metric.reset_states()
results = {}
eval_iters = tf.nest.map_structure(iter, self.eval_datasets)
for task in self.tasks:
outputs = None
name = task.name
eval_iter = eval_iters[name]
task_eval_steps = self.eval_steps.get(name, None) or num_steps
outputs = self.task_fns[name](
eval_iter,
task_eval_steps,
state=outputs,
reduce_fn=task.aggregate_logs)
task_metrics = self.validation_metrics[name]
task_loss = self.validation_losses[name]
logs = {}
for metric in task_metrics + [task_loss]:
logs[metric.name] = metric.result()
if outputs:
metrics = task.reduce_aggregated_logs(
outputs, global_step=self.global_step)
logs.update(metrics)
results[name] = logs
if self._checkpoint_exporter:
self._checkpoint_exporter.maybe_export_checkpoint(
self.checkpoint, results, self.global_step.numpy())
return results
|