# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Semantic segmentation configuration definition."""

import dataclasses
import math
import os
from typing import List, Optional, Sequence, Union

from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import hyperparams
from official.modeling import optimization
from official.vision.configs import backbones
from official.vision.configs import common
from official.vision.configs import decoders
from official.vision.ops import preprocess_ops


class DenseFeatureConfig(hyperparams.Config):
  """Config for dense features, such as RGB pixels, masks, or heatmaps.

  Dense features are stored as encoded images in TF examples, so they must
  have 1, 3, or 4 channels. A feature with a different number of channels
  (e.g. optical flow) can be encoded as multiple 1-channel features. The
  default config is for RGB input, with mean and stddev from the ImageNet
  dataset. Only 8-bit encoded features (maximum value 255) are supported.

  Attributes:
    feature_name: The key of the feature in TF examples.
    num_channels: An `int` specifying the number of channels of the feature.
    mean: A list of floats in the range [0, 255] representing the mean value
      of each channel. The length of the list must match num_channels.
    stddev: A list of floats in the range [0, 255] representing the standard
      deviation of each channel. The length must match num_channels.
  """
  feature_name: str = 'image/encoded'
  num_channels: int = 3
  mean: List[float] = dataclasses.field(
      default_factory=lambda: preprocess_ops.MEAN_RGB
  )
  stddev: List[float] = dataclasses.field(
      default_factory=lambda: preprocess_ops.STDDEV_RGB
  )
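

# Illustrative sketch (not referenced by the experiments below): a 2-channel
# feature such as optical flow can be split into two 1-channel
# DenseFeatureConfigs, as described in the class docstring. The TF example
# keys and normalization constants here are hypothetical.
def _example_optical_flow_features() -> List[DenseFeatureConfig]:
  """Returns two 1-channel features encoding x/y flow components."""
  return [
      DenseFeatureConfig(
          feature_name='image/flow_x/encoded',  # Hypothetical key.
          num_channels=1,
          mean=[127.5],  # Assumes flow quantized symmetrically around 128.
          stddev=[127.5]),
      DenseFeatureConfig(
          feature_name='image/flow_y/encoded',  # Hypothetical key.
          num_channels=1,
          mean=[127.5],
          stddev=[127.5]),
  ]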


class DataConfig(cfg.DataConfig):
  """Input config for training and evaluation."""
  image_feature: DenseFeatureConfig = dataclasses.field(
      default_factory=DenseFeatureConfig
  )
  output_size: List[int] = dataclasses.field(default_factory=list)
  # If crop_size is specified, the image is first resized to output_size and
  # then a crop of size crop_size is taken.
  crop_size: List[int] = dataclasses.field(default_factory=list)
  input_path: Optional[Union[Sequence[str], str, hyperparams.Config]] = None
  weights: Optional[hyperparams.Config] = None
  global_batch_size: int = 0
  is_training: bool = True
  dtype: str = 'float32'
  shuffle_buffer_size: int = 1000
  cycle_length: int = 10
  # If resize_eval_groundtruth is set to False, the original image sizes are
  # used for eval. In that case, groundtruth_padded_size has to be specified
  # too to allow batching of the variable input sizes of images.
  resize_eval_groundtruth: bool = True
  groundtruth_padded_size: List[int] = dataclasses.field(default_factory=list)
  aug_scale_min: float = 1.0
  aug_scale_max: float = 1.0
  aug_rand_hflip: bool = True
  preserve_aspect_ratio: bool = True
  aug_policy: Optional[str] = None
  drop_remainder: bool = True
  file_type: str = 'tfrecord'
  decoder: Optional[common.DataDecoder] = dataclasses.field(
      default_factory=common.DataDecoder
  )
  additional_dense_features: List[DenseFeatureConfig] = dataclasses.field(
      default_factory=list)
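

# Illustrative sketch: an eval DataConfig that keeps the original ground-truth
# sizes (resize_eval_groundtruth=False), which requires
# groundtruth_padded_size so variable-size images can still be batched. The
# input path is a placeholder.
def _example_eval_data() -> DataConfig:
  return DataConfig(
      input_path='/path/to/val*',  # Placeholder.
      output_size=[512, 512],
      is_training=False,
      global_batch_size=8,
      resize_eval_groundtruth=False,
      groundtruth_padded_size=[512, 512],
      drop_remainder=False)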


class SegmentationHead(hyperparams.Config):
  """Segmentation head config."""
  level: int = 3
  num_convs: int = 2
  num_filters: int = 256
  use_depthwise_convolution: bool = False
  prediction_kernel_size: int = 1
  upsample_factor: int = 1
  logit_activation: Optional[str] = None  # None, 'sigmoid', or 'softmax'.
  # One of None, 'deeplabv3plus', 'panoptic_fpn_fusion', or 'pyramid_fusion'.
  feature_fusion: Optional[str] = None
  # deeplabv3plus feature fusion params.
  low_level: Union[int, str] = 2
  low_level_num_filters: int = 48
  # panoptic_fpn_fusion params.
  decoder_min_level: Optional[Union[int, str]] = None
  decoder_max_level: Optional[Union[int, str]] = None


class MaskScoringHead(hyperparams.Config):
  """Mask scoring head config."""
  num_convs: int = 4
  num_filters: int = 128
  fc_input_size: List[int] = dataclasses.field(default_factory=list)
  num_fcs: int = 2
  fc_dims: int = 1024
  use_depthwise_convolution: bool = False


class SemanticSegmentationModel(hyperparams.Config):
  """Semantic segmentation model config."""
  num_classes: int = 0
  input_size: List[int] = dataclasses.field(default_factory=list)
  min_level: int = 3
  max_level: int = 6
  head: SegmentationHead = dataclasses.field(default_factory=SegmentationHead)
  backbone: backbones.Backbone = dataclasses.field(
      default_factory=lambda: backbones.Backbone(  # pylint: disable=g-long-lambda
          type='resnet', resnet=backbones.ResNet()
      )
  )
  decoder: decoders.Decoder = dataclasses.field(
      default_factory=lambda: decoders.Decoder(type='identity')
  )
  mask_scoring_head: Optional[MaskScoringHead] = None
  norm_activation: common.NormActivation = dataclasses.field(
      default_factory=common.NormActivation
  )


class Losses(hyperparams.Config):
  """Loss function config."""
  loss_weight: float = 1.0
  label_smoothing: float = 0.0
  ignore_label: int = 255
  gt_is_matting_map: bool = False
  class_weights: List[float] = dataclasses.field(default_factory=list)
  l2_weight_decay: float = 0.0
  use_groundtruth_dimension: bool = True
  # If true, use binary cross entropy (sigmoid) in the loss; otherwise, use
  # categorical cross entropy (softmax).
  use_binary_cross_entropy: bool = False
  top_k_percent_pixels: float = 1.0
  mask_scoring_weight: float = 1.0
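

# Illustrative sketch: a Losses config for class-imbalanced data that
# re-weights classes via class_weights and keeps only the hardest 25% of
# pixels in the loss via top_k_percent_pixels (hard pixel mining). The
# weights are made-up values for a hypothetical 3-class problem.
def _example_imbalanced_losses() -> Losses:
  return Losses(
      class_weights=[0.5, 1.0, 2.5],  # Made-up per-class weights.
      top_k_percent_pixels=0.25,
      ignore_label=255)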


class Evaluation(hyperparams.Config):
  """Evaluation config."""
  report_per_class_iou: bool = True
  report_train_mean_iou: bool = True  # Turning this off can speed up training.


class ExportConfig(hyperparams.Config):
  """Model export config."""
  # Whether to rescale the predicted mask to the original image size.
  rescale_output: bool = False


class SemanticSegmentationTask(cfg.TaskConfig):
  """The task config for semantic segmentation."""
  model: SemanticSegmentationModel = dataclasses.field(
      default_factory=SemanticSegmentationModel
  )
  train_data: DataConfig = dataclasses.field(
      default_factory=lambda: DataConfig(is_training=True)
  )
  validation_data: DataConfig = dataclasses.field(
      default_factory=lambda: DataConfig(is_training=False)
  )
  losses: Losses = dataclasses.field(default_factory=Losses)
  evaluation: Evaluation = dataclasses.field(default_factory=Evaluation)
  train_input_partition_dims: List[int] = dataclasses.field(
      default_factory=list)
  eval_input_partition_dims: List[int] = dataclasses.field(
      default_factory=list)
  init_checkpoint: Optional[str] = None
  # One of 'all', 'backbone', and/or 'decoder'.
  init_checkpoint_modules: Union[str, List[str]] = 'all'
  export_config: ExportConfig = dataclasses.field(default_factory=ExportConfig)
  allow_image_summary: bool = True


@exp_factory.register_config_factory('semantic_segmentation')
def semantic_segmentation() -> cfg.ExperimentConfig:
  """Semantic segmentation general."""
  return cfg.ExperimentConfig(
      task=SemanticSegmentationTask(),
      trainer=cfg.TrainerConfig(),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
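

# With the factory registrations above, an experiment config can be looked up
# by name and then overridden before use; a minimal sketch (the override
# values are arbitrary):
#
#   from official.core import exp_factory
#   config = exp_factory.get_exp_config('seg_deeplabv3_pascal')
#   config.override({'task': {'train_data': {'global_batch_size': 32}}})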


# PASCAL VOC 2012 Dataset
PASCAL_TRAIN_EXAMPLES = 10582
PASCAL_VAL_EXAMPLES = 1449
PASCAL_INPUT_PATH_BASE = 'gs://**/pascal_voc_seg'


@exp_factory.register_config_factory('seg_deeplabv3_pascal')
def seg_deeplabv3_pascal() -> cfg.ExperimentConfig:
  """Image segmentation on pascal voc with resnet deeplabv3."""
  train_batch_size = 16
  eval_batch_size = 8
  steps_per_epoch = PASCAL_TRAIN_EXAMPLES // train_batch_size
  output_stride = 16
  # The standard DeepLab rates are [6, 12, 18] at output_stride = 16.
  aspp_dilation_rates = [12, 24, 36]
  multigrid = [1, 2, 4]
  stem_type = 'v1'
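  # The ASPP/head level equals log2(output_stride): output_stride = 16 gives
  # level 4, i.e. features at 1/2**4 of the input resolution.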
  level = int(math.log2(output_stride))
  config = cfg.ExperimentConfig(
      task=SemanticSegmentationTask(
          model=SemanticSegmentationModel(
              num_classes=21,
              input_size=[None, None, 3],
              backbone=backbones.Backbone(
                  type='dilated_resnet',
                  dilated_resnet=backbones.DilatedResNet(
                      model_id=101,
                      output_stride=output_stride,
                      multigrid=multigrid,
                      stem_type=stem_type)),
              decoder=decoders.Decoder(
                  type='aspp',
                  aspp=decoders.ASPP(
                      level=level, dilation_rates=aspp_dilation_rates)),
              head=SegmentationHead(level=level, num_convs=0),
              norm_activation=common.NormActivation(
                  activation='swish',
                  norm_momentum=0.9997,
                  norm_epsilon=1e-3,
                  use_sync_bn=True)),
          losses=Losses(l2_weight_decay=1e-4),
          train_data=DataConfig(
              input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'train_aug*'),
              # TODO(arashwan): test changing size to 513 to match deeplab.
              output_size=[512, 512],
              is_training=True,
              global_batch_size=train_batch_size,
              aug_scale_min=0.5,
              aug_scale_max=2.0),
          validation_data=DataConfig(
              input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'val*'),
              output_size=[512, 512],
              is_training=False,
              global_batch_size=eval_batch_size,
              resize_eval_groundtruth=False,
              groundtruth_padded_size=[512, 512],
              drop_remainder=False),
          # resnet101
          init_checkpoint='gs://cloud-tpu-checkpoints/vision-2.0/deeplab/deeplab_resnet101_imagenet/ckpt-62400',
          init_checkpoint_modules='backbone'),
      trainer=cfg.TrainerConfig(
          steps_per_loop=steps_per_epoch,
          summary_interval=steps_per_epoch,
          checkpoint_interval=steps_per_epoch,
          train_steps=45 * steps_per_epoch,
          validation_steps=PASCAL_VAL_EXAMPLES // eval_batch_size,
          validation_interval=steps_per_epoch,
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'sgd',
                  'sgd': {
                      'momentum': 0.9
                  }
              },
              'learning_rate': {
                  'type': 'polynomial',
                  'polynomial': {
                      'initial_learning_rate': 0.007,
                      'decay_steps': 45 * steps_per_epoch,
                      'end_learning_rate': 0.0,
                      'power': 0.9
                  }
              },
              'warmup': {
                  'type': 'linear',
                  'linear': {
                      'warmup_steps': 5 * steps_per_epoch,
                      'warmup_learning_rate': 0
                  }
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
  return config


@exp_factory.register_config_factory('seg_deeplabv3plus_pascal')
def seg_deeplabv3plus_pascal() -> cfg.ExperimentConfig:
  """Image segmentation on pascal voc with resnet deeplabv3+."""
  train_batch_size = 16
  eval_batch_size = 8
  steps_per_epoch = PASCAL_TRAIN_EXAMPLES // train_batch_size
  output_stride = 16
  aspp_dilation_rates = [6, 12, 18]
  multigrid = [1, 2, 4]
  stem_type = 'v1'
  level = int(math.log2(output_stride))
  config = cfg.ExperimentConfig(
      task=SemanticSegmentationTask(
          model=SemanticSegmentationModel(
              num_classes=21,
              input_size=[None, None, 3],
              backbone=backbones.Backbone(
                  type='dilated_resnet',
                  dilated_resnet=backbones.DilatedResNet(
                      model_id=101,
                      output_stride=output_stride,
                      stem_type=stem_type,
                      multigrid=multigrid)),
              decoder=decoders.Decoder(
                  type='aspp',
                  aspp=decoders.ASPP(
                      level=level, dilation_rates=aspp_dilation_rates)),
              head=SegmentationHead(
                  level=level,
                  num_convs=2,
                  feature_fusion='deeplabv3plus',
                  low_level=2,
                  low_level_num_filters=48),
              norm_activation=common.NormActivation(
                  activation='swish',
                  norm_momentum=0.9997,
                  norm_epsilon=1e-3,
                  use_sync_bn=True)),
          losses=Losses(l2_weight_decay=1e-4),
          train_data=DataConfig(
              input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'train_aug*'),
              output_size=[512, 512],
              is_training=True,
              global_batch_size=train_batch_size,
              aug_scale_min=0.5,
              aug_scale_max=2.0),
          validation_data=DataConfig(
              input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'val*'),
              output_size=[512, 512],
              is_training=False,
              global_batch_size=eval_batch_size,
              resize_eval_groundtruth=False,
              groundtruth_padded_size=[512, 512],
              drop_remainder=False),
          # resnet101
          init_checkpoint='gs://cloud-tpu-checkpoints/vision-2.0/deeplab/deeplab_resnet101_imagenet/ckpt-62400',
          init_checkpoint_modules='backbone'),
      trainer=cfg.TrainerConfig(
          steps_per_loop=steps_per_epoch,
          summary_interval=steps_per_epoch,
          checkpoint_interval=steps_per_epoch,
          train_steps=45 * steps_per_epoch,
          validation_steps=PASCAL_VAL_EXAMPLES // eval_batch_size,
          validation_interval=steps_per_epoch,
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'sgd',
                  'sgd': {
                      'momentum': 0.9
                  }
              },
              'learning_rate': {
                  'type': 'polynomial',
                  'polynomial': {
                      'initial_learning_rate': 0.007,
                      'decay_steps': 45 * steps_per_epoch,
                      'end_learning_rate': 0.0,
                      'power': 0.9
                  }
              },
              'warmup': {
                  'type': 'linear',
                  'linear': {
                      'warmup_steps': 5 * steps_per_epoch,
                      'warmup_learning_rate': 0
                  }
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
  return config


@exp_factory.register_config_factory('seg_resnetfpn_pascal')
def seg_resnetfpn_pascal() -> cfg.ExperimentConfig:
  """Image segmentation on pascal voc with resnet-fpn."""
  train_batch_size = 256
  eval_batch_size = 32
  steps_per_epoch = PASCAL_TRAIN_EXAMPLES // train_batch_size
  config = cfg.ExperimentConfig(
      task=SemanticSegmentationTask(
          model=SemanticSegmentationModel(
              num_classes=21,
              input_size=[512, 512, 3],
              min_level=3,
              max_level=7,
              backbone=backbones.Backbone(
                  type='resnet', resnet=backbones.ResNet(model_id=50)),
              decoder=decoders.Decoder(type='fpn', fpn=decoders.FPN()),
              head=SegmentationHead(level=3, num_convs=3),
              norm_activation=common.NormActivation(
                  activation='swish', use_sync_bn=True)),
          losses=Losses(l2_weight_decay=1e-4),
          train_data=DataConfig(
              input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'train_aug*'),
              is_training=True,
              global_batch_size=train_batch_size,
              aug_scale_min=0.2,
              aug_scale_max=1.5),
          validation_data=DataConfig(
              input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'val*'),
              is_training=False,
              global_batch_size=eval_batch_size,
              resize_eval_groundtruth=False,
              groundtruth_padded_size=[512, 512],
              drop_remainder=False),
      ),
      trainer=cfg.TrainerConfig(
          steps_per_loop=steps_per_epoch,
          summary_interval=steps_per_epoch,
          checkpoint_interval=steps_per_epoch,
          train_steps=450 * steps_per_epoch,
          validation_steps=PASCAL_VAL_EXAMPLES // eval_batch_size,
          validation_interval=steps_per_epoch,
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'sgd',
                  'sgd': {
                      'momentum': 0.9
                  }
              },
              'learning_rate': {
                  'type': 'polynomial',
                  'polynomial': {
                      'initial_learning_rate': 0.007,
                      'decay_steps': 450 * steps_per_epoch,
                      'end_learning_rate': 0.0,
                      'power': 0.9
                  }
              },
              'warmup': {
                  'type': 'linear',
                  'linear': {
                      'warmup_steps': 5 * steps_per_epoch,
                      'warmup_learning_rate': 0
                  }
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
  return config


@exp_factory.register_config_factory('mnv2_deeplabv3_pascal')
def mnv2_deeplabv3_pascal() -> cfg.ExperimentConfig:
  """Image segmentation on pascal with mobilenetv2 deeplabv3."""
  train_batch_size = 16
  eval_batch_size = 16
  steps_per_epoch = PASCAL_TRAIN_EXAMPLES // train_batch_size
  output_stride = 16
  aspp_dilation_rates = []
  level = int(math.log2(output_stride))
  pool_kernel_size = []
  config = cfg.ExperimentConfig(
      task=SemanticSegmentationTask(
          model=SemanticSegmentationModel(
              num_classes=21,
              input_size=[None, None, 3],
              backbone=backbones.Backbone(
                  type='mobilenet',
                  mobilenet=backbones.MobileNet(
                      model_id='MobileNetV2', output_stride=output_stride)),
              decoder=decoders.Decoder(
                  type='aspp',
                  aspp=decoders.ASPP(
                      level=level,
                      dilation_rates=aspp_dilation_rates,
                      pool_kernel_size=pool_kernel_size)),
              head=SegmentationHead(level=level, num_convs=0),
              norm_activation=common.NormActivation(
                  activation='relu',
                  norm_momentum=0.99,
                  norm_epsilon=1e-3,
                  use_sync_bn=True)),
          losses=Losses(l2_weight_decay=4e-5),
          train_data=DataConfig(
              input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'train_aug*'),
              output_size=[512, 512],
              is_training=True,
              global_batch_size=train_batch_size,
              aug_scale_min=0.5,
              aug_scale_max=2.0),
          validation_data=DataConfig(
              input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'val*'),
              output_size=[512, 512],
              is_training=False,
              global_batch_size=eval_batch_size,
              resize_eval_groundtruth=False,
              groundtruth_padded_size=[512, 512],
              drop_remainder=False),
          # mobilenetv2
          init_checkpoint='gs://tf_model_garden/cloud/vision-2.0/deeplab/deeplabv3_mobilenetv2_coco/best_ckpt-63',
          init_checkpoint_modules=['backbone', 'decoder']),
      trainer=cfg.TrainerConfig(
          steps_per_loop=steps_per_epoch,
          summary_interval=steps_per_epoch,
          checkpoint_interval=steps_per_epoch,
          train_steps=30000,
          validation_steps=PASCAL_VAL_EXAMPLES // eval_batch_size,
          validation_interval=steps_per_epoch,
          best_checkpoint_eval_metric='mean_iou',
          best_checkpoint_export_subdir='best_ckpt',
          best_checkpoint_metric_comp='higher',
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'sgd',
                  'sgd': {
                      'momentum': 0.9
                  }
              },
              'learning_rate': {
                  'type': 'polynomial',
                  'polynomial': {
                      'initial_learning_rate': 0.007 * train_batch_size / 16,
                      'decay_steps': 30000,
                      'end_learning_rate': 0.0,
                      'power': 0.9
                  }
              },
              'warmup': {
                  'type': 'linear',
                  'linear': {
                      'warmup_steps': 5 * steps_per_epoch,
                      'warmup_learning_rate': 0
                  }
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
  return config


# Cityscapes Dataset (Download and process the dataset yourself)
CITYSCAPES_TRAIN_EXAMPLES = 2975
CITYSCAPES_VAL_EXAMPLES = 500
CITYSCAPES_INPUT_PATH_BASE = 'cityscapes'


@exp_factory.register_config_factory('seg_deeplabv3plus_cityscapes')
def seg_deeplabv3plus_cityscapes() -> cfg.ExperimentConfig:
  """Image segmentation on cityscapes with resnet deeplabv3+."""
  train_batch_size = 16
  eval_batch_size = 16
  steps_per_epoch = CITYSCAPES_TRAIN_EXAMPLES // train_batch_size
  output_stride = 16
  aspp_dilation_rates = [6, 12, 18]
  multigrid = [1, 2, 4]
  stem_type = 'v1'
  level = int(math.log2(output_stride))
  config = cfg.ExperimentConfig(
      task=SemanticSegmentationTask(
          model=SemanticSegmentationModel(
              # Cityscapes uses only 19 semantic classes for train/evaluation.
              # The void (background) class is ignored in train and evaluation.
              num_classes=19,
              input_size=[None, None, 3],
              backbone=backbones.Backbone(
                  type='dilated_resnet',
                  dilated_resnet=backbones.DilatedResNet(
                      model_id=101,
                      output_stride=output_stride,
                      stem_type=stem_type,
                      multigrid=multigrid)),
              decoder=decoders.Decoder(
                  type='aspp',
                  aspp=decoders.ASPP(
                      level=level,
                      dilation_rates=aspp_dilation_rates,
                      pool_kernel_size=[512, 1024])),
              head=SegmentationHead(
                  level=level,
                  num_convs=2,
                  feature_fusion='deeplabv3plus',
                  low_level=2,
                  low_level_num_filters=48),
              norm_activation=common.NormActivation(
                  activation='swish',
                  norm_momentum=0.99,
                  norm_epsilon=1e-3,
                  use_sync_bn=True)),
          losses=Losses(l2_weight_decay=1e-4),
          train_data=DataConfig(
              input_path=os.path.join(CITYSCAPES_INPUT_PATH_BASE,
                                      'train_fine**'),
              crop_size=[512, 1024],
              output_size=[1024, 2048],
              is_training=True,
              global_batch_size=train_batch_size,
              aug_scale_min=0.5,
              aug_scale_max=2.0),
          validation_data=DataConfig(
              input_path=os.path.join(CITYSCAPES_INPUT_PATH_BASE, 'val_fine*'),
              output_size=[1024, 2048],
              is_training=False,
              global_batch_size=eval_batch_size,
              resize_eval_groundtruth=True,
              drop_remainder=False),
          # resnet101
          init_checkpoint='gs://cloud-tpu-checkpoints/vision-2.0/deeplab/deeplab_resnet101_imagenet/ckpt-62400',
          init_checkpoint_modules='backbone'),
      trainer=cfg.TrainerConfig(
          steps_per_loop=steps_per_epoch,
          summary_interval=steps_per_epoch,
          checkpoint_interval=steps_per_epoch,
          train_steps=500 * steps_per_epoch,
          validation_steps=CITYSCAPES_VAL_EXAMPLES // eval_batch_size,
          validation_interval=steps_per_epoch,
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'sgd',
                  'sgd': {
                      'momentum': 0.9
                  }
              },
              'learning_rate': {
                  'type': 'polynomial',
                  'polynomial': {
                      'initial_learning_rate': 0.01,
                      'decay_steps': 500 * steps_per_epoch,
                      'end_learning_rate': 0.0,
                      'power': 0.9
                  }
              },
              'warmup': {
                  'type': 'linear',
                  'linear': {
                      'warmup_steps': 5 * steps_per_epoch,
                      'warmup_learning_rate': 0
                  }
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
  return config


@exp_factory.register_config_factory('mnv2_deeplabv3_cityscapes')
def mnv2_deeplabv3_cityscapes() -> cfg.ExperimentConfig:
  """Image segmentation on cityscapes with mobilenetv2 deeplabv3."""
  train_batch_size = 16
  eval_batch_size = 16
  steps_per_epoch = CITYSCAPES_TRAIN_EXAMPLES // train_batch_size
  output_stride = 16
  aspp_dilation_rates = []
  pool_kernel_size = [512, 1024]
  level = int(math.log2(output_stride))
  config = cfg.ExperimentConfig(
      task=SemanticSegmentationTask(
          model=SemanticSegmentationModel(
              # Cityscapes uses only 19 semantic classes for train/evaluation.
              # The void (background) class is ignored in train and evaluation.
              num_classes=19,
              input_size=[None, None, 3],
              backbone=backbones.Backbone(
                  type='mobilenet',
                  mobilenet=backbones.MobileNet(
                      model_id='MobileNetV2', output_stride=output_stride)),
              decoder=decoders.Decoder(
                  type='aspp',
                  aspp=decoders.ASPP(
                      level=level,
                      dilation_rates=aspp_dilation_rates,
                      pool_kernel_size=pool_kernel_size)),
              head=SegmentationHead(level=level, num_convs=0),
              norm_activation=common.NormActivation(
                  activation='relu',
                  norm_momentum=0.99,
                  norm_epsilon=1e-3,
                  use_sync_bn=True)),
          losses=Losses(l2_weight_decay=4e-5),
          train_data=DataConfig(
              input_path=os.path.join(CITYSCAPES_INPUT_PATH_BASE,
                                      'train_fine**'),
              crop_size=[512, 1024],
              output_size=[1024, 2048],
              is_training=True,
              global_batch_size=train_batch_size,
              aug_scale_min=0.5,
              aug_scale_max=2.0),
          validation_data=DataConfig(
              input_path=os.path.join(CITYSCAPES_INPUT_PATH_BASE, 'val_fine*'),
              output_size=[1024, 2048],
              is_training=False,
              global_batch_size=eval_batch_size,
              resize_eval_groundtruth=True,
              drop_remainder=False),
          # COCO pre-trained mobilenetv2 checkpoint.
          init_checkpoint='gs://tf_model_garden/cloud/vision-2.0/deeplab/deeplabv3_mobilenetv2_coco/best_ckpt-63',
          init_checkpoint_modules='backbone'),
      trainer=cfg.TrainerConfig(
          steps_per_loop=steps_per_epoch,
          summary_interval=steps_per_epoch,
          checkpoint_interval=steps_per_epoch,
          train_steps=100000,
          validation_steps=CITYSCAPES_VAL_EXAMPLES // eval_batch_size,
          validation_interval=steps_per_epoch,
          best_checkpoint_eval_metric='mean_iou',
          best_checkpoint_export_subdir='best_ckpt',
          best_checkpoint_metric_comp='higher',
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'sgd',
                  'sgd': {
                      'momentum': 0.9
                  }
              },
              'learning_rate': {
                  'type': 'polynomial',
                  'polynomial': {
                      'initial_learning_rate': 0.01,
                      'decay_steps': 100000,
                      'end_learning_rate': 0.0,
                      'power': 0.9
                  }
              },
              'warmup': {
                  'type': 'linear',
                  'linear': {
                      'warmup_steps': 5 * steps_per_epoch,
                      'warmup_learning_rate': 0
                  }
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
  return config


@exp_factory.register_config_factory('mnv2_deeplabv3plus_cityscapes')
def mnv2_deeplabv3plus_cityscapes() -> cfg.ExperimentConfig:
  """Image segmentation on cityscapes with mobilenetv2 deeplabv3plus."""
  config = mnv2_deeplabv3_cityscapes()
  config.task.model.head = SegmentationHead(
      level=4,
      num_convs=2,
      feature_fusion='deeplabv3plus',
      use_depthwise_convolution=True,
      low_level='2/depthwise',
      low_level_num_filters=48)
  config.task.model.backbone.mobilenet.output_intermediate_endpoints = True
  return config
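

# A registered experiment can then be selected by name from the command line;
# a minimal sketch using the Model Garden trainer (paths are placeholders and
# the override file is optional):
#
#   python3 -m official.vision.train \
#       --experiment=mnv2_deeplabv3plus_cityscapes \
#       --mode=train_and_eval \
#       --model_dir=/tmp/seg_model \
#       --config_file=overrides.yaml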