# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A function to build a DetectionModel from configuration."""
import functools
from object_detection.builders import anchor_generator_builder
from object_detection.builders import box_coder_builder
from object_detection.builders import box_predictor_builder
from object_detection.builders import hyperparams_builder
from object_detection.builders import image_resizer_builder
from object_detection.builders import losses_builder
from object_detection.builders import matcher_builder
from object_detection.builders import post_processing_builder
from object_detection.builders import region_similarity_calculator_builder as sim_calc
from object_detection.core import balanced_positive_negative_sampler as sampler
from object_detection.core import post_processing
from object_detection.core import target_assigner
from object_detection.meta_architectures import center_net_meta_arch
from object_detection.meta_architectures import context_rcnn_meta_arch
from object_detection.meta_architectures import faster_rcnn_meta_arch
from object_detection.meta_architectures import rfcn_meta_arch
from object_detection.meta_architectures import ssd_meta_arch
from object_detection.predictors.heads import mask_head
from object_detection.protos import losses_pb2
from object_detection.protos import model_pb2
from object_detection.utils import label_map_util
from object_detection.utils import ops
from object_detection.utils import tf_version
## Feature Extractors for TF
## This section conditionally imports different feature extractors based on
## the TensorFlow version.
##
# pylint: disable=g-import-not-at-top
if tf_version.is_tf2():
  from object_detection.models import center_net_hourglass_feature_extractor
  from object_detection.models import center_net_resnet_feature_extractor
  from object_detection.models import center_net_resnet_v1_fpn_feature_extractor
  from object_detection.models import faster_rcnn_inception_resnet_v2_keras_feature_extractor as frcnn_inc_res_keras
  from object_detection.models import faster_rcnn_resnet_keras_feature_extractor as frcnn_resnet_keras
  from object_detection.models import ssd_resnet_v1_fpn_keras_feature_extractor as ssd_resnet_v1_fpn_keras
  from object_detection.models.ssd_mobilenet_v1_fpn_keras_feature_extractor import SSDMobileNetV1FpnKerasFeatureExtractor
  from object_detection.models.ssd_mobilenet_v1_keras_feature_extractor import SSDMobileNetV1KerasFeatureExtractor
  from object_detection.models.ssd_mobilenet_v2_fpn_keras_feature_extractor import SSDMobileNetV2FpnKerasFeatureExtractor
  from object_detection.models.ssd_mobilenet_v2_keras_feature_extractor import SSDMobileNetV2KerasFeatureExtractor
  from object_detection.predictors import rfcn_keras_box_predictor

if tf_version.is_tf1():
  from object_detection.models import faster_rcnn_inception_resnet_v2_feature_extractor as frcnn_inc_res
  from object_detection.models import faster_rcnn_inception_v2_feature_extractor as frcnn_inc_v2
  from object_detection.models import faster_rcnn_nas_feature_extractor as frcnn_nas
  from object_detection.models import faster_rcnn_pnas_feature_extractor as frcnn_pnas
  from object_detection.models import faster_rcnn_resnet_v1_feature_extractor as frcnn_resnet_v1
  from object_detection.models import ssd_resnet_v1_fpn_feature_extractor as ssd_resnet_v1_fpn
  from object_detection.models import ssd_resnet_v1_ppn_feature_extractor as ssd_resnet_v1_ppn
  from object_detection.models.embedded_ssd_mobilenet_v1_feature_extractor import EmbeddedSSDMobileNetV1FeatureExtractor
  from object_detection.models.ssd_inception_v2_feature_extractor import SSDInceptionV2FeatureExtractor
  from object_detection.models.ssd_mobilenet_v2_fpn_feature_extractor import SSDMobileNetV2FpnFeatureExtractor
  from object_detection.models.ssd_mobilenet_v2_mnasfpn_feature_extractor import SSDMobileNetV2MnasFPNFeatureExtractor
  from object_detection.models.ssd_inception_v3_feature_extractor import SSDInceptionV3FeatureExtractor
  from object_detection.models.ssd_mobilenet_edgetpu_feature_extractor import SSDMobileNetEdgeTPUFeatureExtractor
  from object_detection.models.ssd_mobilenet_v1_feature_extractor import SSDMobileNetV1FeatureExtractor
  from object_detection.models.ssd_mobilenet_v1_fpn_feature_extractor import SSDMobileNetV1FpnFeatureExtractor
  from object_detection.models.ssd_mobilenet_v1_ppn_feature_extractor import SSDMobileNetV1PpnFeatureExtractor
  from object_detection.models.ssd_mobilenet_v2_feature_extractor import SSDMobileNetV2FeatureExtractor
  from object_detection.models.ssd_mobilenet_v3_feature_extractor import SSDMobileNetV3LargeFeatureExtractor
  from object_detection.models.ssd_mobilenet_v3_feature_extractor import SSDMobileNetV3SmallFeatureExtractor
  from object_detection.models.ssd_mobiledet_feature_extractor import SSDMobileDetCPUFeatureExtractor
  from object_detection.models.ssd_mobiledet_feature_extractor import SSDMobileDetDSPFeatureExtractor
  from object_detection.models.ssd_mobiledet_feature_extractor import SSDMobileDetEdgeTPUFeatureExtractor
  from object_detection.models.ssd_mobiledet_feature_extractor import SSDMobileDetGPUFeatureExtractor
  from object_detection.models.ssd_pnasnet_feature_extractor import SSDPNASNetFeatureExtractor
  from object_detection.predictors import rfcn_box_predictor
# pylint: enable=g-import-not-at-top

if tf_version.is_tf2():
  SSD_KERAS_FEATURE_EXTRACTOR_CLASS_MAP = {
      'ssd_mobilenet_v1_keras': SSDMobileNetV1KerasFeatureExtractor,
      'ssd_mobilenet_v1_fpn_keras': SSDMobileNetV1FpnKerasFeatureExtractor,
      'ssd_mobilenet_v2_keras': SSDMobileNetV2KerasFeatureExtractor,
      'ssd_mobilenet_v2_fpn_keras': SSDMobileNetV2FpnKerasFeatureExtractor,
      'ssd_resnet50_v1_fpn_keras':
          ssd_resnet_v1_fpn_keras.SSDResNet50V1FpnKerasFeatureExtractor,
      'ssd_resnet101_v1_fpn_keras':
          ssd_resnet_v1_fpn_keras.SSDResNet101V1FpnKerasFeatureExtractor,
      'ssd_resnet152_v1_fpn_keras':
          ssd_resnet_v1_fpn_keras.SSDResNet152V1FpnKerasFeatureExtractor,
  }

  FASTER_RCNN_KERAS_FEATURE_EXTRACTOR_CLASS_MAP = {
      'faster_rcnn_resnet50_keras':
          frcnn_resnet_keras.FasterRCNNResnet50KerasFeatureExtractor,
      'faster_rcnn_resnet101_keras':
          frcnn_resnet_keras.FasterRCNNResnet101KerasFeatureExtractor,
      'faster_rcnn_resnet152_keras':
          frcnn_resnet_keras.FasterRCNNResnet152KerasFeatureExtractor,
      'faster_rcnn_inception_resnet_v2_keras':
          frcnn_inc_res_keras.FasterRCNNInceptionResnetV2KerasFeatureExtractor,
  }

  CENTER_NET_EXTRACTOR_FUNCTION_MAP = {
      'resnet_v2_50': center_net_resnet_feature_extractor.resnet_v2_50,
      'resnet_v2_101': center_net_resnet_feature_extractor.resnet_v2_101,
      'resnet_v1_50_fpn':
          center_net_resnet_v1_fpn_feature_extractor.resnet_v1_50_fpn,
      'resnet_v1_101_fpn':
          center_net_resnet_v1_fpn_feature_extractor.resnet_v1_101_fpn,
      'hourglass_104': center_net_hourglass_feature_extractor.hourglass_104,
  }

  FEATURE_EXTRACTOR_MAPS = [
      CENTER_NET_EXTRACTOR_FUNCTION_MAP,
      FASTER_RCNN_KERAS_FEATURE_EXTRACTOR_CLASS_MAP,
      SSD_KERAS_FEATURE_EXTRACTOR_CLASS_MAP
  ]

if tf_version.is_tf1():
  SSD_FEATURE_EXTRACTOR_CLASS_MAP = {
      'ssd_inception_v2':
          SSDInceptionV2FeatureExtractor,
      'ssd_inception_v3':
          SSDInceptionV3FeatureExtractor,
      'ssd_mobilenet_v1':
          SSDMobileNetV1FeatureExtractor,
      'ssd_mobilenet_v1_fpn':
          SSDMobileNetV1FpnFeatureExtractor,
      'ssd_mobilenet_v1_ppn':
          SSDMobileNetV1PpnFeatureExtractor,
      'ssd_mobilenet_v2':
          SSDMobileNetV2FeatureExtractor,
      'ssd_mobilenet_v2_fpn':
          SSDMobileNetV2FpnFeatureExtractor,
      'ssd_mobilenet_v2_mnasfpn':
          SSDMobileNetV2MnasFPNFeatureExtractor,
      'ssd_mobilenet_v3_large':
          SSDMobileNetV3LargeFeatureExtractor,
      'ssd_mobilenet_v3_small':
          SSDMobileNetV3SmallFeatureExtractor,
      'ssd_mobilenet_edgetpu':
          SSDMobileNetEdgeTPUFeatureExtractor,
      'ssd_resnet50_v1_fpn':
          ssd_resnet_v1_fpn.SSDResnet50V1FpnFeatureExtractor,
      'ssd_resnet101_v1_fpn':
          ssd_resnet_v1_fpn.SSDResnet101V1FpnFeatureExtractor,
      'ssd_resnet152_v1_fpn':
          ssd_resnet_v1_fpn.SSDResnet152V1FpnFeatureExtractor,
      'ssd_resnet50_v1_ppn':
          ssd_resnet_v1_ppn.SSDResnet50V1PpnFeatureExtractor,
      'ssd_resnet101_v1_ppn':
          ssd_resnet_v1_ppn.SSDResnet101V1PpnFeatureExtractor,
      'ssd_resnet152_v1_ppn':
          ssd_resnet_v1_ppn.SSDResnet152V1PpnFeatureExtractor,
      'embedded_ssd_mobilenet_v1':
          EmbeddedSSDMobileNetV1FeatureExtractor,
      'ssd_pnasnet':
          SSDPNASNetFeatureExtractor,
      'ssd_mobiledet_cpu':
          SSDMobileDetCPUFeatureExtractor,
      'ssd_mobiledet_dsp':
          SSDMobileDetDSPFeatureExtractor,
      'ssd_mobiledet_edgetpu':
          SSDMobileDetEdgeTPUFeatureExtractor,
      'ssd_mobiledet_gpu':
          SSDMobileDetGPUFeatureExtractor,
  }

  FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP = {
      'faster_rcnn_nas':
          frcnn_nas.FasterRCNNNASFeatureExtractor,
      'faster_rcnn_pnas':
          frcnn_pnas.FasterRCNNPNASFeatureExtractor,
      'faster_rcnn_inception_resnet_v2':
          frcnn_inc_res.FasterRCNNInceptionResnetV2FeatureExtractor,
      'faster_rcnn_inception_v2':
          frcnn_inc_v2.FasterRCNNInceptionV2FeatureExtractor,
      'faster_rcnn_resnet50':
          frcnn_resnet_v1.FasterRCNNResnet50FeatureExtractor,
      'faster_rcnn_resnet101':
          frcnn_resnet_v1.FasterRCNNResnet101FeatureExtractor,
      'faster_rcnn_resnet152':
          frcnn_resnet_v1.FasterRCNNResnet152FeatureExtractor,
  }

  FEATURE_EXTRACTOR_MAPS = [
      SSD_FEATURE_EXTRACTOR_CLASS_MAP,
      FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP
  ]


def _check_feature_extractor_exists(feature_extractor_type):
  feature_extractors = set().union(*FEATURE_EXTRACTOR_MAPS)
  if feature_extractor_type not in feature_extractors:
    raise ValueError(
        '{} is not supported. See `model_builder.py` for feature extractors '
        'compatible with different versions of TensorFlow.'.format(
            feature_extractor_type))
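
# A minimal sketch of the guard above (extractor names taken from the maps in
# this file): under TF2, asking for a TF1-only extractor fails fast here
# rather than deep inside model construction.
#
#   _check_feature_extractor_exists('ssd_mobilenet_v2_fpn_keras')  # OK in TF2
#   _check_feature_extractor_exists('no_such_extractor')  # raises ValueError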


def _build_ssd_feature_extractor(feature_extractor_config,
                                 is_training,
                                 freeze_batchnorm,
                                 reuse_weights=None):
  """Builds a ssd_meta_arch.SSDFeatureExtractor based on config.

  Args:
    feature_extractor_config: A SSDFeatureExtractor proto config from
      ssd.proto.
    is_training: True if this feature extractor is being built for training.
    freeze_batchnorm: Whether to freeze batch norm parameters during
      training or not. When training with a small batch size (e.g. 1), it is
      desirable to freeze batch norm update and use pretrained batch norm
      params.
    reuse_weights: if the feature extractor should reuse weights.

  Returns:
    ssd_meta_arch.SSDFeatureExtractor based on config.

  Raises:
    ValueError: On invalid feature extractor type.
  """
  feature_type = feature_extractor_config.type
  depth_multiplier = feature_extractor_config.depth_multiplier
  min_depth = feature_extractor_config.min_depth
  pad_to_multiple = feature_extractor_config.pad_to_multiple
  use_explicit_padding = feature_extractor_config.use_explicit_padding
  use_depthwise = feature_extractor_config.use_depthwise

  is_keras = tf_version.is_tf2()
  if is_keras:
    conv_hyperparams = hyperparams_builder.KerasLayerHyperparams(
        feature_extractor_config.conv_hyperparams)
  else:
    conv_hyperparams = hyperparams_builder.build(
        feature_extractor_config.conv_hyperparams, is_training)
  override_base_feature_extractor_hyperparams = (
      feature_extractor_config.override_base_feature_extractor_hyperparams)

  if not is_keras and feature_type not in SSD_FEATURE_EXTRACTOR_CLASS_MAP:
    raise ValueError('Unknown ssd feature_extractor: {}'.format(feature_type))

  if is_keras:
    feature_extractor_class = SSD_KERAS_FEATURE_EXTRACTOR_CLASS_MAP[
        feature_type]
  else:
    feature_extractor_class = SSD_FEATURE_EXTRACTOR_CLASS_MAP[feature_type]
  kwargs = {
      'is_training':
          is_training,
      'depth_multiplier':
          depth_multiplier,
      'min_depth':
          min_depth,
      'pad_to_multiple':
          pad_to_multiple,
      'use_explicit_padding':
          use_explicit_padding,
      'use_depthwise':
          use_depthwise,
      'override_base_feature_extractor_hyperparams':
          override_base_feature_extractor_hyperparams
  }

  if feature_extractor_config.HasField('replace_preprocessor_with_placeholder'):
    kwargs.update({
        'replace_preprocessor_with_placeholder':
            feature_extractor_config.replace_preprocessor_with_placeholder
    })

  if feature_extractor_config.HasField('num_layers'):
    kwargs.update({'num_layers': feature_extractor_config.num_layers})

  if is_keras:
    kwargs.update({
        'conv_hyperparams': conv_hyperparams,
        'inplace_batchnorm_update': False,
        'freeze_batchnorm': freeze_batchnorm
    })
  else:
    kwargs.update({
        'conv_hyperparams_fn': conv_hyperparams,
        'reuse_weights': reuse_weights,
    })

  if feature_extractor_config.HasField('fpn'):
    kwargs.update({
        'fpn_min_level':
            feature_extractor_config.fpn.min_level,
        'fpn_max_level':
            feature_extractor_config.fpn.max_level,
        'additional_layer_depth':
            feature_extractor_config.fpn.additional_layer_depth,
    })

  return feature_extractor_class(**kwargs)
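
# A hedged usage sketch (not part of this module's API; the field values are
# illustrative). Under TF2 the type below resolves to a Keras extractor from
# SSD_KERAS_FEATURE_EXTRACTOR_CLASS_MAP:
#
#   from google.protobuf import text_format
#   from object_detection.protos import ssd_pb2
#
#   fe_config = text_format.Parse("""
#     type: 'ssd_mobilenet_v2_keras'
#     depth_multiplier: 1.0
#     min_depth: 16
#     conv_hyperparams {
#       regularizer { l2_regularizer { weight: 4e-5 } }
#       initializer { truncated_normal_initializer { stddev: 0.03 } }
#     }
#   """, ssd_pb2.SsdFeatureExtractor())
#   extractor = _build_ssd_feature_extractor(
#       fe_config, is_training=True, freeze_batchnorm=False)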


def _build_ssd_model(ssd_config, is_training, add_summaries):
  """Builds an SSD detection model based on the model config.

  Args:
    ssd_config: A ssd.proto object containing the config for the desired
      SSDMetaArch.
    is_training: True if this model is being built for training purposes.
    add_summaries: Whether to add tf summaries in the model.

  Returns:
    SSDMetaArch based on the config.

  Raises:
    ValueError: If ssd_config.type is not recognized (i.e. not registered in
      model_class_map).
  """
  num_classes = ssd_config.num_classes
  _check_feature_extractor_exists(ssd_config.feature_extractor.type)

  # Feature extractor
  feature_extractor = _build_ssd_feature_extractor(
      feature_extractor_config=ssd_config.feature_extractor,
      freeze_batchnorm=ssd_config.freeze_batchnorm,
      is_training=is_training)

  box_coder = box_coder_builder.build(ssd_config.box_coder)
  matcher = matcher_builder.build(ssd_config.matcher)
  region_similarity_calculator = sim_calc.build(
      ssd_config.similarity_calculator)
  encode_background_as_zeros = ssd_config.encode_background_as_zeros
  negative_class_weight = ssd_config.negative_class_weight
  anchor_generator = anchor_generator_builder.build(
      ssd_config.anchor_generator)
  if feature_extractor.is_keras_model:
    ssd_box_predictor = box_predictor_builder.build_keras(
        hyperparams_fn=hyperparams_builder.KerasLayerHyperparams,
        freeze_batchnorm=ssd_config.freeze_batchnorm,
        inplace_batchnorm_update=False,
        num_predictions_per_location_list=anchor_generator
        .num_anchors_per_location(),
        box_predictor_config=ssd_config.box_predictor,
        is_training=is_training,
        num_classes=num_classes,
        add_background_class=ssd_config.add_background_class)
  else:
    ssd_box_predictor = box_predictor_builder.build(
        hyperparams_builder.build, ssd_config.box_predictor, is_training,
        num_classes, ssd_config.add_background_class)
  image_resizer_fn = image_resizer_builder.build(ssd_config.image_resizer)
  non_max_suppression_fn, score_conversion_fn = post_processing_builder.build(
      ssd_config.post_processing)
  (classification_loss, localization_loss, classification_weight,
   localization_weight, hard_example_miner, random_example_sampler,
   expected_loss_weights_fn) = losses_builder.build(ssd_config.loss)
  normalize_loss_by_num_matches = ssd_config.normalize_loss_by_num_matches
  normalize_loc_loss_by_codesize = ssd_config.normalize_loc_loss_by_codesize

  equalization_loss_config = ops.EqualizationLossConfig(
      weight=ssd_config.loss.equalization_loss.weight,
      exclude_prefixes=ssd_config.loss.equalization_loss.exclude_prefixes)

  target_assigner_instance = target_assigner.TargetAssigner(
      region_similarity_calculator,
      matcher,
      box_coder,
      negative_class_weight=negative_class_weight)

  ssd_meta_arch_fn = ssd_meta_arch.SSDMetaArch
  kwargs = {}

  return ssd_meta_arch_fn(
      is_training=is_training,
      anchor_generator=anchor_generator,
      box_predictor=ssd_box_predictor,
      box_coder=box_coder,
      feature_extractor=feature_extractor,
      encode_background_as_zeros=encode_background_as_zeros,
      image_resizer_fn=image_resizer_fn,
      non_max_suppression_fn=non_max_suppression_fn,
      score_conversion_fn=score_conversion_fn,
      classification_loss=classification_loss,
      localization_loss=localization_loss,
      classification_loss_weight=classification_weight,
      localization_loss_weight=localization_weight,
      normalize_loss_by_num_matches=normalize_loss_by_num_matches,
      hard_example_miner=hard_example_miner,
      target_assigner_instance=target_assigner_instance,
      add_summaries=add_summaries,
      normalize_loc_loss_by_codesize=normalize_loc_loss_by_codesize,
      freeze_batchnorm=ssd_config.freeze_batchnorm,
      inplace_batchnorm_update=ssd_config.inplace_batchnorm_update,
      add_background_class=ssd_config.add_background_class,
      explicit_background_class=ssd_config.explicit_background_class,
      random_example_sampler=random_example_sampler,
      expected_loss_weights_fn=expected_loss_weights_fn,
      use_confidences_as_targets=ssd_config.use_confidences_as_targets,
      implicit_example_weight=ssd_config.implicit_example_weight,
      equalization_loss_config=equalization_loss_config,
      return_raw_detections_during_predict=(
          ssd_config.return_raw_detections_during_predict),
      **kwargs)
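
# A hedged sketch of the `ssd { ... }` proto fields this builder reads (values
# illustrative, not a complete or tuned training config):
#
#   ssd {
#     num_classes: 90
#     feature_extractor { type: 'ssd_mobilenet_v2_keras' ... }
#     box_coder { faster_rcnn_box_coder { } }
#     matcher { argmax_matcher { } }
#     similarity_calculator { iou_similarity { } }
#     anchor_generator { ssd_anchor_generator { } }
#     box_predictor { convolutional_box_predictor { ... } }
#     image_resizer { fixed_shape_resizer { height: 300 width: 300 } }
#     post_processing { batch_non_max_suppression { ... } }
#     loss { ... }
#   }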


def _build_faster_rcnn_feature_extractor(
    feature_extractor_config, is_training, reuse_weights=True,
    inplace_batchnorm_update=False):
  """Builds a faster_rcnn_meta_arch.FasterRCNNFeatureExtractor based on config.

  Args:
    feature_extractor_config: A FasterRcnnFeatureExtractor proto config from
      faster_rcnn.proto.
    is_training: True if this feature extractor is being built for training.
    reuse_weights: if the feature extractor should reuse weights.
    inplace_batchnorm_update: Whether to update batch_norm inplace during
      training. This is required for batch norm to work correctly on TPUs. When
      this is false, user must add a control dependency on
      tf.GraphKeys.UPDATE_OPS for train/loss op in order to update the batch
      norm moving average parameters.

  Returns:
    faster_rcnn_meta_arch.FasterRCNNFeatureExtractor based on config.

  Raises:
    ValueError: On invalid feature extractor type.
  """
  if inplace_batchnorm_update:
    raise ValueError('inplace batchnorm updates not supported.')
  feature_type = feature_extractor_config.type
  first_stage_features_stride = (
      feature_extractor_config.first_stage_features_stride)
  batch_norm_trainable = feature_extractor_config.batch_norm_trainable

  if feature_type not in FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP:
    raise ValueError('Unknown Faster R-CNN feature_extractor: {}'.format(
        feature_type))
  feature_extractor_class = FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP[
      feature_type]
  return feature_extractor_class(
      is_training, first_stage_features_stride,
      batch_norm_trainable, reuse_weights=reuse_weights)


def _build_faster_rcnn_keras_feature_extractor(
    feature_extractor_config, is_training,
    inplace_batchnorm_update=False):
  """Builds a faster_rcnn_meta_arch.FasterRCNNKerasFeatureExtractor from config.

  Args:
    feature_extractor_config: A FasterRcnnFeatureExtractor proto config from
      faster_rcnn.proto.
    is_training: True if this feature extractor is being built for training.
    inplace_batchnorm_update: Whether to update batch_norm inplace during
      training. This is required for batch norm to work correctly on TPUs. When
      this is false, user must add a control dependency on
      tf.GraphKeys.UPDATE_OPS for train/loss op in order to update the batch
      norm moving average parameters.

  Returns:
    faster_rcnn_meta_arch.FasterRCNNKerasFeatureExtractor based on config.

  Raises:
    ValueError: On invalid feature extractor type.
  """
  if inplace_batchnorm_update:
    raise ValueError('inplace batchnorm updates not supported.')
  feature_type = feature_extractor_config.type
  first_stage_features_stride = (
      feature_extractor_config.first_stage_features_stride)
  batch_norm_trainable = feature_extractor_config.batch_norm_trainable

  if feature_type not in FASTER_RCNN_KERAS_FEATURE_EXTRACTOR_CLASS_MAP:
    raise ValueError('Unknown Faster R-CNN feature_extractor: {}'.format(
        feature_type))
  feature_extractor_class = FASTER_RCNN_KERAS_FEATURE_EXTRACTOR_CLASS_MAP[
      feature_type]
  return feature_extractor_class(
      is_training, first_stage_features_stride,
      batch_norm_trainable)


def _build_faster_rcnn_model(frcnn_config, is_training, add_summaries):
  """Builds a Faster R-CNN or R-FCN detection model based on the model config.

  Builds R-FCN model if the second_stage_box_predictor in the config is of
  type `rfcn_box_predictor` else builds a Faster R-CNN model.

  Args:
    frcnn_config: A faster_rcnn.proto object containing the config for the
      desired FasterRCNNMetaArch or RFCNMetaArch.
    is_training: True if this model is being built for training purposes.
    add_summaries: Whether to add tf summaries in the model.

  Returns:
    FasterRCNNMetaArch based on the config.

  Raises:
    ValueError: If frcnn_config.type is not recognized (i.e. not registered in
      model_class_map).
  """
  num_classes = frcnn_config.num_classes
  image_resizer_fn = image_resizer_builder.build(frcnn_config.image_resizer)
  _check_feature_extractor_exists(frcnn_config.feature_extractor.type)
  is_keras = tf_version.is_tf2()

  if is_keras:
    feature_extractor = _build_faster_rcnn_keras_feature_extractor(
        frcnn_config.feature_extractor, is_training,
        inplace_batchnorm_update=frcnn_config.inplace_batchnorm_update)
  else:
    feature_extractor = _build_faster_rcnn_feature_extractor(
        frcnn_config.feature_extractor, is_training,
        inplace_batchnorm_update=frcnn_config.inplace_batchnorm_update)

  number_of_stages = frcnn_config.number_of_stages
  first_stage_anchor_generator = anchor_generator_builder.build(
      frcnn_config.first_stage_anchor_generator)

  first_stage_target_assigner = target_assigner.create_target_assigner(
      'FasterRCNN',
      'proposal',
      use_matmul_gather=frcnn_config.use_matmul_gather_in_matcher)
  first_stage_atrous_rate = frcnn_config.first_stage_atrous_rate
  if is_keras:
    first_stage_box_predictor_arg_scope_fn = (
        hyperparams_builder.KerasLayerHyperparams(
            frcnn_config.first_stage_box_predictor_conv_hyperparams))
  else:
    first_stage_box_predictor_arg_scope_fn = hyperparams_builder.build(
        frcnn_config.first_stage_box_predictor_conv_hyperparams, is_training)
  first_stage_box_predictor_kernel_size = (
      frcnn_config.first_stage_box_predictor_kernel_size)
  first_stage_box_predictor_depth = frcnn_config.first_stage_box_predictor_depth
  first_stage_minibatch_size = frcnn_config.first_stage_minibatch_size
  use_static_shapes = frcnn_config.use_static_shapes and (
      frcnn_config.use_static_shapes_for_eval or is_training)
  first_stage_sampler = sampler.BalancedPositiveNegativeSampler(
      positive_fraction=frcnn_config.first_stage_positive_balance_fraction,
      is_static=(frcnn_config.use_static_balanced_label_sampler and
                 use_static_shapes))
  first_stage_max_proposals = frcnn_config.first_stage_max_proposals
  if (frcnn_config.first_stage_nms_iou_threshold < 0 or
      frcnn_config.first_stage_nms_iou_threshold > 1.0):
    raise ValueError('iou_threshold not in [0, 1.0].')
  if (is_training and frcnn_config.second_stage_batch_size >
      first_stage_max_proposals):
    raise ValueError('second_stage_batch_size should be no greater than '
                     'first_stage_max_proposals.')
  first_stage_non_max_suppression_fn = functools.partial(
      post_processing.batch_multiclass_non_max_suppression,
      score_thresh=frcnn_config.first_stage_nms_score_threshold,
      iou_thresh=frcnn_config.first_stage_nms_iou_threshold,
      max_size_per_class=frcnn_config.first_stage_max_proposals,
      max_total_size=frcnn_config.first_stage_max_proposals,
      use_static_shapes=use_static_shapes,
      use_partitioned_nms=frcnn_config.use_partitioned_nms_in_first_stage,
      use_combined_nms=frcnn_config.use_combined_nms_in_first_stage)
  first_stage_loc_loss_weight = (
      frcnn_config.first_stage_localization_loss_weight)
  first_stage_obj_loss_weight = frcnn_config.first_stage_objectness_loss_weight

  initial_crop_size = frcnn_config.initial_crop_size
  maxpool_kernel_size = frcnn_config.maxpool_kernel_size
  maxpool_stride = frcnn_config.maxpool_stride

  second_stage_target_assigner = target_assigner.create_target_assigner(
      'FasterRCNN',
      'detection',
      use_matmul_gather=frcnn_config.use_matmul_gather_in_matcher)
  if is_keras:
    second_stage_box_predictor = box_predictor_builder.build_keras(
        hyperparams_builder.KerasLayerHyperparams,
        freeze_batchnorm=False,
        inplace_batchnorm_update=False,
        num_predictions_per_location_list=[1],
        box_predictor_config=frcnn_config.second_stage_box_predictor,
        is_training=is_training,
        num_classes=num_classes)
  else:
    second_stage_box_predictor = box_predictor_builder.build(
        hyperparams_builder.build,
        frcnn_config.second_stage_box_predictor,
        is_training=is_training,
        num_classes=num_classes)
  second_stage_batch_size = frcnn_config.second_stage_batch_size
  second_stage_sampler = sampler.BalancedPositiveNegativeSampler(
      positive_fraction=frcnn_config.second_stage_balance_fraction,
      is_static=(frcnn_config.use_static_balanced_label_sampler and
                 use_static_shapes))
  (second_stage_non_max_suppression_fn, second_stage_score_conversion_fn
  ) = post_processing_builder.build(frcnn_config.second_stage_post_processing)
  second_stage_localization_loss_weight = (
      frcnn_config.second_stage_localization_loss_weight)
  second_stage_classification_loss = (
      losses_builder.build_faster_rcnn_classification_loss(
          frcnn_config.second_stage_classification_loss))
  second_stage_classification_loss_weight = (
      frcnn_config.second_stage_classification_loss_weight)
  second_stage_mask_prediction_loss_weight = (
      frcnn_config.second_stage_mask_prediction_loss_weight)

  hard_example_miner = None
  if frcnn_config.HasField('hard_example_miner'):
    hard_example_miner = losses_builder.build_hard_example_miner(
        frcnn_config.hard_example_miner,
        second_stage_classification_loss_weight,
        second_stage_localization_loss_weight)

  crop_and_resize_fn = (
      ops.matmul_crop_and_resize if frcnn_config.use_matmul_crop_and_resize
      else ops.native_crop_and_resize)
  clip_anchors_to_image = frcnn_config.clip_anchors_to_image

  common_kwargs = {
      'is_training':
          is_training,
      'num_classes':
          num_classes,
      'image_resizer_fn':
          image_resizer_fn,
      'feature_extractor':
          feature_extractor,
      'number_of_stages':
          number_of_stages,
      'first_stage_anchor_generator':
          first_stage_anchor_generator,
      'first_stage_target_assigner':
          first_stage_target_assigner,
      'first_stage_atrous_rate':
          first_stage_atrous_rate,
      'first_stage_box_predictor_arg_scope_fn':
          first_stage_box_predictor_arg_scope_fn,
      'first_stage_box_predictor_kernel_size':
          first_stage_box_predictor_kernel_size,
      'first_stage_box_predictor_depth':
          first_stage_box_predictor_depth,
      'first_stage_minibatch_size':
          first_stage_minibatch_size,
      'first_stage_sampler':
          first_stage_sampler,
      'first_stage_non_max_suppression_fn':
          first_stage_non_max_suppression_fn,
      'first_stage_max_proposals':
          first_stage_max_proposals,
      'first_stage_localization_loss_weight':
          first_stage_loc_loss_weight,
      'first_stage_objectness_loss_weight':
          first_stage_obj_loss_weight,
      'second_stage_target_assigner':
          second_stage_target_assigner,
      'second_stage_batch_size':
          second_stage_batch_size,
      'second_stage_sampler':
          second_stage_sampler,
      'second_stage_non_max_suppression_fn':
          second_stage_non_max_suppression_fn,
      'second_stage_score_conversion_fn':
          second_stage_score_conversion_fn,
      'second_stage_localization_loss_weight':
          second_stage_localization_loss_weight,
      'second_stage_classification_loss':
          second_stage_classification_loss,
      'second_stage_classification_loss_weight':
          second_stage_classification_loss_weight,
      'hard_example_miner':
          hard_example_miner,
      'add_summaries':
          add_summaries,
      'crop_and_resize_fn':
          crop_and_resize_fn,
      'clip_anchors_to_image':
          clip_anchors_to_image,
      'use_static_shapes':
          use_static_shapes,
      'resize_masks':
          frcnn_config.resize_masks,
      'return_raw_detections_during_predict':
          frcnn_config.return_raw_detections_during_predict,
      'output_final_box_features':
          frcnn_config.output_final_box_features
  }

  if ((not is_keras and isinstance(second_stage_box_predictor,
                                   rfcn_box_predictor.RfcnBoxPredictor)) or
      (is_keras and
       isinstance(second_stage_box_predictor,
                  rfcn_keras_box_predictor.RfcnKerasBoxPredictor))):
    return rfcn_meta_arch.RFCNMetaArch(
        second_stage_rfcn_box_predictor=second_stage_box_predictor,
        **common_kwargs)
  elif frcnn_config.HasField('context_config'):
    context_config = frcnn_config.context_config
    common_kwargs.update({
        'attention_bottleneck_dimension':
            context_config.attention_bottleneck_dimension,
        'attention_temperature':
            context_config.attention_temperature
    })
    return context_rcnn_meta_arch.ContextRCNNMetaArch(
        initial_crop_size=initial_crop_size,
        maxpool_kernel_size=maxpool_kernel_size,
        maxpool_stride=maxpool_stride,
        second_stage_mask_rcnn_box_predictor=second_stage_box_predictor,
        second_stage_mask_prediction_loss_weight=(
            second_stage_mask_prediction_loss_weight),
        **common_kwargs)
  else:
    return faster_rcnn_meta_arch.FasterRCNNMetaArch(
        initial_crop_size=initial_crop_size,
        maxpool_kernel_size=maxpool_kernel_size,
        maxpool_stride=maxpool_stride,
        second_stage_mask_rcnn_box_predictor=second_stage_box_predictor,
        second_stage_mask_prediction_loss_weight=(
            second_stage_mask_prediction_loss_weight),
        **common_kwargs)
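
# Dispatch summary for _build_faster_rcnn_model (mirrors the branches above):
#   * second_stage_box_predictor is an R-FCN predictor -> RFCNMetaArch
#   * frcnn_config has a context_config field          -> ContextRCNNMetaArch
#   * otherwise                                        -> FasterRCNNMetaArch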

EXPERIMENTAL_META_ARCH_BUILDER_MAP = {
}


def _build_experimental_model(config, is_training, add_summaries=True):
  return EXPERIMENTAL_META_ARCH_BUILDER_MAP[config.name](
      is_training, add_summaries)

# The class ID in the groundtruth/model architecture is usually 0-based while
# the ID in the label map is 1-based. The offset is used to convert between
# the two.
CLASS_ID_OFFSET = 1
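# For example, a label map item with id 1 corresponds to model class id 0
# (1 - CLASS_ID_OFFSET).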
KEYPOINT_STD_DEV_DEFAULT = 1.0


def keypoint_proto_to_params(kp_config, keypoint_map_dict):
  """Converts CenterNet.KeypointEstimation proto to parameter namedtuple."""
  label_map_item = keypoint_map_dict[kp_config.keypoint_class_name]

  classification_loss, localization_loss, _, _, _, _, _ = (
      losses_builder.build(kp_config.loss))

  keypoint_indices = [
      keypoint.id for keypoint in label_map_item.keypoints
  ]
  keypoint_labels = [
      keypoint.label for keypoint in label_map_item.keypoints
  ]
  keypoint_std_dev_dict = {
      label: KEYPOINT_STD_DEV_DEFAULT for label in keypoint_labels
  }
  if kp_config.keypoint_label_to_std:
    for label, value in kp_config.keypoint_label_to_std.items():
      keypoint_std_dev_dict[label] = value
  keypoint_std_dev = [keypoint_std_dev_dict[label] for label in keypoint_labels]

  return center_net_meta_arch.KeypointEstimationParams(
      task_name=kp_config.task_name,
      class_id=label_map_item.id - CLASS_ID_OFFSET,
      keypoint_indices=keypoint_indices,
      classification_loss=classification_loss,
      localization_loss=localization_loss,
      keypoint_labels=keypoint_labels,
      keypoint_std_dev=keypoint_std_dev,
      task_loss_weight=kp_config.task_loss_weight,
      keypoint_regression_loss_weight=kp_config.keypoint_regression_loss_weight,
      keypoint_heatmap_loss_weight=kp_config.keypoint_heatmap_loss_weight,
      keypoint_offset_loss_weight=kp_config.keypoint_offset_loss_weight,
      heatmap_bias_init=kp_config.heatmap_bias_init,
      keypoint_candidate_score_threshold=(
          kp_config.keypoint_candidate_score_threshold),
      num_candidates_per_keypoint=kp_config.num_candidates_per_keypoint,
      peak_max_pool_kernel_size=kp_config.peak_max_pool_kernel_size,
      unmatched_keypoint_score=kp_config.unmatched_keypoint_score,
      box_scale=kp_config.box_scale,
      candidate_search_scale=kp_config.candidate_search_scale,
      candidate_ranking_mode=kp_config.candidate_ranking_mode,
      offset_peak_radius=kp_config.offset_peak_radius,
      per_keypoint_offset=kp_config.per_keypoint_offset)
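
# A toy walkthrough of the std-dev resolution above (hedged, values
# illustrative): with keypoint_labels = ['nose', 'left_eye'] and
# kp_config.keypoint_label_to_std = {'nose': 0.8}, keypoint_std_dev resolves
# to [0.8, 1.0]; unlisted labels fall back to KEYPOINT_STD_DEV_DEFAULT.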


def object_detection_proto_to_params(od_config):
  """Converts CenterNet.ObjectDetection proto to parameter namedtuple."""
  loss = losses_pb2.Loss()
  # Add dummy classification loss to avoid the loss_builder throwing error.
  # TODO(yuhuic): update the loss builder to take the classification loss
  # directly.
  loss.classification_loss.weighted_sigmoid.CopyFrom(
      losses_pb2.WeightedSigmoidClassificationLoss())
  loss.localization_loss.CopyFrom(od_config.localization_loss)
  _, localization_loss, _, _, _, _, _ = losses_builder.build(loss)
  return center_net_meta_arch.ObjectDetectionParams(
      localization_loss=localization_loss,
      scale_loss_weight=od_config.scale_loss_weight,
      offset_loss_weight=od_config.offset_loss_weight,
      task_loss_weight=od_config.task_loss_weight)


def object_center_proto_to_params(oc_config):
  """Converts CenterNet.ObjectCenter proto to parameter namedtuple."""
  loss = losses_pb2.Loss()
  # Add dummy localization loss to avoid the loss_builder throwing error.
  # TODO(yuhuic): update the loss builder to take the localization loss
  # directly.
  loss.localization_loss.weighted_l2.CopyFrom(
      losses_pb2.WeightedL2LocalizationLoss())
  loss.classification_loss.CopyFrom(oc_config.classification_loss)
  classification_loss, _, _, _, _, _, _ = losses_builder.build(loss)
  return center_net_meta_arch.ObjectCenterParams(
      classification_loss=classification_loss,
      object_center_loss_weight=oc_config.object_center_loss_weight,
      heatmap_bias_init=oc_config.heatmap_bias_init,
      min_box_overlap_iou=oc_config.min_box_overlap_iou,
      max_box_predictions=oc_config.max_box_predictions,
      use_labeled_classes=oc_config.use_labeled_classes)


def mask_proto_to_params(mask_config):
  """Converts CenterNet.MaskEstimation proto to parameter namedtuple."""
  loss = losses_pb2.Loss()
  # Add dummy localization loss to avoid the loss_builder throwing error.
  loss.localization_loss.weighted_l2.CopyFrom(
      losses_pb2.WeightedL2LocalizationLoss())
  loss.classification_loss.CopyFrom(mask_config.classification_loss)
  classification_loss, _, _, _, _, _, _ = losses_builder.build(loss)
  return center_net_meta_arch.MaskParams(
      classification_loss=classification_loss,
      task_loss_weight=mask_config.task_loss_weight,
      mask_height=mask_config.mask_height,
      mask_width=mask_config.mask_width,
      score_threshold=mask_config.score_threshold,
      heatmap_bias_init=mask_config.heatmap_bias_init)


def _build_center_net_model(center_net_config, is_training, add_summaries):
  """Build a CenterNet detection model.

  Args:
    center_net_config: A CenterNet proto object with model configuration.
    is_training: True if this model is being built for training purposes.
    add_summaries: Whether to add tf summaries in the model.

  Returns:
    CenterNetMetaArch based on the config.
  """
  image_resizer_fn = image_resizer_builder.build(
      center_net_config.image_resizer)
  _check_feature_extractor_exists(center_net_config.feature_extractor.type)
  feature_extractor = _build_center_net_feature_extractor(
      center_net_config.feature_extractor)
  object_center_params = object_center_proto_to_params(
      center_net_config.object_center_params)

  object_detection_params = None
  if center_net_config.HasField('object_detection_task'):
    object_detection_params = object_detection_proto_to_params(
        center_net_config.object_detection_task)

  keypoint_params_dict = None
  if center_net_config.keypoint_estimation_task:
    label_map_proto = label_map_util.load_labelmap(
        center_net_config.keypoint_label_map_path)
    keypoint_map_dict = {
        item.name: item for item in label_map_proto.item if item.keypoints
    }
    keypoint_params_dict = {}
    keypoint_class_id_set = set()
    all_keypoint_indices = []
    for task in center_net_config.keypoint_estimation_task:
      kp_params = keypoint_proto_to_params(task, keypoint_map_dict)
      keypoint_params_dict[task.task_name] = kp_params
      all_keypoint_indices.extend(kp_params.keypoint_indices)
      if kp_params.class_id in keypoint_class_id_set:
        raise ValueError('Multiple keypoint tasks mapping to the same class '
                         'id is not allowed: %d' % kp_params.class_id)
      else:
        keypoint_class_id_set.add(kp_params.class_id)
    if len(all_keypoint_indices) > len(set(all_keypoint_indices)):
      raise ValueError('Some keypoint indices are used more than once.')

  mask_params = None
  if center_net_config.HasField('mask_estimation_task'):
    mask_params = mask_proto_to_params(center_net_config.mask_estimation_task)

  return center_net_meta_arch.CenterNetMetaArch(
      is_training=is_training,
      add_summaries=add_summaries,
      num_classes=center_net_config.num_classes,
      feature_extractor=feature_extractor,
      image_resizer_fn=image_resizer_fn,
      object_center_params=object_center_params,
      object_detection_params=object_detection_params,
      keypoint_params_dict=keypoint_params_dict,
      mask_params=mask_params)


def _build_center_net_feature_extractor(
    feature_extractor_config):
  """Build a CenterNet feature extractor from the given config."""
  if feature_extractor_config.type not in CENTER_NET_EXTRACTOR_FUNCTION_MAP:
    raise ValueError('\'{}\' is not a known CenterNet feature extractor type'
                     .format(feature_extractor_config.type))

  return CENTER_NET_EXTRACTOR_FUNCTION_MAP[feature_extractor_config.type](
      channel_means=list(feature_extractor_config.channel_means),
      channel_stds=list(feature_extractor_config.channel_stds),
      bgr_ordering=feature_extractor_config.bgr_ordering
  )
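
# A hedged sketch of how a CenterNet feature extractor config resolves (the
# nested message name is assumed from center_net.proto; values illustrative):
#
#   from google.protobuf import text_format
#   from object_detection.protos import center_net_pb2
#
#   fe_config = text_format.Parse("""
#     type: 'resnet_v2_50'
#     channel_means: [123.68, 116.78, 103.94]
#     channel_stds: [58.4, 57.1, 57.4]
#     bgr_ordering: false
#   """, center_net_pb2.CenterNet.FeatureExtractor())
#   extractor = _build_center_net_feature_extractor(fe_config)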


META_ARCH_BUILDER_MAP = {
    'ssd': _build_ssd_model,
    'faster_rcnn': _build_faster_rcnn_model,
    'experimental_model': _build_experimental_model,
    'center_net': _build_center_net_model
}


def build(model_config, is_training, add_summaries=True):
  """Builds a DetectionModel based on the model config.

  Args:
    model_config: A model.proto object containing the config for the desired
      DetectionModel.
    is_training: True if this model is being built for training purposes.
    add_summaries: Whether to add tensorflow summaries in the model graph.

  Returns:
    DetectionModel based on the config.

  Raises:
    ValueError: On invalid meta architecture or model.
  """
  if not isinstance(model_config, model_pb2.DetectionModel):
    raise ValueError('model_config not of type model_pb2.DetectionModel.')

  meta_architecture = model_config.WhichOneof('model')

  if meta_architecture not in META_ARCH_BUILDER_MAP:
    raise ValueError('Unknown meta architecture: {}'.format(meta_architecture))
  else:
    build_func = META_ARCH_BUILDER_MAP[meta_architecture]
    return build_func(getattr(model_config, meta_architecture), is_training,
                      add_summaries)
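
# A minimal end-to-end sketch (hedged): `config_util` is the usual companion
# utility for loading pipeline configs, and 'pipeline.config' is an
# illustrative path, not a file shipped with this module.
#
#   from object_detection.utils import config_util
#
#   configs = config_util.get_configs_from_pipeline_file('pipeline.config')
#   detection_model = build(configs['model'], is_training=True)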