Spaces:
Runtime error
Runtime error
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Factory methods to build models.""" | |
from typing import Optional | |
import tensorflow as tf, tf_keras | |
from official.vision.configs import image_classification as classification_cfg | |
from official.vision.configs import maskrcnn as maskrcnn_cfg | |
from official.vision.configs import retinanet as retinanet_cfg | |
from official.vision.configs import semantic_segmentation as segmentation_cfg | |
from official.vision.modeling import backbones | |
from official.vision.modeling import classification_model | |
from official.vision.modeling import decoders | |
from official.vision.modeling import maskrcnn_model | |
from official.vision.modeling import retinanet_model | |
from official.vision.modeling import segmentation_model | |
from official.vision.modeling.heads import dense_prediction_heads | |
from official.vision.modeling.heads import instance_heads | |
from official.vision.modeling.heads import segmentation_heads | |
from official.vision.modeling.layers import detection_generator | |
from official.vision.modeling.layers import mask_sampler | |
from official.vision.modeling.layers import roi_aligner | |
from official.vision.modeling.layers import roi_generator | |
from official.vision.modeling.layers import roi_sampler | |
def build_classification_model(
    input_specs: tf_keras.layers.InputSpec,
    model_config: classification_cfg.ImageClassificationModel,
    l2_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
    skip_logits_layer: bool = False,
    backbone: Optional[tf_keras.Model] = None) -> tf_keras.Model:
  """Builds the classification model.

  Args:
    input_specs: Input spec forwarded to the backbone factory and to the
      `ClassificationModel` constructor.
    model_config: Classification model config. This function reads its
      `backbone`, `norm_activation`, `num_classes`, `dropout_rate`,
      `kernel_initializer` and `add_head_batch_norm` fields.
    l2_regularizer: Optional regularizer handed to the backbone factory and
      used as the model's `kernel_regularizer`.
    skip_logits_layer: Forwarded unchanged to `ClassificationModel`.
    backbone: Optional pre-built backbone. When it is not supplied, one is
      created from `model_config.backbone`.

  Returns:
    The assembled `classification_model.ClassificationModel`.
  """
  norm_cfg = model_config.norm_activation

  # Reuse the caller's backbone when given, otherwise build one from config.
  backbone = backbone or backbones.factory.build_backbone(
      input_specs=input_specs,
      backbone_config=model_config.backbone,
      norm_activation_config=norm_cfg,
      l2_regularizer=l2_regularizer)

  return classification_model.ClassificationModel(
      backbone=backbone,
      num_classes=model_config.num_classes,
      input_specs=input_specs,
      dropout_rate=model_config.dropout_rate,
      kernel_initializer=model_config.kernel_initializer,
      kernel_regularizer=l2_regularizer,
      add_head_batch_norm=model_config.add_head_batch_norm,
      use_sync_bn=norm_cfg.use_sync_bn,
      norm_momentum=norm_cfg.norm_momentum,
      norm_epsilon=norm_cfg.norm_epsilon,
      skip_logits_layer=skip_logits_layer)
def build_maskrcnn(input_specs: tf_keras.layers.InputSpec,
                   model_config: maskrcnn_cfg.MaskRCNN,
                   l2_regularizer: Optional[
                       tf_keras.regularizers.Regularizer] = None,
                   backbone: Optional[tf_keras.Model] = None,
                   decoder: Optional[tf_keras.Model] = None) -> tf_keras.Model:
  """Builds Mask R-CNN model.

  Args:
    input_specs: Input spec. It is forwarded to the backbone factory, and
      `input_specs.shape[1:]` is used to create the symbolic input that the
      backbone is called on.
    model_config: Mask R-CNN config providing the backbone, anchor, head, ROI
      and detection-generator settings read below.
    l2_regularizer: Optional regularizer forwarded to the backbone/decoder
      factories and used as `kernel_regularizer` of the heads.
    backbone: Optional pre-built backbone. Built from `model_config.backbone`
      when not supplied.
    decoder: Optional pre-built decoder. Built through the decoder factory
      when not supplied.

  Returns:
    The assembled `maskrcnn_model.MaskRCNNModel`. When
    `model_config.roi_sampler.cascade_iou_thresholds` is non-empty, the model
    receives a list of detection heads and one extra ROI sampler per
    threshold. When `model_config.include_mask` is false, the mask head, mask
    sampler and mask ROI aligner are passed as `None`.
  """
  norm_cfg = model_config.norm_activation

  backbone = backbone or backbones.factory.build_backbone(
      input_specs=input_specs,
      backbone_config=model_config.backbone,
      norm_activation_config=norm_cfg,
      l2_regularizer=l2_regularizer)
  backbone_features = backbone(tf_keras.Input(input_specs.shape[1:]))

  decoder = decoder or decoders.factory.build_decoder(
      input_specs=backbone.output_specs,
      model_config=model_config,
      l2_regularizer=l2_regularizer)

  rpn_cfg = model_config.rpn_head
  proposal_cfg = model_config.roi_generator
  sampler_cfg = model_config.roi_sampler
  aligner_cfg = model_config.roi_aligner
  det_head_cfg = model_config.detection_head
  det_generator_cfg = model_config.detection_generator
  anchor_cfg = model_config.anchor

  # Normalization / regularization arguments shared by the RPN head and every
  # detection head.
  shared_kwargs = dict(
      activation=norm_cfg.activation,
      use_sync_bn=norm_cfg.use_sync_bn,
      norm_momentum=norm_cfg.norm_momentum,
      norm_epsilon=norm_cfg.norm_epsilon,
      kernel_regularizer=l2_regularizer)

  rpn_head = dense_prediction_heads.RPNHead(
      min_level=model_config.min_level,
      max_level=model_config.max_level,
      num_anchors_per_location=(
          len(anchor_cfg.aspect_ratios) * anchor_cfg.num_scales),
      num_convs=rpn_cfg.num_convs,
      num_filters=rpn_cfg.num_filters,
      use_separable_conv=rpn_cfg.use_separable_conv,
      **shared_kwargs)

  def _new_detection_head(name):
    """Creates one detection head that differs from the others only by name."""
    return instance_heads.DetectionHead(
        num_classes=model_config.num_classes,
        num_convs=det_head_cfg.num_convs,
        num_filters=det_head_cfg.num_filters,
        use_separable_conv=det_head_cfg.use_separable_conv,
        num_fcs=det_head_cfg.num_fcs,
        fc_dims=det_head_cfg.fc_dims,
        class_agnostic_bbox_pred=det_head_cfg.class_agnostic_bbox_pred,
        name=name,
        **shared_kwargs)

  detection_head = _new_detection_head('detection_head')

  # The RPN head is only invoked here when a decoder is present.
  if decoder:
    rpn_head(decoder(backbone_features))

  cascade_thresholds = sampler_cfg.cascade_iou_thresholds
  if cascade_thresholds:
    # One additional head per cascade threshold, numbered from 1; the model
    # then receives the list of heads instead of a single head.
    detection_head = [detection_head] + [
        _new_detection_head('detection_head_{}'.format(stage))
        for stage in range(1, len(cascade_thresholds) + 1)
    ]

  proposal_generator = roi_generator.MultilevelROIGenerator(
      pre_nms_top_k=proposal_cfg.pre_nms_top_k,
      pre_nms_score_threshold=proposal_cfg.pre_nms_score_threshold,
      pre_nms_min_size_threshold=proposal_cfg.pre_nms_min_size_threshold,
      nms_iou_threshold=proposal_cfg.nms_iou_threshold,
      num_proposals=proposal_cfg.num_proposals,
      test_pre_nms_top_k=proposal_cfg.test_pre_nms_top_k,
      test_pre_nms_score_threshold=proposal_cfg.test_pre_nms_score_threshold,
      test_pre_nms_min_size_threshold=(
          proposal_cfg.test_pre_nms_min_size_threshold),
      test_nms_iou_threshold=proposal_cfg.test_nms_iou_threshold,
      test_num_proposals=proposal_cfg.test_num_proposals,
      use_batched_nms=proposal_cfg.use_batched_nms)

  # The first sampler is fully driven by the config.
  roi_samplers = [
      roi_sampler.ROISampler(
          mix_gt_boxes=sampler_cfg.mix_gt_boxes,
          num_sampled_rois=sampler_cfg.num_sampled_rois,
          foreground_fraction=sampler_cfg.foreground_fraction,
          foreground_iou_threshold=sampler_cfg.foreground_iou_threshold,
          background_iou_high_threshold=(
              sampler_cfg.background_iou_high_threshold),
          background_iou_low_threshold=(
              sampler_cfg.background_iou_low_threshold))
  ]
  # Initialize additional ROI samplers for cascade heads, one per threshold.
  roi_samplers.extend(
      roi_sampler.ROISampler(
          mix_gt_boxes=False,
          num_sampled_rois=sampler_cfg.num_sampled_rois,
          foreground_iou_threshold=iou,
          background_iou_high_threshold=iou,
          background_iou_low_threshold=0.0,
          skip_subsampling=True)
      for iou in (cascade_thresholds or ()))

  box_aligner = roi_aligner.MultilevelROIAligner(
      crop_size=aligner_cfg.crop_size,
      sample_offset=aligner_cfg.sample_offset)

  box_generator = detection_generator.DetectionGenerator(
      apply_nms=det_generator_cfg.apply_nms,
      pre_nms_top_k=det_generator_cfg.pre_nms_top_k,
      pre_nms_score_threshold=det_generator_cfg.pre_nms_score_threshold,
      nms_iou_threshold=det_generator_cfg.nms_iou_threshold,
      max_num_detections=det_generator_cfg.max_num_detections,
      nms_version=det_generator_cfg.nms_version,
      use_cpu_nms=det_generator_cfg.use_cpu_nms,
      soft_nms_sigma=det_generator_cfg.soft_nms_sigma,
      use_sigmoid_probability=det_generator_cfg.use_sigmoid_probability)

  # Mask components stay `None` unless the config enables the mask branch.
  mask_head = None
  mask_target_sampler = None
  mask_aligner = None
  if model_config.include_mask:
    mask_cfg = model_config.mask_head
    mask_aligner_cfg = model_config.mask_roi_aligner
    # NOTE(review): unlike the RPN and detection heads, no `use_sync_bn` is
    # forwarded to the mask head — confirm whether that is intentional.
    mask_head = instance_heads.MaskHead(
        num_classes=model_config.num_classes,
        upsample_factor=mask_cfg.upsample_factor,
        num_convs=mask_cfg.num_convs,
        num_filters=mask_cfg.num_filters,
        use_separable_conv=mask_cfg.use_separable_conv,
        activation=norm_cfg.activation,
        norm_momentum=norm_cfg.norm_momentum,
        norm_epsilon=norm_cfg.norm_epsilon,
        kernel_regularizer=l2_regularizer,
        class_agnostic=mask_cfg.class_agnostic)
    mask_target_sampler = mask_sampler.MaskSampler(
        mask_target_size=(
            mask_aligner_cfg.crop_size * mask_cfg.upsample_factor),
        num_sampled_masks=model_config.mask_sampler.num_sampled_masks)
    mask_aligner = roi_aligner.MultilevelROIAligner(
        crop_size=mask_aligner_cfg.crop_size,
        sample_offset=mask_aligner_cfg.sample_offset)

  return maskrcnn_model.MaskRCNNModel(
      backbone=backbone,
      decoder=decoder,
      rpn_head=rpn_head,
      detection_head=detection_head,
      roi_generator=proposal_generator,
      roi_sampler=roi_samplers,
      roi_aligner=box_aligner,
      detection_generator=box_generator,
      mask_head=mask_head,
      mask_sampler=mask_target_sampler,
      mask_roi_aligner=mask_aligner,
      class_agnostic_bbox_pred=det_head_cfg.class_agnostic_bbox_pred,
      cascade_class_ensemble=det_head_cfg.cascade_class_ensemble,
      min_level=model_config.min_level,
      max_level=model_config.max_level,
      num_scales=anchor_cfg.num_scales,
      aspect_ratios=anchor_cfg.aspect_ratios,
      anchor_size=anchor_cfg.anchor_size,
      outer_boxes_scale=model_config.outer_boxes_scale)
def build_retinanet(
    input_specs: tf_keras.layers.InputSpec,
    model_config: retinanet_cfg.RetinaNet,
    l2_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
    backbone: Optional[tf_keras.Model] = None,
    decoder: Optional[tf_keras.Model] = None
) -> tf_keras.Model:
  """Builds RetinaNet model.

  Args:
    input_specs: Input spec. It is forwarded to the backbone factory;
      `input_specs.shape[1:]` creates the symbolic input the backbone is
      called on, and `shape[1]` / `shape[2]` are recorded as the
      `input_image_size` of the TFLite post-processing config.
    model_config: RetinaNet config providing the backbone, anchor, head and
      detection-generator settings read below.
    l2_regularizer: Optional regularizer forwarded to the backbone/decoder
      factories and used as the head's `kernel_regularizer`.
    backbone: Optional pre-built backbone. Built from `model_config.backbone`
      when not supplied.
    decoder: Optional pre-built decoder. Built through the decoder factory
      when not supplied.

  Returns:
    The assembled `retinanet_model.RetinaNetModel`.
  """
  norm_cfg = model_config.norm_activation

  backbone = backbone or backbones.factory.build_backbone(
      input_specs=input_specs,
      backbone_config=model_config.backbone,
      norm_activation_config=norm_cfg,
      l2_regularizer=l2_regularizer)
  backbone_features = backbone(tf_keras.Input(input_specs.shape[1:]))

  decoder = decoder or decoders.factory.build_decoder(
      input_specs=backbone.output_specs,
      model_config=model_config,
      l2_regularizer=l2_regularizer)

  head_cfg = model_config.head
  anchor_cfg = model_config.anchor
  det_generator_cfg = model_config.detection_generator

  # Attribute-head configs are handed to the head as plain dicts; a missing
  # list is treated as "no attribute heads".
  attribute_head_dicts = [
      attr_cfg.as_dict() for attr_cfg in (head_cfg.attribute_heads or [])
  ]

  head = dense_prediction_heads.RetinaNetHead(
      min_level=model_config.min_level,
      max_level=model_config.max_level,
      num_classes=model_config.num_classes,
      num_anchors_per_location=(
          len(anchor_cfg.aspect_ratios) * anchor_cfg.num_scales),
      num_convs=head_cfg.num_convs,
      num_filters=head_cfg.num_filters,
      attribute_heads=attribute_head_dicts,
      share_classification_heads=head_cfg.share_classification_heads,
      use_separable_conv=head_cfg.use_separable_conv,
      activation=norm_cfg.activation,
      use_sync_bn=norm_cfg.use_sync_bn,
      norm_momentum=norm_cfg.norm_momentum,
      norm_epsilon=norm_cfg.norm_epsilon,
      kernel_regularizer=l2_regularizer,
      share_level_convs=head_cfg.share_level_convs,
  )

  # Run the decoder and head once so that their trainable weights are
  # initialized. Skipped when there is no decoder.
  if decoder:
    _ = head(decoder(backbone_features))

  # Add `input_image_size` into the TFLite post-processing config.
  image_size = (input_specs.shape[1], input_specs.shape[2])
  tflite_post_processing_config = (
      det_generator_cfg.tflite_post_processing.as_dict())
  tflite_post_processing_config['input_image_size'] = image_size

  box_generator = detection_generator.MultilevelDetectionGenerator(
      apply_nms=det_generator_cfg.apply_nms,
      pre_nms_top_k=det_generator_cfg.pre_nms_top_k,
      pre_nms_score_threshold=det_generator_cfg.pre_nms_score_threshold,
      nms_iou_threshold=det_generator_cfg.nms_iou_threshold,
      max_num_detections=det_generator_cfg.max_num_detections,
      nms_version=det_generator_cfg.nms_version,
      use_cpu_nms=det_generator_cfg.use_cpu_nms,
      soft_nms_sigma=det_generator_cfg.soft_nms_sigma,
      tflite_post_processing_config=tflite_post_processing_config,
      return_decoded=det_generator_cfg.return_decoded,
      use_class_agnostic_nms=det_generator_cfg.use_class_agnostic_nms,
      box_coder_weights=det_generator_cfg.box_coder_weights,
  )

  return retinanet_model.RetinaNetModel(
      backbone,
      decoder,
      head,
      box_generator,
      min_level=model_config.min_level,
      max_level=model_config.max_level,
      num_scales=anchor_cfg.num_scales,
      aspect_ratios=anchor_cfg.aspect_ratios,
      anchor_size=anchor_cfg.anchor_size)
def build_segmentation_model(
    input_specs: tf_keras.layers.InputSpec,
    model_config: segmentation_cfg.SemanticSegmentationModel,
    l2_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
    backbone: Optional[tf_keras.Model] = None,
    decoder: Optional[tf_keras.Model] = None
) -> tf_keras.Model:
  """Builds Segmentation model.

  Args:
    input_specs: Input spec forwarded to the backbone factory.
    model_config: Segmentation config providing the backbone, head,
      normalization and optional mask-scoring-head settings read below.
    l2_regularizer: Optional regularizer forwarded to the backbone/decoder
      factories and used as `kernel_regularizer` of the heads.
    backbone: Optional pre-built backbone. Built from `model_config.backbone`
      when not supplied.
    decoder: Optional pre-built decoder. Built through the decoder factory
      when not supplied.

  Returns:
    The assembled `segmentation_model.SegmentationModel`. Its
    `mask_scoring_head` is `None` unless `model_config.mask_scoring_head` is
    set.
  """
  norm_cfg = model_config.norm_activation

  backbone = backbone or backbones.factory.build_backbone(
      input_specs=input_specs,
      backbone_config=model_config.backbone,
      norm_activation_config=norm_cfg,
      l2_regularizer=l2_regularizer)

  decoder = decoder or decoders.factory.build_decoder(
      input_specs=backbone.output_specs,
      model_config=model_config,
      l2_regularizer=l2_regularizer)

  # Normalization / regularization arguments shared by both heads.
  shared_kwargs = dict(
      activation=norm_cfg.activation,
      use_sync_bn=norm_cfg.use_sync_bn,
      norm_momentum=norm_cfg.norm_momentum,
      norm_epsilon=norm_cfg.norm_epsilon,
      kernel_regularizer=l2_regularizer)

  head_cfg = model_config.head
  seg_head = segmentation_heads.SegmentationHead(
      num_classes=model_config.num_classes,
      level=head_cfg.level,
      num_convs=head_cfg.num_convs,
      prediction_kernel_size=head_cfg.prediction_kernel_size,
      num_filters=head_cfg.num_filters,
      use_depthwise_convolution=head_cfg.use_depthwise_convolution,
      upsample_factor=head_cfg.upsample_factor,
      feature_fusion=head_cfg.feature_fusion,
      low_level=head_cfg.low_level,
      low_level_num_filters=head_cfg.low_level_num_filters,
      logit_activation=head_cfg.logit_activation,
      **shared_kwargs)

  # The mask scoring head is optional; its own config fields are expanded
  # into keyword arguments alongside the shared ones.
  scoring_cfg = model_config.mask_scoring_head
  mask_scoring_head = None
  if scoring_cfg:
    mask_scoring_head = segmentation_heads.MaskScoring(
        num_classes=model_config.num_classes,
        **scoring_cfg.as_dict(),
        **shared_kwargs)

  return segmentation_model.SegmentationModel(
      backbone, decoder, seg_head, mask_scoring_head=mask_scoring_head)