deanna-emery's picture
updates
93528c6
raw
history blame
12.1 kB
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model definition for the ShapeMask Model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf, tf_keras
from official.legacy.detection.dataloader import anchor
from official.legacy.detection.dataloader import mode_keys
from official.legacy.detection.evaluation import factory as eval_factory
from official.legacy.detection.modeling import base_model
from official.legacy.detection.modeling import losses
from official.legacy.detection.modeling.architecture import factory
from official.legacy.detection.ops import postprocess_ops
from official.legacy.detection.utils import box_utils
class ShapeMaskModel(base_model.Model):
"""ShapeMask model function."""
def __init__(self, params):
super(ShapeMaskModel, self).__init__(params)
self._params = params
self._keras_model = None
# Architecture generators.
self._backbone_fn = factory.backbone_generator(params)
self._fpn_fn = factory.multilevel_features_generator(params)
self._retinanet_head_fn = factory.retinanet_head_generator(params)
self._shape_prior_head_fn = factory.shapeprior_head_generator(params)
self._coarse_mask_fn = factory.coarsemask_head_generator(params)
self._fine_mask_fn = factory.finemask_head_generator(params)
# Loss functions.
self._cls_loss_fn = losses.RetinanetClassLoss(
params.retinanet_loss, params.architecture.num_classes)
self._box_loss_fn = losses.RetinanetBoxLoss(params.retinanet_loss)
self._box_loss_weight = params.retinanet_loss.box_loss_weight
# Mask loss function.
self._shapemask_prior_loss_fn = losses.ShapemaskMseLoss()
self._shapemask_loss_fn = losses.ShapemaskLoss()
self._shape_prior_loss_weight = (
params.shapemask_loss.shape_prior_loss_weight)
self._coarse_mask_loss_weight = (
params.shapemask_loss.coarse_mask_loss_weight)
self._fine_mask_loss_weight = (params.shapemask_loss.fine_mask_loss_weight)
# Predict function.
self._generate_detections_fn = postprocess_ops.MultilevelDetectionGenerator(
params.architecture.min_level, params.architecture.max_level,
params.postprocess)
def build_outputs(self, inputs, mode):
is_training = mode == mode_keys.TRAIN
images = inputs['image']
if 'anchor_boxes' in inputs:
anchor_boxes = inputs['anchor_boxes']
else:
anchor_boxes = anchor.Anchor(
self._params.architecture.min_level,
self._params.architecture.max_level, self._params.anchor.num_scales,
self._params.anchor.aspect_ratios, self._params.anchor.anchor_size,
images.get_shape().as_list()[1:3]).multilevel_boxes
batch_size = tf.shape(images)[0]
for level in anchor_boxes:
anchor_boxes[level] = tf.tile(
tf.expand_dims(anchor_boxes[level], 0), [batch_size, 1, 1, 1])
backbone_features = self._backbone_fn(images, is_training=is_training)
fpn_features = self._fpn_fn(backbone_features, is_training=is_training)
cls_outputs, box_outputs = self._retinanet_head_fn(
fpn_features, is_training=is_training)
valid_boxes, valid_scores, valid_classes, valid_detections = (
self._generate_detections_fn(box_outputs, cls_outputs, anchor_boxes,
inputs['image_info'][:, 1:2, :]))
image_size = images.get_shape().as_list()[1:3]
valid_outer_boxes = box_utils.compute_outer_boxes(
tf.reshape(valid_boxes, [-1, 4]),
image_size,
scale=self._params.shapemask_parser.outer_box_scale)
valid_outer_boxes = tf.reshape(valid_outer_boxes, tf.shape(valid_boxes))
# Wrapping if else code paths into a layer to make the checkpoint loadable
# in prediction mode.
class SampledBoxesLayer(tf_keras.layers.Layer):
"""ShapeMask model function."""
def call(self, inputs, val_boxes, val_classes, val_outer_boxes, training):
if training:
boxes = inputs['mask_boxes']
outer_boxes = inputs['mask_outer_boxes']
classes = inputs['mask_classes']
else:
boxes = val_boxes
classes = val_classes
outer_boxes = val_outer_boxes
return boxes, classes, outer_boxes
boxes, classes, outer_boxes = SampledBoxesLayer()(
inputs,
valid_boxes,
valid_classes,
valid_outer_boxes,
training=is_training)
instance_features, prior_masks = self._shape_prior_head_fn(
fpn_features, boxes, outer_boxes, classes, is_training)
coarse_mask_logits = self._coarse_mask_fn(instance_features, prior_masks,
classes, is_training)
fine_mask_logits = self._fine_mask_fn(instance_features, coarse_mask_logits,
classes, is_training)
model_outputs = {
'cls_outputs': cls_outputs,
'box_outputs': box_outputs,
'fine_mask_logits': fine_mask_logits,
'coarse_mask_logits': coarse_mask_logits,
'prior_masks': prior_masks,
}
if not is_training:
model_outputs.update({
'num_detections': valid_detections,
'detection_boxes': valid_boxes,
'detection_outer_boxes': valid_outer_boxes,
'detection_masks': fine_mask_logits,
'detection_classes': valid_classes,
'detection_scores': valid_scores,
})
return model_outputs
def build_loss_fn(self):
if self._keras_model is None:
raise ValueError('build_loss_fn() must be called after build_model().')
filter_fn = self.make_filter_trainable_variables_fn()
trainable_variables = filter_fn(self._keras_model.trainable_variables)
def _total_loss_fn(labels, outputs):
cls_loss = self._cls_loss_fn(outputs['cls_outputs'],
labels['cls_targets'],
labels['num_positives'])
box_loss = self._box_loss_fn(outputs['box_outputs'],
labels['box_targets'],
labels['num_positives'])
# Adds Shapemask model losses.
shape_prior_loss = self._shapemask_prior_loss_fn(outputs['prior_masks'],
labels['mask_targets'],
labels['mask_is_valid'])
coarse_mask_loss = self._shapemask_loss_fn(outputs['coarse_mask_logits'],
labels['mask_targets'],
labels['mask_is_valid'])
fine_mask_loss = self._shapemask_loss_fn(outputs['fine_mask_logits'],
labels['fine_mask_targets'],
labels['mask_is_valid'])
model_loss = (
cls_loss + self._box_loss_weight * box_loss +
shape_prior_loss * self._shape_prior_loss_weight +
coarse_mask_loss * self._coarse_mask_loss_weight +
fine_mask_loss * self._fine_mask_loss_weight)
l2_regularization_loss = self.weight_decay_loss(trainable_variables)
total_loss = model_loss + l2_regularization_loss
shapemask_losses = {
'total_loss': total_loss,
'loss': total_loss,
'retinanet_cls_loss': cls_loss,
'l2_regularization_loss': l2_regularization_loss,
'retinanet_box_loss': box_loss,
'shapemask_prior_loss': shape_prior_loss,
'shapemask_coarse_mask_loss': coarse_mask_loss,
'shapemask_fine_mask_loss': fine_mask_loss,
'model_loss': model_loss,
}
return shapemask_losses
return _total_loss_fn
def build_input_layers(self, params, mode):
is_training = mode == mode_keys.TRAIN
input_shape = (
params.shapemask_parser.output_size +
[params.shapemask_parser.num_channels])
if is_training:
batch_size = params.train.batch_size
input_layer = {
'image':
tf_keras.layers.Input(
shape=input_shape,
batch_size=batch_size,
name='image',
dtype=tf.bfloat16 if self._use_bfloat16 else tf.float32),
'image_info':
tf_keras.layers.Input(
shape=[4, 2], batch_size=batch_size, name='image_info'),
'mask_classes':
tf_keras.layers.Input(
shape=[params.shapemask_parser.num_sampled_masks],
batch_size=batch_size,
name='mask_classes',
dtype=tf.int64),
'mask_outer_boxes':
tf_keras.layers.Input(
shape=[params.shapemask_parser.num_sampled_masks, 4],
batch_size=batch_size,
name='mask_outer_boxes',
dtype=tf.float32),
'mask_boxes':
tf_keras.layers.Input(
shape=[params.shapemask_parser.num_sampled_masks, 4],
batch_size=batch_size,
name='mask_boxes',
dtype=tf.float32),
}
else:
batch_size = params.eval.batch_size
input_layer = {
'image':
tf_keras.layers.Input(
shape=input_shape,
batch_size=batch_size,
name='image',
dtype=tf.bfloat16 if self._use_bfloat16 else tf.float32),
'image_info':
tf_keras.layers.Input(
shape=[4, 2], batch_size=batch_size, name='image_info'),
}
return input_layer
def build_model(self, params, mode):
if self._keras_model is None:
input_layers = self.build_input_layers(self._params, mode)
outputs = self.model_outputs(input_layers, mode)
model = tf_keras.models.Model(
inputs=input_layers, outputs=outputs, name='shapemask')
assert model is not None, 'Fail to build tf_keras.Model.'
model.optimizer = self.build_optimizer()
self._keras_model = model
return self._keras_model
def post_processing(self, labels, outputs):
required_output_fields = [
'num_detections', 'detection_boxes', 'detection_classes',
'detection_masks', 'detection_scores'
]
for field in required_output_fields:
if field not in outputs:
raise ValueError(
'"{}" is missing in outputs, requried {} found {}'.format(
field, required_output_fields, outputs.keys()))
required_label_fields = ['image_info']
for field in required_label_fields:
if field not in labels:
raise ValueError(
'"{}" is missing in labels, requried {} found {}'.format(
field, required_label_fields, labels.keys()))
predictions = {
'image_info': labels['image_info'],
'num_detections': outputs['num_detections'],
'detection_boxes': outputs['detection_boxes'],
'detection_outer_boxes': outputs['detection_outer_boxes'],
'detection_classes': outputs['detection_classes'],
'detection_scores': outputs['detection_scores'],
'detection_masks': outputs['detection_masks'],
}
if 'groundtruths' in labels:
predictions['source_id'] = labels['groundtruths']['source_id']
labels = labels['groundtruths']
return labels, predictions
def eval_metrics(self):
return eval_factory.evaluator_generator(self._params.eval)