# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines the base task abstraction.""" | |
import abc | |
import functools | |
from typing import Optional | |
from absl import logging | |
import tensorflow as tf, tf_keras | |
from official.core import config_definitions | |
from official.modeling import optimization | |
from official.modeling import performance | |
from official.modeling.privacy import configs | |
from official.modeling.privacy import ops | |
OptimizationConfig = optimization.OptimizationConfig | |
RuntimeConfig = config_definitions.RuntimeConfig | |
DifferentialPrivacyConfig = configs.DifferentialPrivacyConfig | |


class Task(tf.Module, metaclass=abc.ABCMeta):
  """A single-replica view of training procedure.

  Tasks provide artifacts for training/validation procedures, including
  loading/iterating over Datasets, training/validation steps, calculating the
  loss and customized metrics with reduction.
  """

  # Special keys in train/validate step returned logs.
  loss = "loss"

  def __init__(self,
               params,
               logging_dir: Optional[str] = None,
               name: Optional[str] = None):
    """Task initialization.

    Args:
      params: the task configuration instance, which can be any of dataclass,
        ConfigDict, namedtuple, etc.
      logging_dir: a string pointing to where the model, summaries etc. will be
        saved. You can also write additional stuff in this directory.
      name: the task name.
    """
    super().__init__(name=name)
    self._task_config = params
    self._logging_dir = (
        logging_dir or ""
    )  # An empty string hints at the current working dir.

  @property
  def task_config(self):
    return self._task_config

  @property
  def logging_dir(self) -> str:
    return self._logging_dir

  @classmethod
  def create_optimizer(cls,
                       optimizer_config: OptimizationConfig,
                       runtime_config: Optional[RuntimeConfig] = None,
                       dp_config: Optional[DifferentialPrivacyConfig] = None):
    """Creates a TF optimizer from configurations.

    Args:
      optimizer_config: the parameters of the Optimization settings.
      runtime_config: the parameters of the runtime.
      dp_config: the parameters of differential privacy.

    Returns:
      A tf.optimizers.Optimizer object.
    """
    gradient_transformers = None
    if dp_config is not None:
      logging.info("Adding differential privacy transform with config %s.",
                   dp_config.as_dict())
      noise_stddev = dp_config.clipping_norm * dp_config.noise_multiplier
      gradient_transformers = [
          functools.partial(
              ops.clip_l2_norm, l2_norm_clip=dp_config.clipping_norm),
          functools.partial(ops.add_noise, noise_stddev=noise_stddev),
      ]
    opt_factory = optimization.OptimizerFactory(optimizer_config)
    optimizer = opt_factory.build_optimizer(
        opt_factory.build_learning_rate(),
        gradient_transformers=gradient_transformers)
    # Configuring the optimizer when loss_scale is set in the runtime config
    # helps avoid overflow/underflow for float16 computations.
    if runtime_config:
      optimizer = performance.configure_optimizer(
          optimizer,
          use_float16=runtime_config.mixed_precision_dtype == "float16",
          loss_scale=runtime_config.loss_scale)
    return optimizer
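
  # An illustrative sketch (not part of the API): building an SGD optimizer
  # with a constant learning rate from a nested-dict OptimizationConfig. The
  # config values below are hypothetical.
  #
  #   optimizer = Task.create_optimizer(
  #       OptimizationConfig({
  #           "optimizer": {"type": "sgd", "sgd": {"momentum": 0.9}},
  #           "learning_rate": {
  #               "type": "constant",
  #               "constant": {"learning_rate": 0.1},
  #           },
  #       }))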

  def initialize(self, model: tf_keras.Model):
    """[Optional] A callback function used as CheckpointManager's init_fn.

    This function will be called when no checkpoint is found for the model.
    If there is a checkpoint, the checkpoint will be loaded and this function
    will not be called. You can use this callback function to load a pretrained
    checkpoint, saved under a directory other than the model_dir.

    Args:
      model: The keras.Model built or used by this task.
    """
    ckpt_dir_or_file = self.task_config.init_checkpoint
    logging.info("Trying to load pretrained checkpoint from %s",
                 ckpt_dir_or_file)
    if ckpt_dir_or_file and tf.io.gfile.isdir(ckpt_dir_or_file):
      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
    if not ckpt_dir_or_file:
      logging.info("No checkpoint file found from %s. Will not load.",
                   ckpt_dir_or_file)
      return

    if hasattr(model, "checkpoint_items"):
      checkpoint_items = model.checkpoint_items
    else:
      checkpoint_items = dict(model=model)
    ckpt = tf.train.Checkpoint(**checkpoint_items)
    status = ckpt.read(ckpt_dir_or_file)
    status.expect_partial().assert_existing_objects_matched()
    logging.info("Finished loading pretrained checkpoint from %s",
                 ckpt_dir_or_file)
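
  # A minimal sketch of how `checkpoint_items` interacts with `initialize`:
  # a model may expose a mapping of sub-objects so that only part of it (for
  # example a pretrained backbone) is restored. The `backbone` attribute is
  # hypothetical.
  #
  #   class MyModel(tf_keras.Model):
  #
  #     @property
  #     def checkpoint_items(self):
  #       return dict(backbone=self.backbone)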

  def build_model(self) -> tf_keras.Model:
    """[Optional] Creates model architecture.

    Returns:
      A model instance.
    """  # pytype: disable=bad-return-type  # typed-keras

  @abc.abstractmethod
  def build_inputs(self,
                   params,
                   input_context: Optional[tf.distribute.InputContext] = None):
    """Returns a dataset or a nested structure of dataset functions.

    Dataset functions define per-host datasets with the per-replica batch size.
    With distributed training, this method runs on remote hosts.

    Args:
      params: hyperparams to create input pipelines, which can be any of
        dataclass, ConfigDict, namedtuple, etc.
      input_context: optional distribution input pipeline context.

    Returns:
      A nested structure of per-replica input functions.
    """

  def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
    """Standard interface to compute losses.

    Args:
      labels: optional label tensors.
      model_outputs: a nested structure of output tensors.
      aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model.

    Returns:
      The total loss tensor.
    """
    del model_outputs, labels

    if aux_losses is None:
      losses = [tf.constant(0.0, dtype=tf.float32)]
    else:
      losses = aux_losses
    total_loss = tf.add_n(losses)
    return total_loss
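
  # A minimal sketch of a `build_losses` override for sparse classification,
  # assuming `model_outputs` are logits:
  #
  #   def build_losses(self, labels, model_outputs, aux_losses=None):
  #     loss = tf.reduce_mean(
  #         tf_keras.losses.sparse_categorical_crossentropy(
  #             labels, model_outputs, from_logits=True))
  #     if aux_losses:
  #       loss += tf.add_n(aux_losses)
  #     return loss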

  def build_metrics(self, training: bool = True):
    """Gets streaming metrics for training/validation."""
    del training
    return []
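
  # A minimal sketch of a `build_metrics` override:
  #
  #   def build_metrics(self, training=True):
  #     del training
  #     return [tf_keras.metrics.SparseCategoricalAccuracy(name="accuracy")]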

  def process_metrics(self, metrics, labels, model_outputs, **kwargs):
    """Process and update metrics.

    Called when using the custom training loop API.

    Args:
      metrics: a nested structure of metrics objects. The return of function
        self.build_metrics.
      labels: a tensor or a nested structure of tensors.
      model_outputs: a tensor or a nested structure of tensors. For example,
        output of the keras model built by self.build_model.
      **kwargs: other args.
    """
    for metric in metrics:
      metric.update_state(labels, model_outputs)

  def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
    """Process and update compiled_metrics.

    Called when using the compile/fit API.

    Args:
      compiled_metrics: the compiled metrics (model.compiled_metrics).
      labels: a tensor or a nested structure of tensors.
      model_outputs: a tensor or a nested structure of tensors. For example,
        output of the keras model built by self.build_model.
    """
    compiled_metrics.update_state(labels, model_outputs)

  def train_step(self,
                 inputs,
                 model: tf_keras.Model,
                 optimizer: tf_keras.optimizers.Optimizer,
                 metrics=None):
    """Does forward and backward.

    With distribution strategies, this method runs on devices.

    Args:
      inputs: a dictionary of input tensors.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    if isinstance(inputs, tuple) and len(inputs) == 2:
      features, labels = inputs
    else:
      features, labels = inputs, inputs
    with tf.GradientTape() as tape:
      outputs = model(features, training=True)
      # Computes per-replica loss.
      if model.compiled_loss:
        loss = model.compiled_loss(
            labels, outputs, regularization_losses=model.losses)
        loss += self.build_losses(
            labels=labels, model_outputs=outputs, aux_losses=None)
      else:
        loss = self.build_losses(
            labels=labels, model_outputs=outputs, aux_losses=model.losses)
      # Scales loss as the default gradients allreduce performs sum inside the
      # optimizer.
      scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync

      # For mixed precision, when a LossScaleOptimizer is used, the loss is
      # scaled to avoid numeric underflow.
      if isinstance(optimizer,
                    tf_keras.mixed_precision.LossScaleOptimizer):
        scaled_loss = optimizer.get_scaled_loss(scaled_loss)
    tvars = model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)
    if isinstance(optimizer,
                  tf_keras.mixed_precision.LossScaleOptimizer):
      grads = optimizer.get_unscaled_gradients(grads)
    optimizer.apply_gradients(list(zip(grads, tvars)))
    logs = {self.loss: loss}
    if metrics:
      self.process_metrics(metrics, labels, outputs)
    if model.compiled_metrics:
      self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
      logs.update({m.name: m.result() for m in metrics or []})
      logs.update({m.name: m.result() for m in model.metrics})
    return logs
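
  # Sketch of how a custom training loop typically drives this step on each
  # replica (iterator, model, optimizer, and metrics setup elided):
  #
  #   strategy = tf.distribute.get_strategy()
  #   per_replica_logs = strategy.run(
  #       task.train_step, args=(next(iterator), model, optimizer, metrics))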

  def validation_step(self, inputs, model: tf_keras.Model, metrics=None):
    """Validation step.

    With distribution strategies, this method runs on devices.

    Args:
      inputs: a dictionary of input tensors.
      model: the keras.Model.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    if isinstance(inputs, tuple) and len(inputs) == 2:
      features, labels = inputs
    else:
      features, labels = inputs, inputs
    outputs = self.inference_step(features, model)
    loss = self.build_losses(
        labels=labels, model_outputs=outputs, aux_losses=model.losses)
    logs = {self.loss: loss}
    if metrics:
      self.process_metrics(metrics, labels, outputs)
    if model.compiled_metrics:
      self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
      logs.update({m.name: m.result() for m in metrics or []})
      logs.update({m.name: m.result() for m in model.metrics})
    return logs

  def inference_step(self, inputs, model: tf_keras.Model):
    """Performs the forward step.

    With distribution strategies, this method runs on devices.

    Args:
      inputs: a dictionary of input tensors.
      model: the keras.Model.

    Returns:
      Model outputs.
    """
    return model(inputs, training=False)

  def aggregate_logs(self, state, step_logs):
    """Optional aggregation over logs returned from a validation step.

    Given step_logs from a validation step, this function aggregates the logs
    after each eval_step() (see the eval_reduce() function in
    official/core/base_trainer.py). It runs on CPU and can be used to
    aggregate metrics during validation, when there are too many metrics to
    fit into TPU memory. Note that this may increase latency due to data
    transfer between TPU and CPU. Also, the step output from a validation step
    may be a tuple with elements from replicas, in which case a concatenation
    of the elements is needed.

    Args:
      state: The current state of training, for example, it can be a sequence
        of metrics.
      step_logs: Logs from a validation step. Can be a dictionary.
    """
    pass
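
  # A minimal sketch of an `aggregate_logs` override that accumulates
  # per-step outputs on the CPU; the "logits" key is hypothetical:
  #
  #   def aggregate_logs(self, state, step_logs):
  #     if state is None:
  #       state = {"logits": []}
  #     state["logits"].append(step_logs["logits"])
  #     return state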

  def reduce_aggregated_logs(self,
                             aggregated_logs,
                             global_step: Optional[tf.Tensor] = None):
    """Optional reduction of aggregated logs over validation steps.

    This function reduces aggregated logs at the end of validation, and can be
    used to compute the final metrics. It runs on CPU and is called in each
    eval_end() in the base trainer (see the eval_end() function in
    official/core/base_trainer.py).

    Args:
      aggregated_logs: Aggregated logs over multiple validation steps.
      global_step: An optional variable of the global step.

    Returns:
      A dictionary of reduced results.
    """
    return {}
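

# Putting it together (illustrative only): a concrete subclass overrides
# `build_inputs` plus whatever the training loop needs, then the pieces are
# wired up roughly as below. All names (MyTask, my_config) are hypothetical,
# and a real loop would run under a distribution strategy via strategy.run.
#
#   task = MyTask(my_config, logging_dir="/tmp/my_task")
#   model = task.build_model()
#   optimizer = MyTask.create_optimizer(my_config.optimizer_config)
#   metrics = task.build_metrics(training=True)
#   dataset = task.build_inputs(my_config.train_data)
#   for inputs in dataset:
#     logs = task.train_step(inputs, model, optimizer, metrics)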