# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model architecture factory."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from official.legacy.detection.modeling.architecture import fpn
from official.legacy.detection.modeling.architecture import heads
from official.legacy.detection.modeling.architecture import identity
from official.legacy.detection.modeling.architecture import nn_ops
from official.legacy.detection.modeling.architecture import resnet
from official.legacy.detection.modeling.architecture import spinenet


def norm_activation_generator(params):
  return nn_ops.norm_activation_builder(
      momentum=params.batch_norm_momentum,
      epsilon=params.batch_norm_epsilon,
      trainable=params.batch_norm_trainable,
      activation=params.activation)


def backbone_generator(params):
  """Generator function for various backbone models."""
  if params.architecture.backbone == 'resnet':
    resnet_params = params.resnet
    backbone_fn = resnet.Resnet(
        resnet_depth=resnet_params.resnet_depth,
        activation=params.norm_activation.activation,
        norm_activation=norm_activation_generator(
            params.norm_activation))
  elif params.architecture.backbone == 'spinenet':
    spinenet_params = params.spinenet
    backbone_fn = spinenet.SpineNetBuilder(model_id=spinenet_params.model_id)
  else:
    raise ValueError('Backbone model `{}` is not supported.'
                     .format(params.architecture.backbone))

  return backbone_fn


def multilevel_features_generator(params):
  """Generator function for various FPN models."""
  if params.architecture.multilevel_features == 'fpn':
    fpn_params = params.fpn
    fpn_fn = fpn.Fpn(
        min_level=params.architecture.min_level,
        max_level=params.architecture.max_level,
        fpn_feat_dims=fpn_params.fpn_feat_dims,
        use_separable_conv=fpn_params.use_separable_conv,
        activation=params.norm_activation.activation,
        use_batch_norm=fpn_params.use_batch_norm,
        norm_activation=norm_activation_generator(
            params.norm_activation))
  elif params.architecture.multilevel_features == 'identity':
    fpn_fn = identity.Identity()
  else:
    raise ValueError('The multi-level feature model `{}` is not supported.'
                     .format(params.architecture.multilevel_features))

  return fpn_fn


def retinanet_head_generator(params):
  """Generator function for RetinaNet head architecture."""
  head_params = params.retinanet_head
  anchors_per_location = params.anchor.num_scales * len(
      params.anchor.aspect_ratios)
  return heads.RetinanetHead(
      params.architecture.min_level,
      params.architecture.max_level,
      params.architecture.num_classes,
      anchors_per_location,
      head_params.num_convs,
      head_params.num_filters,
      head_params.use_separable_conv,
      norm_activation=norm_activation_generator(params.norm_activation))


def rpn_head_generator(params):
  """Generator function for RPN head architecture."""
  head_params = params.rpn_head
  anchors_per_location = params.anchor.num_scales * len(
      params.anchor.aspect_ratios)
  return heads.RpnHead(
      params.architecture.min_level,
      params.architecture.max_level,
      anchors_per_location,
      head_params.num_convs,
      head_params.num_filters,
      head_params.use_separable_conv,
      params.norm_activation.activation,
      head_params.use_batch_norm,
      norm_activation=norm_activation_generator(params.norm_activation))


def oln_rpn_head_generator(params):
  """Generator function for OLN-proposal (OLN-RPN) head architecture."""
  head_params = params.rpn_head
  anchors_per_location = params.anchor.num_scales * len(
      params.anchor.aspect_ratios)
  return heads.OlnRpnHead(
      params.architecture.min_level,
      params.architecture.max_level,
      anchors_per_location,
      head_params.num_convs,
      head_params.num_filters,
      head_params.use_separable_conv,
      params.norm_activation.activation,
      head_params.use_batch_norm,
      norm_activation=norm_activation_generator(params.norm_activation))


def fast_rcnn_head_generator(params):
  """Generator function for Fast R-CNN head architecture."""
  head_params = params.frcnn_head
  return heads.FastrcnnHead(
      params.architecture.num_classes,
      head_params.num_convs,
      head_params.num_filters,
      head_params.use_separable_conv,
      head_params.num_fcs,
      head_params.fc_dims,
      params.norm_activation.activation,
      head_params.use_batch_norm,
      norm_activation=norm_activation_generator(params.norm_activation))


def oln_box_score_head_generator(params):
  """Generator function for Scoring Fast R-CNN head architecture."""
  head_params = params.frcnn_head
  return heads.OlnBoxScoreHead(
      params.architecture.num_classes,
      head_params.num_convs,
      head_params.num_filters,
      head_params.use_separable_conv,
      head_params.num_fcs,
      head_params.fc_dims,
      params.norm_activation.activation,
      head_params.use_batch_norm,
      norm_activation=norm_activation_generator(params.norm_activation))


def mask_rcnn_head_generator(params):
  """Generator function for Mask R-CNN head architecture."""
  head_params = params.mrcnn_head
  return heads.MaskrcnnHead(
      params.architecture.num_classes,
      params.architecture.mask_target_size,
      head_params.num_convs,
      head_params.num_filters,
      head_params.use_separable_conv,
      params.norm_activation.activation,
      head_params.use_batch_norm,
      norm_activation=norm_activation_generator(params.norm_activation))


def oln_mask_score_head_generator(params):
  """Generator function for Scoring Mask R-CNN head architecture."""
  head_params = params.mrcnn_head
  return heads.OlnMaskScoreHead(
      params.architecture.num_classes,
      params.architecture.mask_target_size,
      head_params.num_convs,
      head_params.num_filters,
      head_params.use_separable_conv,
      params.norm_activation.activation,
      head_params.use_batch_norm,
      norm_activation=norm_activation_generator(params.norm_activation))


def shapeprior_head_generator(params):
  """Generator function for shape prior head architecture."""
  head_params = params.shapemask_head
  return heads.ShapemaskPriorHead(
      params.architecture.num_classes,
      head_params.num_downsample_channels,
      head_params.mask_crop_size,
      head_params.use_category_for_mask,
      head_params.shape_prior_path)


def coarsemask_head_generator(params):
  """Generator function for ShapeMask coarse mask head architecture."""
  head_params = params.shapemask_head
  return heads.ShapemaskCoarsemaskHead(
      params.architecture.num_classes,
      head_params.num_downsample_channels,
      head_params.mask_crop_size,
      head_params.use_category_for_mask,
      head_params.num_convs,
      norm_activation=norm_activation_generator(params.norm_activation))


def finemask_head_generator(params):
  """Generator function for Shapemask fine mask head architecture."""
  head_params = params.shapemask_head
  return heads.ShapemaskFinemaskHead(
      params.architecture.num_classes,
      head_params.num_downsample_channels,
      head_params.mask_crop_size,
      head_params.use_category_for_mask,
      head_params.num_convs,
      head_params.upsample_factor)
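

# Illustrative usage sketch (not part of this module's API). The `params`
# argument passed to every generator above is assumed to be an attribute-style
# config object exposing `architecture`, `anchor`, `norm_activation`, and the
# per-component sub-configs accessed in this file; `cfg` below is a
# hypothetical instance of such a config, elided here.
#
#   cfg = ...  # e.g. a detection config with `architecture.backbone='resnet'`
#   backbone_fn = backbone_generator(cfg)
#   fpn_fn = multilevel_features_generator(cfg)
#   head_fn = retinanet_head_generator(cfg)
#
# The returned callables are typically composed by a model builder: the
# backbone maps images to features, the multilevel-feature function builds the
# feature pyramid, and the head produces per-level predictions from it.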