# Tutorial 5: Training Tricks

MMSegmentation supports the following training tricks out of the box.
|
## Different Learning Rate (LR) for Backbone and Heads

In semantic segmentation, some methods set the LR of the heads larger than that of the backbone to achieve better performance or faster convergence.

In MMSegmentation, you may add the following lines to the config to make the LR of the heads 10 times that of the backbone.
|
```python
optimizer = dict(
    paramwise_cfg=dict(
        custom_keys={
            'head': dict(lr_mult=10.)}))
```
|
With this modification, the LR of every parameter whose name contains `'head'` (for example, the parameters of `decode_head` and `auxiliary_head`) is multiplied by 10.
You may refer to the [MMCV doc](https://mmcv.readthedocs.io/en/latest/api.html#mmcv.runner.DefaultOptimizerConstructor) for further details.
|
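The effect of `lr_mult` can be illustrated with a small pure-Python sketch that turns a `custom_keys` rule into per-parameter learning rates. This is not MMCV's code: `build_param_groups` is a hypothetical helper, and the matching rule (a key matches when it is a substring of the parameter name, with longer keys tried first) is our assumption about how the lookup works. Parameter objects are replaced by plain strings so the sketch runs without PyTorch.

```python
def build_param_groups(named_params, base_lr, custom_keys):
    """Hypothetical helper: build one optimizer param group per parameter.

    A key in ``custom_keys`` matches a parameter when the key is a substring
    of the parameter name. Longer keys are tried first (assumption), so a
    more specific key such as 'decode_head' would win over 'head'.
    """
    # Sort alphabetically, then by length (longest first); the sort is stable,
    # so keys of equal length keep their alphabetical order.
    sorted_keys = sorted(sorted(custom_keys), key=len, reverse=True)
    groups = []
    for name, param in named_params:
        lr = base_lr
        for key in sorted_keys:
            if key in name:
                lr = base_lr * custom_keys[key].get('lr_mult', 1.0)
                break  # only the first (longest) matching key is applied
        groups.append({'params': [param], 'lr': lr})
    return groups


if __name__ == '__main__':
    # Illustrative parameter names; strings stand in for real tensors.
    named = [
        ('backbone.layer1.0.conv1.weight', 'p0'),
        ('decode_head.conv_seg.weight', 'p1'),
        ('auxiliary_head.conv_seg.bias', 'p2'),
    ]
    groups = build_param_groups(named, base_lr=0.01,
                                custom_keys={'head': dict(lr_mult=10.)})
    for (name, _), group in zip(named, groups):
        print(name, group['lr'])
```

With a real model you would pass `model.named_parameters()` instead of the string pairs; each returned dict follows PyTorch's per-parameter-options format, so the list could be handed to an optimizer such as `torch.optim.SGD(groups, lr=0.01, momentum=0.9)`. One group per parameter keeps the sketch simple, at the cost of many small groups.
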
## Online Hard Example Mining (OHEM)

We implement a pixel sampler [here](https://github.com/open-mmlab/mmsegmentation/tree/master/mmseg/core/seg/sampler) for sampling pixels during training.
Here is an example config for training PSPNet with OHEM enabled.
|
```python
_base_ = './pspnet_r50-d8_512x1024_40k_cityscapes.py'
model = dict(
    decode_head=dict(
        sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=100000)))
```
|
With this config, only pixels whose confidence score is below 0.7 are used for training, and at least 100000 pixels are kept during training. If `thresh` is not specified, the `min_kept` pixels with the largest loss are selected instead.
|
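As a sketch of the last case, the config below drops `thresh` so that selection relies only on the loss ranking. It assumes the same base config as above, and `min_kept=100000` is simply reused from the previous example rather than a tuned value.

```python
_base_ = './pspnet_r50-d8_512x1024_40k_cityscapes.py'
model = dict(
    decode_head=dict(
        # No `thresh`: the `min_kept` pixels with the largest loss are selected.
        sampler=dict(type='OHEMPixelSampler', min_kept=100000)))
```
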
|
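Both selection rules can be sketched in plain Python on a flat list of pixels. This is a simplified illustration rather than the code of `OHEMPixelSampler`: the function names are hypothetical, `gt_probs` is assumed to hold each pixel's predicted probability for its ground-truth class (its confidence score), and ignored pixels are assumed to be filtered out beforehand.

```python
def ohem_mask_by_thresh(gt_probs, thresh, min_kept):
    """Keep pixels whose ground-truth-class confidence is below ``thresh``.

    If that leaves fewer than ``min_kept`` pixels, the ``min_kept`` least
    confident pixels are kept instead, so at least ``min_kept`` survive
    (or all pixels, if there are fewer than ``min_kept``).
    """
    keep = [p < thresh for p in gt_probs]
    if sum(keep) < min_kept:
        # Indices ordered from least to most confident.
        order = sorted(range(len(gt_probs)), key=lambda i: gt_probs[i])
        for i in order[:min_kept]:
            keep[i] = True
    return keep


def ohem_mask_by_loss(losses, min_kept):
    """Keep the ``min_kept`` pixels with the largest loss (no threshold)."""
    order = sorted(range(len(losses)), key=lambda i: losses[i], reverse=True)
    keep = [False] * len(losses)
    for i in order[:min_kept]:
        keep[i] = True
    return keep


if __name__ == '__main__':
    probs = [0.95, 0.30, 0.80, 0.10, 0.65]
    # 3 pixels are below 0.7 -> [False, True, False, True, True]
    print(ohem_mask_by_thresh(probs, thresh=0.7, min_kept=2))
    # only 3 below 0.7, so the 4 least confident are kept -> [False, True, True, True, True]
    print(ohem_mask_by_thresh(probs, thresh=0.7, min_kept=4))
    # top-2 losses -> [False, True, False, True]
    print(ohem_mask_by_loss([0.2, 1.5, 0.7, 2.3], min_kept=2))
```

For cross-entropy, a pixel's loss is the negative log of its ground-truth-class probability, so both rules favor low-confidence pixels; the threshold variant adds a fixed confidence cut-off on top of the `min_kept` floor.
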
## Class Balanced Loss

For a dataset whose class distribution is imbalanced, you may change the loss weight of each class.
Here is an example for the Cityscapes dataset.
|
```python
_base_ = './pspnet_r50-d8_512x1024_40k_cityscapes.py'
model = dict(
    decode_head=dict(
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0,
            # DeepLab used this class weight for Cityscapes (one weight per class, 19 in total)
            class_weight=[0.8373, 0.9180, 0.8660, 1.0345, 1.0166, 0.9969, 0.9754,
                          1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037,
                          1.0865, 1.0955, 1.0865, 1.1529, 1.0507])))
```
|
`class_weight` is passed to `CrossEntropyLoss` as the `weight` argument. Please refer to the [PyTorch doc](https://pytorch.org/docs/stable/nn.html?highlight=crossentropy#torch.nn.CrossEntropyLoss) for details.
|
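To see what `weight` does, here is a pure-Python sketch of the class-weighted cross-entropy formula described in the PyTorch documentation: each pixel's loss is scaled by the weight of its target class, and with `reduction='mean'` the sum is divided by the sum of those weights rather than by the pixel count. The helper `weighted_cross_entropy` is hypothetical and illustrates PyTorch's documented semantics only; MMSegmentation's loss wrapper may apply its own reduction on top, which this sketch does not model.

```python
import math


def weighted_cross_entropy(logits, targets, class_weight):
    """Class-weighted cross-entropy with PyTorch-style 'mean' reduction.

    logits: list of per-pixel score lists (one score per class).
    targets: list of ground-truth class indices, one per pixel.
    class_weight: list with one weight per class.
    """
    total_loss, total_weight = 0.0, 0.0
    for scores, t in zip(logits, targets):
        # Numerically stable log-sum-exp for the softmax denominator.
        m = max(scores)
        log_sum_exp = m + math.log(sum(math.exp(s - m) for s in scores))
        w = class_weight[t]
        # -w * log(softmax(scores)[t])
        total_loss += -w * (scores[t] - log_sum_exp)
        total_weight += w
    # Normalize by the summed weights of the targets, not by the pixel count.
    return total_loss / total_weight


if __name__ == '__main__':
    logits = [[2.0, 0.5], [0.1, 1.2], [1.0, 1.0]]
    targets = [0, 1, 1]
    print(weighted_cross_entropy(logits, targets, [1.0, 1.0]))  # plain mean CE
    print(weighted_cross_entropy(logits, targets, [0.5, 2.0]))  # class 1 counts more
```

With all weights equal to 1 this reduces to the ordinary mean cross-entropy; raising the weight of a rare class makes its pixels contribute proportionally more to the averaged loss.
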