# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Model definition for the ShapeMask Model."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import tf_keras

from official.legacy.detection.dataloader import anchor
from official.legacy.detection.dataloader import mode_keys
from official.legacy.detection.evaluation import factory as eval_factory
from official.legacy.detection.modeling import base_model
from official.legacy.detection.modeling import losses
from official.legacy.detection.modeling.architecture import factory
from official.legacy.detection.ops import postprocess_ops
from official.legacy.detection.utils import box_utils


class ShapeMaskModel(base_model.Model):
  """ShapeMask model function.

  Wires a backbone, a multilevel feature generator and a RetinaNet head to
  the shape-prior, coarse-mask and fine-mask heads, and exposes builders for
  the Keras model, its inputs, its loss function and its post-processing.
  """

  def __init__(self, params):
    """Creates the sub-network generators, loss functions and detector.

    Args:
      params: Model configuration. This constructor reads
        `params.retinanet_loss`, `params.architecture`,
        `params.shapemask_loss` and `params.postprocess`, and forwards the
        whole object to the architecture factory functions.
    """
    super(ShapeMaskModel, self).__init__(params)

    self._params = params
    # Built lazily by `build_model()` and cached here.
    self._keras_model = None

    # Architecture generators.
    self._backbone_fn = factory.backbone_generator(params)
    self._fpn_fn = factory.multilevel_features_generator(params)
    self._retinanet_head_fn = factory.retinanet_head_generator(params)
    self._shape_prior_head_fn = factory.shapeprior_head_generator(params)
    self._coarse_mask_fn = factory.coarsemask_head_generator(params)
    self._fine_mask_fn = factory.finemask_head_generator(params)

    # Loss functions.
    self._cls_loss_fn = losses.RetinanetClassLoss(
        params.retinanet_loss, params.architecture.num_classes)
    self._box_loss_fn = losses.RetinanetBoxLoss(params.retinanet_loss)
    self._box_loss_weight = params.retinanet_loss.box_loss_weight

    # Mask loss function.
    self._shapemask_prior_loss_fn = losses.ShapemaskMseLoss()
    self._shapemask_loss_fn = losses.ShapemaskLoss()
    self._shape_prior_loss_weight = (
        params.shapemask_loss.shape_prior_loss_weight)
    self._coarse_mask_loss_weight = (
        params.shapemask_loss.coarse_mask_loss_weight)
    self._fine_mask_loss_weight = (params.shapemask_loss.fine_mask_loss_weight)

    # Predict function.
    self._generate_detections_fn = postprocess_ops.MultilevelDetectionGenerator(
        params.architecture.min_level, params.architecture.max_level,
        params.postprocess)

  def build_outputs(self, inputs, mode):
    """Builds the forward pass and returns the model output tensors.

    Args:
      inputs: Dict of input tensors. Always read: 'image' and 'image_info'.
        Optional: 'anchor_boxes' (used as-is when present). Read only in
        training mode: 'mask_boxes', 'mask_outer_boxes' and 'mask_classes'.
      mode: Mode key; compared against `mode_keys.TRAIN` to decide whether
        the network runs in training mode.

    Returns:
      Dict with 'cls_outputs', 'box_outputs', 'fine_mask_logits',
      'coarse_mask_logits' and 'prior_masks'. When not training it also
      holds 'num_detections', 'detection_boxes', 'detection_outer_boxes',
      'detection_masks', 'detection_classes' and 'detection_scores'.
    """
    is_training = mode == mode_keys.TRAIN
    images = inputs['image']

    if 'anchor_boxes' in inputs:
      anchor_boxes = inputs['anchor_boxes']
    else:
      anchor_boxes = anchor.Anchor(
          self._params.architecture.min_level,
          self._params.architecture.max_level,
          self._params.anchor.num_scales,
          self._params.anchor.aspect_ratios,
          self._params.anchor.anchor_size,
          images.get_shape().as_list()[1:3]).multilevel_boxes

      # Freshly generated anchors carry no batch dimension: add a leading
      # one and repeat each level's boxes once per image in the batch.
      batch_size = tf.shape(images)[0]
      for level in anchor_boxes:
        anchor_boxes[level] = tf.tile(
            tf.expand_dims(anchor_boxes[level], 0), [batch_size, 1, 1, 1])

    backbone_features = self._backbone_fn(images, is_training=is_training)
    fpn_features = self._fpn_fn(backbone_features, is_training=is_training)
    cls_outputs, box_outputs = self._retinanet_head_fn(
        fpn_features, is_training=is_training)

    # NOTE(review): row 1 of `image_info` is presumably the scaled image
    # size expected by the detection generator -- confirm against the parser.
    valid_boxes, valid_scores, valid_classes, valid_detections = (
        self._generate_detections_fn(box_outputs, cls_outputs, anchor_boxes,
                                     inputs['image_info'][:, 1:2, :]))

    # Static dims 1 and 2 of the image tensor (presumably height and width
    # of an NHWC batch -- confirm).
    image_size = images.get_shape().as_list()[1:3]
    # Outer boxes are computed on a flat list of boxes and then restored to
    # the shape of `valid_boxes`.
    valid_outer_boxes = box_utils.compute_outer_boxes(
        tf.reshape(valid_boxes, [-1, 4]),
        image_size,
        scale=self._params.shapemask_parser.outer_box_scale)
    valid_outer_boxes = tf.reshape(valid_outer_boxes, tf.shape(valid_boxes))

    # Wrapping if else code paths into a layer to make the checkpoint loadable
    # in prediction mode.
    class SampledBoxesLayer(tf_keras.layers.Layer):
      """Selects which boxes, classes and outer boxes feed the mask heads."""

      def call(self, inputs, val_boxes, val_classes, val_outer_boxes,
               training):
        """Returns sampled inputs when training, else the detections.

        Args:
          inputs: Dict of model inputs; 'mask_boxes', 'mask_outer_boxes' and
            'mask_classes' are read only when `training` is truthy.
          val_boxes: Boxes used when not training.
          val_classes: Classes used when not training.
          val_outer_boxes: Outer boxes used when not training.
          training: Whether the layer runs in training mode.

        Returns:
          A `(boxes, classes, outer_boxes)` tuple.
        """
        if training:
          boxes = inputs['mask_boxes']
          outer_boxes = inputs['mask_outer_boxes']
          classes = inputs['mask_classes']
        else:
          boxes = val_boxes
          classes = val_classes
          outer_boxes = val_outer_boxes
        return boxes, classes, outer_boxes

    boxes, classes, outer_boxes = SampledBoxesLayer()(
        inputs,
        valid_boxes,
        valid_classes,
        valid_outer_boxes,
        training=is_training)

    instance_features, prior_masks = self._shape_prior_head_fn(
        fpn_features, boxes, outer_boxes, classes, is_training)
    coarse_mask_logits = self._coarse_mask_fn(instance_features, prior_masks,
                                              classes, is_training)
    fine_mask_logits = self._fine_mask_fn(instance_features,
                                          coarse_mask_logits, classes,
                                          is_training)

    model_outputs = {
        'cls_outputs': cls_outputs,
        'box_outputs': box_outputs,
        'fine_mask_logits': fine_mask_logits,
        'coarse_mask_logits': coarse_mask_logits,
        'prior_masks': prior_masks,
    }

    if not is_training:
      model_outputs.update({
          'num_detections': valid_detections,
          'detection_boxes': valid_boxes,
          'detection_outer_boxes': valid_outer_boxes,
          'detection_masks': fine_mask_logits,
          'detection_classes': valid_classes,
          'detection_scores': valid_scores,
      })

    return model_outputs

  def build_loss_fn(self):
    """Returns a function computing all ShapeMask losses.

    Returns:
      A callable `fn(labels, outputs)` that returns a dict with the keys
      'total_loss', 'loss', 'retinanet_cls_loss', 'l2_regularization_loss',
      'retinanet_box_loss', 'shapemask_prior_loss',
      'shapemask_coarse_mask_loss', 'shapemask_fine_mask_loss' and
      'model_loss'. 'total_loss' and 'loss' hold the same value.

    Raises:
      ValueError: If the Keras model has not been built yet.
    """
    if self._keras_model is None:
      raise ValueError('build_loss_fn() must be called after build_model().')

    # The set of regularized variables is captured once, when the loss
    # function is built, not on every loss evaluation.
    filter_fn = self.make_filter_trainable_variables_fn()
    trainable_variables = filter_fn(self._keras_model.trainable_variables)

    def _total_loss_fn(labels, outputs):
      """Computes the weighted sum of detection, mask and L2 losses."""
      cls_loss = self._cls_loss_fn(outputs['cls_outputs'],
                                   labels['cls_targets'],
                                   labels['num_positives'])
      box_loss = self._box_loss_fn(outputs['box_outputs'],
                                   labels['box_targets'],
                                   labels['num_positives'])

      # Adds Shapemask model losses.
      shape_prior_loss = self._shapemask_prior_loss_fn(outputs['prior_masks'],
                                                       labels['mask_targets'],
                                                       labels['mask_is_valid'])
      coarse_mask_loss = self._shapemask_loss_fn(outputs['coarse_mask_logits'],
                                                 labels['mask_targets'],
                                                 labels['mask_is_valid'])
      # The fine mask head is supervised by its own target tensor.
      fine_mask_loss = self._shapemask_loss_fn(outputs['fine_mask_logits'],
                                               labels['fine_mask_targets'],
                                               labels['mask_is_valid'])

      model_loss = (
          cls_loss + self._box_loss_weight * box_loss +
          shape_prior_loss * self._shape_prior_loss_weight +
          coarse_mask_loss * self._coarse_mask_loss_weight +
          fine_mask_loss * self._fine_mask_loss_weight)

      l2_regularization_loss = self.weight_decay_loss(trainable_variables)
      total_loss = model_loss + l2_regularization_loss

      shapemask_losses = {
          'total_loss': total_loss,
          'loss': total_loss,
          'retinanet_cls_loss': cls_loss,
          'l2_regularization_loss': l2_regularization_loss,
          'retinanet_box_loss': box_loss,
          'shapemask_prior_loss': shape_prior_loss,
          'shapemask_coarse_mask_loss': coarse_mask_loss,
          'shapemask_fine_mask_loss': fine_mask_loss,
          'model_loss': model_loss,
      }
      return shapemask_losses

    return _total_loss_fn

  def build_input_layers(self, params, mode):
    """Creates the Keras `Input` layers for the given mode.

    Args:
      params: Configuration. Reads `shapemask_parser.output_size`,
        `shapemask_parser.num_channels`, and either `train.batch_size` or
        `eval.batch_size`; in training mode it also reads
        `shapemask_parser.num_sampled_masks`.
      mode: Mode key; compared against `mode_keys.TRAIN`.

    Returns:
      Dict of `tf_keras.layers.Input` tensors keyed by 'image' and
      'image_info', plus 'mask_classes', 'mask_outer_boxes' and 'mask_boxes'
      in training mode.
    """
    is_training = mode == mode_keys.TRAIN
    # `output_size` is concatenated with a list, so it is presumably a list
    # of spatial dimensions -- confirm against the config definition.
    input_shape = (
        params.shapemask_parser.output_size +
        [params.shapemask_parser.num_channels])
    batch_size = (
        params.train.batch_size if is_training else params.eval.batch_size)
    # `_use_bfloat16` is not set in this class; presumably provided by
    # `base_model.Model` -- confirm.
    image_dtype = tf.bfloat16 if self._use_bfloat16 else tf.float32

    # Inputs shared by training and evaluation.
    input_layer = {
        'image':
            tf_keras.layers.Input(
                shape=input_shape,
                batch_size=batch_size,
                name='image',
                dtype=image_dtype),
        'image_info':
            tf_keras.layers.Input(
                shape=[4, 2], batch_size=batch_size, name='image_info'),
    }

    if is_training:
      # Sampled ground-truth masks are only fed to the model during training.
      num_sampled_masks = params.shapemask_parser.num_sampled_masks
      input_layer.update({
          'mask_classes':
              tf_keras.layers.Input(
                  shape=[num_sampled_masks],
                  batch_size=batch_size,
                  name='mask_classes',
                  dtype=tf.int64),
          'mask_outer_boxes':
              tf_keras.layers.Input(
                  shape=[num_sampled_masks, 4],
                  batch_size=batch_size,
                  name='mask_outer_boxes',
                  dtype=tf.float32),
          'mask_boxes':
              tf_keras.layers.Input(
                  shape=[num_sampled_masks, 4],
                  batch_size=batch_size,
                  name='mask_boxes',
                  dtype=tf.float32),
      })

    return input_layer

  def build_model(self, params, mode):
    """Builds the Keras model on first call and returns the cached model.

    Args:
      params: Unused; the configuration passed to the constructor
        (`self._params`) is used instead. Kept for interface compatibility.
      mode: Mode key forwarded to `build_input_layers` and `model_outputs`.
        Only the mode of the first call has an effect, because the model is
        cached afterwards.

    Returns:
      The `tf_keras.models.Model` named 'shapemask'.
    """
    if self._keras_model is None:
      input_layers = self.build_input_layers(self._params, mode)
      outputs = self.model_outputs(input_layers, mode)

      model = tf_keras.models.Model(
          inputs=input_layers, outputs=outputs, name='shapemask')
      assert model is not None, 'Fail to build tf_keras.Model.'
      model.optimizer = self.build_optimizer()
      self._keras_model = model

    return self._keras_model

  def post_processing(self, labels, outputs):
    """Validates model outputs and assembles the prediction dict.

    Args:
      labels: Dict that must contain 'image_info'. If it also contains
        'groundtruths', that sub-dict supplies 'source_id' and is returned
        in place of `labels`.
      outputs: Dict of model outputs; must contain every detection field
        copied into the predictions.

    Returns:
      A `(labels, predictions)` tuple. `predictions` holds 'image_info',
      'num_detections', 'detection_boxes', 'detection_outer_boxes',
      'detection_classes', 'detection_scores' and 'detection_masks', plus
      'source_id' when ground truths are available.

    Raises:
      ValueError: If a required field is missing in `outputs` or `labels`.
    """
    # Every field read from `outputs` below is validated here, so a missing
    # key yields a descriptive ValueError rather than a bare KeyError.
    required_output_fields = [
        'num_detections', 'detection_boxes', 'detection_outer_boxes',
        'detection_classes', 'detection_masks', 'detection_scores'
    ]

    for field in required_output_fields:
      if field not in outputs:
        raise ValueError(
            '"{}" is missing in outputs, required {} found {}'.format(
                field, required_output_fields, outputs.keys()))

    required_label_fields = ['image_info']
    for field in required_label_fields:
      if field not in labels:
        raise ValueError(
            '"{}" is missing in labels, required {} found {}'.format(
                field, required_label_fields, labels.keys()))

    predictions = {
        'image_info': labels['image_info'],
        'num_detections': outputs['num_detections'],
        'detection_boxes': outputs['detection_boxes'],
        'detection_outer_boxes': outputs['detection_outer_boxes'],
        'detection_classes': outputs['detection_classes'],
        'detection_scores': outputs['detection_scores'],
        'detection_masks': outputs['detection_masks'],
    }

    if 'groundtruths' in labels:
      predictions['source_id'] = labels['groundtruths']['source_id']
      labels = labels['groundtruths']

    return labels, predictions

  def eval_metrics(self):
    """Returns the evaluator built from the `eval` section of the config."""
    return eval_factory.evaluator_generator(self._params.eval)