# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Factory methods to build models."""

from typing import Optional

import tensorflow as tf, tf_keras

from official.vision.configs import image_classification as classification_cfg
from official.vision.configs import maskrcnn as maskrcnn_cfg
from official.vision.configs import retinanet as retinanet_cfg
from official.vision.configs import semantic_segmentation as segmentation_cfg
from official.vision.modeling import backbones
from official.vision.modeling import classification_model
from official.vision.modeling import decoders
from official.vision.modeling import maskrcnn_model
from official.vision.modeling import retinanet_model
from official.vision.modeling import segmentation_model
from official.vision.modeling.heads import dense_prediction_heads
from official.vision.modeling.heads import instance_heads
from official.vision.modeling.heads import segmentation_heads
from official.vision.modeling.layers import detection_generator
from official.vision.modeling.layers import mask_sampler
from official.vision.modeling.layers import roi_aligner
from official.vision.modeling.layers import roi_generator
from official.vision.modeling.layers import roi_sampler


def build_classification_model(
    input_specs: tf_keras.layers.InputSpec,
    model_config: classification_cfg.ImageClassificationModel,
    l2_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
    skip_logits_layer: bool = False,
    backbone: Optional[tf_keras.Model] = None) -> tf_keras.Model:
  """Builds the classification model.

  Args:
    input_specs: Input spec, forwarded to the backbone factory and to
      `ClassificationModel`.
    model_config: Image classification model config.
    l2_regularizer: Optional kernel regularizer for the backbone (when it is
      built here) and for the classification head.
    skip_logits_layer: Forwarded unchanged to `ClassificationModel`.
    backbone: Optional prebuilt backbone. When None, a backbone is built from
      `model_config.backbone`.

  Returns:
    A `classification_model.ClassificationModel` instance.
  """
  norm_activation_config = model_config.norm_activation
  if backbone is None:
    backbone = backbones.factory.build_backbone(
        input_specs=input_specs,
        backbone_config=model_config.backbone,
        norm_activation_config=norm_activation_config,
        l2_regularizer=l2_regularizer)

  model = classification_model.ClassificationModel(
      backbone=backbone,
      num_classes=model_config.num_classes,
      input_specs=input_specs,
      dropout_rate=model_config.dropout_rate,
      kernel_initializer=model_config.kernel_initializer,
      kernel_regularizer=l2_regularizer,
      add_head_batch_norm=model_config.add_head_batch_norm,
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      skip_logits_layer=skip_logits_layer)
  return model


def _build_detection_head(
    model_config: maskrcnn_cfg.MaskRCNN,
    l2_regularizer: Optional[tf_keras.regularizers.Regularizer],
    name: str) -> instance_heads.DetectionHead:
  """Builds one Mask R-CNN box head from `model_config.detection_head`.

  Shared by the first-stage detection head and every cascade-stage head so
  that all of them are configured identically except for their layer name.

  Args:
    model_config: Mask R-CNN model config.
    l2_regularizer: Optional kernel regularizer for the head.
    name: Keras layer name for the head.

  Returns:
    An `instance_heads.DetectionHead` instance.
  """
  head_config = model_config.detection_head
  norm_activation_config = model_config.norm_activation
  return instance_heads.DetectionHead(
      num_classes=model_config.num_classes,
      num_convs=head_config.num_convs,
      num_filters=head_config.num_filters,
      use_separable_conv=head_config.use_separable_conv,
      num_fcs=head_config.num_fcs,
      fc_dims=head_config.fc_dims,
      class_agnostic_bbox_pred=head_config.class_agnostic_bbox_pred,
      activation=norm_activation_config.activation,
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      kernel_regularizer=l2_regularizer,
      name=name)


def build_maskrcnn(input_specs: tf_keras.layers.InputSpec,
                   model_config: maskrcnn_cfg.MaskRCNN,
                   l2_regularizer: Optional[
                       tf_keras.regularizers.Regularizer] = None,
                   backbone: Optional[tf_keras.Model] = None,
                   decoder: Optional[tf_keras.Model] = None) -> tf_keras.Model:
  """Builds Mask R-CNN model.

  When `model_config.roi_sampler.cascade_iou_thresholds` is non-empty, one
  extra `DetectionHead` and one extra `ROISampler` are created per threshold,
  and the detection head passed to the model becomes a list of heads.

  Args:
    input_specs: Input spec; `input_specs.shape[1:]` is used to create the
      symbolic input that builds the backbone.
    model_config: Mask R-CNN model config.
    l2_regularizer: Optional kernel regularizer for the backbone/decoder (when
      built here) and for the heads.
    backbone: Optional prebuilt backbone. When None, one is built from
      `model_config.backbone`.
    decoder: Optional prebuilt decoder. When None, the decoder factory is
      used; it may itself return None.

  Returns:
    A `maskrcnn_model.MaskRCNNModel` instance.
  """
  norm_activation_config = model_config.norm_activation
  if backbone is None:
    backbone = backbones.factory.build_backbone(
        input_specs=input_specs,
        backbone_config=model_config.backbone,
        norm_activation_config=norm_activation_config,
        l2_regularizer=l2_regularizer)
  # Calling the backbone on a symbolic input builds its weights.
  backbone_features = backbone(tf_keras.Input(input_specs.shape[1:]))

  if decoder is None:
    decoder = decoders.factory.build_decoder(
        input_specs=backbone.output_specs,
        model_config=model_config,
        l2_regularizer=l2_regularizer)

  rpn_head_config = model_config.rpn_head
  roi_generator_config = model_config.roi_generator
  roi_sampler_config = model_config.roi_sampler
  roi_aligner_config = model_config.roi_aligner
  detection_head_config = model_config.detection_head
  generator_config = model_config.detection_generator
  cascade_iou_thresholds = roi_sampler_config.cascade_iou_thresholds
  num_anchors_per_location = (
      len(model_config.anchor.aspect_ratios) * model_config.anchor.num_scales)

  rpn_head = dense_prediction_heads.RPNHead(
      min_level=model_config.min_level,
      max_level=model_config.max_level,
      num_anchors_per_location=num_anchors_per_location,
      num_convs=rpn_head_config.num_convs,
      num_filters=rpn_head_config.num_filters,
      use_separable_conv=rpn_head_config.use_separable_conv,
      activation=norm_activation_config.activation,
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      kernel_regularizer=l2_regularizer)

  detection_head = _build_detection_head(
      model_config, l2_regularizer, name='detection_head')

  # Builds decoder and RPN head so that their trainable weights are
  # initialized. Skipped when there is no decoder.
  if decoder is not None:
    decoder_features = decoder(backbone_features)
    rpn_head(decoder_features)

  if cascade_iou_thresholds:
    # Cascade stages get heads named detection_head_1 ... detection_head_N;
    # the model then receives the whole list of heads.
    detection_head = [detection_head] + [
        _build_detection_head(
            model_config, l2_regularizer, name=f'detection_head_{stage}')
        for stage in range(1, len(cascade_iou_thresholds) + 1)
    ]

  roi_generator_obj = roi_generator.MultilevelROIGenerator(
      pre_nms_top_k=roi_generator_config.pre_nms_top_k,
      pre_nms_score_threshold=roi_generator_config.pre_nms_score_threshold,
      pre_nms_min_size_threshold=(
          roi_generator_config.pre_nms_min_size_threshold),
      nms_iou_threshold=roi_generator_config.nms_iou_threshold,
      num_proposals=roi_generator_config.num_proposals,
      test_pre_nms_top_k=roi_generator_config.test_pre_nms_top_k,
      test_pre_nms_score_threshold=(
          roi_generator_config.test_pre_nms_score_threshold),
      test_pre_nms_min_size_threshold=(
          roi_generator_config.test_pre_nms_min_size_threshold),
      test_nms_iou_threshold=roi_generator_config.test_nms_iou_threshold,
      test_num_proposals=roi_generator_config.test_num_proposals,
      use_batched_nms=roi_generator_config.use_batched_nms)

  roi_sampler_cascade = [
      roi_sampler.ROISampler(
          mix_gt_boxes=roi_sampler_config.mix_gt_boxes,
          num_sampled_rois=roi_sampler_config.num_sampled_rois,
          foreground_fraction=roi_sampler_config.foreground_fraction,
          foreground_iou_threshold=(
              roi_sampler_config.foreground_iou_threshold),
          background_iou_high_threshold=(
              roi_sampler_config.background_iou_high_threshold),
          background_iou_low_threshold=(
              roi_sampler_config.background_iou_low_threshold))
  ]
  # Initialize additional ROI samplers for the cascade heads: each stage uses
  # its IoU threshold as both the foreground and the upper background
  # threshold, does not mix in ground-truth boxes, and passes
  # skip_subsampling=True.
  if cascade_iou_thresholds:
    for iou in cascade_iou_thresholds:
      roi_sampler_cascade.append(
          roi_sampler.ROISampler(
              mix_gt_boxes=False,
              num_sampled_rois=roi_sampler_config.num_sampled_rois,
              foreground_iou_threshold=iou,
              background_iou_high_threshold=iou,
              background_iou_low_threshold=0.0,
              skip_subsampling=True))

  roi_aligner_obj = roi_aligner.MultilevelROIAligner(
      crop_size=roi_aligner_config.crop_size,
      sample_offset=roi_aligner_config.sample_offset)

  detection_generator_obj = detection_generator.DetectionGenerator(
      apply_nms=generator_config.apply_nms,
      pre_nms_top_k=generator_config.pre_nms_top_k,
      pre_nms_score_threshold=generator_config.pre_nms_score_threshold,
      nms_iou_threshold=generator_config.nms_iou_threshold,
      max_num_detections=generator_config.max_num_detections,
      nms_version=generator_config.nms_version,
      use_cpu_nms=generator_config.use_cpu_nms,
      soft_nms_sigma=generator_config.soft_nms_sigma,
      use_sigmoid_probability=generator_config.use_sigmoid_probability)

  if model_config.include_mask:
    mask_head_config = model_config.mask_head
    mask_roi_aligner_config = model_config.mask_roi_aligner
    mask_head = instance_heads.MaskHead(
        num_classes=model_config.num_classes,
        upsample_factor=mask_head_config.upsample_factor,
        num_convs=mask_head_config.num_convs,
        num_filters=mask_head_config.num_filters,
        use_separable_conv=mask_head_config.use_separable_conv,
        activation=norm_activation_config.activation,
        # Previously omitted, so the mask head ignored the sync-BN setting
        # that every other head in this file honours.
        use_sync_bn=norm_activation_config.use_sync_bn,
        norm_momentum=norm_activation_config.norm_momentum,
        norm_epsilon=norm_activation_config.norm_epsilon,
        kernel_regularizer=l2_regularizer,
        class_agnostic=mask_head_config.class_agnostic)

    # Mask targets are sized to the mask head's output: the mask ROI crop
    # size times the head's upsample factor.
    mask_sampler_obj = mask_sampler.MaskSampler(
        mask_target_size=(
            mask_roi_aligner_config.crop_size *
            mask_head_config.upsample_factor),
        num_sampled_masks=model_config.mask_sampler.num_sampled_masks)

    mask_roi_aligner_obj = roi_aligner.MultilevelROIAligner(
        crop_size=mask_roi_aligner_config.crop_size,
        sample_offset=mask_roi_aligner_config.sample_offset)
  else:
    mask_head = None
    mask_sampler_obj = None
    mask_roi_aligner_obj = None

  model = maskrcnn_model.MaskRCNNModel(
      backbone=backbone,
      decoder=decoder,
      rpn_head=rpn_head,
      detection_head=detection_head,
      roi_generator=roi_generator_obj,
      roi_sampler=roi_sampler_cascade,
      roi_aligner=roi_aligner_obj,
      detection_generator=detection_generator_obj,
      mask_head=mask_head,
      mask_sampler=mask_sampler_obj,
      mask_roi_aligner=mask_roi_aligner_obj,
      class_agnostic_bbox_pred=detection_head_config.class_agnostic_bbox_pred,
      cascade_class_ensemble=detection_head_config.cascade_class_ensemble,
      min_level=model_config.min_level,
      max_level=model_config.max_level,
      num_scales=model_config.anchor.num_scales,
      aspect_ratios=model_config.anchor.aspect_ratios,
      anchor_size=model_config.anchor.anchor_size,
      outer_boxes_scale=model_config.outer_boxes_scale)
  return model


def build_retinanet(
    input_specs: tf_keras.layers.InputSpec,
    model_config: retinanet_cfg.RetinaNet,
    l2_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
    backbone: Optional[tf_keras.Model] = None,
    decoder: Optional[tf_keras.Model] = None
) -> tf_keras.Model:
  """Builds RetinaNet model.

  Args:
    input_specs: Input spec; `shape[1:]` builds the backbone and
      `shape[1]`/`shape[2]` are recorded as `input_image_size` in the TFLite
      post-processing config.
    model_config: RetinaNet model config.
    l2_regularizer: Optional kernel regularizer for the backbone/decoder (when
      built here) and for the head.
    backbone: Optional prebuilt backbone. When None, one is built from
      `model_config.backbone`.
    decoder: Optional prebuilt decoder. When None, the decoder factory is
      used; it may itself return None.

  Returns:
    A `retinanet_model.RetinaNetModel` instance.
  """
  norm_activation_config = model_config.norm_activation
  if backbone is None:
    backbone = backbones.factory.build_backbone(
        input_specs=input_specs,
        backbone_config=model_config.backbone,
        norm_activation_config=norm_activation_config,
        l2_regularizer=l2_regularizer)
  backbone_features = backbone(tf_keras.Input(input_specs.shape[1:]))

  if decoder is None:
    decoder = decoders.factory.build_decoder(
        input_specs=backbone.output_specs,
        model_config=model_config,
        l2_regularizer=l2_regularizer)

  head_config = model_config.head
  generator_config = model_config.detection_generator
  num_anchors_per_location = (
      len(model_config.anchor.aspect_ratios) * model_config.anchor.num_scales)

  head = dense_prediction_heads.RetinaNetHead(
      min_level=model_config.min_level,
      max_level=model_config.max_level,
      num_classes=model_config.num_classes,
      num_anchors_per_location=num_anchors_per_location,
      num_convs=head_config.num_convs,
      num_filters=head_config.num_filters,
      attribute_heads=[
          cfg.as_dict() for cfg in (head_config.attribute_heads or [])
      ],
      share_classification_heads=head_config.share_classification_heads,
      use_separable_conv=head_config.use_separable_conv,
      activation=norm_activation_config.activation,
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      kernel_regularizer=l2_regularizer,
      share_level_convs=head_config.share_level_convs,
  )

  # Builds decoder and head so that their trainable weights are initialized.
  if decoder is not None:
    decoder_features = decoder(backbone_features)
    _ = head(decoder_features)

  # Add `input_image_size` into `tflite_post_processing_config`.
  tflite_post_processing_config = (
      generator_config.tflite_post_processing.as_dict()
  )
  tflite_post_processing_config['input_image_size'] = (
      input_specs.shape[1],
      input_specs.shape[2],
  )
  detection_generator_obj = detection_generator.MultilevelDetectionGenerator(
      apply_nms=generator_config.apply_nms,
      pre_nms_top_k=generator_config.pre_nms_top_k,
      pre_nms_score_threshold=generator_config.pre_nms_score_threshold,
      nms_iou_threshold=generator_config.nms_iou_threshold,
      max_num_detections=generator_config.max_num_detections,
      nms_version=generator_config.nms_version,
      use_cpu_nms=generator_config.use_cpu_nms,
      soft_nms_sigma=generator_config.soft_nms_sigma,
      tflite_post_processing_config=tflite_post_processing_config,
      return_decoded=generator_config.return_decoded,
      use_class_agnostic_nms=generator_config.use_class_agnostic_nms,
      box_coder_weights=generator_config.box_coder_weights,
  )

  model = retinanet_model.RetinaNetModel(
      backbone,
      decoder,
      head,
      detection_generator_obj,
      min_level=model_config.min_level,
      max_level=model_config.max_level,
      num_scales=model_config.anchor.num_scales,
      aspect_ratios=model_config.anchor.aspect_ratios,
      anchor_size=model_config.anchor.anchor_size)
  return model


def build_segmentation_model(
    input_specs: tf_keras.layers.InputSpec,
    model_config: segmentation_cfg.SemanticSegmentationModel,
    l2_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
    backbone: Optional[tf_keras.Model] = None,
    decoder: Optional[tf_keras.Model] = None
) -> tf_keras.Model:
  """Builds Segmentation model.

  Args:
    input_specs: Input spec, forwarded to the backbone factory.
    model_config: Semantic segmentation model config.
    l2_regularizer: Optional kernel regularizer for the backbone/decoder (when
      built here) and for the heads.
    backbone: Optional prebuilt backbone. When None, one is built from
      `model_config.backbone`.
    decoder: Optional prebuilt decoder. When None, the decoder factory is
      used.

  Returns:
    A `segmentation_model.SegmentationModel` instance; it includes a
    `MaskScoring` head only when `model_config.mask_scoring_head` is set.
  """
  norm_activation_config = model_config.norm_activation
  if backbone is None:
    backbone = backbones.factory.build_backbone(
        input_specs=input_specs,
        backbone_config=model_config.backbone,
        norm_activation_config=norm_activation_config,
        l2_regularizer=l2_regularizer)

  if decoder is None:
    decoder = decoders.factory.build_decoder(
        input_specs=backbone.output_specs,
        model_config=model_config,
        l2_regularizer=l2_regularizer)

  head_config = model_config.head

  head = segmentation_heads.SegmentationHead(
      num_classes=model_config.num_classes,
      level=head_config.level,
      num_convs=head_config.num_convs,
      prediction_kernel_size=head_config.prediction_kernel_size,
      num_filters=head_config.num_filters,
      use_depthwise_convolution=head_config.use_depthwise_convolution,
      upsample_factor=head_config.upsample_factor,
      feature_fusion=head_config.feature_fusion,
      low_level=head_config.low_level,
      low_level_num_filters=head_config.low_level_num_filters,
      activation=norm_activation_config.activation,
      logit_activation=head_config.logit_activation,
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      kernel_regularizer=l2_regularizer)

  mask_scoring_head = None
  if model_config.mask_scoring_head:
    # NOTE(review): the mask-scoring config dict is splatted alongside the
    # explicit norm/activation kwargs, so it must not contain those keys.
    mask_scoring_head = segmentation_heads.MaskScoring(
        num_classes=model_config.num_classes,
        **model_config.mask_scoring_head.as_dict(),
        activation=norm_activation_config.activation,
        use_sync_bn=norm_activation_config.use_sync_bn,
        norm_momentum=norm_activation_config.norm_momentum,
        norm_epsilon=norm_activation_config.norm_epsilon,
        kernel_regularizer=l2_regularizer)

  model = segmentation_model.SegmentationModel(
      backbone, decoder, head, mask_scoring_head=mask_scoring_head)
  return model