Spaces:
Runtime error
Runtime error
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model definition for the Object Localization Network (OLN) Model."""
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import tensorflow as tf, tf_keras | |
from official.legacy.detection.dataloader import anchor | |
from official.legacy.detection.dataloader import mode_keys | |
from official.legacy.detection.modeling import losses | |
from official.legacy.detection.modeling.architecture import factory | |
from official.legacy.detection.modeling.maskrcnn_model import MaskrcnnModel | |
from official.legacy.detection.ops import postprocess_ops | |
from official.legacy.detection.ops import roi_ops | |
from official.legacy.detection.ops import spatial_transform_ops | |
from official.legacy.detection.ops import target_ops | |
from official.legacy.detection.utils import box_utils | |
class OlnMaskModel(MaskrcnnModel):
  """OLN-Mask model function.

  Object Localization Network (OLN) built on top of `MaskrcnnModel`. Which
  stages run is controlled by the `params.architecture` switches:

    * OLN-Proposal: no Fast R-CNN class head, no Fast R-CNN box head and no
      mask head. The RPN proposals themselves are post-processed into
      detections.
    * OLN-Box: RPN followed by the Fast R-CNN head, optionally with a box-IoU
      scoring branch (enabled by `frcnn_head.has_scoring`).
    * OLN-Mask: OLN-Box followed by a mask head.

  When `rpn_head.has_centerness` is set, an OLN RPN head with an extra
  centerness output replaces the standard RPN head.
  """

  def __init__(self, params):
    """Creates the sub-network generators, samplers, losses and post-processor.

    Args:
      params: experiment config. Reads the `architecture`, `rpn_head`,
        `frcnn_head`, `mrcnn_head`, `anchor`, `roi_proposal`, `roi_sampling`,
        `mask_sampling`, `rpn_score_loss`, `rpn_box_loss`, `frcnn_box_loss`,
        `frcnn_box_score_loss`, `postprocess` and `train` sections.

    Raises:
      ValueError: if `params.train.transpose_input` is set; transposed inputs
        are not supported by this model.
    """
    super(OlnMaskModel, self).__init__(params)

    self._params = params

    # Different heads and layers.
    self._include_rpn_class = params.architecture.include_rpn_class
    self._include_mask = params.architecture.include_mask
    self._include_frcnn_class = params.architecture.include_frcnn_class
    self._include_frcnn_box = params.architecture.include_frcnn_box
    self._include_centerness = params.rpn_head.has_centerness
    # The box-IoU scoring branch only exists when the Fast R-CNN box head
    # does, and the mask scoring branch only when the mask head does.
    self._include_box_score = (params.frcnn_head.has_scoring and
                               params.architecture.include_frcnn_box)
    self._include_mask_score = (params.mrcnn_head.has_scoring and
                                params.architecture.include_mask)

    # Architecture generators.
    self._backbone_fn = factory.backbone_generator(params)
    self._fpn_fn = factory.multilevel_features_generator(params)
    # NOTE(review): this unconditional RPN head is immediately overwritten by
    # one of the branches below. It is kept because constructing a head may
    # advance Keras' automatic layer naming. Removing it could therefore
    # rename the variables of the head that is actually used and break
    # loading of existing checkpoints. Confirm before deleting.
    self._rpn_head_fn = factory.rpn_head_generator(params)
    if self._include_centerness:
      self._rpn_head_fn = factory.oln_rpn_head_generator(params)
    else:
      self._rpn_head_fn = factory.rpn_head_generator(params)
    self._generate_rois_fn = roi_ops.OlnROIGenerator(params.roi_proposal)
    self._sample_rois_fn = target_ops.ROIScoreSampler(params.roi_sampling)
    self._sample_masks_fn = target_ops.MaskSampler(
        params.architecture.mask_target_size,
        params.mask_sampling.num_mask_samples_per_image)

    if self._include_box_score:
      self._frcnn_head_fn = factory.oln_box_score_head_generator(params)
    else:
      self._frcnn_head_fn = factory.fast_rcnn_head_generator(params)
    if self._include_mask:
      if self._include_mask_score:
        self._mrcnn_head_fn = factory.oln_mask_score_head_generator(params)
      else:
        self._mrcnn_head_fn = factory.mask_rcnn_head_generator(params)

    # Loss functions.
    self._rpn_score_loss_fn = losses.RpnScoreLoss(params.rpn_score_loss)
    self._rpn_box_loss_fn = losses.RpnBoxLoss(params.rpn_box_loss)
    if self._include_centerness:
      self._rpn_iou_loss_fn = losses.OlnRpnIoULoss()
      self._rpn_center_loss_fn = losses.OlnRpnCenterLoss()
    self._frcnn_class_loss_fn = losses.FastrcnnClassLoss()
    self._frcnn_box_loss_fn = losses.FastrcnnBoxLoss(params.frcnn_box_loss)
    if self._include_box_score:
      self._frcnn_box_score_loss_fn = losses.OlnBoxScoreLoss(
          params.frcnn_box_score_loss)
    if self._include_mask:
      self._mask_loss_fn = losses.MaskrcnnLoss()

    self._generate_detections_fn = postprocess_ops.OlnDetectionGenerator(
        params.postprocess)

    self._transpose_input = params.train.transpose_input
    # An explicit raise (not `assert`) so the check survives `python -O`.
    if self._transpose_input:
      raise ValueError('Transpose input is not supported.')

  def build_outputs(self, inputs, mode):
    """Builds the OLN forward pass.

    Args:
      inputs: dict of input tensors (see `build_input_layers`). Always holds
        'image' and 'image_info'. In training it also holds 'gt_boxes',
        'gt_classes' and, when the mask head is enabled, 'gt_masks'.
      mode: a `mode_keys` value. `mode_keys.TRAIN` enables proposal sampling
        and training-target generation.

    Returns:
      A dict of output tensors. It always contains:
        * the RPN outputs: 'rpn_score_outputs', 'rpn_box_outputs', and
          'rpn_center_outputs' when centerness is enabled;
        * the post-processed detections: 'num_detections',
          'detection_boxes', 'detection_classes', 'detection_scores'.
      Unless the model is proposal-only, it also contains:
        * the Fast R-CNN outputs: 'class_outputs', 'box_outputs', and
          'box_score_outputs' when box scoring is enabled;
        * in training, the targets 'class_targets', 'box_targets' and
          'box_iou_targets';
        * with the mask head, 'mask_targets', 'sampled_class_targets' and
          'mask_outputs' in training, or 'detection_masks' otherwise.
    """
    is_training = mode == mode_keys.TRAIN
    model_outputs = {}

    def _to_float32(tensors):
      """Casts every tensor in a (possibly nested) structure to float32."""
      return tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), tensors)

    image = inputs['image']
    _, image_height, image_width, _ = image.get_shape().as_list()
    backbone_features = self._backbone_fn(image, is_training)
    fpn_features = self._fpn_fn(backbone_features, is_training)

    # ---- RPN (with optional centerness). ----
    if self._include_centerness:
      rpn_score_outputs, rpn_box_outputs, rpn_center_outputs = (
          self._rpn_head_fn(fpn_features, is_training))
      model_outputs.update({
          'rpn_center_outputs': _to_float32(rpn_center_outputs),
      })
      # The uncast centerness predictions are handed to the ROI generator
      # as `object_scores`.
      object_scores = rpn_center_outputs
    else:
      rpn_score_outputs, rpn_box_outputs = self._rpn_head_fn(
          fpn_features, is_training)
      object_scores = None
    model_outputs.update({
        'rpn_score_outputs': _to_float32(rpn_score_outputs),
        'rpn_box_outputs': _to_float32(rpn_box_outputs),
    })

    input_anchor = anchor.Anchor(self._params.architecture.min_level,
                                 self._params.architecture.max_level,
                                 self._params.anchor.num_scales,
                                 self._params.anchor.aspect_ratios,
                                 self._params.anchor.anchor_size,
                                 (image_height, image_width))
    # `image_info[:, 1, :]` is passed as the per-image size. It is presumably
    # the scaled image height/width; confirm against the input parser.
    rpn_rois, rpn_roi_scores = self._generate_rois_fn(
        rpn_box_outputs,
        rpn_score_outputs,
        input_anchor.multilevel_boxes,
        inputs['image_info'][:, 1, :],
        is_training,
        is_box_lrtb=self._include_centerness,
        object_scores=object_scores,
    )

    # ---- OLN-Proposal: the RPN proposals are the detections. ----
    if (not self._include_frcnn_class and
        not self._include_frcnn_box and
        not self._include_mask):
      # Dummy all-zero box deltas (dy, dx, dh, dw = 0, 0, 0, 0), duplicated
      # for two classes, so post-processing keeps each proposal's box.
      box_outputs = tf.zeros_like(rpn_rois)
      box_outputs = tf.concat([box_outputs, box_outputs], -1)
      boxes, scores, classes, valid_detections = self._generate_detections_fn(
          box_outputs, rpn_roi_scores, rpn_rois,
          inputs['image_info'][:, 1:2, :],
          is_single_fg_score=True,  # if no_background, no softmax is applied.
          keep_nms=True)
      model_outputs.update({
          'num_detections': valid_detections,
          'detection_boxes': boxes,
          'detection_classes': classes,
          'detection_scores': scores,
      })
      return model_outputs
    # ---- OLN-Proposal finishes here. ----

    if is_training:
      rpn_rois = tf.stop_gradient(rpn_rois)
      rpn_roi_scores = tf.stop_gradient(rpn_roi_scores)

      # Sample proposals. From here on, in training, `rpn_rois` and
      # `rpn_roi_scores` refer to the sampled ROIs only.
      (rpn_rois, rpn_roi_scores, matched_gt_boxes, matched_gt_classes,
       matched_gt_indices) = (
           self._sample_rois_fn(rpn_rois, rpn_roi_scores, inputs['gt_boxes'],
                                inputs['gt_classes']))

      # Create bounding box training targets.
      box_targets = box_utils.encode_boxes(
          matched_gt_boxes, rpn_rois, weights=[10.0, 10.0, 5.0, 5.0])
      # If the target is background (class 0), the box target is all 0s.
      box_targets = tf.where(
          tf.tile(
              tf.expand_dims(tf.equal(matched_gt_classes, 0), axis=-1),
              [1, 1, 4]), tf.zeros_like(box_targets), box_targets)
      model_outputs.update({
          'class_targets': matched_gt_classes,
          'box_targets': box_targets,
      })

      # Box-IoU targets: for each sampled ROI, its largest overlap with any
      # ground-truth box.
      box_ious = box_utils.bbox_overlap(rpn_rois, inputs['gt_boxes'])
      matched_box_ious = tf.reduce_max(box_ious, 2)
      model_outputs.update({
          'box_iou_targets': matched_box_ious,
      })

    # ---- Fast R-CNN head (with optional box-IoU scoring). ----
    roi_features = spatial_transform_ops.multilevel_crop_and_resize(
        fpn_features, rpn_rois, output_size=7)

    if self._include_box_score:
      class_outputs, box_outputs, score_outputs = self._frcnn_head_fn(
          roi_features, is_training)
      model_outputs.update({
          'box_score_outputs': _to_float32(score_outputs),
      })
    else:
      class_outputs, box_outputs = self._frcnn_head_fn(
          roi_features, is_training)
    model_outputs.update({
        'class_outputs': _to_float32(class_outputs),
        'box_outputs': _to_float32(box_outputs),
    })

    # Add this output to train to make the checkpoint loadable in predict mode.
    # If we skip it in train mode, the heads will be out-of-order and
    # checkpoint loading will fail.
    if not self._include_frcnn_box:
      # No box regression: zero deltas so the detections keep the ROI boxes.
      box_outputs = tf.zeros_like(box_outputs)

    if self._include_box_score:
      score_outputs = tf.cast(tf.squeeze(score_outputs, -1),
                              rpn_roi_scores.dtype)
      # box-score = (rpn-score * box-iou)^(1/2). In training both tensors are
      # per sampled ROI; otherwise both are per generated proposal, so they
      # line up element-wise.
      box_scores = tf.pow(
          rpn_roi_scores * tf.sigmoid(score_outputs), 1/2.)
    else:
      # No box-IoU scoring branch: rank by the RPN proposal scores, as the
      # OLN-Proposal path does. Without this fallback `box_scores` would be
      # unbound below whenever the Fast R-CNN classifier is disabled.
      box_scores = rpn_roi_scores

    if not self._include_frcnn_class:
      boxes, scores, classes, valid_detections = self._generate_detections_fn(
          box_outputs,
          box_scores,
          rpn_rois,
          inputs['image_info'][:, 1:2, :],
          is_single_fg_score=True,
          keep_nms=True,)
    else:
      boxes, scores, classes, valid_detections = self._generate_detections_fn(
          box_outputs, class_outputs, rpn_rois,
          inputs['image_info'][:, 1:2, :],
          keep_nms=True,)
    model_outputs.update({
        'num_detections': valid_detections,
        'detection_boxes': boxes,
        'detection_classes': classes,
        'detection_scores': scores,
    })
    # ---- OLN-Box finishes here. ----

    if not self._include_mask:
      return model_outputs

    # ---- Mask head. ----
    if is_training:
      rpn_rois, classes, mask_targets = self._sample_masks_fn(
          rpn_rois, matched_gt_boxes, matched_gt_classes, matched_gt_indices,
          inputs['gt_masks'])
      mask_targets = tf.stop_gradient(mask_targets)
      classes = tf.cast(classes, dtype=tf.int32)
      model_outputs.update({
          'mask_targets': mask_targets,
          'sampled_class_targets': classes,
      })
    else:
      # Outside training, masks are predicted for the final detections.
      rpn_rois = boxes
      classes = tf.cast(classes, dtype=tf.int32)

    mask_roi_features = spatial_transform_ops.multilevel_crop_and_resize(
        fpn_features, rpn_rois, output_size=14)
    mask_outputs = self._mrcnn_head_fn(mask_roi_features, classes, is_training)

    if is_training:
      model_outputs.update({
          'mask_outputs': _to_float32(mask_outputs),
      })
    else:
      model_outputs.update({'detection_masks': tf.nn.sigmoid(mask_outputs)})
    return model_outputs

  def build_loss_fn(self):
    """Returns a function computing the OLN training losses.

    Returns:
      `_total_loss_fn(labels, outputs)`. Here `labels` holds the RPN targets
      ('rpn_score_targets', 'rpn_box_targets', and 'rpn_center_targets' with
      centerness), and `outputs` is the dict produced by `build_outputs` in
      training mode. The function returns a dict of named losses; every
      disabled head contributes 0.0.

    Raises:
      ValueError: if called before `build_model()`.
    """
    if self._keras_model is None:
      raise ValueError('build_loss_fn() must be called after build_model().')

    # Resolve the regularized variable set once, when the loss fn is built.
    filter_fn = self.make_filter_trainable_variables_fn()
    trainable_variables = filter_fn(self._keras_model.trainable_variables)

    def _total_loss_fn(labels, outputs):
      """Sums the enabled head losses plus L2 weight decay."""
      if self._include_rpn_class:
        rpn_score_loss = self._rpn_score_loss_fn(outputs['rpn_score_outputs'],
                                                 labels['rpn_score_targets'])
      else:
        rpn_score_loss = 0.0

      if self._include_centerness:
        rpn_center_loss = self._rpn_center_loss_fn(
            outputs['rpn_center_outputs'], labels['rpn_center_targets'])
        # With centerness the RPN box branch is trained with an IoU loss,
        # which also consumes the centerness targets.
        rpn_box_loss = self._rpn_iou_loss_fn(
            outputs['rpn_box_outputs'], labels['rpn_box_targets'],
            labels['rpn_center_targets'])
      else:
        rpn_center_loss = 0.0
        rpn_box_loss = self._rpn_box_loss_fn(
            outputs['rpn_box_outputs'], labels['rpn_box_targets'])

      if self._include_frcnn_class:
        frcnn_class_loss = self._frcnn_class_loss_fn(
            outputs['class_outputs'], outputs['class_targets'])
      else:
        frcnn_class_loss = 0.0

      if self._include_frcnn_box:
        frcnn_box_loss = self._frcnn_box_loss_fn(
            outputs['box_outputs'], outputs['class_targets'],
            outputs['box_targets'])
      else:
        frcnn_box_loss = 0.0

      if self._include_box_score:
        box_score_loss = self._frcnn_box_score_loss_fn(
            outputs['box_score_outputs'], outputs['box_iou_targets'])
      else:
        box_score_loss = 0.0

      if self._include_mask:
        mask_loss = self._mask_loss_fn(outputs['mask_outputs'],
                                       outputs['mask_targets'],
                                       outputs['sampled_class_targets'])
      else:
        mask_loss = 0.0

      model_loss = (
          rpn_score_loss + rpn_box_loss + rpn_center_loss +
          frcnn_class_loss + frcnn_box_loss + box_score_loss +
          mask_loss)

      l2_regularization_loss = self.weight_decay_loss(trainable_variables)
      total_loss = model_loss + l2_regularization_loss
      return {
          'total_loss': total_loss,
          'loss': total_loss,
          'fast_rcnn_class_loss': frcnn_class_loss,
          'fast_rcnn_box_loss': frcnn_box_loss,
          'fast_rcnn_box_score_loss': box_score_loss,
          'mask_loss': mask_loss,
          'model_loss': model_loss,
          'l2_regularization_loss': l2_regularization_loss,
          'rpn_score_loss': rpn_score_loss,
          'rpn_box_loss': rpn_box_loss,
          'rpn_center_loss': rpn_center_loss,
      }

    return _total_loss_fn

  def build_input_layers(self, params, mode):
    """Creates the Keras `Input` placeholders for the given mode.

    Args:
      params: experiment config. Reads `olnmask_parser` for shapes, and
        `train.batch_size` or `eval.batch_size` depending on `mode`.
      mode: a `mode_keys` value.

    Returns:
      A dict of `Input` layers with 'image' and 'image_info'. In training it
      also has 'gt_boxes', 'gt_classes' and, when the mask head is enabled,
      'gt_masks'.
    """
    is_training = mode == mode_keys.TRAIN
    parser_params = params.olnmask_parser
    input_shape = parser_params.output_size + [parser_params.num_channels]
    batch_size = (params.train.batch_size if is_training
                  else params.eval.batch_size)

    input_layer = {
        'image':
            tf_keras.layers.Input(
                shape=input_shape,
                batch_size=batch_size,
                name='image',
                dtype=tf.bfloat16 if self._use_bfloat16 else tf.float32),
        'image_info':
            tf_keras.layers.Input(
                shape=[4, 2],
                batch_size=batch_size,
                name='image_info',
            ),
    }
    if not is_training:
      return input_layer

    # Ground-truth inputs are only needed to build training targets.
    input_layer['gt_boxes'] = tf_keras.layers.Input(
        shape=[parser_params.max_num_instances, 4],
        batch_size=batch_size,
        name='gt_boxes')
    input_layer['gt_classes'] = tf_keras.layers.Input(
        shape=[parser_params.max_num_instances],
        batch_size=batch_size,
        name='gt_classes',
        dtype=tf.int64)
    if self._include_mask:
      input_layer['gt_masks'] = tf_keras.layers.Input(
          shape=[
              parser_params.max_num_instances,
              parser_params.mask_crop_size,
              parser_params.mask_crop_size
          ],
          batch_size=batch_size,
          name='gt_masks')
    return input_layer

  def build_model(self, params, mode):
    """Builds (once) and returns the Keras model with its optimizer attached.

    Args:
      params: unused here; the config stored at construction time
        (`self._params`) is used instead. Kept for interface compatibility.
      mode: a `mode_keys` value selecting the input layers and outputs.

    Returns:
      The cached `tf_keras.Model`, built on the first call.
    """
    if self._keras_model is None:
      input_layers = self.build_input_layers(self._params, mode)
      outputs = self.model_outputs(input_layers, mode)

      model = tf_keras.models.Model(
          inputs=input_layers, outputs=outputs, name='olnmask')
      assert model is not None, 'Fail to build tf_keras.Model.'
      model.optimizer = self.build_optimizer()
      self._keras_model = model

    return self._keras_model