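# NOTE: the following is residue from the web file viewer this file was copied
# from; it is not part of the module:
#   deanna-emery's picture / updates / 93528c6 / raw / history blame / 14.8 kB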
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for maskrcnn_model.py."""
import os
# Import libraries
from absl.testing import parameterized
import numpy as np
import tensorflow as tf, tf_keras
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.vision.modeling import maskrcnn_model
from official.vision.modeling.backbones import resnet
from official.vision.modeling.decoders import fpn
from official.vision.modeling.heads import dense_prediction_heads
from official.vision.modeling.heads import instance_heads
from official.vision.modeling.layers import detection_generator
from official.vision.modeling.layers import mask_sampler
from official.vision.modeling.layers import roi_aligner
from official.vision.modeling.layers import roi_generator
from official.vision.modeling.layers import roi_sampler
from official.vision.ops import anchor
class MaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
  """Build, forward-pass, config and checkpoint tests for `MaskRCNNModel`."""

  def _make_mask_components(self, include_mask, num_classes):
    """Creates the optional mask-branch objects handed to `MaskRCNNModel`.

    Args:
      include_mask: Whether the mask branch should be created.
      num_classes: Number of classes given to the mask head.

    Returns:
      A `(mask_head, mask_sampler, mask_roi_aligner)` tuple. All three entries
      are `None` when `include_mask` is false.
    """
    if not include_mask:
      return None, None, None
    # Construction order (head, sampler, aligner) is kept identical across all
    # tests so that any automatically assigned layer names stay the same.
    mask_head = instance_heads.MaskHead(
        num_classes=num_classes, upsample_factor=2)
    mask_sampler_obj = mask_sampler.MaskSampler(
        mask_target_size=28, num_sampled_masks=1)
    mask_roi_aligner_obj = roi_aligner.MultilevelROIAligner(crop_size=14)
    return mask_head, mask_sampler_obj, mask_roi_aligner_obj

  def _make_groundtruth(self, include_mask, use_outer_boxes=True):
    """Creates the fixed two-image groundtruth used when calling the model.

    Each image carries three box slots; rows and class ids equal to -1 look
    like padding for unused slots -- TODO confirm against the model contract.

    Args:
      include_mask: Whether all-ones groundtruth masks should be returned.
      use_outer_boxes: Whether groundtruth outer boxes should be returned.

    Returns:
      A `(gt_boxes, gt_classes, gt_masks, gt_outer_boxes)` tuple of numpy
      arrays, where `gt_masks` / `gt_outer_boxes` are `None` when disabled.
    """
    gt_boxes = np.array(
        [[[10, 10, 15, 15], [2.5, 2.5, 7.5, 7.5], [-1, -1, -1, -1]],
         [[100, 100, 150, 150], [-1, -1, -1, -1], [-1, -1, -1, -1]]],
        dtype=np.float32)
    gt_outer_boxes = None
    if use_outer_boxes:
      gt_outer_boxes = np.array(
          [[[11, 11, 16.5, 16.5], [2.75, 2.75, 8.25, 8.25], [-1, -1, -1, -1]],
           [[110, 110, 165, 165], [-1, -1, -1, -1], [-1, -1, -1, -1]]],
          dtype=np.float32)
    gt_classes = np.array([[2, 1, -1], [1, -1, -1]], dtype=np.int32)
    gt_masks = np.ones((2, 3, 100, 100)) if include_mask else None
    return gt_boxes, gt_classes, gt_masks, gt_outer_boxes

  def _build_default_model(self, include_mask):
    """Builds the levels 3-7, two-class model used by the config/ckpt tests.

    Args:
      include_mask: Whether the mask branch is attached to the model.

    Returns:
      A `(model, components)` tuple where `components` maps 'backbone',
      'decoder', 'rpn_head', 'detection_head' and 'mask_head' to the objects
      the model was built from ('mask_head' is `None` without a mask branch).
    """
    input_specs = tf_keras.layers.InputSpec(shape=[None, None, None, 3])
    backbone = resnet.ResNet(model_id=50, input_specs=input_specs)
    decoder = fpn.FPN(
        min_level=3, max_level=7, input_specs=backbone.output_specs)
    rpn_head = dense_prediction_heads.RPNHead(
        min_level=3, max_level=7, num_anchors_per_location=3)
    detection_head = instance_heads.DetectionHead(num_classes=2)
    roi_generator_obj = roi_generator.MultilevelROIGenerator()
    roi_sampler_obj = roi_sampler.ROISampler()
    roi_aligner_obj = roi_aligner.MultilevelROIAligner()
    detection_generator_obj = detection_generator.DetectionGenerator()
    mask_head, mask_sampler_obj, mask_roi_aligner_obj = (
        self._make_mask_components(include_mask, num_classes=2))
    model = maskrcnn_model.MaskRCNNModel(
        backbone,
        decoder,
        rpn_head,
        detection_head,
        roi_generator_obj,
        roi_sampler_obj,
        roi_aligner_obj,
        detection_generator_obj,
        mask_head,
        mask_sampler_obj,
        mask_roi_aligner_obj,
        min_level=3,
        max_level=7,
        num_scales=3,
        aspect_ratios=[1.0],
        anchor_size=3)
    components = dict(
        backbone=backbone,
        decoder=decoder,
        rpn_head=rpn_head,
        detection_head=detection_head,
        mask_head=mask_head)
    return model, components

  @combinations.generate(
      combinations.combine(
          include_mask=[True, False],
          use_separable_conv=[True, False],
          build_anchor_boxes=[True, False],
          use_outer_boxes=[True, False],
          is_training=[True, False]))
  def test_build_model(self, include_mask, use_separable_conv,
                       build_anchor_boxes, use_outer_boxes, is_training):
    """Builds the model and calls it once for every flag combination.

    Only checks that construction and a single call do not raise; the returned
    outputs are inspected in `test_forward`.
    """
    num_classes = 3
    min_level = 3
    max_level = 7
    num_scales = 3
    aspect_ratios = [1.0]
    anchor_size = 3
    resnet_model_id = 50
    num_anchors_per_location = num_scales * len(aspect_ratios)
    image_size = 384
    images = np.random.rand(2, image_size, image_size, 3)
    image_shape = np.array([[image_size, image_size], [image_size, image_size]])

    if build_anchor_boxes:
      # Use the same `anchor_size` that is given to the model below instead of
      # a second hard-coded literal, so the two cannot drift apart.
      anchor_boxes = anchor.Anchor(
          min_level=min_level,
          max_level=max_level,
          num_scales=num_scales,
          aspect_ratios=aspect_ratios,
          anchor_size=anchor_size,
          image_size=(image_size, image_size)).multilevel_boxes
      # Add a leading axis and repeat it twice to match the two input images.
      # Values are replaced in place so the container type is left untouched.
      for level in anchor_boxes:
        anchor_boxes[level] = tf.tile(
            tf.expand_dims(anchor_boxes[level], axis=0), [2, 1, 1, 1])
    else:
      anchor_boxes = None

    backbone = resnet.ResNet(model_id=resnet_model_id)
    decoder = fpn.FPN(
        input_specs=backbone.output_specs,
        min_level=min_level,
        max_level=max_level,
        use_separable_conv=use_separable_conv)
    rpn_head = dense_prediction_heads.RPNHead(
        min_level=min_level,
        max_level=max_level,
        num_anchors_per_location=num_anchors_per_location,
        num_convs=1)
    detection_head = instance_heads.DetectionHead(num_classes=num_classes)
    roi_generator_obj = roi_generator.MultilevelROIGenerator()
    roi_sampler_obj = roi_sampler.ROISampler()
    roi_aligner_obj = roi_aligner.MultilevelROIAligner()
    detection_generator_obj = detection_generator.DetectionGenerator()
    mask_head, mask_sampler_obj, mask_roi_aligner_obj = (
        self._make_mask_components(include_mask, num_classes))
    model = maskrcnn_model.MaskRCNNModel(
        backbone,
        decoder,
        rpn_head,
        detection_head,
        roi_generator_obj,
        roi_sampler_obj,
        roi_aligner_obj,
        detection_generator_obj,
        mask_head,
        mask_sampler_obj,
        mask_roi_aligner_obj,
        min_level=min_level,
        max_level=max_level,
        num_scales=num_scales,
        aspect_ratios=aspect_ratios,
        anchor_size=anchor_size)

    gt_boxes, gt_classes, gt_masks, gt_outer_boxes = self._make_groundtruth(
        include_mask, use_outer_boxes=use_outer_boxes)

    # Results will be checked in test_forward.
    _ = model(
        images,
        image_shape,
        anchor_boxes,
        gt_boxes,
        gt_classes,
        gt_masks,
        gt_outer_boxes,
        training=is_training)

  @combinations.generate(
      combinations.combine(
          strategy=[
              strategy_combinations.cloud_tpu_strategy,
              strategy_combinations.one_device_strategy_gpu,
          ],
          include_mask=[True, False],
          build_anchor_boxes=[True, False],
          use_cascade_heads=[True, False],
          training=[True, False],
      ))
  def test_forward(self, strategy, include_mask, build_anchor_boxes, training,
                   use_cascade_heads):
    """Runs the model under a distribution strategy and checks output keys.

    In training mode the result must contain the target/output entries, in
    inference mode the detection entries; mask entries are only required when
    the mask branch is enabled.
    """
    num_classes = 3
    min_level = 3
    max_level = 4
    num_scales = 3
    aspect_ratios = [1.0]
    anchor_size = 3
    if use_cascade_heads:
      cascade_iou_thresholds = [0.6]
      class_agnostic_bbox_pred = True
      cascade_class_ensemble = True
    else:
      cascade_iou_thresholds = None
      class_agnostic_bbox_pred = False
      cascade_class_ensemble = False

    image_size = (256, 256)
    images = np.random.rand(2, image_size[0], image_size[1], 3)
    image_shape = np.array([[224, 100], [100, 224]])

    # NOTE(review): the original indentation was lost; everything up to and
    # including the model call is placed inside the strategy scope here, with
    # the assertions outside it -- confirm against the upstream file.
    with strategy.scope():
      if build_anchor_boxes:
        anchor_boxes = anchor.Anchor(
            min_level=min_level,
            max_level=max_level,
            num_scales=num_scales,
            aspect_ratios=aspect_ratios,
            anchor_size=anchor_size,
            image_size=image_size).multilevel_boxes
      else:
        anchor_boxes = None
      num_anchors_per_location = len(aspect_ratios) * num_scales

      input_specs = tf_keras.layers.InputSpec(shape=[None, None, None, 3])
      backbone = resnet.ResNet(model_id=50, input_specs=input_specs)
      decoder = fpn.FPN(
          min_level=min_level,
          max_level=max_level,
          input_specs=backbone.output_specs)
      rpn_head = dense_prediction_heads.RPNHead(
          min_level=min_level,
          max_level=max_level,
          num_anchors_per_location=num_anchors_per_location)
      detection_head = instance_heads.DetectionHead(
          num_classes=num_classes,
          class_agnostic_bbox_pred=class_agnostic_bbox_pred)
      roi_generator_obj = roi_generator.MultilevelROIGenerator()

      # NOTE(review): `roi_sampler_cascade` is assembled here, but the model
      # below receives only `roi_sampler_obj`, i.e. the last sampler created.
      # Confirm whether the whole list was meant to be passed instead.
      roi_sampler_cascade = []
      roi_sampler_obj = roi_sampler.ROISampler()
      roi_sampler_cascade.append(roi_sampler_obj)
      if cascade_iou_thresholds:
        for iou in cascade_iou_thresholds:
          roi_sampler_obj = roi_sampler.ROISampler(
              mix_gt_boxes=False,
              foreground_iou_threshold=iou,
              background_iou_high_threshold=iou,
              background_iou_low_threshold=0.0,
              skip_subsampling=True)
          roi_sampler_cascade.append(roi_sampler_obj)
      roi_aligner_obj = roi_aligner.MultilevelROIAligner()
      detection_generator_obj = detection_generator.DetectionGenerator()
      mask_head, mask_sampler_obj, mask_roi_aligner_obj = (
          self._make_mask_components(include_mask, num_classes))
      model = maskrcnn_model.MaskRCNNModel(
          backbone,
          decoder,
          rpn_head,
          detection_head,
          roi_generator_obj,
          roi_sampler_obj,
          roi_aligner_obj,
          detection_generator_obj,
          mask_head,
          mask_sampler_obj,
          mask_roi_aligner_obj,
          class_agnostic_bbox_pred=class_agnostic_bbox_pred,
          cascade_class_ensemble=cascade_class_ensemble,
          min_level=min_level,
          max_level=max_level,
          num_scales=num_scales,
          aspect_ratios=aspect_ratios,
          anchor_size=anchor_size)

      gt_boxes, gt_classes, gt_masks, gt_outer_boxes = self._make_groundtruth(
          include_mask)

      results = model(
          images,
          image_shape,
          anchor_boxes,
          gt_boxes,
          gt_classes,
          gt_masks,
          gt_outer_boxes,
          training=training)

    self.assertIn('rpn_boxes', results)
    self.assertIn('rpn_scores', results)
    if training:
      self.assertIn('class_targets', results)
      self.assertIn('box_targets', results)
      self.assertIn('class_outputs', results)
      self.assertIn('box_outputs', results)
      if include_mask:
        self.assertIn('mask_outputs', results)
    else:
      self.assertIn('detection_boxes', results)
      self.assertIn('detection_scores', results)
      self.assertIn('detection_classes', results)
      self.assertIn('num_detections', results)
      if include_mask:
        self.assertIn('detection_masks', results)

  @parameterized.parameters(
      (False,),
      (True,),
  )
  def test_serialize_deserialize(self, include_mask):
    """Round-trips the model through `get_config` / `from_config`."""
    model, _ = self._build_default_model(include_mask)

    config = model.get_config()
    new_model = maskrcnn_model.MaskRCNNModel.from_config(config)

    # Validate that the config can be forced to JSON.
    _ = new_model.to_json()

    # If the serialization was successful, the new config should match the old.
    self.assertAllEqual(model.get_config(), new_model.get_config())

  @parameterized.parameters(
      (False,),
      (True,),
  )
  def test_checkpoint(self, include_mask):
    """Checks `checkpoint_items` and partial restores from a saved checkpoint."""
    model, parts = self._build_default_model(include_mask)

    # The detection head is compared wrapped in a list, the other items as-is.
    expect_checkpoint_items = dict(
        backbone=parts['backbone'],
        decoder=parts['decoder'],
        rpn_head=parts['rpn_head'],
        detection_head=[parts['detection_head']])
    if include_mask:
      expect_checkpoint_items['mask_head'] = parts['mask_head']
    self.assertAllEqual(expect_checkpoint_items, model.checkpoint_items)

    # Test save and load checkpoints.
    ckpt = tf.train.Checkpoint(model=model, **model.checkpoint_items)
    save_dir = self.create_tempdir().full_path
    ckpt.save(os.path.join(save_dir, 'ckpt'))

    # Loading only a subset of the saved objects must still match every
    # object that the partial checkpoint tracks.
    partial_ckpt = tf.train.Checkpoint(backbone=parts['backbone'])
    partial_ckpt.read(tf.train.latest_checkpoint(
        save_dir)).expect_partial().assert_existing_objects_matched()

    if include_mask:
      partial_ckpt_mask = tf.train.Checkpoint(
          backbone=parts['backbone'], mask_head=parts['mask_head'])
      partial_ckpt_mask.restore(tf.train.latest_checkpoint(
          save_dir)).expect_partial().assert_existing_objects_matched()
if __name__ == '__main__':
  # Hand control to the TensorFlow test runner when executed as a script.
  tf.test.main()