Feng Wang committed
Commit 0c109d5 · Parent(s): 6bae5e0
feat(utils): freeze module (#1156)
Files changed:
- README.md +1 -0
- docs/freeze_module.md +37 -0
- tests/__init__.py +2 -0
- tests/utils/test_model_utils.py +107 -0
- yolox/core/trainer.py +6 -5
- yolox/exp/yolox_base.py +1 -0
- yolox/utils/model_utils.py +85 -6
README.md
CHANGED
@@ -188,6 +188,7 @@ python -m yolox.tools.eval -n yolox-s -c yolox_s.pth -b 1 -d 1 --conf 0.001 --f
 
 * [Training on custom data](docs/train_custom_data.md)
 * [Manipulating training image size](docs/manipulate_training_image_size.md)
+* [Freezing model](docs/freeze_module.md)
 
 </details>
 
docs/freeze_module.md
ADDED
@@ -0,0 +1,37 @@
# Freeze module

This page guides users through freezing modules in YOLOX.
Exp controls everything in YOLOX, so let's start by creating an Exp object.

## 1. Create your own experiment object

We take the YOLOX-S model on the COCO dataset as an example to give a clearer guide.

Import the config you want (or write your own Exp object inheriting from `yolox.exp.BaseExp`).
```python
from yolox.exp.default.yolox_s import Exp as MyExp
```

## 2. Override the `get_model` method

Here is a simple example that freezes the backbone (FPN not included) of the model.
```python
class Exp(MyExp):

    def get_model(self):
        from yolox.utils import freeze_module
        model = super().get_model()
        freeze_module(model.backbone.backbone)
        return model
```
If you want to freeze the FPN together with the backbone, `freeze_module(model.backbone)` might help.
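`freeze_module` also accepts an optional `name` argument (see `yolox/utils/model_utils.py` later in this commit), which freezes every parameter and submodule whose qualified name contains that string. A minimal sketch of the same override using the by-name form; note that in the standard YOLOX model this matches the whole `backbone`, i.e. darknet plus FPN:
```python
class Exp(MyExp):

    def get_model(self):
        from yolox.utils import freeze_module
        model = super().get_model()
        # freeze every parameter/module whose qualified name contains "backbone"
        freeze_module(model, name="backbone")
        return model
```
Freezing by object reference, as in step 2, remains the way to freeze only part of the backbone.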

## 3. Train

Suppose the path of your Exp is `/path/to/my_exp.py`; use the following command to train your model.
```bash
python3 -m yolox.tools.train -f /path/to/my_exp.py
```
For more details of training, run the following command.
```bash
python3 -m yolox.tools.train --help
```
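After training, evaluation can be run against the same Exp file. This is only a sketch: it assumes `yolox.tools.eval` accepts the same `-f/--exp_file` flag as the train tool, and the checkpoint path is a placeholder you should point at your own output directory; the remaining flags mirror the README example above.
```bash
python3 -m yolox.tools.eval -f /path/to/my_exp.py -c /path/to/your_ckpt.pth -b 64 -d 1 --conf 0.001
```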
tests/__init__.py
ADDED
@@ -0,0 +1,2 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
tests/utils/test_model_utils.py
ADDED
@@ -0,0 +1,107 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.

import unittest

import torch
from torch import nn

from yolox.utils import adjust_status, freeze_module
from yolox.exp import get_exp


class TestModelUtils(unittest.TestCase):

    def setUp(self):
        self.model: nn.Module = get_exp(exp_name="yolox-s").get_model()

    def test_model_state_adjust_status(self):
        data = torch.ones(1, 10, 10, 10)
        # use bn since bn changes state during train/val
        model = nn.BatchNorm2d(10)
        prev_state = model.state_dict()

        modes = [False, True]
        results = [True, False]

        # test under train/eval mode
        for mode, result in zip(modes, results):
            with adjust_status(model, training=mode):
                model(data)

            model_state = model.state_dict()
            self.assertTrue(len(model_state) == len(prev_state))
            self.assertEqual(
                result,
                all([torch.allclose(v, model_state[k]) for k, v in prev_state.items()])
            )

        # test recursive context case
        prev_state = model.state_dict()
        with adjust_status(model, training=False):
            with adjust_status(model, training=False):
                model(data)
        model_state = model.state_dict()
        self.assertTrue(len(model_state) == len(prev_state))
        self.assertTrue(
            all([torch.allclose(v, model_state[k]) for k, v in prev_state.items()])
        )

    def test_model_effect_adjust_status(self):
        # test context effect
        self.model.train()
        with adjust_status(self.model, training=False):
            for module in self.model.modules():
                self.assertFalse(module.training)
        # all training after exit
        for module in self.model.modules():
            self.assertTrue(module.training)

        # only backbone set to eval
        self.model.backbone.eval()
        with adjust_status(self.model, training=False):
            for module in self.model.modules():
                self.assertFalse(module.training)

        for name, module in self.model.named_modules():
            if "backbone" in name:
                self.assertFalse(module.training)
            else:
                self.assertTrue(module.training)

    def test_freeze_module(self):
        model = nn.Sequential(
            nn.Conv2d(3, 10, 1),
            nn.BatchNorm2d(10),
            nn.ReLU(),
        )
        data = torch.rand(1, 3, 10, 10)
        model.train()
        assert isinstance(model[1], nn.BatchNorm2d)
        before_states = model[1].state_dict()
        freeze_module(model[1])
        model(data)
        after_states = model[1].state_dict()
        self.assertTrue(
            all([torch.allclose(v, after_states[k]) for k, v in before_states.items()])
        )

        # yolox test
        self.model.train()
        for module in self.model.modules():
            self.assertTrue(module.training)

        freeze_module(self.model, "backbone")
        for module in self.model.backbone.modules():
            self.assertFalse(module.training)
        for p in self.model.backbone.parameters():
            self.assertFalse(p.requires_grad)

        for module in self.model.head.modules():
            self.assertTrue(module.training)
        for p in self.model.head.parameters():
            self.assertTrue(p.requires_grad)


if __name__ == "__main__":
    unittest.main()
yolox/core/trainer.py
CHANGED
@@ -16,6 +16,7 @@ from yolox.utils import (
     MeterBuffer,
     ModelEMA,
     WandbLogger,
+    adjust_status,
     all_reduce_norm,
     get_local_rank,
     get_model_info,
@@ -169,7 +170,6 @@ class Trainer:
         self.ema_model.updates = self.max_iter * self.start_epoch

         self.model = model
-        self.model.train()

         self.evaluator = self.exp.get_evaluator(
             batch_size=self.args.batch_size, is_distributed=self.is_distributed
@@ -320,13 +320,14 @@
         if is_parallel(evalmodel):
             evalmodel = evalmodel.module

-        ap50_95, ap50, summary = self.exp.eval(
-            evalmodel, self.evaluator, self.is_distributed
-        )
+        with adjust_status(evalmodel, training=False):
+            ap50_95, ap50, summary = self.exp.eval(
+                evalmodel, self.evaluator, self.is_distributed
+            )
+
         update_best_ckpt = ap50_95 > self.best_ap
         self.best_ap = max(self.best_ap, ap50_95)

-        self.model.train()
         if self.rank == 0:
             if self.args.logger == "tensorboard":
                 self.tblogger.add_scalar("val/COCOAP50", ap50, self.epoch + 1)
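The point of this change: evaluation now runs inside `adjust_status`, which restores each submodule's previous train/eval mode on exit, so the trainer no longer needs a blanket `self.model.train()` that would un-freeze modules frozen via `freeze_module`. A minimal standalone sketch of that behaviour (toy modules, not the YOLOX trainer):
```python
import torch
from torch import nn

from yolox.utils import adjust_status, freeze_module

# toy model standing in for a YOLOX model
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
model.train()
freeze_module(model[1])  # the BN layer: requires_grad=False and eval mode

data = torch.rand(1, 3, 32, 32)
with adjust_status(model, training=False):
    # inside the context everything is in eval mode, as during evaluation
    assert not any(m.training for m in model.modules())
    model(data)

# after exit the previous per-module modes are restored:
# the frozen BN stays in eval mode while the rest returns to train mode
assert model[0].training and not model[1].training
```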
yolox/exp/yolox_base.py
CHANGED
@@ -124,6 +124,7 @@ class Exp(BaseExp):

         self.model.apply(init_yolo)
         self.model.head.initialize_biases(1e-2)
+        self.model.train()
         return self.model

     def get_data_loader(
yolox/utils/model_utils.py
CHANGED
@@ -2,7 +2,9 @@
 # -*- coding:utf-8 -*-
 # Copyright (c) Megvii Inc. All rights reserved.

+import contextlib
 from copy import deepcopy
+from typing import Sequence

 import torch
 import torch.nn as nn
@@ -13,11 +15,12 @@ __all__ = [
     "fuse_model",
     "get_model_info",
     "replace_module",
+    "freeze_module",
+    "adjust_status",
 ]


-def get_model_info(model, tsize):
-
+def get_model_info(model: nn.Module, tsize: Sequence[int]) -> str:
     stride = 64
     img = torch.zeros((1, 3, stride, stride), device=next(model.parameters()).device)
     flops, params = profile(deepcopy(model), inputs=(img,), verbose=False)
@@ -28,8 +31,18 @@ def get_model_info(model, tsize):
     return info


-def fuse_conv_and_bn(conv, bn):
-
+def fuse_conv_and_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
+    """
+    Fuse convolution and batchnorm layers.
+    Check more info on https://tehnokv.com/posts/fusing-batchnorm-and-conv/
+
+    Args:
+        conv (nn.Conv2d): convolution to fuse.
+        bn (nn.BatchNorm2d): batchnorm to fuse.
+
+    Returns:
+        nn.Conv2d: fused convolution that behaves the same as the input conv and bn.
+    """
     fusedconv = (
         nn.Conv2d(
             conv.in_channels,
@@ -63,7 +76,15 @@ def fuse_conv_and_bn(conv, bn):
     return fusedconv


-def fuse_model(model):
+def fuse_model(model: nn.Module) -> nn.Module:
+    """Fuse conv and bn in the model.
+
+    Args:
+        model (nn.Module): model to fuse.
+
+    Returns:
+        nn.Module: fused model.
+    """
     from yolox.models.network_blocks import BaseConv

     for m in model.modules():
@@ -74,7 +95,7 @@ def fuse_model(model):
     return model


-def replace_module(module, replaced_module_type, new_module_type, replace_func=None):
+def replace_module(module, replaced_module_type, new_module_type, replace_func=None) -> nn.Module:
     """
     Replace given type in module to a new type. mostly used in deploy.

@@ -104,3 +125,61 @@ def replace_module(module, replaced_module_type, new_module_type, replace_func=None):
                 model.add_module(name, new_child)

     return model
+
+
+def freeze_module(module: nn.Module, name=None) -> nn.Module:
+    """Freeze module in place.
+
+    Args:
+        module (nn.Module): module to freeze.
+        name (str, optional): name to freeze. If not given, freeze the whole module.
+            Note that fuzzy match is not supported. Defaults to None.
+
+    Examples:
+        freeze the backbone of the model
+        >>> freeze_module(model.backbone)
+
+        or freeze the backbone of the model by name
+        >>> freeze_module(model, name="backbone")
+    """
+    for param_name, parameter in module.named_parameters():
+        if name is None or name in param_name:
+            parameter.requires_grad = False
+
+    # ensure modules like BN and dropout are frozen as well
+    for module_name, sub_module in module.named_modules():
+        # strictly speaking, there is no need to call eval for every single sub_module
+        if name is None or name in module_name:
+            sub_module.eval()
+
+    return module
+
+
+@contextlib.contextmanager
+def adjust_status(module: nn.Module, training: bool = False) -> nn.Module:
+    """Adjust module to training/eval mode temporarily.
+
+    Args:
+        module (nn.Module): module to adjust status.
+        training (bool): training mode to set. True for train mode, False for eval mode.
+
+    Examples:
+        >>> with adjust_status(model, training=False):
+        ...     model(data)
+    """
+    status = {}
+
+    def backup_status(module):
+        for m in module.modules():
+            # save prev status to dict
+            status[m] = m.training
+            m.training = training
+
+    def recover_status(module):
+        for m in module.modules():
+            # recover prev status from dict
+            m.training = status.pop(m)
+
+    backup_status(module)
+    yield module
+    recover_status(module)
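A quick sanity check of the new utilities from an interactive session might look like the following sketch; it uses only `get_exp` and `freeze_module` as exercised by the tests above, and the printed counts depend on the exp you load.
```python
import torch.nn as nn

from yolox.exp import get_exp
from yolox.utils import freeze_module

model: nn.Module = get_exp(exp_name="yolox-s").get_model()
freeze_module(model, name="backbone")

# compare trainable vs. frozen parameter counts after freezing the backbone
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
print(f"trainable: {trainable:,}  frozen: {frozen:,}")

# the frozen part also stays in eval mode, matching what the tests assert
assert all(not m.training for m in model.backbone.modules())
```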