# Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """A function to build a DetectionModel from configuration.""" import functools from object_detection.builders import anchor_generator_builder from object_detection.builders import box_coder_builder from object_detection.builders import box_predictor_builder from object_detection.builders import hyperparams_builder from object_detection.builders import image_resizer_builder from object_detection.builders import losses_builder from object_detection.builders import matcher_builder from object_detection.builders import post_processing_builder from object_detection.builders import region_similarity_calculator_builder as sim_calc from object_detection.core import balanced_positive_negative_sampler as sampler from object_detection.core import post_processing from object_detection.core import target_assigner from object_detection.meta_architectures import center_net_meta_arch from object_detection.meta_architectures import context_rcnn_meta_arch from object_detection.meta_architectures import faster_rcnn_meta_arch from object_detection.meta_architectures import rfcn_meta_arch from object_detection.meta_architectures import ssd_meta_arch from object_detection.predictors.heads import mask_head from object_detection.protos import losses_pb2 from object_detection.protos import model_pb2 from object_detection.utils import label_map_util from object_detection.utils import ops from object_detection.utils import tf_version ## Feature Extractors for TF ## This section conditionally imports different feature extractors based on the ## Tensorflow version. ## # pylint: disable=g-import-not-at-top if tf_version.is_tf2(): from object_detection.models import center_net_hourglass_feature_extractor from object_detection.models import center_net_resnet_feature_extractor from object_detection.models import center_net_resnet_v1_fpn_feature_extractor from object_detection.models import faster_rcnn_inception_resnet_v2_keras_feature_extractor as frcnn_inc_res_keras from object_detection.models import faster_rcnn_resnet_keras_feature_extractor as frcnn_resnet_keras from object_detection.models import ssd_resnet_v1_fpn_keras_feature_extractor as ssd_resnet_v1_fpn_keras from object_detection.models.ssd_mobilenet_v1_fpn_keras_feature_extractor import SSDMobileNetV1FpnKerasFeatureExtractor from object_detection.models.ssd_mobilenet_v1_keras_feature_extractor import SSDMobileNetV1KerasFeatureExtractor from object_detection.models.ssd_mobilenet_v2_fpn_keras_feature_extractor import SSDMobileNetV2FpnKerasFeatureExtractor from object_detection.models.ssd_mobilenet_v2_keras_feature_extractor import SSDMobileNetV2KerasFeatureExtractor from object_detection.predictors import rfcn_keras_box_predictor if tf_version.is_tf1(): from object_detection.models import faster_rcnn_inception_resnet_v2_feature_extractor as frcnn_inc_res from object_detection.models import faster_rcnn_inception_v2_feature_extractor as frcnn_inc_v2 from object_detection.models import faster_rcnn_nas_feature_extractor as frcnn_nas from object_detection.models import faster_rcnn_pnas_feature_extractor as frcnn_pnas from object_detection.models import faster_rcnn_resnet_v1_feature_extractor as frcnn_resnet_v1 from object_detection.models import ssd_resnet_v1_fpn_feature_extractor as ssd_resnet_v1_fpn from object_detection.models import ssd_resnet_v1_ppn_feature_extractor as ssd_resnet_v1_ppn from object_detection.models.embedded_ssd_mobilenet_v1_feature_extractor import EmbeddedSSDMobileNetV1FeatureExtractor from object_detection.models.ssd_inception_v2_feature_extractor import SSDInceptionV2FeatureExtractor from object_detection.models.ssd_mobilenet_v2_fpn_feature_extractor import SSDMobileNetV2FpnFeatureExtractor from object_detection.models.ssd_mobilenet_v2_mnasfpn_feature_extractor import SSDMobileNetV2MnasFPNFeatureExtractor from object_detection.models.ssd_inception_v3_feature_extractor import SSDInceptionV3FeatureExtractor from object_detection.models.ssd_mobilenet_edgetpu_feature_extractor import SSDMobileNetEdgeTPUFeatureExtractor from object_detection.models.ssd_mobilenet_v1_feature_extractor import SSDMobileNetV1FeatureExtractor from object_detection.models.ssd_mobilenet_v1_fpn_feature_extractor import SSDMobileNetV1FpnFeatureExtractor from object_detection.models.ssd_mobilenet_v1_ppn_feature_extractor import SSDMobileNetV1PpnFeatureExtractor from object_detection.models.ssd_mobilenet_v2_feature_extractor import SSDMobileNetV2FeatureExtractor from object_detection.models.ssd_mobilenet_v3_feature_extractor import SSDMobileNetV3LargeFeatureExtractor from object_detection.models.ssd_mobilenet_v3_feature_extractor import SSDMobileNetV3SmallFeatureExtractor from object_detection.models.ssd_mobiledet_feature_extractor import SSDMobileDetCPUFeatureExtractor from object_detection.models.ssd_mobiledet_feature_extractor import SSDMobileDetDSPFeatureExtractor from object_detection.models.ssd_mobiledet_feature_extractor import SSDMobileDetEdgeTPUFeatureExtractor from object_detection.models.ssd_mobiledet_feature_extractor import SSDMobileDetGPUFeatureExtractor from object_detection.models.ssd_pnasnet_feature_extractor import SSDPNASNetFeatureExtractor from object_detection.predictors import rfcn_box_predictor # pylint: enable=g-import-not-at-top if tf_version.is_tf2(): SSD_KERAS_FEATURE_EXTRACTOR_CLASS_MAP = { 'ssd_mobilenet_v1_keras': SSDMobileNetV1KerasFeatureExtractor, 'ssd_mobilenet_v1_fpn_keras': SSDMobileNetV1FpnKerasFeatureExtractor, 'ssd_mobilenet_v2_keras': SSDMobileNetV2KerasFeatureExtractor, 'ssd_mobilenet_v2_fpn_keras': SSDMobileNetV2FpnKerasFeatureExtractor, 'ssd_resnet50_v1_fpn_keras': ssd_resnet_v1_fpn_keras.SSDResNet50V1FpnKerasFeatureExtractor, 'ssd_resnet101_v1_fpn_keras': ssd_resnet_v1_fpn_keras.SSDResNet101V1FpnKerasFeatureExtractor, 'ssd_resnet152_v1_fpn_keras': ssd_resnet_v1_fpn_keras.SSDResNet152V1FpnKerasFeatureExtractor, } FASTER_RCNN_KERAS_FEATURE_EXTRACTOR_CLASS_MAP = { 'faster_rcnn_resnet50_keras': frcnn_resnet_keras.FasterRCNNResnet50KerasFeatureExtractor, 'faster_rcnn_resnet101_keras': frcnn_resnet_keras.FasterRCNNResnet101KerasFeatureExtractor, 'faster_rcnn_resnet152_keras': frcnn_resnet_keras.FasterRCNNResnet152KerasFeatureExtractor, 'faster_rcnn_inception_resnet_v2_keras': frcnn_inc_res_keras.FasterRCNNInceptionResnetV2KerasFeatureExtractor, } CENTER_NET_EXTRACTOR_FUNCTION_MAP = { 'resnet_v2_50': center_net_resnet_feature_extractor.resnet_v2_50, 'resnet_v2_101': center_net_resnet_feature_extractor.resnet_v2_101, 'resnet_v1_50_fpn': center_net_resnet_v1_fpn_feature_extractor.resnet_v1_50_fpn, 'resnet_v1_101_fpn': center_net_resnet_v1_fpn_feature_extractor.resnet_v1_101_fpn, 'hourglass_104': center_net_hourglass_feature_extractor.hourglass_104, } FEATURE_EXTRACTOR_MAPS = [ CENTER_NET_EXTRACTOR_FUNCTION_MAP, FASTER_RCNN_KERAS_FEATURE_EXTRACTOR_CLASS_MAP, SSD_KERAS_FEATURE_EXTRACTOR_CLASS_MAP ] if tf_version.is_tf1(): SSD_FEATURE_EXTRACTOR_CLASS_MAP = { 'ssd_inception_v2': SSDInceptionV2FeatureExtractor, 'ssd_inception_v3': SSDInceptionV3FeatureExtractor, 'ssd_mobilenet_v1': SSDMobileNetV1FeatureExtractor, 'ssd_mobilenet_v1_fpn': SSDMobileNetV1FpnFeatureExtractor, 'ssd_mobilenet_v1_ppn': SSDMobileNetV1PpnFeatureExtractor, 'ssd_mobilenet_v2': SSDMobileNetV2FeatureExtractor, 'ssd_mobilenet_v2_fpn': SSDMobileNetV2FpnFeatureExtractor, 'ssd_mobilenet_v2_mnasfpn': SSDMobileNetV2MnasFPNFeatureExtractor, 'ssd_mobilenet_v3_large': SSDMobileNetV3LargeFeatureExtractor, 'ssd_mobilenet_v3_small': SSDMobileNetV3SmallFeatureExtractor, 'ssd_mobilenet_edgetpu': SSDMobileNetEdgeTPUFeatureExtractor, 'ssd_resnet50_v1_fpn': ssd_resnet_v1_fpn.SSDResnet50V1FpnFeatureExtractor, 'ssd_resnet101_v1_fpn': ssd_resnet_v1_fpn.SSDResnet101V1FpnFeatureExtractor, 'ssd_resnet152_v1_fpn': ssd_resnet_v1_fpn.SSDResnet152V1FpnFeatureExtractor, 'ssd_resnet50_v1_ppn': ssd_resnet_v1_ppn.SSDResnet50V1PpnFeatureExtractor, 'ssd_resnet101_v1_ppn': ssd_resnet_v1_ppn.SSDResnet101V1PpnFeatureExtractor, 'ssd_resnet152_v1_ppn': ssd_resnet_v1_ppn.SSDResnet152V1PpnFeatureExtractor, 'embedded_ssd_mobilenet_v1': EmbeddedSSDMobileNetV1FeatureExtractor, 'ssd_pnasnet': SSDPNASNetFeatureExtractor, 'ssd_mobiledet_cpu': SSDMobileDetCPUFeatureExtractor, 'ssd_mobiledet_dsp': SSDMobileDetDSPFeatureExtractor, 'ssd_mobiledet_edgetpu': SSDMobileDetEdgeTPUFeatureExtractor, 'ssd_mobiledet_gpu': SSDMobileDetGPUFeatureExtractor, } FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP = { 'faster_rcnn_nas': frcnn_nas.FasterRCNNNASFeatureExtractor, 'faster_rcnn_pnas': frcnn_pnas.FasterRCNNPNASFeatureExtractor, 'faster_rcnn_inception_resnet_v2': frcnn_inc_res.FasterRCNNInceptionResnetV2FeatureExtractor, 'faster_rcnn_inception_v2': frcnn_inc_v2.FasterRCNNInceptionV2FeatureExtractor, 'faster_rcnn_resnet50': frcnn_resnet_v1.FasterRCNNResnet50FeatureExtractor, 'faster_rcnn_resnet101': frcnn_resnet_v1.FasterRCNNResnet101FeatureExtractor, 'faster_rcnn_resnet152': frcnn_resnet_v1.FasterRCNNResnet152FeatureExtractor, } FEATURE_EXTRACTOR_MAPS = [ SSD_FEATURE_EXTRACTOR_CLASS_MAP, FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP ] def _check_feature_extractor_exists(feature_extractor_type): feature_extractors = set().union(*FEATURE_EXTRACTOR_MAPS) if feature_extractor_type not in feature_extractors: raise ValueError('{} is not supported. See `model_builder.py` for features ' 'extractors compatible with different versions of ' 'Tensorflow'.format(feature_extractor_type)) def _build_ssd_feature_extractor(feature_extractor_config, is_training, freeze_batchnorm, reuse_weights=None): """Builds a ssd_meta_arch.SSDFeatureExtractor based on config. Args: feature_extractor_config: A SSDFeatureExtractor proto config from ssd.proto. is_training: True if this feature extractor is being built for training. freeze_batchnorm: Whether to freeze batch norm parameters during training or not. When training with a small batch size (e.g. 1), it is desirable to freeze batch norm update and use pretrained batch norm params. reuse_weights: if the feature extractor should reuse weights. Returns: ssd_meta_arch.SSDFeatureExtractor based on config. Raises: ValueError: On invalid feature extractor type. """ feature_type = feature_extractor_config.type depth_multiplier = feature_extractor_config.depth_multiplier min_depth = feature_extractor_config.min_depth pad_to_multiple = feature_extractor_config.pad_to_multiple use_explicit_padding = feature_extractor_config.use_explicit_padding use_depthwise = feature_extractor_config.use_depthwise is_keras = tf_version.is_tf2() if is_keras: conv_hyperparams = hyperparams_builder.KerasLayerHyperparams( feature_extractor_config.conv_hyperparams) else: conv_hyperparams = hyperparams_builder.build( feature_extractor_config.conv_hyperparams, is_training) override_base_feature_extractor_hyperparams = ( feature_extractor_config.override_base_feature_extractor_hyperparams) if not is_keras and feature_type not in SSD_FEATURE_EXTRACTOR_CLASS_MAP: raise ValueError('Unknown ssd feature_extractor: {}'.format(feature_type)) if is_keras: feature_extractor_class = SSD_KERAS_FEATURE_EXTRACTOR_CLASS_MAP[ feature_type] else: feature_extractor_class = SSD_FEATURE_EXTRACTOR_CLASS_MAP[feature_type] kwargs = { 'is_training': is_training, 'depth_multiplier': depth_multiplier, 'min_depth': min_depth, 'pad_to_multiple': pad_to_multiple, 'use_explicit_padding': use_explicit_padding, 'use_depthwise': use_depthwise, 'override_base_feature_extractor_hyperparams': override_base_feature_extractor_hyperparams } if feature_extractor_config.HasField('replace_preprocessor_with_placeholder'): kwargs.update({ 'replace_preprocessor_with_placeholder': feature_extractor_config.replace_preprocessor_with_placeholder }) if feature_extractor_config.HasField('num_layers'): kwargs.update({'num_layers': feature_extractor_config.num_layers}) if is_keras: kwargs.update({ 'conv_hyperparams': conv_hyperparams, 'inplace_batchnorm_update': False, 'freeze_batchnorm': freeze_batchnorm }) else: kwargs.update({ 'conv_hyperparams_fn': conv_hyperparams, 'reuse_weights': reuse_weights, }) if feature_extractor_config.HasField('fpn'): kwargs.update({ 'fpn_min_level': feature_extractor_config.fpn.min_level, 'fpn_max_level': feature_extractor_config.fpn.max_level, 'additional_layer_depth': feature_extractor_config.fpn.additional_layer_depth, }) return feature_extractor_class(**kwargs) def _build_ssd_model(ssd_config, is_training, add_summaries): """Builds an SSD detection model based on the model config. Args: ssd_config: A ssd.proto object containing the config for the desired SSDMetaArch. is_training: True if this model is being built for training purposes. add_summaries: Whether to add tf summaries in the model. Returns: SSDMetaArch based on the config. Raises: ValueError: If ssd_config.type is not recognized (i.e. not registered in model_class_map). """ num_classes = ssd_config.num_classes _check_feature_extractor_exists(ssd_config.feature_extractor.type) # Feature extractor feature_extractor = _build_ssd_feature_extractor( feature_extractor_config=ssd_config.feature_extractor, freeze_batchnorm=ssd_config.freeze_batchnorm, is_training=is_training) box_coder = box_coder_builder.build(ssd_config.box_coder) matcher = matcher_builder.build(ssd_config.matcher) region_similarity_calculator = sim_calc.build( ssd_config.similarity_calculator) encode_background_as_zeros = ssd_config.encode_background_as_zeros negative_class_weight = ssd_config.negative_class_weight anchor_generator = anchor_generator_builder.build( ssd_config.anchor_generator) if feature_extractor.is_keras_model: ssd_box_predictor = box_predictor_builder.build_keras( hyperparams_fn=hyperparams_builder.KerasLayerHyperparams, freeze_batchnorm=ssd_config.freeze_batchnorm, inplace_batchnorm_update=False, num_predictions_per_location_list=anchor_generator .num_anchors_per_location(), box_predictor_config=ssd_config.box_predictor, is_training=is_training, num_classes=num_classes, add_background_class=ssd_config.add_background_class) else: ssd_box_predictor = box_predictor_builder.build( hyperparams_builder.build, ssd_config.box_predictor, is_training, num_classes, ssd_config.add_background_class) image_resizer_fn = image_resizer_builder.build(ssd_config.image_resizer) non_max_suppression_fn, score_conversion_fn = post_processing_builder.build( ssd_config.post_processing) (classification_loss, localization_loss, classification_weight, localization_weight, hard_example_miner, random_example_sampler, expected_loss_weights_fn) = losses_builder.build(ssd_config.loss) normalize_loss_by_num_matches = ssd_config.normalize_loss_by_num_matches normalize_loc_loss_by_codesize = ssd_config.normalize_loc_loss_by_codesize equalization_loss_config = ops.EqualizationLossConfig( weight=ssd_config.loss.equalization_loss.weight, exclude_prefixes=ssd_config.loss.equalization_loss.exclude_prefixes) target_assigner_instance = target_assigner.TargetAssigner( region_similarity_calculator, matcher, box_coder, negative_class_weight=negative_class_weight) ssd_meta_arch_fn = ssd_meta_arch.SSDMetaArch kwargs = {} return ssd_meta_arch_fn( is_training=is_training, anchor_generator=anchor_generator, box_predictor=ssd_box_predictor, box_coder=box_coder, feature_extractor=feature_extractor, encode_background_as_zeros=encode_background_as_zeros, image_resizer_fn=image_resizer_fn, non_max_suppression_fn=non_max_suppression_fn, score_conversion_fn=score_conversion_fn, classification_loss=classification_loss, localization_loss=localization_loss, classification_loss_weight=classification_weight, localization_loss_weight=localization_weight, normalize_loss_by_num_matches=normalize_loss_by_num_matches, hard_example_miner=hard_example_miner, target_assigner_instance=target_assigner_instance, add_summaries=add_summaries, normalize_loc_loss_by_codesize=normalize_loc_loss_by_codesize, freeze_batchnorm=ssd_config.freeze_batchnorm, inplace_batchnorm_update=ssd_config.inplace_batchnorm_update, add_background_class=ssd_config.add_background_class, explicit_background_class=ssd_config.explicit_background_class, random_example_sampler=random_example_sampler, expected_loss_weights_fn=expected_loss_weights_fn, use_confidences_as_targets=ssd_config.use_confidences_as_targets, implicit_example_weight=ssd_config.implicit_example_weight, equalization_loss_config=equalization_loss_config, return_raw_detections_during_predict=( ssd_config.return_raw_detections_during_predict), **kwargs) def _build_faster_rcnn_feature_extractor( feature_extractor_config, is_training, reuse_weights=True, inplace_batchnorm_update=False): """Builds a faster_rcnn_meta_arch.FasterRCNNFeatureExtractor based on config. Args: feature_extractor_config: A FasterRcnnFeatureExtractor proto config from faster_rcnn.proto. is_training: True if this feature extractor is being built for training. reuse_weights: if the feature extractor should reuse weights. inplace_batchnorm_update: Whether to update batch_norm inplace during training. This is required for batch norm to work correctly on TPUs. When this is false, user must add a control dependency on tf.GraphKeys.UPDATE_OPS for train/loss op in order to update the batch norm moving average parameters. Returns: faster_rcnn_meta_arch.FasterRCNNFeatureExtractor based on config. Raises: ValueError: On invalid feature extractor type. """ if inplace_batchnorm_update: raise ValueError('inplace batchnorm updates not supported.') feature_type = feature_extractor_config.type first_stage_features_stride = ( feature_extractor_config.first_stage_features_stride) batch_norm_trainable = feature_extractor_config.batch_norm_trainable if feature_type not in FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP: raise ValueError('Unknown Faster R-CNN feature_extractor: {}'.format( feature_type)) feature_extractor_class = FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP[ feature_type] return feature_extractor_class( is_training, first_stage_features_stride, batch_norm_trainable, reuse_weights=reuse_weights) def _build_faster_rcnn_keras_feature_extractor( feature_extractor_config, is_training, inplace_batchnorm_update=False): """Builds a faster_rcnn_meta_arch.FasterRCNNKerasFeatureExtractor from config. Args: feature_extractor_config: A FasterRcnnFeatureExtractor proto config from faster_rcnn.proto. is_training: True if this feature extractor is being built for training. inplace_batchnorm_update: Whether to update batch_norm inplace during training. This is required for batch norm to work correctly on TPUs. When this is false, user must add a control dependency on tf.GraphKeys.UPDATE_OPS for train/loss op in order to update the batch norm moving average parameters. Returns: faster_rcnn_meta_arch.FasterRCNNKerasFeatureExtractor based on config. Raises: ValueError: On invalid feature extractor type. """ if inplace_batchnorm_update: raise ValueError('inplace batchnorm updates not supported.') feature_type = feature_extractor_config.type first_stage_features_stride = ( feature_extractor_config.first_stage_features_stride) batch_norm_trainable = feature_extractor_config.batch_norm_trainable if feature_type not in FASTER_RCNN_KERAS_FEATURE_EXTRACTOR_CLASS_MAP: raise ValueError('Unknown Faster R-CNN feature_extractor: {}'.format( feature_type)) feature_extractor_class = FASTER_RCNN_KERAS_FEATURE_EXTRACTOR_CLASS_MAP[ feature_type] return feature_extractor_class( is_training, first_stage_features_stride, batch_norm_trainable) def _build_faster_rcnn_model(frcnn_config, is_training, add_summaries): """Builds a Faster R-CNN or R-FCN detection model based on the model config. Builds R-FCN model if the second_stage_box_predictor in the config is of type `rfcn_box_predictor` else builds a Faster R-CNN model. Args: frcnn_config: A faster_rcnn.proto object containing the config for the desired FasterRCNNMetaArch or RFCNMetaArch. is_training: True if this model is being built for training purposes. add_summaries: Whether to add tf summaries in the model. Returns: FasterRCNNMetaArch based on the config. Raises: ValueError: If frcnn_config.type is not recognized (i.e. not registered in model_class_map). """ num_classes = frcnn_config.num_classes image_resizer_fn = image_resizer_builder.build(frcnn_config.image_resizer) _check_feature_extractor_exists(frcnn_config.feature_extractor.type) is_keras = tf_version.is_tf2() if is_keras: feature_extractor = _build_faster_rcnn_keras_feature_extractor( frcnn_config.feature_extractor, is_training, inplace_batchnorm_update=frcnn_config.inplace_batchnorm_update) else: feature_extractor = _build_faster_rcnn_feature_extractor( frcnn_config.feature_extractor, is_training, inplace_batchnorm_update=frcnn_config.inplace_batchnorm_update) number_of_stages = frcnn_config.number_of_stages first_stage_anchor_generator = anchor_generator_builder.build( frcnn_config.first_stage_anchor_generator) first_stage_target_assigner = target_assigner.create_target_assigner( 'FasterRCNN', 'proposal', use_matmul_gather=frcnn_config.use_matmul_gather_in_matcher) first_stage_atrous_rate = frcnn_config.first_stage_atrous_rate if is_keras: first_stage_box_predictor_arg_scope_fn = ( hyperparams_builder.KerasLayerHyperparams( frcnn_config.first_stage_box_predictor_conv_hyperparams)) else: first_stage_box_predictor_arg_scope_fn = hyperparams_builder.build( frcnn_config.first_stage_box_predictor_conv_hyperparams, is_training) first_stage_box_predictor_kernel_size = ( frcnn_config.first_stage_box_predictor_kernel_size) first_stage_box_predictor_depth = frcnn_config.first_stage_box_predictor_depth first_stage_minibatch_size = frcnn_config.first_stage_minibatch_size use_static_shapes = frcnn_config.use_static_shapes and ( frcnn_config.use_static_shapes_for_eval or is_training) first_stage_sampler = sampler.BalancedPositiveNegativeSampler( positive_fraction=frcnn_config.first_stage_positive_balance_fraction, is_static=(frcnn_config.use_static_balanced_label_sampler and use_static_shapes)) first_stage_max_proposals = frcnn_config.first_stage_max_proposals if (frcnn_config.first_stage_nms_iou_threshold < 0 or frcnn_config.first_stage_nms_iou_threshold > 1.0): raise ValueError('iou_threshold not in [0, 1.0].') if (is_training and frcnn_config.second_stage_batch_size > first_stage_max_proposals): raise ValueError('second_stage_batch_size should be no greater than ' 'first_stage_max_proposals.') first_stage_non_max_suppression_fn = functools.partial( post_processing.batch_multiclass_non_max_suppression, score_thresh=frcnn_config.first_stage_nms_score_threshold, iou_thresh=frcnn_config.first_stage_nms_iou_threshold, max_size_per_class=frcnn_config.first_stage_max_proposals, max_total_size=frcnn_config.first_stage_max_proposals, use_static_shapes=use_static_shapes, use_partitioned_nms=frcnn_config.use_partitioned_nms_in_first_stage, use_combined_nms=frcnn_config.use_combined_nms_in_first_stage) first_stage_loc_loss_weight = ( frcnn_config.first_stage_localization_loss_weight) first_stage_obj_loss_weight = frcnn_config.first_stage_objectness_loss_weight initial_crop_size = frcnn_config.initial_crop_size maxpool_kernel_size = frcnn_config.maxpool_kernel_size maxpool_stride = frcnn_config.maxpool_stride second_stage_target_assigner = target_assigner.create_target_assigner( 'FasterRCNN', 'detection', use_matmul_gather=frcnn_config.use_matmul_gather_in_matcher) if is_keras: second_stage_box_predictor = box_predictor_builder.build_keras( hyperparams_builder.KerasLayerHyperparams, freeze_batchnorm=False, inplace_batchnorm_update=False, num_predictions_per_location_list=[1], box_predictor_config=frcnn_config.second_stage_box_predictor, is_training=is_training, num_classes=num_classes) else: second_stage_box_predictor = box_predictor_builder.build( hyperparams_builder.build, frcnn_config.second_stage_box_predictor, is_training=is_training, num_classes=num_classes) second_stage_batch_size = frcnn_config.second_stage_batch_size second_stage_sampler = sampler.BalancedPositiveNegativeSampler( positive_fraction=frcnn_config.second_stage_balance_fraction, is_static=(frcnn_config.use_static_balanced_label_sampler and use_static_shapes)) (second_stage_non_max_suppression_fn, second_stage_score_conversion_fn ) = post_processing_builder.build(frcnn_config.second_stage_post_processing) second_stage_localization_loss_weight = ( frcnn_config.second_stage_localization_loss_weight) second_stage_classification_loss = ( losses_builder.build_faster_rcnn_classification_loss( frcnn_config.second_stage_classification_loss)) second_stage_classification_loss_weight = ( frcnn_config.second_stage_classification_loss_weight) second_stage_mask_prediction_loss_weight = ( frcnn_config.second_stage_mask_prediction_loss_weight) hard_example_miner = None if frcnn_config.HasField('hard_example_miner'): hard_example_miner = losses_builder.build_hard_example_miner( frcnn_config.hard_example_miner, second_stage_classification_loss_weight, second_stage_localization_loss_weight) crop_and_resize_fn = ( ops.matmul_crop_and_resize if frcnn_config.use_matmul_crop_and_resize else ops.native_crop_and_resize) clip_anchors_to_image = ( frcnn_config.clip_anchors_to_image) common_kwargs = { 'is_training': is_training, 'num_classes': num_classes, 'image_resizer_fn': image_resizer_fn, 'feature_extractor': feature_extractor, 'number_of_stages': number_of_stages, 'first_stage_anchor_generator': first_stage_anchor_generator, 'first_stage_target_assigner': first_stage_target_assigner, 'first_stage_atrous_rate': first_stage_atrous_rate, 'first_stage_box_predictor_arg_scope_fn': first_stage_box_predictor_arg_scope_fn, 'first_stage_box_predictor_kernel_size': first_stage_box_predictor_kernel_size, 'first_stage_box_predictor_depth': first_stage_box_predictor_depth, 'first_stage_minibatch_size': first_stage_minibatch_size, 'first_stage_sampler': first_stage_sampler, 'first_stage_non_max_suppression_fn': first_stage_non_max_suppression_fn, 'first_stage_max_proposals': first_stage_max_proposals, 'first_stage_localization_loss_weight': first_stage_loc_loss_weight, 'first_stage_objectness_loss_weight': first_stage_obj_loss_weight, 'second_stage_target_assigner': second_stage_target_assigner, 'second_stage_batch_size': second_stage_batch_size, 'second_stage_sampler': second_stage_sampler, 'second_stage_non_max_suppression_fn': second_stage_non_max_suppression_fn, 'second_stage_score_conversion_fn': second_stage_score_conversion_fn, 'second_stage_localization_loss_weight': second_stage_localization_loss_weight, 'second_stage_classification_loss': second_stage_classification_loss, 'second_stage_classification_loss_weight': second_stage_classification_loss_weight, 'hard_example_miner': hard_example_miner, 'add_summaries': add_summaries, 'crop_and_resize_fn': crop_and_resize_fn, 'clip_anchors_to_image': clip_anchors_to_image, 'use_static_shapes': use_static_shapes, 'resize_masks': frcnn_config.resize_masks, 'return_raw_detections_during_predict': frcnn_config.return_raw_detections_during_predict, 'output_final_box_features': frcnn_config.output_final_box_features } if ((not is_keras and isinstance(second_stage_box_predictor, rfcn_box_predictor.RfcnBoxPredictor)) or (is_keras and isinstance(second_stage_box_predictor, rfcn_keras_box_predictor.RfcnKerasBoxPredictor))): return rfcn_meta_arch.RFCNMetaArch( second_stage_rfcn_box_predictor=second_stage_box_predictor, **common_kwargs) elif frcnn_config.HasField('context_config'): context_config = frcnn_config.context_config common_kwargs.update({ 'attention_bottleneck_dimension': context_config.attention_bottleneck_dimension, 'attention_temperature': context_config.attention_temperature }) return context_rcnn_meta_arch.ContextRCNNMetaArch( initial_crop_size=initial_crop_size, maxpool_kernel_size=maxpool_kernel_size, maxpool_stride=maxpool_stride, second_stage_mask_rcnn_box_predictor=second_stage_box_predictor, second_stage_mask_prediction_loss_weight=( second_stage_mask_prediction_loss_weight), **common_kwargs) else: return faster_rcnn_meta_arch.FasterRCNNMetaArch( initial_crop_size=initial_crop_size, maxpool_kernel_size=maxpool_kernel_size, maxpool_stride=maxpool_stride, second_stage_mask_rcnn_box_predictor=second_stage_box_predictor, second_stage_mask_prediction_loss_weight=( second_stage_mask_prediction_loss_weight), **common_kwargs) EXPERIMENTAL_META_ARCH_BUILDER_MAP = { } def _build_experimental_model(config, is_training, add_summaries=True): return EXPERIMENTAL_META_ARCH_BUILDER_MAP[config.name]( is_training, add_summaries) # The class ID in the groundtruth/model architecture is usually 0-based while # the ID in the label map is 1-based. The offset is used to convert between the # the two. CLASS_ID_OFFSET = 1 KEYPOINT_STD_DEV_DEFAULT = 1.0 def keypoint_proto_to_params(kp_config, keypoint_map_dict): """Converts CenterNet.KeypointEstimation proto to parameter namedtuple.""" label_map_item = keypoint_map_dict[kp_config.keypoint_class_name] classification_loss, localization_loss, _, _, _, _, _ = ( losses_builder.build(kp_config.loss)) keypoint_indices = [ keypoint.id for keypoint in label_map_item.keypoints ] keypoint_labels = [ keypoint.label for keypoint in label_map_item.keypoints ] keypoint_std_dev_dict = { label: KEYPOINT_STD_DEV_DEFAULT for label in keypoint_labels } if kp_config.keypoint_label_to_std: for label, value in kp_config.keypoint_label_to_std.items(): keypoint_std_dev_dict[label] = value keypoint_std_dev = [keypoint_std_dev_dict[label] for label in keypoint_labels] return center_net_meta_arch.KeypointEstimationParams( task_name=kp_config.task_name, class_id=label_map_item.id - CLASS_ID_OFFSET, keypoint_indices=keypoint_indices, classification_loss=classification_loss, localization_loss=localization_loss, keypoint_labels=keypoint_labels, keypoint_std_dev=keypoint_std_dev, task_loss_weight=kp_config.task_loss_weight, keypoint_regression_loss_weight=kp_config.keypoint_regression_loss_weight, keypoint_heatmap_loss_weight=kp_config.keypoint_heatmap_loss_weight, keypoint_offset_loss_weight=kp_config.keypoint_offset_loss_weight, heatmap_bias_init=kp_config.heatmap_bias_init, keypoint_candidate_score_threshold=( kp_config.keypoint_candidate_score_threshold), num_candidates_per_keypoint=kp_config.num_candidates_per_keypoint, peak_max_pool_kernel_size=kp_config.peak_max_pool_kernel_size, unmatched_keypoint_score=kp_config.unmatched_keypoint_score, box_scale=kp_config.box_scale, candidate_search_scale=kp_config.candidate_search_scale, candidate_ranking_mode=kp_config.candidate_ranking_mode, offset_peak_radius=kp_config.offset_peak_radius, per_keypoint_offset=kp_config.per_keypoint_offset) def object_detection_proto_to_params(od_config): """Converts CenterNet.ObjectDetection proto to parameter namedtuple.""" loss = losses_pb2.Loss() # Add dummy classification loss to avoid the loss_builder throwing error. # TODO(yuhuic): update the loss builder to take the classification loss # directly. loss.classification_loss.weighted_sigmoid.CopyFrom( losses_pb2.WeightedSigmoidClassificationLoss()) loss.localization_loss.CopyFrom(od_config.localization_loss) _, localization_loss, _, _, _, _, _ = (losses_builder.build(loss)) return center_net_meta_arch.ObjectDetectionParams( localization_loss=localization_loss, scale_loss_weight=od_config.scale_loss_weight, offset_loss_weight=od_config.offset_loss_weight, task_loss_weight=od_config.task_loss_weight) def object_center_proto_to_params(oc_config): """Converts CenterNet.ObjectCenter proto to parameter namedtuple.""" loss = losses_pb2.Loss() # Add dummy localization loss to avoid the loss_builder throwing error. # TODO(yuhuic): update the loss builder to take the localization loss # directly. loss.localization_loss.weighted_l2.CopyFrom( losses_pb2.WeightedL2LocalizationLoss()) loss.classification_loss.CopyFrom(oc_config.classification_loss) classification_loss, _, _, _, _, _, _ = (losses_builder.build(loss)) return center_net_meta_arch.ObjectCenterParams( classification_loss=classification_loss, object_center_loss_weight=oc_config.object_center_loss_weight, heatmap_bias_init=oc_config.heatmap_bias_init, min_box_overlap_iou=oc_config.min_box_overlap_iou, max_box_predictions=oc_config.max_box_predictions, use_labeled_classes=oc_config.use_labeled_classes) def mask_proto_to_params(mask_config): """Converts CenterNet.MaskEstimation proto to parameter namedtuple.""" loss = losses_pb2.Loss() # Add dummy localization loss to avoid the loss_builder throwing error. loss.localization_loss.weighted_l2.CopyFrom( losses_pb2.WeightedL2LocalizationLoss()) loss.classification_loss.CopyFrom(mask_config.classification_loss) classification_loss, _, _, _, _, _, _ = (losses_builder.build(loss)) return center_net_meta_arch.MaskParams( classification_loss=classification_loss, task_loss_weight=mask_config.task_loss_weight, mask_height=mask_config.mask_height, mask_width=mask_config.mask_width, score_threshold=mask_config.score_threshold, heatmap_bias_init=mask_config.heatmap_bias_init) def _build_center_net_model(center_net_config, is_training, add_summaries): """Build a CenterNet detection model. Args: center_net_config: A CenterNet proto object with model configuration. is_training: True if this model is being built for training purposes. add_summaries: Whether to add tf summaries in the model. Returns: CenterNetMetaArch based on the config. """ image_resizer_fn = image_resizer_builder.build( center_net_config.image_resizer) _check_feature_extractor_exists(center_net_config.feature_extractor.type) feature_extractor = _build_center_net_feature_extractor( center_net_config.feature_extractor) object_center_params = object_center_proto_to_params( center_net_config.object_center_params) object_detection_params = None if center_net_config.HasField('object_detection_task'): object_detection_params = object_detection_proto_to_params( center_net_config.object_detection_task) keypoint_params_dict = None if center_net_config.keypoint_estimation_task: label_map_proto = label_map_util.load_labelmap( center_net_config.keypoint_label_map_path) keypoint_map_dict = { item.name: item for item in label_map_proto.item if item.keypoints } keypoint_params_dict = {} keypoint_class_id_set = set() all_keypoint_indices = [] for task in center_net_config.keypoint_estimation_task: kp_params = keypoint_proto_to_params(task, keypoint_map_dict) keypoint_params_dict[task.task_name] = kp_params all_keypoint_indices.extend(kp_params.keypoint_indices) if kp_params.class_id in keypoint_class_id_set: raise ValueError(('Multiple keypoint tasks map to the same class id is ' 'not allowed: %d' % kp_params.class_id)) else: keypoint_class_id_set.add(kp_params.class_id) if len(all_keypoint_indices) > len(set(all_keypoint_indices)): raise ValueError('Some keypoint indices are used more than once.') mask_params = None if center_net_config.HasField('mask_estimation_task'): mask_params = mask_proto_to_params(center_net_config.mask_estimation_task) return center_net_meta_arch.CenterNetMetaArch( is_training=is_training, add_summaries=add_summaries, num_classes=center_net_config.num_classes, feature_extractor=feature_extractor, image_resizer_fn=image_resizer_fn, object_center_params=object_center_params, object_detection_params=object_detection_params, keypoint_params_dict=keypoint_params_dict, mask_params=mask_params) def _build_center_net_feature_extractor( feature_extractor_config): """Build a CenterNet feature extractor from the given config.""" if feature_extractor_config.type not in CENTER_NET_EXTRACTOR_FUNCTION_MAP: raise ValueError('\'{}\' is not a known CenterNet feature extractor type' .format(feature_extractor_config.type)) return CENTER_NET_EXTRACTOR_FUNCTION_MAP[feature_extractor_config.type]( channel_means=list(feature_extractor_config.channel_means), channel_stds=list(feature_extractor_config.channel_stds), bgr_ordering=feature_extractor_config.bgr_ordering ) META_ARCH_BUILDER_MAP = { 'ssd': _build_ssd_model, 'faster_rcnn': _build_faster_rcnn_model, 'experimental_model': _build_experimental_model, 'center_net': _build_center_net_model } def build(model_config, is_training, add_summaries=True): """Builds a DetectionModel based on the model config. Args: model_config: A model.proto object containing the config for the desired DetectionModel. is_training: True if this model is being built for training purposes. add_summaries: Whether to add tensorflow summaries in the model graph. Returns: DetectionModel based on the config. Raises: ValueError: On invalid meta architecture or model. """ if not isinstance(model_config, model_pb2.DetectionModel): raise ValueError('model_config not of type model_pb2.DetectionModel.') meta_architecture = model_config.WhichOneof('model') if meta_architecture not in META_ARCH_BUILDER_MAP: raise ValueError('Unknown meta architecture: {}'.format(meta_architecture)) else: build_func = META_ARCH_BUILDER_MAP[meta_architecture] return build_func(getattr(model_config, meta_architecture), is_training, add_summaries)