# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Semantic segmentation configuration definition."""

import dataclasses
import math
import os
from typing import List, Optional, Sequence, Union

from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import hyperparams
from official.modeling import optimization
from official.vision.configs import backbones
from official.vision.configs import common
from official.vision.configs import decoders
from official.vision.ops import preprocess_ops


@dataclasses.dataclass
class DenseFeatureConfig(hyperparams.Config):
  """Config for dense features, such as RGB pixels, masks, heatmaps.

  The dense features are encoded images in TF examples, and thus have 1, 3 or
  4 channels. Features with a different number of channels (e.g. optical flow)
  can be encoded as multiple 1-channel features. The default config is for RGB
  input, with the mean and stddev of the ImageNet datasets. Only 8-bit encoded
  features (maximum value 255) are supported.

  Attributes:
    feature_name: The key of the feature in TF examples.
    num_channels: An `int` specifying the number of channels of the feature.
    mean: A list of floats in the range of [0, 255] representing the mean
      value of each channel. The length of the list should match num_channels.
    stddev: A list of floats in the range of [0, 255] representing the
      standard deviation of each channel. The length should match
      num_channels.
  """
  feature_name: str = 'image/encoded'
  num_channels: int = 3
  mean: List[float] = dataclasses.field(
      default_factory=lambda: preprocess_ops.MEAN_RGB
  )
  stddev: List[float] = dataclasses.field(
      default_factory=lambda: preprocess_ops.STDDEV_RGB
  )

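# Example (illustrative only, not part of the config API): a hypothetical
# single-channel feature, e.g. a depth map encoded as an 8-bit image, that
# could be supplied via `DataConfig.additional_dense_features` below. The
# feature key and the mean/stddev values are placeholders, not dataset
# statistics.
#
#   depth_feature = DenseFeatureConfig(
#       feature_name='image/depth/encoded',
#       num_channels=1,
#       mean=[127.5],
#       stddev=[127.5],
#   )
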
@dataclasses.dataclass
class DataConfig(cfg.DataConfig):
  """Input config for training."""
  image_feature: DenseFeatureConfig = dataclasses.field(
      default_factory=DenseFeatureConfig
  )
  output_size: List[int] = dataclasses.field(default_factory=list)
  # If crop_size is specified, the image is first resized to output_size and
  # then a crop of size crop_size is taken.
  crop_size: List[int] = dataclasses.field(default_factory=list)
  input_path: Union[Sequence[str], str, hyperparams.Config] = None
  weights: Optional[hyperparams.Config] = None
  global_batch_size: int = 0
  is_training: bool = True
  dtype: str = 'float32'
  shuffle_buffer_size: int = 1000
  cycle_length: int = 10
  # If resize_eval_groundtruth is set to False, the original image sizes are
  # used for eval. In that case, groundtruth_padded_size must also be
  # specified to allow batching images of variable sizes.
  resize_eval_groundtruth: bool = True
  groundtruth_padded_size: List[int] = dataclasses.field(default_factory=list)
  aug_scale_min: float = 1.0
  aug_scale_max: float = 1.0
  aug_rand_hflip: bool = True
  preserve_aspect_ratio: bool = True
  aug_policy: Optional[str] = None
  drop_remainder: bool = True
  file_type: str = 'tfrecord'
  decoder: Optional[common.DataDecoder] = dataclasses.field(
      default_factory=common.DataDecoder
  )
  additional_dense_features: List[DenseFeatureConfig] = dataclasses.field(
      default_factory=list
  )


@dataclasses.dataclass
class SegmentationHead(hyperparams.Config):
  """Segmentation head config."""
  level: int = 3
  num_convs: int = 2
  num_filters: int = 256
  use_depthwise_convolution: bool = False
  prediction_kernel_size: int = 1
  upsample_factor: int = 1
  logit_activation: Optional[str] = None  # None, 'sigmoid', or 'softmax'.
  # One of None, 'deeplabv3plus', 'panoptic_fpn_fusion' or 'pyramid_fusion'.
  feature_fusion: Optional[str] = None
  # deeplabv3plus feature fusion params.
  low_level: Union[int, str] = 2
  low_level_num_filters: int = 48
  # panoptic_fpn_fusion params.
  decoder_min_level: Optional[Union[int, str]] = None
  decoder_max_level: Optional[Union[int, str]] = None


@dataclasses.dataclass
class MaskScoringHead(hyperparams.Config):
  """Mask Scoring head config."""
  num_convs: int = 4
  num_filters: int = 128
  fc_input_size: List[int] = dataclasses.field(default_factory=list)
  num_fcs: int = 2
  fc_dims: int = 1024
  use_depthwise_convolution: bool = False


@dataclasses.dataclass
class SemanticSegmentationModel(hyperparams.Config):
  """Semantic segmentation model config."""
  num_classes: int = 0
  input_size: List[int] = dataclasses.field(default_factory=list)
  min_level: int = 3
  max_level: int = 6
  head: SegmentationHead = dataclasses.field(default_factory=SegmentationHead)
  backbone: backbones.Backbone = dataclasses.field(
      default_factory=lambda: backbones.Backbone(  # pylint: disable=g-long-lambda
          type='resnet', resnet=backbones.ResNet()
      )
  )
  decoder: decoders.Decoder = dataclasses.field(
      default_factory=lambda: decoders.Decoder(type='identity')
  )
  mask_scoring_head: Optional[MaskScoringHead] = None
  norm_activation: common.NormActivation = dataclasses.field(
      default_factory=common.NormActivation
  )


@dataclasses.dataclass
class Losses(hyperparams.Config):
  """Loss function config."""
  loss_weight: float = 1.0
  label_smoothing: float = 0.0
  ignore_label: int = 255
  gt_is_matting_map: bool = False
  class_weights: List[float] = dataclasses.field(default_factory=list)
  l2_weight_decay: float = 0.0
  use_groundtruth_dimension: bool = True
  # If True, use binary cross entropy (sigmoid) in the loss; otherwise, use
  # categorical cross entropy (softmax).
  use_binary_cross_entropy: bool = False
  top_k_percent_pixels: float = 1.0
  mask_scoring_weight: float = 1.0


@dataclasses.dataclass
class Evaluation(hyperparams.Config):
  """Evaluation config."""
  report_per_class_iou: bool = True
  report_train_mean_iou: bool = True  # Turning this off can speed up training.


@dataclasses.dataclass
class ExportConfig(hyperparams.Config):
  """Model export config."""
  # Whether to rescale the predicted mask to the original image size.
  rescale_output: bool = False

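# Example (illustrative): a head configured for DeepLabV3+ style fusion, in
# which a low-level backbone feature is projected to low_level_num_filters
# channels and fused with the decoder output. The values mirror the defaults
# above and are not a tuning recommendation.
#
#   head = SegmentationHead(
#       level=4,
#       num_convs=2,
#       feature_fusion='deeplabv3plus',
#       low_level=2,
#       low_level_num_filters=48,
#   )
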
@dataclasses.dataclass
class SemanticSegmentationTask(cfg.TaskConfig):
  """The task config."""
  model: SemanticSegmentationModel = dataclasses.field(
      default_factory=SemanticSegmentationModel
  )
  train_data: DataConfig = dataclasses.field(
      default_factory=lambda: DataConfig(is_training=True)
  )
  validation_data: DataConfig = dataclasses.field(
      default_factory=lambda: DataConfig(is_training=False)
  )
  losses: Losses = dataclasses.field(default_factory=Losses)
  evaluation: Evaluation = dataclasses.field(default_factory=Evaluation)
  train_input_partition_dims: List[int] = dataclasses.field(
      default_factory=list)
  eval_input_partition_dims: List[int] = dataclasses.field(
      default_factory=list)
  init_checkpoint: Optional[str] = None
  init_checkpoint_modules: Union[
      str, List[str]] = 'all'  # all, backbone, and/or decoder.
  export_config: ExportConfig = dataclasses.field(default_factory=ExportConfig)
  allow_image_summary: bool = True


@exp_factory.register_config_factory('semantic_segmentation')
def semantic_segmentation() -> cfg.ExperimentConfig:
  """Semantic segmentation general."""
  return cfg.ExperimentConfig(
      task=SemanticSegmentationTask(),
      trainer=cfg.TrainerConfig(),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])

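# Example (illustrative): registered experiments can be retrieved by name via
# the experiment factory and then overridden field by field, e.g.:
#
#   config = exp_factory.get_exp_config('semantic_segmentation')
#   config.task.model.num_classes = 21
#   config.task.train_data.global_batch_size = 64
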
# PASCAL VOC 2012 Dataset
PASCAL_TRAIN_EXAMPLES = 10582
PASCAL_VAL_EXAMPLES = 1449
PASCAL_INPUT_PATH_BASE = 'gs://**/pascal_voc_seg'


@exp_factory.register_config_factory('seg_deeplabv3_pascal')
def seg_deeplabv3_pascal() -> cfg.ExperimentConfig:
  """Image segmentation on pascal voc with resnet deeplabv3."""
  train_batch_size = 16
  eval_batch_size = 8
  steps_per_epoch = PASCAL_TRAIN_EXAMPLES // train_batch_size
  output_stride = 16
  # [12, 24, 36] is the output_stride = 8 recipe from DeepLab; use
  # [6, 12, 18] for output_stride = 16.
  aspp_dilation_rates = [12, 24, 36]
  multigrid = [1, 2, 4]
  stem_type = 'v1'
  level = int(math.log2(output_stride))
  config = cfg.ExperimentConfig(
      task=SemanticSegmentationTask(
          model=SemanticSegmentationModel(
              num_classes=21,
              input_size=[None, None, 3],
              backbone=backbones.Backbone(
                  type='dilated_resnet',
                  dilated_resnet=backbones.DilatedResNet(
                      model_id=101,
                      output_stride=output_stride,
                      multigrid=multigrid,
                      stem_type=stem_type)),
              decoder=decoders.Decoder(
                  type='aspp',
                  aspp=decoders.ASPP(
                      level=level, dilation_rates=aspp_dilation_rates)),
              head=SegmentationHead(level=level, num_convs=0),
              norm_activation=common.NormActivation(
                  activation='swish',
                  norm_momentum=0.9997,
                  norm_epsilon=1e-3,
                  use_sync_bn=True)),
          losses=Losses(l2_weight_decay=1e-4),
          train_data=DataConfig(
              input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'train_aug*'),
              # TODO(arashwan): test changing size to 513 to match deeplab.
              output_size=[512, 512],
              is_training=True,
              global_batch_size=train_batch_size,
              aug_scale_min=0.5,
              aug_scale_max=2.0),
          validation_data=DataConfig(
              input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'val*'),
              output_size=[512, 512],
              is_training=False,
              global_batch_size=eval_batch_size,
              resize_eval_groundtruth=False,
              groundtruth_padded_size=[512, 512],
              drop_remainder=False),
          # resnet101
          init_checkpoint='gs://cloud-tpu-checkpoints/vision-2.0/deeplab/deeplab_resnet101_imagenet/ckpt-62400',
          init_checkpoint_modules='backbone'),
      trainer=cfg.TrainerConfig(
          steps_per_loop=steps_per_epoch,
          summary_interval=steps_per_epoch,
          checkpoint_interval=steps_per_epoch,
          train_steps=45 * steps_per_epoch,
          validation_steps=PASCAL_VAL_EXAMPLES // eval_batch_size,
          validation_interval=steps_per_epoch,
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'sgd',
                  'sgd': {
                      'momentum': 0.9
                  }
              },
              'learning_rate': {
                  'type': 'polynomial',
                  'polynomial': {
                      'initial_learning_rate': 0.007,
                      'decay_steps': 45 * steps_per_epoch,
                      'end_learning_rate': 0.0,
                      'power': 0.9
                  }
              },
              'warmup': {
                  'type': 'linear',
                  'linear': {
                      'warmup_steps': 5 * steps_per_epoch,
                      'warmup_learning_rate': 0
                  }
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
  return config

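# For reference, the arithmetic implied by the constants above:
# steps_per_epoch = 10582 // 16 = 661, so the schedule trains for
# 45 * 661 = 29745 steps, with a 5-epoch (3305-step) linear warmup and a
# polynomial decay to zero over the full 29745 steps.
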
@exp_factory.register_config_factory('seg_deeplabv3plus_pascal')
def seg_deeplabv3plus_pascal() -> cfg.ExperimentConfig:
  """Image segmentation on pascal voc with resnet deeplabv3+."""
  train_batch_size = 16
  eval_batch_size = 8
  steps_per_epoch = PASCAL_TRAIN_EXAMPLES // train_batch_size
  output_stride = 16
  aspp_dilation_rates = [6, 12, 18]
  multigrid = [1, 2, 4]
  stem_type = 'v1'
  level = int(math.log2(output_stride))
  config = cfg.ExperimentConfig(
      task=SemanticSegmentationTask(
          model=SemanticSegmentationModel(
              num_classes=21,
              input_size=[None, None, 3],
              backbone=backbones.Backbone(
                  type='dilated_resnet',
                  dilated_resnet=backbones.DilatedResNet(
                      model_id=101,
                      output_stride=output_stride,
                      stem_type=stem_type,
                      multigrid=multigrid)),
              decoder=decoders.Decoder(
                  type='aspp',
                  aspp=decoders.ASPP(
                      level=level, dilation_rates=aspp_dilation_rates)),
              head=SegmentationHead(
                  level=level,
                  num_convs=2,
                  feature_fusion='deeplabv3plus',
                  low_level=2,
                  low_level_num_filters=48),
              norm_activation=common.NormActivation(
                  activation='swish',
                  norm_momentum=0.9997,
                  norm_epsilon=1e-3,
                  use_sync_bn=True)),
          losses=Losses(l2_weight_decay=1e-4),
          train_data=DataConfig(
              input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'train_aug*'),
              output_size=[512, 512],
              is_training=True,
              global_batch_size=train_batch_size,
              aug_scale_min=0.5,
              aug_scale_max=2.0),
          validation_data=DataConfig(
              input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'val*'),
              output_size=[512, 512],
              is_training=False,
              global_batch_size=eval_batch_size,
              resize_eval_groundtruth=False,
              groundtruth_padded_size=[512, 512],
              drop_remainder=False),
          # resnet101
          init_checkpoint='gs://cloud-tpu-checkpoints/vision-2.0/deeplab/deeplab_resnet101_imagenet/ckpt-62400',
          init_checkpoint_modules='backbone'),
      trainer=cfg.TrainerConfig(
          steps_per_loop=steps_per_epoch,
          summary_interval=steps_per_epoch,
          checkpoint_interval=steps_per_epoch,
          train_steps=45 * steps_per_epoch,
          validation_steps=PASCAL_VAL_EXAMPLES // eval_batch_size,
          validation_interval=steps_per_epoch,
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'sgd',
                  'sgd': {
                      'momentum': 0.9
                  }
              },
              'learning_rate': {
                  'type': 'polynomial',
                  'polynomial': {
                      'initial_learning_rate': 0.007,
                      'decay_steps': 45 * steps_per_epoch,
                      'end_learning_rate': 0.0,
                      'power': 0.9
                  }
              },
              'warmup': {
                  'type': 'linear',
                  'linear': {
                      'warmup_steps': 5 * steps_per_epoch,
                      'warmup_learning_rate': 0
                  }
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
  return config


@exp_factory.register_config_factory('seg_resnetfpn_pascal')
def seg_resnetfpn_pascal() -> cfg.ExperimentConfig:
  """Image segmentation on pascal voc with resnet-fpn."""
  train_batch_size = 256
  eval_batch_size = 32
  steps_per_epoch = PASCAL_TRAIN_EXAMPLES // train_batch_size
  config = cfg.ExperimentConfig(
      task=SemanticSegmentationTask(
          model=SemanticSegmentationModel(
              num_classes=21,
              input_size=[512, 512, 3],
              min_level=3,
              max_level=7,
              backbone=backbones.Backbone(
                  type='resnet', resnet=backbones.ResNet(model_id=50)),
              decoder=decoders.Decoder(type='fpn', fpn=decoders.FPN()),
              head=SegmentationHead(level=3, num_convs=3),
              norm_activation=common.NormActivation(
                  activation='swish', use_sync_bn=True)),
          losses=Losses(l2_weight_decay=1e-4),
          train_data=DataConfig(
              input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'train_aug*'),
              is_training=True,
              global_batch_size=train_batch_size,
              aug_scale_min=0.2,
              aug_scale_max=1.5),
          validation_data=DataConfig(
              input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'val*'),
              is_training=False,
              global_batch_size=eval_batch_size,
              resize_eval_groundtruth=False,
              groundtruth_padded_size=[512, 512],
              drop_remainder=False),
      ),
      trainer=cfg.TrainerConfig(
          steps_per_loop=steps_per_epoch,
          summary_interval=steps_per_epoch,
          checkpoint_interval=steps_per_epoch,
          train_steps=450 * steps_per_epoch,
          validation_steps=PASCAL_VAL_EXAMPLES // eval_batch_size,
          validation_interval=steps_per_epoch,
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'sgd',
                  'sgd': {
                      'momentum': 0.9
                  }
              },
              'learning_rate': {
                  'type': 'polynomial',
                  'polynomial': {
                      'initial_learning_rate': 0.007,
                      'decay_steps': 450 * steps_per_epoch,
                      'end_learning_rate': 0.0,
                      'power': 0.9
                  }
              },
              'warmup': {
                  'type': 'linear',
                  'linear': {
                      'warmup_steps': 5 * steps_per_epoch,
                      'warmup_learning_rate': 0
                  }
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
  return config

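# Note (for orientation): pyramid level L corresponds to a feature stride of
# 2**L, so the FPN config above spans strides 8 (level 3) through 128
# (level 7), with the head predicting at level 3, i.e. 1/8 input resolution.
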
@exp_factory.register_config_factory('mnv2_deeplabv3_pascal')
def mnv2_deeplabv3_pascal() -> cfg.ExperimentConfig:
  """Image segmentation on pascal with mobilenetv2 deeplabv3."""
  train_batch_size = 16
  eval_batch_size = 16
  steps_per_epoch = PASCAL_TRAIN_EXAMPLES // train_batch_size
  output_stride = 16
  aspp_dilation_rates = []
  level = int(math.log2(output_stride))
  pool_kernel_size = []
  config = cfg.ExperimentConfig(
      task=SemanticSegmentationTask(
          model=SemanticSegmentationModel(
              num_classes=21,
              input_size=[None, None, 3],
              backbone=backbones.Backbone(
                  type='mobilenet',
                  mobilenet=backbones.MobileNet(
                      model_id='MobileNetV2', output_stride=output_stride)),
              decoder=decoders.Decoder(
                  type='aspp',
                  aspp=decoders.ASPP(
                      level=level,
                      dilation_rates=aspp_dilation_rates,
                      pool_kernel_size=pool_kernel_size)),
              head=SegmentationHead(level=level, num_convs=0),
              norm_activation=common.NormActivation(
                  activation='relu',
                  norm_momentum=0.99,
                  norm_epsilon=1e-3,
                  use_sync_bn=True)),
          losses=Losses(l2_weight_decay=4e-5),
          train_data=DataConfig(
              input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'train_aug*'),
              output_size=[512, 512],
              is_training=True,
              global_batch_size=train_batch_size,
              aug_scale_min=0.5,
              aug_scale_max=2.0),
          validation_data=DataConfig(
              input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'val*'),
              output_size=[512, 512],
              is_training=False,
              global_batch_size=eval_batch_size,
              resize_eval_groundtruth=False,
              groundtruth_padded_size=[512, 512],
              drop_remainder=False),
          # mobilenetv2
          init_checkpoint='gs://tf_model_garden/cloud/vision-2.0/deeplab/deeplabv3_mobilenetv2_coco/best_ckpt-63',
          init_checkpoint_modules=['backbone', 'decoder']),
      trainer=cfg.TrainerConfig(
          steps_per_loop=steps_per_epoch,
          summary_interval=steps_per_epoch,
          checkpoint_interval=steps_per_epoch,
          train_steps=30000,
          validation_steps=PASCAL_VAL_EXAMPLES // eval_batch_size,
          validation_interval=steps_per_epoch,
          best_checkpoint_eval_metric='mean_iou',
          best_checkpoint_export_subdir='best_ckpt',
          best_checkpoint_metric_comp='higher',
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'sgd',
                  'sgd': {
                      'momentum': 0.9
                  }
              },
              'learning_rate': {
                  'type': 'polynomial',
                  'polynomial': {
                      'initial_learning_rate': 0.007 * train_batch_size / 16,
                      'decay_steps': 30000,
                      'end_learning_rate': 0.0,
                      'power': 0.9
                  }
              },
              'warmup': {
                  'type': 'linear',
                  'linear': {
                      'warmup_steps': 5 * steps_per_epoch,
                      'warmup_learning_rate': 0
                  }
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
  return config


# Cityscapes Dataset (Download and process the dataset yourself)
CITYSCAPES_TRAIN_EXAMPLES = 2975
CITYSCAPES_VAL_EXAMPLES = 500
CITYSCAPES_INPUT_PATH_BASE = 'cityscapes'

@exp_factory.register_config_factory('seg_deeplabv3plus_cityscapes')
def seg_deeplabv3plus_cityscapes() -> cfg.ExperimentConfig:
  """Image segmentation on cityscapes with resnet deeplabv3+."""
  train_batch_size = 16
  eval_batch_size = 16
  steps_per_epoch = CITYSCAPES_TRAIN_EXAMPLES // train_batch_size
  output_stride = 16
  aspp_dilation_rates = [6, 12, 18]
  multigrid = [1, 2, 4]
  stem_type = 'v1'
  level = int(math.log2(output_stride))
  config = cfg.ExperimentConfig(
      task=SemanticSegmentationTask(
          model=SemanticSegmentationModel(
              # Cityscapes uses only 19 semantic classes for train/evaluation.
              # The void (background) class is ignored in train and evaluation.
              num_classes=19,
              input_size=[None, None, 3],
              backbone=backbones.Backbone(
                  type='dilated_resnet',
                  dilated_resnet=backbones.DilatedResNet(
                      model_id=101,
                      output_stride=output_stride,
                      stem_type=stem_type,
                      multigrid=multigrid)),
              decoder=decoders.Decoder(
                  type='aspp',
                  aspp=decoders.ASPP(
                      level=level,
                      dilation_rates=aspp_dilation_rates,
                      pool_kernel_size=[512, 1024])),
              head=SegmentationHead(
                  level=level,
                  num_convs=2,
                  feature_fusion='deeplabv3plus',
                  low_level=2,
                  low_level_num_filters=48),
              norm_activation=common.NormActivation(
                  activation='swish',
                  norm_momentum=0.99,
                  norm_epsilon=1e-3,
                  use_sync_bn=True)),
          losses=Losses(l2_weight_decay=1e-4),
          train_data=DataConfig(
              input_path=os.path.join(CITYSCAPES_INPUT_PATH_BASE,
                                      'train_fine**'),
              crop_size=[512, 1024],
              output_size=[1024, 2048],
              is_training=True,
              global_batch_size=train_batch_size,
              aug_scale_min=0.5,
              aug_scale_max=2.0),
          validation_data=DataConfig(
              input_path=os.path.join(CITYSCAPES_INPUT_PATH_BASE, 'val_fine*'),
              output_size=[1024, 2048],
              is_training=False,
              global_batch_size=eval_batch_size,
              resize_eval_groundtruth=True,
              drop_remainder=False),
          # resnet101
          init_checkpoint='gs://cloud-tpu-checkpoints/vision-2.0/deeplab/deeplab_resnet101_imagenet/ckpt-62400',
          init_checkpoint_modules='backbone'),
      trainer=cfg.TrainerConfig(
          steps_per_loop=steps_per_epoch,
          summary_interval=steps_per_epoch,
          checkpoint_interval=steps_per_epoch,
          train_steps=500 * steps_per_epoch,
          validation_steps=CITYSCAPES_VAL_EXAMPLES // eval_batch_size,
          validation_interval=steps_per_epoch,
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'sgd',
                  'sgd': {
                      'momentum': 0.9
                  }
              },
              'learning_rate': {
                  'type': 'polynomial',
                  'polynomial': {
                      'initial_learning_rate': 0.01,
                      'decay_steps': 500 * steps_per_epoch,
                      'end_learning_rate': 0.0,
                      'power': 0.9
                  }
              },
              'warmup': {
                  'type': 'linear',
                  'linear': {
                      'warmup_steps': 5 * steps_per_epoch,
                      'warmup_learning_rate': 0
                  }
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
  return config

@exp_factory.register_config_factory('mnv2_deeplabv3_cityscapes')
def mnv2_deeplabv3_cityscapes() -> cfg.ExperimentConfig:
  """Image segmentation on cityscapes with mobilenetv2 deeplabv3."""
  train_batch_size = 16
  eval_batch_size = 16
  steps_per_epoch = CITYSCAPES_TRAIN_EXAMPLES // train_batch_size
  output_stride = 16
  aspp_dilation_rates = []
  pool_kernel_size = [512, 1024]
  level = int(math.log2(output_stride))
  config = cfg.ExperimentConfig(
      task=SemanticSegmentationTask(
          model=SemanticSegmentationModel(
              # Cityscapes uses only 19 semantic classes for train/evaluation.
              # The void (background) class is ignored in train and evaluation.
              num_classes=19,
              input_size=[None, None, 3],
              backbone=backbones.Backbone(
                  type='mobilenet',
                  mobilenet=backbones.MobileNet(
                      model_id='MobileNetV2', output_stride=output_stride)),
              decoder=decoders.Decoder(
                  type='aspp',
                  aspp=decoders.ASPP(
                      level=level,
                      dilation_rates=aspp_dilation_rates,
                      pool_kernel_size=pool_kernel_size)),
              head=SegmentationHead(level=level, num_convs=0),
              norm_activation=common.NormActivation(
                  activation='relu',
                  norm_momentum=0.99,
                  norm_epsilon=1e-3,
                  use_sync_bn=True)),
          losses=Losses(l2_weight_decay=4e-5),
          train_data=DataConfig(
              input_path=os.path.join(CITYSCAPES_INPUT_PATH_BASE,
                                      'train_fine**'),
              crop_size=[512, 1024],
              output_size=[1024, 2048],
              is_training=True,
              global_batch_size=train_batch_size,
              aug_scale_min=0.5,
              aug_scale_max=2.0),
          validation_data=DataConfig(
              input_path=os.path.join(CITYSCAPES_INPUT_PATH_BASE, 'val_fine*'),
              output_size=[1024, 2048],
              is_training=False,
              global_batch_size=eval_batch_size,
              resize_eval_groundtruth=True,
              drop_remainder=False),
          # COCO pre-trained mobilenetv2 checkpoint.
          init_checkpoint='gs://tf_model_garden/cloud/vision-2.0/deeplab/deeplabv3_mobilenetv2_coco/best_ckpt-63',
          init_checkpoint_modules='backbone'),
      trainer=cfg.TrainerConfig(
          steps_per_loop=steps_per_epoch,
          summary_interval=steps_per_epoch,
          checkpoint_interval=steps_per_epoch,
          train_steps=100000,
          validation_steps=CITYSCAPES_VAL_EXAMPLES // eval_batch_size,
          validation_interval=steps_per_epoch,
          best_checkpoint_eval_metric='mean_iou',
          best_checkpoint_export_subdir='best_ckpt',
          best_checkpoint_metric_comp='higher',
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'sgd',
                  'sgd': {
                      'momentum': 0.9
                  }
              },
              'learning_rate': {
                  'type': 'polynomial',
                  'polynomial': {
                      'initial_learning_rate': 0.01,
                      'decay_steps': 100000,
                      'end_learning_rate': 0.0,
                      'power': 0.9
                  }
              },
              'warmup': {
                  'type': 'linear',
                  'linear': {
                      'warmup_steps': 5 * steps_per_epoch,
                      'warmup_learning_rate': 0
                  }
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
  return config


@exp_factory.register_config_factory('mnv2_deeplabv3plus_cityscapes')
def mnv2_deeplabv3plus_cityscapes() -> cfg.ExperimentConfig:
  """Image segmentation on cityscapes with mobilenetv2 deeplabv3plus."""
  config = mnv2_deeplabv3_cityscapes()
  config.task.model.head = SegmentationHead(
      level=4,
      num_convs=2,
      feature_fusion='deeplabv3plus',
      use_depthwise_convolution=True,
      low_level='2/depthwise',
      low_level_num_filters=48)
  config.task.model.backbone.mobilenet.output_intermediate_endpoints = True
  return config
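
# Example (illustrative): any experiment registered above can likewise be
# fetched by name and overridden, mirroring how mnv2_deeplabv3plus_cityscapes
# derives from mnv2_deeplabv3_cityscapes, e.g.:
#
#   config = exp_factory.get_exp_config('mnv2_deeplabv3plus_cityscapes')
#   config.task.train_data.global_batch_size = 32
#   config.trainer.train_steps = 50000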