# NOTE: hosting-page banner captured with this file ("Spaces: Runtime error");
# it is not part of the module.
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Video classification task definition."""
from typing import Any, Optional, List, Tuple | |
from absl import logging | |
import tensorflow as tf, tf_keras | |
from official.core import base_task | |
from official.core import task_factory | |
from official.modeling import tf_utils | |
from official.vision.configs import video_classification as exp_cfg | |
from official.vision.dataloaders import input_reader_factory | |
from official.vision.dataloaders import video_input | |
from official.vision.modeling import factory_3d | |
from official.vision.ops import augment | |
class VideoClassificationTask(base_task.Task):
  """A task for video classification.

  Everything is driven by `self.task_config`: the model, the input pipeline,
  the losses and metrics, and the train / validation / inference steps.
  """

  def _get_num_classes(self):
    """Gets the number of classes from the training data config."""
    return self.task_config.train_data.num_classes

  def _get_feature_shape(self):
    """Gets the common feature shape for train and eval.

    Returns:
      A list with one entry per dimension: the configured size when the train
      and validation configs agree on it, otherwise None.
    """
    return [
        d1 if d1 == d2 else None
        for d1, d2 in zip(self.task_config.train_data.feature_shape,
                          self.task_config.validation_data.feature_shape)
    ]

  def _get_num_test_views(self):
    """Gets the number of views for test (`num_test_clips * num_test_crops`)."""
    validation_data = self.task_config.validation_data
    return validation_data.num_test_clips * validation_data.num_test_crops

  def _is_multilabel(self):
    """Returns whether the training data is configured as multi-label."""
    return self.task_config.train_data.is_multilabel

  def build_model(self):
    """Builds the video classification model.

    Returns:
      The model created by `factory_3d.build_model`, with its backbone frozen
      when `task_config.freeze_backbone` is set.
    """
    common_input_shape = self._get_feature_shape()
    # The leading None is the batch dimension.
    input_specs = tf_keras.layers.InputSpec(shape=[None] + common_input_shape)
    logging.info('Build model input %r', common_input_shape)

    l2_weight_decay = float(self.task_config.losses.l2_weight_decay)
    # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
    # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
    # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
    l2_regularizer = (
        tf_keras.regularizers.l2(l2_weight_decay / 2.0)
        if l2_weight_decay else None)

    model = factory_3d.build_model(
        self.task_config.model.model_type,
        input_specs=input_specs,
        model_config=self.task_config.model,
        num_classes=self._get_num_classes(),
        l2_regularizer=l2_regularizer)

    if self.task_config.freeze_backbone:
      logging.info('Freezing model backbone.')
      model.backbone.trainable = False
    return model

  def initialize(self, model: tf_keras.Model) -> None:
    """Loads a pretrained checkpoint into `model`.

    Does nothing when `task_config.init_checkpoint` is empty. When it points
    to a directory, the latest checkpoint in that directory is used.

    Args:
      model: the model whose variables (or backbone variables) are restored.

    Raises:
      ValueError: if `task_config.init_checkpoint_modules` is neither 'all'
        nor 'backbone'.
    """
    if not self.task_config.init_checkpoint:
      return

    ckpt_dir_or_file = self.task_config.init_checkpoint
    if tf.io.gfile.isdir(ckpt_dir_or_file):
      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)

    # Restoring checkpoint.
    init_modules = self.task_config.init_checkpoint_modules
    if init_modules == 'all':
      ckpt = tf.train.Checkpoint(model=model)
      status = ckpt.read(ckpt_dir_or_file)
      status.expect_partial().assert_existing_objects_matched()
    elif init_modules == 'backbone':
      ckpt = tf.train.Checkpoint(backbone=model.backbone)
      status = ckpt.read(ckpt_dir_or_file)
      status.expect_partial().assert_existing_objects_matched()
    else:
      raise ValueError(
          "Only 'all' or 'backbone' can be used to initialize the model.")

    logging.info('Finished loading pretrained checkpoint from %s',
                 ckpt_dir_or_file)

  def _get_dataset_fn(self, params):
    """Returns the dataset constructor for `params.file_type`.

    Raises:
      ValueError: if the file type is not 'tfrecord'.
    """
    if params.file_type == 'tfrecord':
      return tf.data.TFRecordDataset
    raise ValueError('Unknown input file type {!r}'.format(params.file_type))

  def _get_decoder_fn(self, params):
    """Returns the decode function for the configured input source.

    Uses the TFDS decoder when `params.tfds_name` is set and the plain decoder
    otherwise. When the training config requests audio output, the audio
    feature is registered on the decoder as a float32 VarLenFeature.

    Raises:
      ValueError: if audio output is requested but no audio feature is set.
    """
    if params.tfds_name:
      decoder = video_input.VideoTfdsDecoder(
          image_key=params.image_field_key, label_key=params.label_field_key)
    else:
      decoder = video_input.Decoder(
          image_key=params.image_field_key, label_key=params.label_field_key)

    if self.task_config.train_data.output_audio:
      audio_feature = self.task_config.train_data.audio_feature
      # Raised explicitly (rather than asserted) so the check still runs when
      # Python is started with -O.
      if not audio_feature:
        raise ValueError('audio feature is empty')
      decoder.add_feature(audio_feature,
                          tf.io.VarLenFeature(dtype=tf.float32))
    return decoder.decode

  def build_inputs(self,
                   params: exp_cfg.DataConfig,
                   input_context: Optional[tf.distribute.InputContext] = None):
    """Builds classification input.

    Args:
      params: data config describing the dataset to read.
      input_context: optional distribution input context forwarded to the
        reader.

    Returns:
      The dataset produced by the input reader.
    """
    parser = video_input.Parser(
        input_params=params,
        image_key=params.image_field_key,
        label_key=params.label_field_key)
    postprocess_fn = video_input.PostBatchProcessor(params)

    if params.mixup_and_cutmix is not None:

      def mixup_and_cutmix(features, labels):
        augmenter = augment.MixupAndCutmix(
            mixup_alpha=params.mixup_and_cutmix.mixup_alpha,
            cutmix_alpha=params.mixup_and_cutmix.cutmix_alpha,
            prob=params.mixup_and_cutmix.prob,
            label_smoothing=params.mixup_and_cutmix.label_smoothing,
            num_classes=self._get_num_classes())
        features['image'], labels = augmenter(features['image'], labels)
        return features, labels

      # Mixup/cutmix replaces (does not wrap) the default post-batch processor.
      postprocess_fn = mixup_and_cutmix

    reader = input_reader_factory.input_reader_generator(
        params,
        dataset_fn=self._get_dataset_fn(params),
        decoder_fn=self._get_decoder_fn(params),
        parser_fn=parser.parse_fn(params.is_training),
        postprocess_fn=postprocess_fn)
    return reader.read(input_context=input_context)

  def build_losses(self,
                   labels: Any,
                   model_outputs: Any,
                   aux_losses: Optional[Any] = None):
    """Builds the classification loss plus optional auxiliary losses.

    Multi-label data uses binary cross entropy. Otherwise categorical cross
    entropy (with label smoothing) is used when `losses.one_hot` is set, and
    sparse categorical cross entropy when it is not. All losses are computed
    with `from_logits=False`.

    Args:
      labels: labels.
      model_outputs: class probabilities of the classifier (the steps in this
        class apply sigmoid/softmax before calling this method).
      aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model.

    Returns:
      A dict of losses: 'class_loss', 'entropy' (multi-label only), 'reg_loss'
      (only when `aux_losses` is non-empty) and the total loss under the key
      `self.loss`.
    """
    all_losses = {}
    losses_config = self.task_config.losses
    total_loss = None

    if self._is_multilabel():
      entropy = -tf.reduce_mean(
          tf.reduce_sum(model_outputs * tf.math.log(model_outputs + 1e-8), -1))
      # NOTE(review): unlike the single-label branch below, this loss is not
      # passed through tf_utils.safe_mean. Confirm that is intended.
      total_loss = tf_keras.losses.binary_crossentropy(
          labels, model_outputs, from_logits=False)
      all_losses.update({
          'class_loss': total_loss,
          'entropy': entropy,
      })
    else:
      if losses_config.one_hot:
        total_loss = tf_keras.losses.categorical_crossentropy(
            labels,
            model_outputs,
            from_logits=False,
            label_smoothing=losses_config.label_smoothing)
      else:
        total_loss = tf_keras.losses.sparse_categorical_crossentropy(
            labels, model_outputs, from_logits=False)

      total_loss = tf_utils.safe_mean(total_loss)
      all_losses.update({
          'class_loss': total_loss,
      })

    if aux_losses:
      all_losses.update({
          'reg_loss': aux_losses,
      })
      total_loss += tf.add_n(aux_losses)
    all_losses[self.loss] = total_loss

    return all_losses

  def build_metrics(self, training: bool = True):
    """Gets streaming metrics for training/validation.

    Args:
      training: unused here; the same metrics are built in both modes.

    Returns:
      A list of keras metrics. One-hot labels get categorical accuracies, plus
      AUC / recall metrics for multi-label data; otherwise the sparse
      accuracy variants are returned.
    """
    if self.task_config.losses.one_hot:
      metrics = [
          tf_keras.metrics.CategoricalAccuracy(name='accuracy'),
          tf_keras.metrics.TopKCategoricalAccuracy(k=1, name='top_1_accuracy'),
          tf_keras.metrics.TopKCategoricalAccuracy(k=5, name='top_5_accuracy')
      ]
      if self._is_multilabel():
        metrics.append(
            tf_keras.metrics.AUC(
                curve='ROC', multi_label=self._is_multilabel(), name='ROC-AUC'))
        metrics.append(
            tf_keras.metrics.RecallAtPrecision(
                0.95, name='RecallAtPrecision95'))
        metrics.append(
            tf_keras.metrics.AUC(
                curve='PR', multi_label=self._is_multilabel(), name='PR-AUC'))
        if self.task_config.metrics.use_per_class_recall:
          for i in range(self._get_num_classes()):
            metrics.append(
                tf_keras.metrics.Recall(class_id=i, name=f'recall-{i}'))
    else:
      metrics = [
          tf_keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
          tf_keras.metrics.SparseTopKCategoricalAccuracy(
              k=1, name='top_1_accuracy'),
          tf_keras.metrics.SparseTopKCategoricalAccuracy(
              k=5, name='top_5_accuracy')
      ]
    return metrics

  def process_metrics(self, metrics: List[Any], labels: Any,
                      model_outputs: Any) -> None:
    """Process and update metrics.

    Called when using custom training loop API.

    Args:
      metrics: a nested structure of metrics objects. The return of function
        self.build_metrics.
      labels: a tensor or a nested structure of tensors.
      model_outputs: a tensor or a nested structure of tensors. For example,
        output of the keras model built by self.build_model.
    """
    for metric in metrics:
      metric.update_state(labels, model_outputs)

  def train_step(self,
                 inputs: Tuple[Any, Any],
                 model: tf_keras.Model,
                 optimizer: tf_keras.optimizers.Optimizer,
                 metrics: Optional[List[Any]] = None):
    """Does forward and backward.

    Args:
      inputs: a `(features, labels)` pair of input tensors.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs: the losses from `build_losses` plus the current
      metric results.
    """
    features, labels = inputs
    input_partition_dims = self.task_config.train_input_partition_dims
    if input_partition_dims:
      strategy = tf.distribute.get_strategy()
      features['image'] = strategy.experimental_split_to_logical_devices(
          features['image'], input_partition_dims)

    num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
    uses_loss_scaling = isinstance(
        optimizer, tf_keras.mixed_precision.LossScaleOptimizer)

    with tf.GradientTape() as tape:
      outputs = model(features, training=True)
      # Casting output layer as float32 is necessary when mixed_precision is
      # mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
      outputs = tf.nest.map_structure(
          lambda x: tf.cast(x, tf.float32), outputs)

      # Computes per-replica loss.
      if self._is_multilabel():
        outputs = tf.nest.map_structure(tf.math.sigmoid, outputs)
      else:
        outputs = tf.nest.map_structure(tf.math.softmax, outputs)
      all_losses = self.build_losses(
          model_outputs=outputs, labels=labels, aux_losses=model.losses)
      loss = all_losses[self.loss]
      # Scales loss as the default gradients allreduce performs sum inside the
      # optimizer.
      scaled_loss = loss / num_replicas

      # For mixed_precision policy, when LossScaleOptimizer is used, loss is
      # scaled for numerical stability.
      if uses_loss_scaling:
        scaled_loss = optimizer.get_scaled_loss(scaled_loss)

    tvars = model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)
    # Scales back gradient before apply_gradients when LossScaleOptimizer is
    # used.
    if uses_loss_scaling:
      grads = optimizer.get_unscaled_gradients(grads)
    optimizer.apply_gradients(list(zip(grads, tvars)))

    logs = all_losses
    if metrics:
      self.process_metrics(metrics, labels, outputs)
      logs.update({m.name: m.result() for m in metrics})
    elif model.compiled_metrics:
      self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
      logs.update({m.name: m.result() for m in model.metrics})
    return logs

  def validation_step(self,
                      inputs: Tuple[Any, Any],
                      model: tf_keras.Model,
                      metrics: Optional[List[Any]] = None):
    """Validation step.

    Args:
      inputs: a `(features, labels)` pair of input tensors.
      model: the keras.Model.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs: the losses from `build_losses` plus the current
      metric results.
    """
    features, labels = inputs
    input_partition_dims = self.task_config.eval_input_partition_dims
    if input_partition_dims:
      strategy = tf.distribute.get_strategy()
      features['image'] = strategy.experimental_split_to_logical_devices(
          features['image'], input_partition_dims)

    outputs = self.inference_step(features, model)
    outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
    logs = self.build_losses(model_outputs=outputs, labels=labels,
                             aux_losses=model.losses)

    if metrics:
      self.process_metrics(metrics, labels, outputs)
      logs.update({m.name: m.result() for m in metrics})
    elif model.compiled_metrics:
      self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
      logs.update({m.name: m.result() for m in model.metrics})
    return logs

  def inference_step(self, features: tf.Tensor, model: tf_keras.Model):
    """Performs the forward step.

    Args:
      features: the model inputs.
      model: the keras.Model.

    Returns:
      Class probabilities (sigmoid for multi-label, softmax otherwise),
      averaged over the test views when more than one view is configured.
    """
    outputs = model(features, training=False)
    if self._is_multilabel():
      outputs = tf.nest.map_structure(tf.math.sigmoid, outputs)
    else:
      outputs = tf.nest.map_structure(tf.math.softmax, outputs)

    num_test_views = self._get_num_test_views()
    if num_test_views > 1:
      # Averaging output probabilities across multiples views. The reshape
      # assumes the views of one example are adjacent along the batch axis.
      outputs = tf.reshape(outputs, [-1, num_test_views, outputs.shape[-1]])
      outputs = tf.reduce_mean(outputs, axis=1)
    return outputs