# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""RetinaNet task definition."""
from typing import Any, List, Mapping, Optional, Tuple

from absl import logging
import tensorflow as tf, tf_keras

from official.common import dataset_fn
from official.core import base_task
from official.core import task_factory
from official.vision.configs import retinanet as exp_cfg
from official.vision.dataloaders import input_reader
from official.vision.dataloaders import input_reader_factory
from official.vision.dataloaders import retinanet_input
from official.vision.dataloaders import tf_example_decoder
from official.vision.dataloaders import tfds_factory
from official.vision.dataloaders import tf_example_label_map_decoder
from official.vision.evaluation import coco_evaluator
from official.vision.losses import focal_loss
from official.vision.losses import loss_utils
from official.vision.modeling import factory
from official.vision.utils.object_detection import visualization_utils


@task_factory.register_task_cls(exp_cfg.RetinaNetTask)
class RetinaNetTask(base_task.Task):
  """A single-replica view of training procedure.

  RetinaNet task provides artifacts for training/evaluation procedures,
  including loading/iterating over Datasets, initializing the model,
  calculating the loss, post-processing, and customized metrics with
  reduction.
  """

  def build_model(self):
    """Build RetinaNet model."""
    input_specs = tf_keras.layers.InputSpec(
        shape=[None] + self.task_config.model.input_size)

    l2_weight_decay = self.task_config.losses.l2_weight_decay
    # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
    # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
    # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
    l2_regularizer = (tf_keras.regularizers.l2(
        l2_weight_decay / 2.0) if l2_weight_decay else None)

    model = factory.build_retinanet(
        input_specs=input_specs,
        model_config=self.task_config.model,
        l2_regularizer=l2_regularizer)

    if self.task_config.freeze_backbone:
      model.backbone.trainable = False

    return model

  def initialize(self, model: tf_keras.Model):
    """Loads pretrained checkpoint."""
    if not self.task_config.init_checkpoint:
      return

    ckpt_dir_or_file = self.task_config.init_checkpoint
    if tf.io.gfile.isdir(ckpt_dir_or_file):
      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)

    # Restoring checkpoint.
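    # If `init_checkpoint_modules` is 'all', every item exposed by
    # `model.checkpoint_items` is restored; otherwise only the explicitly
    # listed modules (e.g. 'backbone', 'decoder') are read from the checkpoint.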
    if self.task_config.init_checkpoint_modules == 'all':
      ckpt = tf.train.Checkpoint(**model.checkpoint_items)
      status = ckpt.read(ckpt_dir_or_file)
      status.expect_partial().assert_existing_objects_matched()
    else:
      ckpt_items = {}
      if 'backbone' in self.task_config.init_checkpoint_modules:
        ckpt_items.update(backbone=model.backbone)
      if 'decoder' in self.task_config.init_checkpoint_modules:
        ckpt_items.update(decoder=model.decoder)

      ckpt = tf.train.Checkpoint(**ckpt_items)
      status = ckpt.read(ckpt_dir_or_file)
      status.expect_partial().assert_existing_objects_matched()

    logging.info('Finished loading pretrained checkpoint from %s',
                 ckpt_dir_or_file)

  def build_inputs(self,
                   params: exp_cfg.DataConfig,
                   input_context: Optional[tf.distribute.InputContext] = None):
    """Build input dataset."""

    if params.tfds_name:
      decoder = tfds_factory.get_detection_decoder(params.tfds_name)
    else:
      decoder_cfg = params.decoder.get()
      if params.decoder.type == 'simple_decoder':
        decoder = tf_example_decoder.TfExampleDecoder(
            regenerate_source_id=decoder_cfg.regenerate_source_id,
            attribute_names=decoder_cfg.attribute_names,
        )
      elif params.decoder.type == 'label_map_decoder':
        decoder = tf_example_label_map_decoder.TfExampleDecoderLabelMap(
            label_map=decoder_cfg.label_map,
            regenerate_source_id=decoder_cfg.regenerate_source_id)
      else:
        raise ValueError('Unknown decoder type: {}!'.format(
            params.decoder.type))

    parser = retinanet_input.Parser(
        output_size=self.task_config.model.input_size[:2],
        min_level=self.task_config.model.min_level,
        max_level=self.task_config.model.max_level,
        num_scales=self.task_config.model.anchor.num_scales,
        aspect_ratios=self.task_config.model.anchor.aspect_ratios,
        anchor_size=self.task_config.model.anchor.anchor_size,
        dtype=params.dtype,
        match_threshold=params.parser.match_threshold,
        unmatched_threshold=params.parser.unmatched_threshold,
        box_coder_weights=(
            self.task_config.model.detection_generator.box_coder_weights
        ),
        aug_type=params.parser.aug_type,
        aug_rand_hflip=params.parser.aug_rand_hflip,
        aug_scale_min=params.parser.aug_scale_min,
        aug_scale_max=params.parser.aug_scale_max,
        skip_crowd_during_training=params.parser.skip_crowd_during_training,
        max_num_instances=params.parser.max_num_instances,
        pad=params.parser.pad,
        keep_aspect_ratio=params.parser.keep_aspect_ratio,
    )

    reader = input_reader_factory.input_reader_generator(
        params,
        dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
        decoder_fn=decoder.decode,
        combine_fn=input_reader.create_combine_fn(params),
        parser_fn=parser.parse_fn(params.is_training))
    dataset = reader.read(input_context=input_context)

    return dataset

  def build_attribute_loss(self,
                           attribute_heads: List[exp_cfg.AttributeHead],
                           outputs: Mapping[str, Any],
                           labels: Mapping[str, Any],
                           box_sample_weight: tf.Tensor) -> float:
    """Computes attribute loss.

    Args:
      attribute_heads: a list of attribute head configs.
      outputs: RetinaNet model outputs.
      labels: RetinaNet labels.
      box_sample_weight: normalized bounding box sample weights.

    Returns:
      Attribute loss of all attribute heads.
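
    Raises:
      ValueError: if an attribute head is missing from the label targets or
        the model outputs, or if the attribute head type is not supported.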
""" params = self.task_config attribute_loss = 0.0 for head in attribute_heads: if head.name not in labels['attribute_targets']: raise ValueError(f'Attribute {head.name} not found in label targets.') if head.name not in outputs['attribute_outputs']: raise ValueError(f'Attribute {head.name} not found in model outputs.') if head.type == 'regression': y_true_att = loss_utils.multi_level_flatten( labels['attribute_targets'][head.name], last_dim=head.size ) y_pred_att = loss_utils.multi_level_flatten( outputs['attribute_outputs'][head.name], last_dim=head.size ) att_loss_fn = tf_keras.losses.Huber( 1.0, reduction=tf_keras.losses.Reduction.SUM) att_loss = att_loss_fn( y_true=y_true_att, y_pred=y_pred_att, sample_weight=box_sample_weight) elif head.type == 'classification': y_true_att = loss_utils.multi_level_flatten( labels['attribute_targets'][head.name], last_dim=None ) y_true_att = tf.one_hot(y_true_att, head.size) y_pred_att = loss_utils.multi_level_flatten( outputs['attribute_outputs'][head.name], last_dim=head.size ) cls_loss_fn = focal_loss.FocalLoss( alpha=params.losses.focal_loss_alpha, gamma=params.losses.focal_loss_gamma, reduction=tf_keras.losses.Reduction.SUM, ) att_loss = cls_loss_fn( y_true=y_true_att, y_pred=y_pred_att, sample_weight=box_sample_weight, ) else: raise ValueError(f'Attribute type {head.type} not supported.') attribute_loss += att_loss return attribute_loss def build_losses( self, outputs: Mapping[str, Any], labels: Mapping[str, Any], aux_losses: Optional[Any] = None, ): """Build RetinaNet losses.""" params = self.task_config attribute_heads = self.task_config.model.head.attribute_heads cls_loss_fn = focal_loss.FocalLoss( alpha=params.losses.focal_loss_alpha, gamma=params.losses.focal_loss_gamma, reduction=tf_keras.losses.Reduction.SUM) box_loss_fn = tf_keras.losses.Huber( params.losses.huber_loss_delta, reduction=tf_keras.losses.Reduction.SUM) # Sums all positives in a batch for normalization and avoids zero # num_positives_sum, which would lead to inf loss during training cls_sample_weight = labels['cls_weights'] box_sample_weight = labels['box_weights'] num_positives = tf.reduce_sum(box_sample_weight) + 1.0 cls_sample_weight = cls_sample_weight / num_positives box_sample_weight = box_sample_weight / num_positives y_true_cls = loss_utils.multi_level_flatten( labels['cls_targets'], last_dim=None) y_true_cls = tf.one_hot(y_true_cls, params.model.num_classes) y_pred_cls = loss_utils.multi_level_flatten( outputs['cls_outputs'], last_dim=params.model.num_classes) y_true_box = loss_utils.multi_level_flatten( labels['box_targets'], last_dim=4) y_pred_box = loss_utils.multi_level_flatten( outputs['box_outputs'], last_dim=4) cls_loss = cls_loss_fn( y_true=y_true_cls, y_pred=y_pred_cls, sample_weight=cls_sample_weight) box_loss = box_loss_fn( y_true=y_true_box, y_pred=y_pred_box, sample_weight=box_sample_weight) model_loss = cls_loss + params.losses.box_loss_weight * box_loss if attribute_heads: model_loss += self.build_attribute_loss(attribute_heads, outputs, labels, box_sample_weight) total_loss = model_loss if aux_losses: reg_loss = tf.reduce_sum(aux_losses) total_loss = model_loss + reg_loss total_loss = params.losses.loss_weight * total_loss return total_loss, cls_loss, box_loss, model_loss def build_metrics(self, training: bool = True): """Build detection metrics.""" metrics = [] metric_names = ['total_loss', 'cls_loss', 'box_loss', 'model_loss'] for name in metric_names: metrics.append(tf_keras.metrics.Mean(name, dtype=tf.float32)) if not training: if ( 
      if (
          self.task_config.validation_data.tfds_name
          and self.task_config.annotation_file
      ):
        raise ValueError(
            "Can't evaluate using annotation file when TFDS is used."
        )
      if self._task_config.use_coco_metrics:
        self.coco_metric = coco_evaluator.COCOEvaluator(
            annotation_file=self.task_config.annotation_file,
            include_mask=False,
            per_category_metrics=self.task_config.per_category_metrics,
            max_num_eval_detections=self.task_config.max_num_eval_detections,
        )
      if self._task_config.use_wod_metrics:
        # To use Waymo open dataset metrics, please install one of the pip
        # packages `waymo-open-dataset-tf-*` from
        # https://github.com/waymo-research/waymo-open-dataset/blob/master/docs/quick_start.md#use-pre-compiled-pippip3-packages-for-linux
        # Note that the package is built with a specific tensorflow version and
        # will produce an error if it does not match the tf version that is
        # currently used.
        try:
          from official.vision.evaluation import wod_detection_evaluator  # pylint: disable=g-import-not-at-top
        except ModuleNotFoundError:
          logging.error('waymo-open-dataset should be installed to enable Waymo'
                        ' evaluator.')
          raise
        self.wod_metric = wod_detection_evaluator.WOD2dDetectionEvaluator()

    return metrics

  def train_step(self,
                 inputs: Tuple[Any, Any],
                 model: tf_keras.Model,
                 optimizer: tf_keras.optimizers.Optimizer,
                 metrics: Optional[List[Any]] = None):
    """Does forward and backward.

    Args:
      inputs: a tuple of input and label tensors.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    features, labels = inputs
    num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
    with tf.GradientTape() as tape:
      outputs = model(features, training=True)
      outputs = tf.nest.map_structure(
          lambda x: tf.cast(x, tf.float32), outputs)

      # Computes per-replica loss.
      loss, cls_loss, box_loss, model_loss = self.build_losses(
          outputs=outputs, labels=labels, aux_losses=model.losses
      )
      scaled_loss = loss / num_replicas

      # For mixed_precision policy, when LossScaleOptimizer is used, loss is
      # scaled for numerical stability.
      if isinstance(optimizer, tf_keras.mixed_precision.LossScaleOptimizer):
        scaled_loss = optimizer.get_scaled_loss(scaled_loss)

    tvars = model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)
    # Scales back gradient when LossScaleOptimizer is used.
    if isinstance(optimizer, tf_keras.mixed_precision.LossScaleOptimizer):
      grads = optimizer.get_unscaled_gradients(grads)
    optimizer.apply_gradients(list(zip(grads, tvars)))

    logs = {self.loss: loss}

    all_losses = {
        'total_loss': loss,
        'cls_loss': cls_loss,
        'box_loss': box_loss,
        'model_loss': model_loss,
    }
    if metrics:
      for m in metrics:
        m.update_state(all_losses[m.name])
        logs.update({m.name: m.result()})

    return logs

  def validation_step(self,
                      inputs: Tuple[Any, Any],
                      model: tf_keras.Model,
                      metrics: Optional[List[Any]] = None):
    """Validation step.

    Args:
      inputs: a tuple of input and label tensors.
      model: the keras.Model.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
""" features, labels = inputs outputs = model(features, anchor_boxes=labels['anchor_boxes'], image_shape=labels['image_info'][:, 1, :], training=False) loss, cls_loss, box_loss, model_loss = self.build_losses( outputs=outputs, labels=labels, aux_losses=model.losses ) logs = {self.loss: loss} all_losses = { 'total_loss': loss, 'cls_loss': cls_loss, 'box_loss': box_loss, 'model_loss': model_loss, } if self._task_config.use_coco_metrics: coco_model_outputs = { 'detection_boxes': outputs['detection_boxes'], 'detection_scores': outputs['detection_scores'], 'detection_classes': outputs['detection_classes'], 'num_detections': outputs['num_detections'], 'source_id': labels['groundtruths']['source_id'], 'image_info': labels['image_info'] } logs.update( {self.coco_metric.name: (labels['groundtruths'], coco_model_outputs)}) if self.task_config.use_wod_metrics: wod_model_outputs = { 'detection_boxes': outputs['detection_boxes'], 'detection_scores': outputs['detection_scores'], 'detection_classes': outputs['detection_classes'], 'num_detections': outputs['num_detections'], 'source_id': labels['groundtruths']['source_id'], 'image_info': labels['image_info'] } logs.update( {self.wod_metric.name: (labels['groundtruths'], wod_model_outputs)}) if metrics: for m in metrics: m.update_state(all_losses[m.name]) logs.update({m.name: m.result()}) if ( hasattr(self.task_config, 'allow_image_summary') and self.task_config.allow_image_summary ): logs.update( {'visualization': (tf.cast(features, dtype=tf.float32), outputs)} ) return logs def aggregate_logs(self, state=None, step_outputs=None): if self._task_config.use_coco_metrics: if state is None: self.coco_metric.reset_states() self.coco_metric.update_state(step_outputs[self.coco_metric.name][0], step_outputs[self.coco_metric.name][1]) if self._task_config.use_wod_metrics: if state is None: self.wod_metric.reset_states() self.wod_metric.update_state(step_outputs[self.wod_metric.name][0], step_outputs[self.wod_metric.name][1]) if 'visualization' in step_outputs: # Update detection state for writing summary if there are artifacts for # visualization. if state is None: state = {} state.update(visualization_utils.update_detection_state(step_outputs)) if state is None: # Create an arbitrary state to indicate it's not the first step in the # following calls to this function. state = True return state def reduce_aggregated_logs(self, aggregated_logs, global_step=None): logs = {} if self._task_config.use_coco_metrics: logs.update(self.coco_metric.result()) if self._task_config.use_wod_metrics: logs.update(self.wod_metric.result()) # Add visualization for summary. if isinstance(aggregated_logs, dict) and 'image' in aggregated_logs: validation_outputs = visualization_utils.visualize_outputs( logs=aggregated_logs, task_config=self.task_config ) logs.update(validation_outputs) return logs