# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image classification task definition.""" | |
from typing import Any, List, Optional, Tuple | |
from absl import logging | |
import tensorflow as tf, tf_keras | |
from official.common import dataset_fn | |
from official.core import base_task | |
from official.core import task_factory | |
from official.modeling import tf_utils | |
from official.vision.configs import image_classification as exp_cfg | |
from official.vision.dataloaders import classification_input | |
from official.vision.dataloaders import input_reader | |
from official.vision.dataloaders import input_reader_factory | |
from official.vision.dataloaders import tfds_factory | |
from official.vision.modeling import factory | |
from official.vision.ops import augment | |
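
# Small tolerance used in `train_step` when thresholding mixup/cutmix soft
# labels into {0, 1} targets for the binary cross entropy loss.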
_EPSILON = 1e-6


@task_factory.register_task_cls(exp_cfg.ImageClassificationTask)
class ImageClassificationTask(base_task.Task):
  """A task for image classification."""

  def build_model(self):
    """Builds classification model."""
    input_specs = tf_keras.layers.InputSpec(
        shape=[None] + self.task_config.model.input_size)

    l2_weight_decay = self.task_config.losses.l2_weight_decay
    # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
    # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
    # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
    l2_regularizer = (tf_keras.regularizers.l2(
        l2_weight_decay / 2.0) if l2_weight_decay else None)
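    # For example, with l2_weight_decay = 1e-4 the regularizer becomes
    # l2(5e-5), and 5e-5 * sum(w**2) equals 1e-4 * tf.nn.l2_loss(w), since
    # tf.nn.l2_loss(w) = sum(w**2) / 2.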
    model = factory.build_classification_model(
        input_specs=input_specs,
        model_config=self.task_config.model,
        l2_regularizer=l2_regularizer)

    if self.task_config.freeze_backbone:
      model.backbone.trainable = False

    # Build the model's variables by running a dummy forward pass.
    dummy_inputs = tf_keras.Input(self.task_config.model.input_size)
    _ = model(dummy_inputs, training=False)
    return model

  def initialize(self, model: tf_keras.Model):
    """Loads pretrained checkpoint."""
    if not self.task_config.init_checkpoint:
      return
    ckpt_dir_or_file = self.task_config.init_checkpoint
    if tf.io.gfile.isdir(ckpt_dir_or_file):
      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)

    # Restoring checkpoint.
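    # 'all' restores the full model; 'backbone' restores only the feature
    # extractor, leaving the classification head randomly initialized.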
    if self.task_config.init_checkpoint_modules == 'all':
      ckpt = tf.train.Checkpoint(model=model)
      status = ckpt.read(ckpt_dir_or_file)
      status.expect_partial().assert_existing_objects_matched()
    elif self.task_config.init_checkpoint_modules == 'backbone':
      ckpt = tf.train.Checkpoint(backbone=model.backbone)
      status = ckpt.read(ckpt_dir_or_file)
      status.expect_partial().assert_existing_objects_matched()
    else:
      raise ValueError(
          "Only 'all' or 'backbone' can be used to initialize the model.")

    logging.info('Finished loading pretrained checkpoint from %s',
                 ckpt_dir_or_file)

  def build_inputs(
      self,
      params: exp_cfg.DataConfig,
      input_context: Optional[tf.distribute.InputContext] = None
  ) -> tf.data.Dataset:
    """Builds classification input."""

    num_classes = self.task_config.model.num_classes
    input_size = self.task_config.model.input_size
    image_field_key = self.task_config.train_data.image_field_key
    label_field_key = self.task_config.train_data.label_field_key
    is_multilabel = self.task_config.train_data.is_multilabel

    if params.tfds_name:
      decoder = tfds_factory.get_classification_decoder(params.tfds_name)
    else:
      decoder = classification_input.Decoder(
          image_field_key=image_field_key, label_field_key=label_field_key,
          is_multilabel=is_multilabel)

    parser = classification_input.Parser(
        output_size=input_size[:2],
        num_classes=num_classes,
        image_field_key=image_field_key,
        label_field_key=label_field_key,
        decode_jpeg_only=params.decode_jpeg_only,
        aug_rand_hflip=params.aug_rand_hflip,
        aug_crop=params.aug_crop,
        aug_type=params.aug_type,
        color_jitter=params.color_jitter,
        random_erasing=params.random_erasing,
        is_multilabel=is_multilabel,
        dtype=params.dtype,
        center_crop_fraction=params.center_crop_fraction,
        tf_resize_method=params.tf_resize_method,
        three_augment=params.three_augment)

    postprocess_fn = None
    if params.mixup_and_cutmix:
      postprocess_fn = augment.MixupAndCutmix(
          mixup_alpha=params.mixup_and_cutmix.mixup_alpha,
          cutmix_alpha=params.mixup_and_cutmix.cutmix_alpha,
          prob=params.mixup_and_cutmix.prob,
          label_smoothing=params.mixup_and_cutmix.label_smoothing,
          num_classes=num_classes)
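
    # Repeated augmentation: sample uniformly from `repeated_augment` copies
    # of the dataset, so a batch can contain several independently augmented
    # views of the same underlying example.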
    def sample_fn(repeated_augment, dataset):
      weights = [1 / repeated_augment] * repeated_augment
      dataset = tf.data.Dataset.sample_from_datasets(
          datasets=[dataset] * repeated_augment,
          weights=weights,
          seed=None,
          stop_on_empty_dataset=True,
      )
      return dataset

    is_repeated_augment = (
        params.is_training
        and params.repeated_augment is not None
    )
    reader = input_reader_factory.input_reader_generator(
        params,
        dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
        decoder_fn=decoder.decode,
        combine_fn=input_reader.create_combine_fn(params),
        parser_fn=parser.parse_fn(params.is_training),
        postprocess_fn=postprocess_fn,
        sample_fn=(lambda ds: sample_fn(params.repeated_augment, ds))
        if is_repeated_augment
        else None,
    )

    dataset = reader.read(input_context=input_context)

    return dataset

  def build_losses(self,
                   labels: tf.Tensor,
                   model_outputs: tf.Tensor,
                   aux_losses: Optional[Any] = None) -> tf.Tensor:
    """Builds sparse categorical cross entropy loss.

    Args:
      labels: Input groundtruth labels.
      model_outputs: Output logits of the classifier.
      aux_losses: The auxiliary loss tensors, i.e. `losses` in tf_keras.Model.

    Returns:
      The total loss tensor.
    """
    losses_config = self.task_config.losses
    is_multilabel = self.task_config.train_data.is_multilabel

    if not is_multilabel:
      if losses_config.use_binary_cross_entropy:
        total_loss = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=labels, logits=model_outputs
        )
        # Average over all object classes inside an image.
        total_loss = tf.reduce_mean(total_loss, axis=-1)
      elif losses_config.one_hot:
        total_loss = tf_keras.losses.categorical_crossentropy(
            labels,
            model_outputs,
            from_logits=True,
            label_smoothing=losses_config.label_smoothing)
      elif losses_config.soft_labels:
        total_loss = tf.nn.softmax_cross_entropy_with_logits(
            labels, model_outputs)
      else:
        total_loss = tf_keras.losses.sparse_categorical_crossentropy(
            labels, model_outputs, from_logits=True)
    else:
      # Multi-label binary cross entropy loss. This applies `reduce_mean`
      # over the class axis.
      total_loss = tf_keras.losses.binary_crossentropy(
          labels,
          model_outputs,
          from_logits=True,
          label_smoothing=losses_config.label_smoothing,
          axis=-1)
      # Multiply by num_classes to behave like `reduce_sum` over classes.
      total_loss = total_loss * self.task_config.model.num_classes
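      # For example, with num_classes=5 a per-example class-mean loss of m
      # becomes 5 * m, i.e. the sum of the per-class losses.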

    total_loss = tf_utils.safe_mean(total_loss)
    if aux_losses:
      total_loss += tf.add_n(aux_losses)

    total_loss = losses_config.loss_weight * total_loss
    return total_loss

  def build_metrics(self,
                    training: bool = True) -> List[tf_keras.metrics.Metric]:
    """Gets streaming metrics for training/validation."""
    is_multilabel = self.task_config.train_data.is_multilabel
    if not is_multilabel:
      k = self.task_config.evaluation.top_k
      if (self.task_config.losses.one_hot or
          self.task_config.losses.soft_labels):
        metrics = [
            tf_keras.metrics.CategoricalAccuracy(name='accuracy'),
            tf_keras.metrics.TopKCategoricalAccuracy(
                k=k, name='top_{}_accuracy'.format(k))]
        if hasattr(
            self.task_config.evaluation, 'precision_and_recall_thresholds'
        ) and self.task_config.evaluation.precision_and_recall_thresholds:
          thresholds = self.task_config.evaluation.precision_and_recall_thresholds  # pylint: disable=line-too-long
          # pylint:disable=g-complex-comprehension
          metrics += [
              tf_keras.metrics.Precision(
                  thresholds=th,
                  name='precision_at_threshold_{}'.format(th),
                  top_k=1) for th in thresholds
          ]
          metrics += [
              tf_keras.metrics.Recall(
                  thresholds=th,
                  name='recall_at_threshold_{}'.format(th),
                  top_k=1) for th in thresholds
          ]

          # Add per-class precision and recall.
          if hasattr(
              self.task_config.evaluation,
              'report_per_class_precision_and_recall'
          ) and self.task_config.evaluation.report_per_class_precision_and_recall:
            for class_id in range(self.task_config.model.num_classes):
              metrics += [
                  tf_keras.metrics.Precision(
                      thresholds=th,
                      class_id=class_id,
                      name=f'precision_at_threshold_{th}/{class_id}',
                      top_k=1) for th in thresholds
              ]
              metrics += [
                  tf_keras.metrics.Recall(
                      thresholds=th,
                      class_id=class_id,
                      name=f'recall_at_threshold_{th}/{class_id}',
                      top_k=1) for th in thresholds
              ]
          # pylint:enable=g-complex-comprehension
      else:
        metrics = [
            tf_keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
            tf_keras.metrics.SparseTopKCategoricalAccuracy(
                k=k, name='top_{}_accuracy'.format(k))]
    else:
      metrics = []
      # These metrics destabilize the training if included in training. The
      # jobs fail due to OOM.
      # TODO(arashwan): Investigate adding following metric to train.
      if not training:
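        # 'globalPR-AUC' pools predictions across all classes into a single
        # PR curve; 'meanPR-AUC' computes one PR-AUC per class and averages
        # them (multi_label=True).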
        metrics = [
            tf_keras.metrics.AUC(
                name='globalPR-AUC',
                curve='PR',
                multi_label=False,
                from_logits=True),
            tf_keras.metrics.AUC(
                name='meanPR-AUC',
                curve='PR',
                multi_label=True,
                num_labels=self.task_config.model.num_classes,
                from_logits=True),
        ]
    return metrics

  def train_step(self,
                 inputs: Tuple[Any, Any],
                 model: tf_keras.Model,
                 optimizer: tf_keras.optimizers.Optimizer,
                 metrics: Optional[List[Any]] = None):
    """Does forward and backward.

    Args:
      inputs: A tuple of input tensors of (features, labels).
      model: A tf_keras.Model instance.
      optimizer: The optimizer for this training step.
      metrics: A nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    features, labels = inputs
    is_multilabel = self.task_config.train_data.is_multilabel
    if self.task_config.losses.one_hot and not is_multilabel:
      labels = tf.one_hot(labels, self.task_config.model.num_classes)
    if self.task_config.losses.use_binary_cross_entropy:
      # BCE loss turns the multiclass task into a multilabel one: the label
      # value of every class present in the image becomes one.
      if self.task_config.train_data.mixup_and_cutmix is not None:
        # Label values below off_value_threshold are mapped to zero and
        # values above it to one. Negative labels are guaranteed to have a
        # value less than or equal to the off_value produced by mixup.
        off_value_threshold = (
            self.task_config.train_data.mixup_and_cutmix.label_smoothing
            / self.task_config.model.num_classes
        )
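        # For example, with label_smoothing=0.1 and num_classes=1000, the
        # off_value is 0.1 / 1000 = 1e-4, so any label <= 1e-4 + _EPSILON is
        # treated as a negative.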
        labels = tf.where(
            tf.less(labels, off_value_threshold + _EPSILON), 0.0, 1.0)
      elif tf.rank(labels) == 1:
        labels = tf.one_hot(labels, self.task_config.model.num_classes)

    num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
    with tf.GradientTape() as tape:
      outputs = model(features, training=True)

      # Casting the output layer to float32 is necessary when mixed precision
      # is mixed_float16 or mixed_bfloat16, to ensure the loss is computed in
      # float32.
      outputs = tf.nest.map_structure(
          lambda x: tf.cast(x, tf.float32), outputs)

      # Computes per-replica loss.
      loss = self.build_losses(
          model_outputs=outputs,
          labels=labels,
          aux_losses=model.losses)
      # Scales loss as the default gradients allreduce performs sum inside the
      # optimizer.
      scaled_loss = loss / num_replicas
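      # For example, with 8 replicas each one contributes loss / 8, so the
      # summed gradients across replicas equal the gradient of the mean loss.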

      # For mixed_precision policy, when LossScaleOptimizer is used, loss is
      # scaled for numerical stability.
      if isinstance(
          optimizer, tf_keras.mixed_precision.LossScaleOptimizer):
        scaled_loss = optimizer.get_scaled_loss(scaled_loss)

    tvars = model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)
    # Scales back gradient before apply_gradients when LossScaleOptimizer is
    # used.
    if isinstance(
        optimizer, tf_keras.mixed_precision.LossScaleOptimizer):
      grads = optimizer.get_unscaled_gradients(grads)
    optimizer.apply_gradients(list(zip(grads, tvars)))

    logs = {self.loss: loss}
    # Convert logits to softmax for metric computation if needed.
    if hasattr(self.task_config.model,
               'output_softmax') and self.task_config.model.output_softmax:
      outputs = tf.nn.softmax(outputs, axis=-1)
    if metrics:
      self.process_metrics(metrics, labels, outputs)
    elif model.compiled_metrics:
      self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
      logs.update({m.name: m.result() for m in model.metrics})
    return logs

  def validation_step(self,
                      inputs: Tuple[Any, Any],
                      model: tf_keras.Model,
                      metrics: Optional[List[Any]] = None):
    """Runs validation step.

    Args:
      inputs: A tuple of input tensors of (features, labels).
      model: A tf_keras.Model instance.
      metrics: A nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    features, labels = inputs
    one_hot = self.task_config.losses.one_hot
    soft_labels = self.task_config.losses.soft_labels
    is_multilabel = self.task_config.train_data.is_multilabel
    # Note: `soft_labels` only applies to the training phase. In the
    # validation phase, labels are still integer ids and need to be converted
    # to one-hot format.
    if (one_hot or soft_labels) and not is_multilabel:
      labels = tf.one_hot(labels, self.task_config.model.num_classes)

    outputs = self.inference_step(features, model)
    outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
    loss = self.build_losses(
        model_outputs=outputs,
        labels=labels,
        aux_losses=model.losses)

    logs = {self.loss: loss}
    # Convert logits to softmax for metric computation if needed.
    if hasattr(self.task_config.model,
               'output_softmax') and self.task_config.model.output_softmax:
      outputs = tf.nn.softmax(outputs, axis=-1)
    if metrics:
      self.process_metrics(metrics, labels, outputs)
    elif model.compiled_metrics:
      self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
      logs.update({m.name: m.result() for m in model.metrics})
    return logs

  def inference_step(self, inputs: tf.Tensor, model: tf_keras.Model):
    """Performs the forward step."""
    return model(inputs, training=False)
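

# Example usage (a minimal sketch; the default-constructed config below is an
# illustrative assumption, not part of this file):
#
#   config = exp_cfg.ImageClassificationTask()
#   task = ImageClassificationTask(config)
#   model = task.build_model()
#   task.initialize(model)
#   train_ds = task.build_inputs(config.train_data)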