deanna-emery's picture
updates
93528c6
raw
history blame
14.3 kB
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Video classification task definition."""
from typing import Any, Optional, List, Tuple
from absl import logging
import tensorflow as tf, tf_keras
from official.core import base_task
from official.core import task_factory
from official.modeling import tf_utils
from official.vision.configs import video_classification as exp_cfg
from official.vision.dataloaders import input_reader_factory
from official.vision.dataloaders import video_input
from official.vision.modeling import factory_3d
from official.vision.ops import augment
@task_factory.register_task_cls(exp_cfg.VideoClassificationTask)
class VideoClassificationTask(base_task.Task):
  """A task for video classification."""

  def _get_num_classes(self):
    """Gets the number of classes (read from the training data config)."""
    return self.task_config.train_data.num_classes

  def _get_feature_shape(self):
    """Gets the feature shape shared by the train and validation configs.

    Dimensions that differ between the two configs are replaced by `None` so
    the resulting input spec accepts both.
    """
    return [
        d1 if d1 == d2 else None
        for d1, d2 in zip(self.task_config.train_data.feature_shape,
                          self.task_config.validation_data.feature_shape)
    ]

  def _get_num_test_views(self):
    """Gets the number of evaluation views (test clips x test crops)."""
    validation_data = self.task_config.validation_data
    return validation_data.num_test_clips * validation_data.num_test_crops

  def _is_multilabel(self):
    """Returns whether the training labels are multi-label."""
    return self.task_config.train_data.is_multilabel

  def build_model(self):
    """Builds the video classification model.

    Returns:
      The model built by `factory_3d.build_model`. When
      `task_config.freeze_backbone` is set, its backbone is made
      non-trainable.
    """
    common_input_shape = self._get_feature_shape()
    input_specs = tf_keras.layers.InputSpec(shape=[None] + common_input_shape)
    logging.info('Build model input %r', common_input_shape)

    l2_weight_decay = float(self.task_config.losses.l2_weight_decay)
    # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
    # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
    # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
    l2_regularizer = (tf_keras.regularizers.l2(
        l2_weight_decay / 2.0) if l2_weight_decay else None)

    model = factory_3d.build_model(
        self.task_config.model.model_type,
        input_specs=input_specs,
        model_config=self.task_config.model,
        num_classes=self._get_num_classes(),
        l2_regularizer=l2_regularizer)
    if self.task_config.freeze_backbone:
      logging.info('Freezing model backbone.')
      model.backbone.trainable = False
    return model

  def initialize(self, model: tf_keras.Model):
    """Loads a pretrained checkpoint into `model`, if one is configured.

    Args:
      model: the model whose weights (all, or only the backbone) are restored.

    Raises:
      ValueError: if `init_checkpoint` is a directory that contains no
        checkpoint, or if `init_checkpoint_modules` is neither 'all' nor
        'backbone'.
    """
    if not self.task_config.init_checkpoint:
      return

    ckpt_dir_or_file = self.task_config.init_checkpoint
    if tf.io.gfile.isdir(ckpt_dir_or_file):
      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
      # latest_checkpoint returns None when the directory holds no checkpoint;
      # fail with a clear message instead of an opaque error from `read`.
      if ckpt_dir_or_file is None:
        raise ValueError('No checkpoint found in directory {!r}.'.format(
            self.task_config.init_checkpoint))

    # Restoring checkpoint.
    init_modules = self.task_config.init_checkpoint_modules
    if init_modules == 'all':
      ckpt = tf.train.Checkpoint(model=model)
    elif init_modules == 'backbone':
      ckpt = tf.train.Checkpoint(backbone=model.backbone)
    else:
      raise ValueError(
          "Only 'all' or 'backbone' can be used to initialize the model.")
    status = ckpt.read(ckpt_dir_or_file)
    status.expect_partial().assert_existing_objects_matched()

    logging.info('Finished loading pretrained checkpoint from %s',
                 ckpt_dir_or_file)

  def _get_dataset_fn(self, params):
    """Returns the dataset constructor for `params.file_type`.

    Raises:
      ValueError: if the file type is not 'tfrecord'.
    """
    if params.file_type == 'tfrecord':
      return tf.data.TFRecordDataset
    else:
      raise ValueError('Unknown input file type {!r}'.format(params.file_type))

  def _get_decoder_fn(self, params):
    """Returns the decode function for the dataset described by `params`.

    Args:
      params: the data config (train or validation) being read.

    Returns:
      The `decode` method of a TFDS decoder (when `params.tfds_name` is set)
      or of a `video_input.Decoder`.

    Raises:
      ValueError: if audio output is requested but no audio feature is named.
    """
    if params.tfds_name:
      decoder = video_input.VideoTfdsDecoder(
          image_key=params.image_field_key, label_key=params.label_field_key)
    else:
      decoder = video_input.Decoder(
          image_key=params.image_field_key, label_key=params.label_field_key)
    # Take the audio settings from the config being decoded (not always the
    # training config), matching the parser in `build_inputs`, which is also
    # constructed from `params`.
    if params.output_audio:
      if not params.audio_feature:
        raise ValueError('audio feature is empty')
      decoder.add_feature(params.audio_feature,
                          tf.io.VarLenFeature(dtype=tf.float32))
    return decoder.decode

  def build_inputs(self,
                   params: exp_cfg.DataConfig,
                   input_context: Optional[tf.distribute.InputContext] = None):
    """Builds the classification input pipeline.

    Args:
      params: the data config to read.
      input_context: optional distribution input context passed to the reader.

    Returns:
      The dataset produced by the configured input reader.
    """
    parser = video_input.Parser(
        input_params=params,
        image_key=params.image_field_key,
        label_key=params.label_field_key)

    post_batch_processor = video_input.PostBatchProcessor(params)
    postprocess_fn = post_batch_processor
    if params.mixup_and_cutmix is not None:

      def mixup_and_cutmix(features, labels):
        augmenter = augment.MixupAndCutmix(
            mixup_alpha=params.mixup_and_cutmix.mixup_alpha,
            cutmix_alpha=params.mixup_and_cutmix.cutmix_alpha,
            prob=params.mixup_and_cutmix.prob,
            label_smoothing=params.mixup_and_cutmix.label_smoothing,
            num_classes=self._get_num_classes())
        features['image'], labels = augmenter(features['image'], labels)
        # Still run the standard post-batch processing: previously enabling
        # mixup/cutmix silently replaced `PostBatchProcessor` altogether.
        return post_batch_processor(features, labels)

      postprocess_fn = mixup_and_cutmix

    reader = input_reader_factory.input_reader_generator(
        params,
        dataset_fn=self._get_dataset_fn(params),
        decoder_fn=self._get_decoder_fn(params),
        parser_fn=parser.parse_fn(params.is_training),
        postprocess_fn=postprocess_fn)

    dataset = reader.read(input_context=input_context)

    return dataset

  def build_losses(self,
                   labels: Any,
                   model_outputs: Any,
                   aux_losses: Optional[Any] = None):
    """Builds the classification loss plus optional auxiliary losses.

    Multi-label tasks use binary cross entropy; otherwise categorical cross
    entropy (when `losses.one_hot` is set, with label smoothing) or sparse
    categorical cross entropy is used. All of them are called with
    `from_logits=False`, so `model_outputs` must already be probabilities
    (`train_step` and `inference_step` apply sigmoid/softmax first).

    Args:
      labels: labels.
      model_outputs: output probabilities of the classifier.
      aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model.

    Returns:
      A dict with 'class_loss' (batch mean of the classification loss),
      'entropy' (multi-label only), 'reg_loss' (sum of `aux_losses`, when
      non-empty) and the total loss under `self.loss`.
    """
    all_losses = {}
    losses_config = self.task_config.losses
    entropy = None

    if self._is_multilabel():
      # Batch mean of -sum(p * log p) over classes; reported, not optimized.
      entropy = -tf.reduce_mean(
          tf.reduce_sum(model_outputs * tf.math.log(model_outputs + 1e-8), -1))
      class_loss = tf_keras.losses.binary_crossentropy(
          labels, model_outputs, from_logits=False)
    elif losses_config.one_hot:
      class_loss = tf_keras.losses.categorical_crossentropy(
          labels,
          model_outputs,
          from_logits=False,
          label_smoothing=losses_config.label_smoothing)
    else:
      class_loss = tf_keras.losses.sparse_categorical_crossentropy(
          labels, model_outputs, from_logits=False)

    # These Keras loss functions return one value per example. Reduce in every
    # branch so the multi-label loss is a scalar batch mean as well, instead of
    # a per-example vector whose gradient would be summed over the batch.
    total_loss = tf_utils.safe_mean(class_loss)
    all_losses['class_loss'] = total_loss
    if entropy is not None:
      all_losses['entropy'] = entropy

    if aux_losses:
      # Log the summed regularization loss (the value actually added to the
      # total) rather than the raw list of tensors.
      reg_loss = tf.add_n(aux_losses)
      all_losses['reg_loss'] = reg_loss
      total_loss += reg_loss
    all_losses[self.loss] = total_loss

    return all_losses

  def build_metrics(self, training: bool = True):
    """Gets streaming metrics for training/validation.

    Args:
      training: unused; the same metrics are built for training and eval.

    Returns:
      A list of Keras metrics: dense (one-hot) accuracy metrics plus, for
      multi-label tasks, ROC-AUC, recall at 95% precision, PR-AUC and
      optional per-class recall; or sparse accuracy metrics otherwise.
    """
    if self.task_config.losses.one_hot:
      metrics = [
          tf_keras.metrics.CategoricalAccuracy(name='accuracy'),
          tf_keras.metrics.TopKCategoricalAccuracy(k=1, name='top_1_accuracy'),
          tf_keras.metrics.TopKCategoricalAccuracy(k=5, name='top_5_accuracy')
      ]
      if self._is_multilabel():
        metrics.append(
            tf_keras.metrics.AUC(
                curve='ROC', multi_label=True, name='ROC-AUC'))
        metrics.append(
            tf_keras.metrics.RecallAtPrecision(
                0.95, name='RecallAtPrecision95'))
        metrics.append(
            tf_keras.metrics.AUC(
                curve='PR', multi_label=True, name='PR-AUC'))
        if self.task_config.metrics.use_per_class_recall:
          for i in range(self._get_num_classes()):
            metrics.append(
                tf_keras.metrics.Recall(class_id=i, name=f'recall-{i}'))
    else:
      metrics = [
          tf_keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
          tf_keras.metrics.SparseTopKCategoricalAccuracy(
              k=1, name='top_1_accuracy'),
          tf_keras.metrics.SparseTopKCategoricalAccuracy(
              k=5, name='top_5_accuracy')
      ]
    return metrics

  def process_metrics(self, metrics: List[Any], labels: Any,
                      model_outputs: Any):
    """Process and update metrics.

    Called when using custom training loop API.

    Args:
      metrics: a nested structure of metrics objects. The return of function
        self.build_metrics.
      labels: a tensor or a nested structure of tensors.
      model_outputs: a tensor or a nested structure of tensors. For example,
        output of the keras model built by self.build_model.
    """
    for metric in metrics:
      metric.update_state(labels, model_outputs)

  def train_step(self,
                 inputs: Tuple[Any, Any],
                 model: tf_keras.Model,
                 optimizer: tf_keras.optimizers.Optimizer,
                 metrics: Optional[List[Any]] = None):
    """Does forward and backward.

    Args:
      inputs: a `(features, labels)` tuple; `features` is a dict holding at
        least the 'image' tensor.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs: the losses from `build_losses` plus metric
      results.
    """
    features, labels = inputs
    strategy = tf.distribute.get_strategy()

    input_partition_dims = self.task_config.train_input_partition_dims
    if input_partition_dims:
      features['image'] = strategy.experimental_split_to_logical_devices(
          features['image'], input_partition_dims)

    num_replicas = strategy.num_replicas_in_sync
    with tf.GradientTape() as tape:
      outputs = model(features, training=True)
      # Casting output layer as float32 is necessary when mixed_precision is
      # mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
      outputs = tf.nest.map_structure(
          lambda x: tf.cast(x, tf.float32), outputs)

      # Convert logits to probabilities; `build_losses` uses from_logits=False.
      if self._is_multilabel():
        outputs = tf.nest.map_structure(tf.math.sigmoid, outputs)
      else:
        outputs = tf.nest.map_structure(tf.math.softmax, outputs)
      # Computes per-replica loss.
      all_losses = self.build_losses(
          model_outputs=outputs, labels=labels, aux_losses=model.losses)
      loss = all_losses[self.loss]
      # Scales loss as the default gradients allreduce performs sum inside the
      # optimizer.
      scaled_loss = loss / num_replicas

      # For mixed_precision policy, when LossScaleOptimizer is used, loss is
      # scaled for numerical stability.
      if isinstance(
          optimizer, tf_keras.mixed_precision.LossScaleOptimizer):
        scaled_loss = optimizer.get_scaled_loss(scaled_loss)

    tvars = model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)
    # Scales back gradient before apply_gradients when LossScaleOptimizer is
    # used.
    if isinstance(optimizer, tf_keras.mixed_precision.LossScaleOptimizer):
      grads = optimizer.get_unscaled_gradients(grads)
    optimizer.apply_gradients(list(zip(grads, tvars)))

    logs = all_losses
    if metrics:
      self.process_metrics(metrics, labels, outputs)
      logs.update({m.name: m.result() for m in metrics})
    elif model.compiled_metrics:
      self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
      logs.update({m.name: m.result() for m in model.metrics})
    return logs

  def validation_step(self,
                      inputs: Tuple[Any, Any],
                      model: tf_keras.Model,
                      metrics: Optional[List[Any]] = None):
    """Validation step.

    Args:
      inputs: a `(features, labels)` tuple; `features` is a dict holding at
        least the 'image' tensor.
      model: the keras.Model.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs: the losses from `build_losses` plus metric
      results.
    """
    features, labels = inputs
    input_partition_dims = self.task_config.eval_input_partition_dims
    if input_partition_dims:
      strategy = tf.distribute.get_strategy()
      features['image'] = strategy.experimental_split_to_logical_devices(
          features['image'], input_partition_dims)

    # Probabilities, already averaged over test views.
    outputs = self.inference_step(features, model)
    outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
    logs = self.build_losses(model_outputs=outputs, labels=labels,
                             aux_losses=model.losses)

    if metrics:
      self.process_metrics(metrics, labels, outputs)
      logs.update({m.name: m.result() for m in metrics})
    elif model.compiled_metrics:
      self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
      logs.update({m.name: m.result() for m in model.metrics})
    return logs

  def inference_step(self, features: tf.Tensor, model: tf_keras.Model):
    """Performs the forward step.

    Returns:
      Class probabilities (sigmoid for multi-label, softmax otherwise). When
      more than one test view is configured, consecutive groups of
      `num_test_views` rows are averaged into one prediction.
    """
    outputs = model(features, training=False)
    if self._is_multilabel():
      outputs = tf.nest.map_structure(tf.math.sigmoid, outputs)
    else:
      outputs = tf.nest.map_structure(tf.math.softmax, outputs)
    num_test_views = self._get_num_test_views()
    if num_test_views > 1:
      # Averaging output probabilities across multiples views.
      outputs = tf.reshape(outputs, [-1, num_test_views, outputs.shape[-1]])
      outputs = tf.reduce_mean(outputs, axis=1)
    return outputs