deanna-emery's picture
updates
93528c6
raw
history blame
18.2 kB
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""RetinaNet task definition."""
from typing import Any, List, Mapping, Optional, Tuple
from absl import logging
import tensorflow as tf, tf_keras
from official.common import dataset_fn
from official.core import base_task
from official.core import task_factory
from official.vision.configs import retinanet as exp_cfg
from official.vision.dataloaders import input_reader
from official.vision.dataloaders import input_reader_factory
from official.vision.dataloaders import retinanet_input
from official.vision.dataloaders import tf_example_decoder
from official.vision.dataloaders import tfds_factory
from official.vision.dataloaders import tf_example_label_map_decoder
from official.vision.evaluation import coco_evaluator
from official.vision.losses import focal_loss
from official.vision.losses import loss_utils
from official.vision.modeling import factory
from official.vision.utils.object_detection import visualization_utils
@task_factory.register_task_cls(exp_cfg.RetinaNetTask)
class RetinaNetTask(base_task.Task):
"""A single-replica view of training procedure.
RetinaNet task provides artifacts for training/evaluation procedures, including
loading/iterating over Datasets, initializing the model, calculating the loss,
post-processing, and customized metrics with reduction.
"""
def build_model(self):
  """Creates the RetinaNet model described by the task config.

  Returns:
    The model returned by `factory.build_retinanet`. Its backbone is marked
    non-trainable when `task_config.freeze_backbone` is set.
  """
  model_config = self.task_config.model
  weight_decay = self.task_config.losses.l2_weight_decay

  # The coefficient is halved so that the Keras l2 regularizer matches the
  # implementation of tf.nn.l2_loss.
  # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
  # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
  regularizer = None
  if weight_decay:
    regularizer = tf_keras.regularizers.l2(weight_decay / 2.0)

  retinanet = factory.build_retinanet(
      input_specs=tf_keras.layers.InputSpec(
          shape=[None] + model_config.input_size),
      model_config=model_config,
      l2_regularizer=regularizer)

  if self.task_config.freeze_backbone:
    retinanet.backbone.trainable = False
  return retinanet
def initialize(self, model: tf_keras.Model):
  """Restores pretrained weights from `task_config.init_checkpoint`, if set.

  When `task_config.init_checkpoint_modules` is 'all', the objects in
  `model.checkpoint_items` are restored. Otherwise `model.backbone` and/or
  `model.decoder` are restored, depending on whether 'backbone' / 'decoder'
  appear in `init_checkpoint_modules`. If neither appears, nothing can be
  restored, so a warning is logged and the method returns without reading
  the checkpoint.

  Args:
    model: the model whose weights are restored in place.
  """
  if not self.task_config.init_checkpoint:
    return

  ckpt_dir_or_file = self.task_config.init_checkpoint
  if tf.io.gfile.isdir(ckpt_dir_or_file):
    ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)

  # Select which parts of the model take part in the restore.
  modules = self.task_config.init_checkpoint_modules
  if modules == 'all':
    ckpt_items = model.checkpoint_items
  else:
    ckpt_items = {}
    if 'backbone' in modules:
      ckpt_items.update(backbone=model.backbone)
    if 'decoder' in modules:
      ckpt_items.update(decoder=model.decoder)
    if not ckpt_items:
      # Reading an empty Checkpoint restores nothing; say so instead of
      # reporting a successful load.
      logging.warning(
          'init_checkpoint_modules=%s selects neither "backbone" nor '
          '"decoder"; no weights were restored from %s', modules,
          ckpt_dir_or_file)
      return

  # Restoring checkpoint.
  ckpt = tf.train.Checkpoint(**ckpt_items)
  status = ckpt.read(ckpt_dir_or_file)
  status.expect_partial().assert_existing_objects_matched()

  logging.info('Finished loading pretrained checkpoint from %s',
               ckpt_dir_or_file)
def build_inputs(self,
                 params: exp_cfg.DataConfig,
                 input_context: Optional[tf.distribute.InputContext] = None):
  """Creates the input dataset for the given data config.

  Args:
    params: the data config. Its `tfds_name`, `decoder`, `parser`, `dtype`,
      `file_type` and `is_training` fields are read here.
    input_context: optional input context forwarded to the reader's `read`.

  Returns:
    The dataset produced by the input reader.

  Raises:
    ValueError: if `params.tfds_name` is unset and `params.decoder.type` is
      neither 'simple_decoder' nor 'label_map_decoder'.
  """
  # Pick the example decoder: TFDS datasets get their own decoder, otherwise
  # the decoder is chosen by the configured decoder type.
  if params.tfds_name:
    decoder = tfds_factory.get_detection_decoder(params.tfds_name)
  else:
    decoder_cfg = params.decoder.get()
    decoder_type = params.decoder.type
    if decoder_type == 'simple_decoder':
      decoder = tf_example_decoder.TfExampleDecoder(
          regenerate_source_id=decoder_cfg.regenerate_source_id,
          attribute_names=decoder_cfg.attribute_names,
      )
    elif decoder_type == 'label_map_decoder':
      decoder = tf_example_label_map_decoder.TfExampleDecoderLabelMap(
          label_map=decoder_cfg.label_map,
          regenerate_source_id=decoder_cfg.regenerate_source_id)
    else:
      raise ValueError('Unknown decoder type: {}!'.format(
          params.decoder.type))

  model_cfg = self.task_config.model
  anchor_cfg = model_cfg.anchor
  parser_cfg = params.parser
  parser = retinanet_input.Parser(
      output_size=model_cfg.input_size[:2],
      min_level=model_cfg.min_level,
      max_level=model_cfg.max_level,
      num_scales=anchor_cfg.num_scales,
      aspect_ratios=anchor_cfg.aspect_ratios,
      anchor_size=anchor_cfg.anchor_size,
      dtype=params.dtype,
      match_threshold=parser_cfg.match_threshold,
      unmatched_threshold=parser_cfg.unmatched_threshold,
      box_coder_weights=model_cfg.detection_generator.box_coder_weights,
      aug_type=parser_cfg.aug_type,
      aug_rand_hflip=parser_cfg.aug_rand_hflip,
      aug_scale_min=parser_cfg.aug_scale_min,
      aug_scale_max=parser_cfg.aug_scale_max,
      skip_crowd_during_training=parser_cfg.skip_crowd_during_training,
      max_num_instances=parser_cfg.max_num_instances,
      pad=parser_cfg.pad,
      keep_aspect_ratio=parser_cfg.keep_aspect_ratio,
  )

  reader = input_reader_factory.input_reader_generator(
      params,
      dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
      decoder_fn=decoder.decode,
      combine_fn=input_reader.create_combine_fn(params),
      parser_fn=parser.parse_fn(params.is_training))
  return reader.read(input_context=input_context)
def build_attribute_loss(self,
                         attribute_heads: List[exp_cfg.AttributeHead],
                         outputs: Mapping[str, Any],
                         labels: Mapping[str, Any],
                         box_sample_weight: tf.Tensor) -> float:
  """Sums the losses of all attribute heads.

  Regression heads use a Huber loss with delta 1.0; classification heads use
  a focal loss configured from the task's loss config. Both are
  sum-reduced and weighted by `box_sample_weight`.

  Args:
    attribute_heads: a list of attribute head configs.
    outputs: RetinaNet model outputs; must contain 'attribute_outputs'.
    labels: RetinaNet labels; must contain 'attribute_targets'.
    box_sample_weight: normalized bounding box sample weights.

  Returns:
    Attribute loss of all attribute heads (0.0 if there are no heads).

  Raises:
    ValueError: if a head is missing from the targets or from the model
      outputs, or if its type is neither 'regression' nor 'classification'.
  """
  task_cfg = self.task_config
  total = 0.0
  for head in attribute_heads:
    if head.name not in labels['attribute_targets']:
      raise ValueError(f'Attribute {head.name} not found in label targets.')
    if head.name not in outputs['attribute_outputs']:
      raise ValueError(f'Attribute {head.name} not found in model outputs.')
    if head.type not in ('regression', 'classification'):
      raise ValueError(f'Attribute type {head.type} not supported.')

    head_targets = labels['attribute_targets'][head.name]
    head_outputs = outputs['attribute_outputs'][head.name]

    if head.type == 'regression':
      y_true = loss_utils.multi_level_flatten(
          head_targets, last_dim=head.size)
      loss_fn = tf_keras.losses.Huber(
          1.0, reduction=tf_keras.losses.Reduction.SUM)
    else:
      # Classification targets are class ids; expand them to one-hot vectors.
      y_true = tf.one_hot(
          loss_utils.multi_level_flatten(head_targets, last_dim=None),
          head.size)
      loss_fn = focal_loss.FocalLoss(
          alpha=task_cfg.losses.focal_loss_alpha,
          gamma=task_cfg.losses.focal_loss_gamma,
          reduction=tf_keras.losses.Reduction.SUM,
      )

    y_pred = loss_utils.multi_level_flatten(head_outputs, last_dim=head.size)
    head_loss = loss_fn(
        y_true=y_true, y_pred=y_pred, sample_weight=box_sample_weight)
    total = total + head_loss

  return total
def build_losses(
    self,
    outputs: Mapping[str, Any],
    labels: Mapping[str, Any],
    aux_losses: Optional[Any] = None,
):
  """Computes the RetinaNet training losses.

  Args:
    outputs: model outputs; 'cls_outputs' and 'box_outputs' are read (plus
      'attribute_outputs' when attribute heads are configured).
    labels: labels; 'cls_targets', 'box_targets', 'cls_weights' and
      'box_weights' are read (plus 'attribute_targets' when attribute heads
      are configured).
    aux_losses: optional auxiliary losses; when truthy, their sum is added
      to the model loss.

  Returns:
    A tuple `(total_loss, cls_loss, box_loss, model_loss)`, where
    `model_loss` is the classification loss plus the weighted box loss (and
    attribute loss, if any) and `total_loss` is
    `loss_weight * (model_loss + sum(aux_losses))`.
  """
  task_cfg = self.task_config
  loss_cfg = task_cfg.losses
  num_classes = task_cfg.model.num_classes
  attribute_heads = self.task_config.model.head.attribute_heads

  # Sums all positives in a batch for normalization and avoids zero
  # num_positives_sum, which would lead to inf loss during training
  cls_weights = labels['cls_weights']
  box_weights = labels['box_weights']
  normalizer = tf.reduce_sum(box_weights) + 1.0
  cls_weights = cls_weights / normalizer
  box_weights = box_weights / normalizer

  # Classification branch: class-id targets are expanded to one-hot.
  cls_targets = tf.one_hot(
      loss_utils.multi_level_flatten(labels['cls_targets'], last_dim=None),
      num_classes)
  cls_predictions = loss_utils.multi_level_flatten(
      outputs['cls_outputs'], last_dim=num_classes)

  # Box regression branch.
  box_targets = loss_utils.multi_level_flatten(
      labels['box_targets'], last_dim=4)
  box_predictions = loss_utils.multi_level_flatten(
      outputs['box_outputs'], last_dim=4)

  focal = focal_loss.FocalLoss(
      alpha=loss_cfg.focal_loss_alpha,
      gamma=loss_cfg.focal_loss_gamma,
      reduction=tf_keras.losses.Reduction.SUM)
  huber = tf_keras.losses.Huber(
      loss_cfg.huber_loss_delta, reduction=tf_keras.losses.Reduction.SUM)

  cls_loss = focal(
      y_true=cls_targets, y_pred=cls_predictions, sample_weight=cls_weights)
  box_loss = huber(
      y_true=box_targets, y_pred=box_predictions, sample_weight=box_weights)

  model_loss = cls_loss + loss_cfg.box_loss_weight * box_loss
  if attribute_heads:
    model_loss += self.build_attribute_loss(
        attribute_heads, outputs, labels, box_weights)

  if aux_losses:
    total_loss = model_loss + tf.reduce_sum(aux_losses)
  else:
    total_loss = model_loss
  total_loss = loss_cfg.loss_weight * total_loss

  return total_loss, cls_loss, box_loss, model_loss
def build_metrics(self, training: bool = True):
  """Creates the loss metrics and, for evaluation, the detection evaluators.

  Args:
    training: when False, the evaluators enabled in the task config are also
      created and stored on `self.coco_metric` / `self.wod_metric`.

  Returns:
    A list of `Mean` metrics named 'total_loss', 'cls_loss', 'box_loss' and
    'model_loss'.

  Raises:
    ValueError: in evaluation mode, if both a TFDS validation dataset and an
      annotation file are configured.
    ModuleNotFoundError: in evaluation mode, if WOD metrics are enabled but
      the WOD evaluator module cannot be imported.
  """
  loss_names = ('total_loss', 'cls_loss', 'box_loss', 'model_loss')
  metrics = [
      tf_keras.metrics.Mean(name, dtype=tf.float32) for name in loss_names
  ]
  if training:
    return metrics

  uses_tfds = self.task_config.validation_data.tfds_name
  if uses_tfds and self.task_config.annotation_file:
    raise ValueError(
        "Can't evaluate using annotation file when TFDS is used."
    )

  if self._task_config.use_coco_metrics:
    self.coco_metric = coco_evaluator.COCOEvaluator(
        annotation_file=self.task_config.annotation_file,
        include_mask=False,
        per_category_metrics=self.task_config.per_category_metrics,
        max_num_eval_detections=self.task_config.max_num_eval_detections,
    )

  if self._task_config.use_wod_metrics:
    # To use Waymo open dataset metrics, please install one of the pip
    # package `waymo-open-dataset-tf-*` from
    # https://github.com/waymo-research/waymo-open-dataset/blob/master/docs/quick_start.md#use-pre-compiled-pippip3-packages-for-linux
    # Note that the package is built with specific tensorflow version and
    # will produce error if it does not match the tf version that is
    # currently used.
    try:
      from official.vision.evaluation import wod_detection_evaluator  # pylint: disable=g-import-not-at-top
    except ModuleNotFoundError:
      logging.error(
          'waymo-open-dataset should be installed to enable Waymo evaluator.')
      raise
    else:
      self.wod_metric = wod_detection_evaluator.WOD2dDetectionEvaluator()

  return metrics
def train_step(self,
               inputs: Tuple[Any, Any],
               model: tf_keras.Model,
               optimizer: tf_keras.optimizers.Optimizer,
               metrics: Optional[List[Any]] = None):
  """Runs one forward and backward pass and applies the gradients.

  Args:
    inputs: a `(features, labels)` pair.
    model: the model, forward pass definition.
    optimizer: the optimizer for this training step.
    metrics: optional metric objects; each metric's name must be one of
      'total_loss', 'cls_loss', 'box_loss' or 'model_loss'.

  Returns:
    A dictionary of logs: the loss under `self.loss`, plus the current
    result of every metric under its name.
  """
  features, labels = inputs
  replica_count = tf.distribute.get_strategy().num_replicas_in_sync
  uses_loss_scaling = isinstance(
      optimizer, tf_keras.mixed_precision.LossScaleOptimizer)

  with tf.GradientTape() as tape:
    raw_outputs = model(features, training=True)
    outputs = tf.nest.map_structure(
        lambda tensor: tf.cast(tensor, tf.float32), raw_outputs)

    # Computes per-replica loss.
    loss, cls_loss, box_loss, model_loss = self.build_losses(
        outputs=outputs, labels=labels, aux_losses=model.losses)
    scaled_loss = loss / replica_count

    # For mixed_precision policy, when LossScaleOptimizer is used, loss is
    # scaled for numerical stability.
    if uses_loss_scaling:
      scaled_loss = optimizer.get_scaled_loss(scaled_loss)

  trainable_vars = model.trainable_variables
  gradients = tape.gradient(scaled_loss, trainable_vars)
  # Scales back gradient when LossScaleOptimizer is used.
  if uses_loss_scaling:
    gradients = optimizer.get_unscaled_gradients(gradients)
  optimizer.apply_gradients(list(zip(gradients, trainable_vars)))

  logs = {self.loss: loss}
  loss_by_name = {
      'total_loss': loss,
      'cls_loss': cls_loss,
      'box_loss': box_loss,
      'model_loss': model_loss,
  }
  for metric in metrics or ():
    metric.update_state(loss_by_name[metric.name])
    logs[metric.name] = metric.result()

  return logs
def validation_step(self,
                    inputs: Tuple[Any, Any],
                    model: tf_keras.Model,
                    metrics: Optional[List[Any]] = None):
  """Runs one validation step.

  Args:
    inputs: a `(features, labels)` pair.
    model: the keras.Model.
    metrics: optional metric objects; each metric's name must be one of
      'total_loss', 'cls_loss', 'box_loss' or 'model_loss'.

  Returns:
    A dictionary of logs: the loss under `self.loss`, a
    `(groundtruths, predictions)` pair under each enabled evaluator's name,
    the current result of every metric under its name, and, when
    `allow_image_summary` is set on the task config, a
    `(features, outputs)` pair under 'visualization'.
  """
  features, labels = inputs
  outputs = model(
      features,
      anchor_boxes=labels['anchor_boxes'],
      image_shape=labels['image_info'][:, 1, :],
      training=False)
  loss, cls_loss, box_loss, model_loss = self.build_losses(
      outputs=outputs, labels=labels, aux_losses=model.losses)

  def _evaluator_inputs():
    """Builds a fresh (groundtruths, predictions) pair for one evaluator."""
    predictions = {
        key: outputs[key]
        for key in ('detection_boxes', 'detection_scores',
                    'detection_classes', 'num_detections')
    }
    predictions['source_id'] = labels['groundtruths']['source_id']
    predictions['image_info'] = labels['image_info']
    return (labels['groundtruths'], predictions)

  logs = {self.loss: loss}
  if self._task_config.use_coco_metrics:
    logs[self.coco_metric.name] = _evaluator_inputs()
  if self.task_config.use_wod_metrics:
    logs[self.wod_metric.name] = _evaluator_inputs()

  loss_by_name = {
      'total_loss': loss,
      'cls_loss': cls_loss,
      'box_loss': box_loss,
      'model_loss': model_loss,
  }
  for metric in metrics or ():
    metric.update_state(loss_by_name[metric.name])
    logs[metric.name] = metric.result()

  # Older configs may not define `allow_image_summary`; treat that as off.
  if getattr(self.task_config, 'allow_image_summary', False):
    logs['visualization'] = (tf.cast(features, dtype=tf.float32), outputs)

  return logs
def aggregate_logs(self, state=None, step_outputs=None):
  """Feeds one step's outputs to the enabled evaluators.

  Args:
    state: the value returned by the previous call, or None on the first
      call, in which case the enabled evaluators are reset before updating.
    step_outputs: the logs of one validation step. For each enabled
      evaluator, the entry under the evaluator's name is indexed as
      `[0]` and `[1]` and passed to its `update_state`.

  Returns:
    The aggregation state: `state` (a new dict if it was None) updated with
    the detection visualization state when `step_outputs` contains
    'visualization'; otherwise `state` unchanged, or True if it was None.
  """
  is_first_step = state is None

  if self._task_config.use_coco_metrics:
    if is_first_step:
      self.coco_metric.reset_states()
    coco_step = step_outputs[self.coco_metric.name]
    self.coco_metric.update_state(coco_step[0], coco_step[1])

  if self._task_config.use_wod_metrics:
    if is_first_step:
      self.wod_metric.reset_states()
    wod_step = step_outputs[self.wod_metric.name]
    self.wod_metric.update_state(wod_step[0], wod_step[1])

  if 'visualization' in step_outputs:
    # Update detection state for writing summary if there are artifacts for
    # visualization.
    state = {} if state is None else state
    state.update(visualization_utils.update_detection_state(step_outputs))

  # Any non-None value marks later calls as "not the first step".
  return True if state is None else state
def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
  """Collects the final evaluation results.

  Args:
    aggregated_logs: the state returned by the last `aggregate_logs` call.
      Visualization summaries are added only when it is a dict containing
      'image'.
    global_step: unused.

  Returns:
    A dict with the results of the enabled evaluators and, if available,
    the visualization outputs.
  """
  results = {}
  if self._task_config.use_coco_metrics:
    results.update(self.coco_metric.result())
  if self._task_config.use_wod_metrics:
    results.update(self.wod_metric.result())

  has_images = (
      isinstance(aggregated_logs, dict) and 'image' in aggregated_logs)
  if not has_images:
    return results

  # Add visualization for summary.
  results.update(
      visualization_utils.visualize_outputs(
          logs=aggregated_logs, task_config=self.task_config))
  return results