deanna-emery's picture
updates
93528c6
raw
history blame
16.7 kB
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image classification task definition."""
from typing import Any, List, Optional, Tuple
from absl import logging
import tensorflow as tf, tf_keras
from official.common import dataset_fn
from official.core import base_task
from official.core import task_factory
from official.modeling import tf_utils
from official.vision.configs import image_classification as exp_cfg
from official.vision.dataloaders import classification_input
from official.vision.dataloaders import input_reader
from official.vision.dataloaders import input_reader_factory
from official.vision.dataloaders import tfds_factory
from official.vision.modeling import factory
from official.vision.ops import augment
# Small tolerance added to the mixup off-value threshold in `train_step` when
# labels are binarized for the binary cross entropy loss.
_EPSILON = 1e-6
@task_factory.register_task_cls(exp_cfg.ImageClassificationTask)
class ImageClassificationTask(base_task.Task):
"""A task for image classification."""
def build_model(self):
    """Creates the classification model described by the task config.

    Returns:
        The model returned by `factory.build_classification_model`, after it
        has been called once on a symbolic input.
    """
    model_config = self.task_config.model

    # Halve the configured weight decay so that the Keras L2 regularizer
    # matches the implementation of tf.nn.l2_loss.
    # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
    # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
    weight_decay = self.task_config.losses.l2_weight_decay
    if weight_decay:
        regularizer = tf_keras.regularizers.l2(weight_decay / 2.0)
    else:
        regularizer = None

    model = factory.build_classification_model(
        input_specs=tf_keras.layers.InputSpec(
            shape=[None] + model_config.input_size),
        model_config=model_config,
        l2_regularizer=regularizer)

    if self.task_config.freeze_backbone:
        model.backbone.trainable = False

    # Call the model once on a symbolic input so that it gets built.
    model(tf_keras.Input(model_config.input_size), training=False)
    return model
def initialize(self, model: tf_keras.Model):
    """Restores pretrained weights from the configured checkpoint.

    Does nothing when `task_config.init_checkpoint` is empty. When it names a
    directory, the latest checkpoint found in that directory is used.

    Args:
        model: The model whose weights (or whose backbone's weights) are
            restored.

    Raises:
        ValueError: If `task_config.init_checkpoint_modules` is neither
            'all' nor 'backbone'.
    """
    checkpoint_path = self.task_config.init_checkpoint
    if not checkpoint_path:
        return

    if tf.io.gfile.isdir(checkpoint_path):
        checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)

    # Select which part of the model the checkpoint is mapped onto.
    modules = self.task_config.init_checkpoint_modules
    if modules == 'all':
        ckpt = tf.train.Checkpoint(model=model)
    elif modules == 'backbone':
        ckpt = tf.train.Checkpoint(backbone=model.backbone)
    else:
        raise ValueError(
            "Only 'all' or 'backbone' can be used to initialize the model.")

    # Restoring checkpoint.
    status = ckpt.read(checkpoint_path)
    status.expect_partial().assert_existing_objects_matched()

    logging.info('Finished loading pretrained checkpoint from %s',
                 checkpoint_path)
def build_inputs(
    self,
    params: exp_cfg.DataConfig,
    input_context: Optional[tf.distribute.InputContext] = None
) -> tf.data.Dataset:
    """Creates the classification input dataset described by `params`.

    Args:
        params: Data config of the dataset to build.
        input_context: Optional `tf.distribute.InputContext`, forwarded to
            the reader's `read` call.

    Returns:
        The dataset returned by the input reader.
    """
    model_config = self.task_config.model
    train_data = self.task_config.train_data
    # The field keys and the multilabel flag are always taken from
    # `train_data`, whichever data config `params` is.
    image_key = train_data.image_field_key
    label_key = train_data.label_field_key
    multilabel = train_data.is_multilabel

    if params.tfds_name:
        decoder = tfds_factory.get_classification_decoder(params.tfds_name)
    else:
        decoder = classification_input.Decoder(
            image_field_key=image_key,
            label_field_key=label_key,
            is_multilabel=multilabel)

    parser = classification_input.Parser(
        output_size=model_config.input_size[:2],
        num_classes=model_config.num_classes,
        image_field_key=image_key,
        label_field_key=label_key,
        decode_jpeg_only=params.decode_jpeg_only,
        aug_rand_hflip=params.aug_rand_hflip,
        aug_crop=params.aug_crop,
        aug_type=params.aug_type,
        color_jitter=params.color_jitter,
        random_erasing=params.random_erasing,
        is_multilabel=multilabel,
        dtype=params.dtype,
        center_crop_fraction=params.center_crop_fraction,
        tf_resize_method=params.tf_resize_method,
        three_augment=params.three_augment)

    # Mixup/CutMix is handed to the reader as its post-processing function,
    # and only when it is configured.
    mixup_config = params.mixup_and_cutmix
    if mixup_config:
        postprocess_fn = augment.MixupAndCutmix(
            mixup_alpha=mixup_config.mixup_alpha,
            cutmix_alpha=mixup_config.cutmix_alpha,
            prob=mixup_config.prob,
            label_smoothing=mixup_config.label_smoothing,
            num_classes=model_config.num_classes)
    else:
        postprocess_fn = None

    def repeat_samples(ds):
        """Samples with equal weights from `repeated_augment` copies of `ds`."""
        copies = params.repeated_augment
        return tf.data.Dataset.sample_from_datasets(
            datasets=[ds] * copies,
            weights=[1 / copies] * copies,
            seed=None,
            stop_on_empty_dataset=True,
        )

    # Repeated augmentation is only used for training data, and only when
    # `repeated_augment` is set.
    if params.is_training and params.repeated_augment is not None:
        sample_fn = repeat_samples
    else:
        sample_fn = None

    reader = input_reader_factory.input_reader_generator(
        params,
        dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
        decoder_fn=decoder.decode,
        combine_fn=input_reader.create_combine_fn(params),
        parser_fn=parser.parse_fn(params.is_training),
        postprocess_fn=postprocess_fn,
        sample_fn=sample_fn,
    )
    return reader.read(input_context=input_context)
def build_losses(self,
                 labels: tf.Tensor,
                 model_outputs: tf.Tensor,
                 aux_losses: Optional[Any] = None) -> tf.Tensor:
    """Computes the classification loss selected by the task config.

    Multilabel data uses binary cross entropy. Otherwise the loss is picked
    from `task_config.losses` in this order of precedence:
    `use_binary_cross_entropy`, `one_hot`, `soft_labels`, and sparse
    categorical cross entropy when none of them is set.

    Args:
        labels: Input groundtruth labels.
        model_outputs: Output logits of the classifier.
        aux_losses: The auxiliary loss tensors, i.e. `losses` in
            tf_keras.Model.

    Returns:
        The total loss tensor.
    """
    losses_config = self.task_config.losses

    if self.task_config.train_data.is_multilabel:
        # Multi-label binary cross entropy loss. This applies `reduce_mean`
        # over the last axis, so multiply by num_classes to behave like
        # `reduce_sum`.
        raw_loss = tf_keras.losses.binary_crossentropy(
            labels,
            model_outputs,
            from_logits=True,
            label_smoothing=losses_config.label_smoothing,
            axis=-1)
        raw_loss = raw_loss * self.task_config.model.num_classes
    elif losses_config.use_binary_cross_entropy:
        # Average over all object classes inside an image.
        raw_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                labels=labels, logits=model_outputs),
            axis=-1)
    elif losses_config.one_hot:
        raw_loss = tf_keras.losses.categorical_crossentropy(
            labels,
            model_outputs,
            from_logits=True,
            label_smoothing=losses_config.label_smoothing)
    elif losses_config.soft_labels:
        raw_loss = tf.nn.softmax_cross_entropy_with_logits(
            labels, model_outputs)
    else:
        raw_loss = tf_keras.losses.sparse_categorical_crossentropy(
            labels, model_outputs, from_logits=True)

    total_loss = tf_utils.safe_mean(raw_loss)
    # Auxiliary losses are added before the loss weight is applied, so they
    # are scaled by it too.
    if aux_losses:
        total_loss += tf.add_n(aux_losses)

    return losses_config.loss_weight * total_loss
def build_metrics(self,
                  training: bool = True) -> List[tf_keras.metrics.Metric]:
    """Creates the streaming metrics for training or validation.

    Args:
        training: Whether the metrics are built for training. Only consulted
            for multilabel data, where training gets no metrics.

    Returns:
        A list of metrics. For multilabel data: empty when training, two
        PR-AUC metrics otherwise. For single-label data: accuracy and top-k
        accuracy (categorical variants when `one_hot` or `soft_labels` is
        set, sparse variants otherwise), plus optional precision/recall
        metrics in the categorical case.
    """
    task_config = self.task_config

    if task_config.train_data.is_multilabel:
        # These metrics destabilize the training if included in training.
        # The jobs fail due to OOM.
        # TODO(arashwan): Investigate adding following metric to train.
        if training:
            return []
        return [
            tf_keras.metrics.AUC(
                name='globalPR-AUC',
                curve='PR',
                multi_label=False,
                from_logits=True),
            tf_keras.metrics.AUC(
                name='meanPR-AUC',
                curve='PR',
                multi_label=True,
                num_labels=task_config.model.num_classes,
                from_logits=True),
        ]

    eval_config = task_config.evaluation
    top_k = eval_config.top_k

    if not (task_config.losses.one_hot or task_config.losses.soft_labels):
        # Labels are integer ids: use the sparse metric variants.
        return [
            tf_keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
            tf_keras.metrics.SparseTopKCategoricalAccuracy(
                k=top_k, name='top_{}_accuracy'.format(top_k)),
        ]

    metrics = [
        tf_keras.metrics.CategoricalAccuracy(name='accuracy'),
        tf_keras.metrics.TopKCategoricalAccuracy(
            k=top_k, name='top_{}_accuracy'.format(top_k)),
    ]

    # Precision/recall metrics are optional; the config field may be absent.
    thresholds = getattr(
        eval_config, 'precision_and_recall_thresholds', None)
    if not thresholds:
        return metrics

    metrics.extend(
        tf_keras.metrics.Precision(
            thresholds=th,
            name='precision_at_threshold_{}'.format(th),
            top_k=1)
        for th in thresholds)
    metrics.extend(
        tf_keras.metrics.Recall(
            thresholds=th,
            name='recall_at_threshold_{}'.format(th),
            top_k=1)
        for th in thresholds)

    # Add per-class precision and recall.
    if getattr(eval_config, 'report_per_class_precision_and_recall', None):
        for class_id in range(task_config.model.num_classes):
            metrics.extend(
                tf_keras.metrics.Precision(
                    thresholds=th,
                    class_id=class_id,
                    name=f'precision_at_threshold_{th}/{class_id}',
                    top_k=1)
                for th in thresholds)
            metrics.extend(
                tf_keras.metrics.Recall(
                    thresholds=th,
                    class_id=class_id,
                    name=f'recall_at_threshold_{th}/{class_id}',
                    top_k=1)
                for th in thresholds)
    return metrics
def train_step(self,
               inputs: Tuple[Any, Any],
               model: tf_keras.Model,
               optimizer: tf_keras.optimizers.Optimizer,
               metrics: Optional[List[Any]] = None):
    """Runs one forward and backward pass and applies the gradients.

    Args:
        inputs: A tuple of input tensors of (features, labels).
        model: A tf_keras.Model instance.
        optimizer: The optimizer for this training step.
        metrics: A nested structure of metrics objects.

    Returns:
        A dictionary of logs. It always holds the loss under `self.loss`;
        results of `model.metrics` are added only when `metrics` is empty
        and the model has compiled metrics.
    """
    features, labels = inputs
    losses_config = self.task_config.losses
    train_data = self.task_config.train_data
    model_config = self.task_config.model

    multilabel = train_data.is_multilabel
    if losses_config.one_hot and not multilabel:
        labels = tf.one_hot(labels, model_config.num_classes)

    if losses_config.use_binary_cross_entropy:
        # BCE loss converts the multiclass classification to multilabel. The
        # label value of every object present in the image becomes one.
        mixup_config = train_data.mixup_and_cutmix
        if mixup_config is not None:
            # Label values below the threshold are mapped to zero, all
            # others to one. Negative labels are guaranteed to be no larger
            # than the off_value produced by mixup.
            off_value_threshold = (
                mixup_config.label_smoothing / model_config.num_classes
            )
            labels = tf.where(
                tf.less(labels, off_value_threshold + _EPSILON), 0.0, 1.0)
        # NOTE(review): `tf.rank` returns a tensor, so this condition is
        # tensor-valued - confirm it behaves as intended under tf.function.
        elif tf.rank(labels) == 1:
            labels = tf.one_hot(labels, model_config.num_classes)

    uses_loss_scaling = isinstance(
        optimizer, tf_keras.mixed_precision.LossScaleOptimizer)
    num_replicas = tf.distribute.get_strategy().num_replicas_in_sync

    with tf.GradientTape() as tape:
        outputs = model(features, training=True)
        # Casting the outputs to float32 is necessary when mixed_precision
        # is mixed_float16 or mixed_bfloat16.
        outputs = tf.nest.map_structure(
            lambda x: tf.cast(x, tf.float32), outputs)

        # Computes per-replica loss.
        loss = self.build_losses(
            model_outputs=outputs,
            labels=labels,
            aux_losses=model.losses)
        # Scales loss as the default gradients allreduce performs sum inside
        # the optimizer.
        scaled_loss = loss / num_replicas

        # With a LossScaleOptimizer the loss is additionally scaled for
        # numerical stability.
        if uses_loss_scaling:
            scaled_loss = optimizer.get_scaled_loss(scaled_loss)

    trainable_vars = model.trainable_variables
    grads = tape.gradient(scaled_loss, trainable_vars)
    # Undo the loss scaling on the gradients before applying them.
    if uses_loss_scaling:
        grads = optimizer.get_unscaled_gradients(grads)
    optimizer.apply_gradients(list(zip(grads, trainable_vars)))

    # The unscaled loss is what gets logged.
    logs = {self.loss: loss}

    # Convert logits to softmax for metric computation if needed.
    if getattr(model_config, 'output_softmax', False):
        outputs = tf.nn.softmax(outputs, axis=-1)

    if metrics:
        self.process_metrics(metrics, labels, outputs)
    elif model.compiled_metrics:
        self.process_compiled_metrics(
            model.compiled_metrics, labels, outputs)
        logs.update({m.name: m.result() for m in model.metrics})
    return logs
def validation_step(self,
                    inputs: Tuple[Any, Any],
                    model: tf_keras.Model,
                    metrics: Optional[List[Any]] = None):
    """Runs one validation step.

    Args:
        inputs: A tuple of input tensors of (features, labels).
        model: A tf_keras.Model instance.
        metrics: A nested structure of metrics objects.

    Returns:
        A dictionary of logs. It always holds the loss under `self.loss`;
        results of `model.metrics` are added only when `metrics` is empty
        and the model has compiled metrics.
    """
    features, labels = inputs
    losses_config = self.task_config.losses
    model_config = self.task_config.model

    uses_one_hot = losses_config.one_hot
    uses_soft_labels = losses_config.soft_labels
    multilabel = self.task_config.train_data.is_multilabel
    # Note: `soft_labels` only applies to the training phase. In the
    # validation phase, labels should still be integer ids and need to be
    # converted to one hot format.
    if (uses_one_hot or uses_soft_labels) and not multilabel:
        labels = tf.one_hot(labels, model_config.num_classes)

    # Cast every output to float32 before computing the loss and metrics.
    outputs = tf.nest.map_structure(
        lambda x: tf.cast(x, tf.float32),
        self.inference_step(features, model))

    loss = self.build_losses(
        model_outputs=outputs,
        labels=labels,
        aux_losses=model.losses)
    logs = {self.loss: loss}

    # Convert logits to softmax for metric computation if needed.
    if getattr(model_config, 'output_softmax', False):
        outputs = tf.nn.softmax(outputs, axis=-1)

    if metrics:
        self.process_metrics(metrics, labels, outputs)
    elif model.compiled_metrics:
        self.process_compiled_metrics(
            model.compiled_metrics, labels, outputs)
        logs.update({m.name: m.result() for m in model.metrics})
    return logs
def inference_step(self, inputs: tf.Tensor, model: tf_keras.Model):
    """Performs the forward step in inference mode.

    Args:
        inputs: The input passed to the model.
        model: A tf_keras.Model instance.

    Returns:
        The result of calling `model` on `inputs` with `training=False`.
    """
    outputs = model(inputs, training=False)
    return outputs