# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model defination for the RetinaNet Model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf, tf_keras
from official.legacy.detection.dataloader import mode_keys
from official.legacy.detection.evaluation import factory as eval_factory
from official.legacy.detection.modeling import base_model
from official.legacy.detection.modeling import losses
from official.legacy.detection.modeling.architecture import factory
from official.legacy.detection.ops import postprocess_ops


class RetinanetModel(base_model.Model):
  """RetinaNet model function."""

  def __init__(self, params):
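    """Builds the RetinaNet components described by `params`.

    Args:
      params: a configuration object with attribute access to the
        architecture, loss, postprocess, parser, and train settings used
        below.
    """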
    super(RetinanetModel, self).__init__(params)

    # For eval metrics.
    self._params = params

    # Architecture generators.
    self._backbone_fn = factory.backbone_generator(params)
    self._fpn_fn = factory.multilevel_features_generator(params)
    self._head_fn = factory.retinanet_head_generator(params)

    # Loss function.
    self._cls_loss_fn = losses.RetinanetClassLoss(
        params.retinanet_loss, params.architecture.num_classes)
    self._box_loss_fn = losses.RetinanetBoxLoss(params.retinanet_loss)
    self._box_loss_weight = params.retinanet_loss.box_loss_weight
    self._keras_model = None

    # Predict function.
    self._generate_detections_fn = postprocess_ops.MultilevelDetectionGenerator(
        params.architecture.min_level, params.architecture.max_level,
        params.postprocess)

    self._transpose_input = params.train.transpose_input
    assert not self._transpose_input, 'Transpose input is not supported.'

    # Input layer.
    self._input_layer = tf_keras.layers.Input(
        shape=(None, None, params.retinanet_parser.num_channels),
        name='',
        dtype=tf.bfloat16 if self._use_bfloat16 else tf.float32)

  def build_outputs(self, inputs, mode):
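    """Builds the raw model outputs.

    Runs the backbone, FPN, and RetinaNet head on `inputs` and returns a dict
    with the per-level 'cls_outputs' and 'box_outputs' (cast back to float32
    when bfloat16 is used).
    """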
    # If the input image is transposed (from NHWC to HWCN), we need to revert it
    # back to the original shape before it's used in the computation.
    if self._transpose_input:
      inputs = tf.transpose(inputs, [3, 0, 1, 2])

    backbone_features = self._backbone_fn(
        inputs, is_training=(mode == mode_keys.TRAIN))
    fpn_features = self._fpn_fn(
        backbone_features, is_training=(mode == mode_keys.TRAIN))
    cls_outputs, box_outputs = self._head_fn(
        fpn_features, is_training=(mode == mode_keys.TRAIN))

    if self._use_bfloat16:
      levels = cls_outputs.keys()
      for level in levels:
        cls_outputs[level] = tf.cast(cls_outputs[level], tf.float32)
        box_outputs[level] = tf.cast(box_outputs[level], tf.float32)

    model_outputs = {
        'cls_outputs': cls_outputs,
        'box_outputs': box_outputs,
    }
    return model_outputs

  def build_loss_fn(self):
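    """Returns a loss function mapping (labels, outputs) to a dict of losses.

    The returned callable computes the classification and box losses, combines
    them with `box_loss_weight`, adds L2 weight decay over the trainable
    variables, and reports each term under 'total_loss', 'cls_loss',
    'box_loss', 'model_loss' and 'l2_regularization_loss'. Requires
    `build_model()` to have been called first.
    """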
    if self._keras_model is None:
      raise ValueError('build_loss_fn() must be called after build_model().')

    filter_fn = self.make_filter_trainable_variables_fn()
    trainable_variables = filter_fn(self._keras_model.trainable_variables)

    def _total_loss_fn(labels, outputs):
      cls_loss = self._cls_loss_fn(outputs['cls_outputs'],
                                   labels['cls_targets'],
                                   labels['num_positives'])
      box_loss = self._box_loss_fn(outputs['box_outputs'],
                                   labels['box_targets'],
                                   labels['num_positives'])
      model_loss = cls_loss + self._box_loss_weight * box_loss
      l2_regularization_loss = self.weight_decay_loss(trainable_variables)
      total_loss = model_loss + l2_regularization_loss
      return {
          'total_loss': total_loss,
          'cls_loss': cls_loss,
          'box_loss': box_loss,
          'model_loss': model_loss,
          'l2_regularization_loss': l2_regularization_loss,
      }

    return _total_loss_fn

  def build_model(self, params, mode=None):
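    """Builds (on first call) and returns the underlying tf_keras.Model."""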
    if self._keras_model is None:
      outputs = self.model_outputs(self._input_layer, mode)

      model = tf_keras.models.Model(
          inputs=self._input_layer, outputs=outputs, name='retinanet')
      assert model is not None, 'Failed to build tf_keras.Model.'
      model.optimizer = self.build_optimizer()
      self._keras_model = model

    return self._keras_model

  def post_processing(self, labels, outputs):
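    """Generates detections and reshapes labels/outputs for evaluation.

    Validates that the required fields are present, runs the multilevel
    detection generator on the raw head outputs, and returns (labels, outputs)
    dicts keyed the way the evaluator expects.
    """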
    # TODO(yeqing): Move the output-related part into build_outputs.
    required_output_fields = ['cls_outputs', 'box_outputs']
    for field in required_output_fields:
      if field not in outputs:
        raise ValueError('"%s" is missing in outputs, required %s found %s' %
                         (field, required_output_fields, outputs.keys()))
    required_label_fields = ['image_info', 'groundtruths']
    for field in required_label_fields:
      if field not in labels:
        raise ValueError('"%s" is missing in labels, required %s found %s' %
                         (field, required_label_fields, labels.keys()))
    boxes, scores, classes, valid_detections = self._generate_detections_fn(
        outputs['box_outputs'], outputs['cls_outputs'], labels['anchor_boxes'],
        labels['image_info'][:, 1:2, :])
    # Discards the old output tensors to save memory. The `cls_outputs` and
    # `box_outputs` are pretty big and could potentially lead to memory issues.
    outputs = {
        'source_id': labels['groundtruths']['source_id'],
        'image_info': labels['image_info'],
        'num_detections': valid_detections,
        'detection_boxes': boxes,
        'detection_classes': classes,
        'detection_scores': scores,
    }
    if 'groundtruths' in labels:
      labels['source_id'] = labels['groundtruths']['source_id']
      labels['boxes'] = labels['groundtruths']['boxes']
      labels['classes'] = labels['groundtruths']['classes']
      labels['areas'] = labels['groundtruths']['areas']
      labels['is_crowds'] = labels['groundtruths']['is_crowds']

    return labels, outputs

  def eval_metrics(self):
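    """Returns the evaluator built from the eval section of the config."""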
    return eval_factory.evaluator_generator(self._params.eval)
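
# A minimal usage sketch (an assumption, not part of this module): `params`
# stands for a detection config object exposing the fields accessed above,
# and `images`/`labels` for a parsed training batch; none of these are
# defined here.
#
#   model_builder = RetinanetModel(params)
#   keras_model = model_builder.build_model(params, mode=mode_keys.TRAIN)
#   loss_fn = model_builder.build_loss_fn()
#   loss_dict = loss_fn(labels, keras_model(images, training=True))
#   evaluator = model_builder.eval_metrics()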