Spaces:
Sleeping
Sleeping
# Copyright 2023 The Orbit Authors. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""A trainer object that can train models with a single output.""" | |
import orbit | |
import tensorflow as tf, tf_keras | |
class SingleTaskTrainer(orbit.StandardTrainer): | |
"""Trains a single-output model on a given dataset. | |
This trainer will handle running a model with one output on a single | |
dataset. It will apply the provided loss function to the model's output | |
to calculate gradients and will apply them via the provided optimizer. It will | |
also supply the output of that model to one or more `tf_keras.metrics.Metric` | |
objects. | |
""" | |
def __init__(self, | |
train_dataset, | |
label_key, | |
model, | |
loss_fn, | |
optimizer, | |
metrics=None, | |
trainer_options=None): | |
"""Initializes a `SingleTaskTrainer` instance. | |
If the `SingleTaskTrainer` should run its model under a distribution | |
strategy, it should be created within that strategy's scope. | |
This trainer will also calculate metrics during training. The loss metric | |
is calculated by default, but other metrics can be passed to the `metrics` | |
arg. | |
Arguments: | |
train_dataset: A `tf.data.Dataset` or `DistributedDataset` that contains a | |
string-keyed dict of `Tensor`s. | |
label_key: The key corresponding to the label value in feature | |
dictionaries dequeued from `train_dataset`. This key will be removed | |
from the dictionary before it is passed to the model. | |
model: A `tf.Module` or Keras `Model` object to evaluate. It must accept a | |
`training` kwarg. | |
loss_fn: A per-element loss function of the form (target, output). The | |
output of this loss function will be reduced via `tf.reduce_mean` to | |
create the final loss. We recommend using the functions in the | |
`tf_keras.losses` package or `tf_keras.losses.Loss` objects with | |
`reduction=tf_keras.losses.reduction.NONE`. | |
optimizer: A `tf_keras.optimizers.Optimizer` instance. | |
metrics: A single `tf_keras.metrics.Metric` object, or a list of | |
`tf_keras.metrics.Metric` objects. | |
trainer_options: An optional `orbit.utils.StandardTrainerOptions` object. | |
""" | |
self.label_key = label_key | |
self.model = model | |
self.loss_fn = loss_fn | |
self.optimizer = optimizer | |
# Capture the strategy from the containing scope. | |
self.strategy = tf.distribute.get_strategy() | |
# We always want to report training loss. | |
self.train_loss = tf_keras.metrics.Mean('training_loss', dtype=tf.float32) | |
# We need self.metrics to be an iterable later, so we handle that here. | |
if metrics is None: | |
self.metrics = [] | |
elif isinstance(metrics, list): | |
self.metrics = metrics | |
else: | |
self.metrics = [metrics] | |
super(SingleTaskTrainer, self).__init__( | |
train_dataset=train_dataset, options=trainer_options) | |
def train_loop_begin(self): | |
"""Actions to take once, at the beginning of each train loop.""" | |
self.train_loss.reset_states() | |
for metric in self.metrics: | |
metric.reset_states() | |
def train_step(self, iterator): | |
"""A train step. Called multiple times per train loop by the superclass.""" | |
def train_fn(inputs): | |
with tf.GradientTape() as tape: | |
# Extract the target value and delete it from the input dict, so that | |
# the model never sees it. | |
target = inputs.pop(self.label_key) | |
# Get the outputs of the model. | |
output = self.model(inputs, training=True) | |
# Get the average per-batch loss and scale it down by the number of | |
# replicas. This ensures that we don't end up multiplying our loss by | |
# the number of workers - gradients are summed, not averaged, across | |
# replicas during the apply_gradients call. | |
# Note, the reduction of loss is explicitly handled and scaled by | |
# num_replicas_in_sync. Recommend to use a plain loss function. | |
# If you're using tf_keras.losses.Loss object, you may need to set | |
# reduction argument explicitly. | |
loss = tf.reduce_mean(self.loss_fn(target, output)) | |
scaled_loss = loss / self.strategy.num_replicas_in_sync | |
# Get the gradients by applying the loss to the model's trainable | |
# variables. | |
gradients = tape.gradient(scaled_loss, self.model.trainable_variables) | |
# Apply the gradients via the optimizer. | |
self.optimizer.apply_gradients( | |
list(zip(gradients, self.model.trainable_variables))) | |
# Update metrics. | |
self.train_loss.update_state(loss) | |
for metric in self.metrics: | |
metric.update_state(target, output) | |
# This is needed to handle distributed computation. | |
self.strategy.run(train_fn, args=(next(iterator),)) | |
def train_loop_end(self): | |
"""Actions to take once after a training loop.""" | |
with self.strategy.scope(): | |
# Export the metrics. | |
metrics = {metric.name: metric.result() for metric in self.metrics} | |
metrics[self.train_loss.name] = self.train_loss.result() | |
return metrics | |