# Tutorial 4: Customize Models
## Customize optimizer
Assume you want to add an optimizer named `MyOptimizer`, which has arguments `a`, `b`, and `c`.
You need to first implement the new optimizer in a file, e.g., in `mmseg/core/optimizer/my_optimizer.py`:
```python
from mmcv.runner import OPTIMIZERS
from torch.optim import Optimizer


@OPTIMIZERS.register_module()
class MyOptimizer(Optimizer):

    def __init__(self, a, b, c):
        pass  # set up the optimizer state from `a`, `b`, and `c` here
```
Then import this module in `mmseg/core/optimizer/__init__.py` so that the registry can
find the new module and add it:
```python
from .my_optimizer import MyOptimizer
```
Then you can use `MyOptimizer` in the `optimizer` field of config files.
In the configs, optimizers are defined by the `optimizer` field, as follows:
```python
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
```
To use your own optimizer, the field can be changed to
```python
optimizer = dict(type='MyOptimizer', a=a_value, b=b_value, c=c_value)
```
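To make the mechanism concrete, the following is a minimal, framework-free sketch of how a registry can map the `type` key of a config dict to a registered class. `ToyRegistry` and `build_from_cfg_sketch` are hypothetical stand-ins written for illustration; they are not mmcv's actual `Registry` or `build_from_cfg` implementation.

```python
class ToyRegistry:
    """Hypothetical, simplified stand-in for a module registry."""

    def __init__(self, name):
        self.name = name
        self._modules = {}

    def register_module(self):
        # Used as a decorator: stores the class under its own name.
        def _register(cls):
            self._modules[cls.__name__] = cls
            return cls
        return _register

    def get(self, key):
        return self._modules[key]


def build_from_cfg_sketch(cfg, registry):
    """Instantiate the class named by cfg['type'] with the remaining keys."""
    args = dict(cfg)  # copy so the caller's config is left untouched
    cls = registry.get(args.pop('type'))
    return cls(**args)


OPTIMIZERS = ToyRegistry('optimizer')


@OPTIMIZERS.register_module()
class MyOptimizer:
    # Plain class for illustration only; a real optimizer would
    # subclass torch.optim.Optimizer.
    def __init__(self, a, b, c):
        self.a, self.b, self.c = a, b, c


optimizer_cfg = dict(type='MyOptimizer', a=1, b=2, c=3)
optimizer = build_from_cfg_sketch(optimizer_cfg, OPTIMIZERS)
print(type(optimizer).__name__, optimizer.a, optimizer.b, optimizer.c)
```

The sketch also suggests why the import in `__init__.py` matters: the decorator only runs, and the class only gets registered, when the module that defines it is imported.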
All optimizers implemented by PyTorch are already supported; the only modification needed is to change the `optimizer` field of config files.
For example, if you want to use `Adam` (though performance may drop considerably), the modification could be as follows.
```python
optimizer = dict(type='Adam', lr=0.0003, weight_decay=0.0001)
```
Users can set the arguments directly, following the [API doc](https://pytorch.org/docs/stable/optim.html?highlight=optim#module-torch.optim) of PyTorch.
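For instance, `torch.optim.Adam` also accepts `betas` and `eps`, so they can be set in the same dict. The two values below are PyTorch's documented defaults for those arguments, shown only to illustrate the pattern and not as a recommended setting.

```python
optimizer = dict(
    type='Adam',
    lr=0.0003,
    betas=(0.9, 0.999),  # coefficients for the running averages
    eps=1e-08,  # term added to the denominator for numerical stability
    weight_decay=0.0001)
```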
## Customize optimizer constructor
Some models may have parameter-specific settings for optimization, e.g. weight decay for BatchNorm layers.
Users can do such fine-grained parameter tuning by customizing the optimizer constructor.
```python
from mmcv.utils import build_from_cfg
from mmcv.runner import OPTIMIZER_BUILDERS
from .cocktail_optimizer import CocktailOptimizer


@OPTIMIZER_BUILDERS.register_module()
class CocktailOptimizerConstructor(object):

    def __init__(self, optimizer_cfg, paramwise_cfg=None):
        pass  # store the configs here

    def __call__(self, model):
        # build `my_optimizer` from the model's parameters here
        return my_optimizer
```
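As an illustration of the kind of logic that could go inside `__call__`, the sketch below splits parameters into two groups so that normalization parameters get a different weight decay. It is framework-free: it takes `(name, param)` pairs such as those yielded by `model.named_parameters()` and returns a list of group dicts in the per-parameter-options format that `torch.optim` optimizers accept. The helper name `split_param_groups`, the `norm_decay_mult` argument, and the rule of detecting normalization layers by the substrings `bn` or `norm` in the parameter name are all assumptions made for this example.

```python
def split_param_groups(named_params, weight_decay, norm_decay_mult=0.0,
                       norm_keys=('bn', 'norm')):
    """Group parameters so that norm layers use a scaled weight decay.

    Hypothetical helper: norm parameters are detected by name only.
    """
    norm_params, other_params = [], []
    for name, param in named_params:
        if any(key in name for key in norm_keys):
            norm_params.append(param)
        else:
            other_params.append(param)
    return [
        dict(params=other_params, weight_decay=weight_decay),
        dict(params=norm_params,
             weight_decay=weight_decay * norm_decay_mult),
    ]


# Strings stand in for parameter tensors in this toy example.
named_params = [
    ('conv1.weight', 'w_conv1'),
    ('bn1.weight', 'w_bn1'),
    ('bn1.bias', 'b_bn1'),
    ('fc.weight', 'w_fc'),
]
groups = split_param_groups(named_params, weight_decay=0.0001)
for group in groups:
    print(group['weight_decay'], group['params'])
```

With a real model, the returned list could be passed as the first argument of an optimizer, e.g. `torch.optim.SGD(groups, lr=0.02, momentum=0.9)`.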
## Develop new components
There are mainly 2 types of components in MMSegmentation.
- backbone: usually a stack of convolutional layers that extracts feature maps, e.g., ResNet, HRNet.
- head: the component that decodes the semantic segmentation map.
### Add new backbones
Here we show how to develop new components with an example of MobileNet.
1. Create a new file `mmseg/models/backbones/mobilenet.py`.
```python
import torch.nn as nn

from ..builder import BACKBONES


@BACKBONES.register_module()
class MobileNet(nn.Module):

    def __init__(self, arg1, arg2):
        pass

    def forward(self, x):  # should return a tuple
        pass

    def init_weights(self, pretrained=None):
        pass
```
2. Import the module in `mmseg/models/backbones/__init__.py`.
```python
from .mobilenet import MobileNet
```
3. Use it in your config file.
```python
model = dict(
    ...
    backbone=dict(
        type='MobileNet',
        arg1=xxx,
        arg2=xxx),
    ...
)
```
### Add new heads
In MMSegmentation, we provide a base class, [BaseDecodeHead](https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/decode_heads/decode_head.py), for all segmentation heads.
All newly implemented decode heads should be derived from it.
Here we show how to develop a new head, using [PSPNet](https://arxiv.org/abs/1612.01105) as an example.
First, add a new decode head in `mmseg/models/decode_heads/psp_head.py`.
PSPNet implements a decode head for segmentation decoding.
To implement a decode head, we basically need to implement three functions of the new module, as follows.
```python
@HEADS.register_module()
class PSPHead(BaseDecodeHead):

    def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
        super(PSPHead, self).__init__(**kwargs)

    def init_weights(self):
        pass  # initialize the weights of the head here

    def forward(self, inputs):
        pass  # decode `inputs` into a segmentation map here
```
Next, users need to import the module in `mmseg/models/decode_heads/__init__.py` so that the corresponding registry can find and load it.
The config file of PSPNet is as follows:
```python
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    pretrained='pretrain_model/resnet50_v1c_trick-2cccc1ad.pth',
    backbone=dict(
        type='ResNetV1c',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        dilations=(1, 1, 2, 4),
        strides=(1, 2, 1, 1),
        norm_cfg=norm_cfg,
        norm_eval=False,
        style='pytorch',
        contract_dilation=True),
    decode_head=dict(
        type='PSPHead',
        in_channels=2048,
        in_index=3,
        channels=512,
        pool_scales=(1, 2, 3, 6),
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)))
```
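The `out_indices` and `in_index` fields work together: the backbone returns a tuple with one feature map per selected stage, and the head picks one entry of that tuple. The sketch below mimics that data flow with strings in place of tensors. `toy_backbone_forward` and `toy_select_input` are hypothetical helpers, and the selection rule shown (plain indexing with `in_index`) is a simplified assumption about how a decode head chooses its input.

```python
def toy_backbone_forward(num_stages, out_indices):
    """Return a tuple holding the output of each stage in out_indices."""
    outs = []
    for i in range(num_stages):
        feature = f'stage{i}_features'  # placeholder for a feature map
        if i in out_indices:
            outs.append(feature)
    return tuple(outs)


def toy_select_input(inputs, in_index):
    """Pick the single feature map that the head decodes."""
    return inputs[in_index]


features = toy_backbone_forward(num_stages=4, out_indices=(0, 1, 2, 3))
head_input = toy_select_input(features, in_index=3)
print(len(features), head_input)
```

This is also why a custom backbone's `forward` should return a tuple even when it produces a single feature map.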
### Add new loss
Assume you want to add a new loss, `MyLoss`, for segmentation decoding.
To add a new loss function, users need to implement it in `mmseg/models/losses/my_loss.py`.
The decorator `weighted_loss` enables the loss to be weighted for each element.
```python
import torch
import torch.nn as nn

from ..builder import LOSSES
from .utils import weighted_loss


@weighted_loss
def my_loss(pred, target):
    assert pred.size() == target.size() and target.numel() > 0
    loss = torch.abs(pred - target)
    return loss


@LOSSES.register_module()
class MyLoss(nn.Module):

    def __init__(self, reduction='mean', loss_weight=1.0):
        super(MyLoss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss = self.loss_weight * my_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
        return loss
```
Then the users need to add it in the `mmseg/models/losses/__init__.py`.
```python
from .my_loss import MyLoss, my_loss
```
To use it, modify the `loss_xxx` field of the component that computes the loss.
Since `MyLoss` is used for segmentation decoding, you need to modify the `loss_decode` field in the head.
`loss_weight` can be used to balance multiple losses.
```python
loss_decode=dict(type='MyLoss', loss_weight=1.0)
```
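To see how element-wise weights, the reduction mode, and `loss_weight` interact, here is a framework-free sketch on plain Python lists. `toy_weighted_l1` is a hypothetical stand-in that imitates what a decorated `my_loss` computes; it is a simplification and not the actual `weighted_loss` implementation (for example, it silently ignores `avg_factor` when the reduction is `'sum'`).

```python
def toy_weighted_l1(pred, target, weight=None, reduction='mean',
                    avg_factor=None):
    """Element-wise |pred - target|, optionally weighted, then reduced."""
    loss = [abs(p - t) for p, t in zip(pred, target)]
    if weight is not None:
        loss = [v * w for v, w in zip(loss, weight)]
    if reduction == 'none':
        return loss
    if reduction == 'sum':
        return sum(loss)
    # 'mean': divide by avg_factor when given, else by the element count
    denom = avg_factor if avg_factor is not None else len(loss)
    return sum(loss) / denom


pred = [1.0, 2.0, 4.0, 3.0]
target = [0.0, 2.0, 2.0, 3.0]
weight = [1.0, 1.0, 0.5, 1.0]  # down-weight the third element

loss_weight = 2.0  # plays the role of `loss_weight` in the config
elementwise = toy_weighted_l1(pred, target, reduction='none')
total = loss_weight * toy_weighted_l1(pred, target, weight)
print(elementwise, total)
```

In this toy setup, `loss_weight` simply scales the reduced value, which is how several losses can be balanced against each other before they are summed.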