Feng Wang committed
Commit 0c109d5 · 1 Parent(s): 6bae5e0

feat(utils): freeze module (#1156)

README.md CHANGED
@@ -188,6 +188,7 @@ python -m yolox.tools.eval -n yolox-s -c yolox_s.pth -b 1 -d 1 --conf 0.001 --f
 
  * [Training on custom data](docs/train_custom_data.md)
  * [Manipulating training image size](docs/manipulate_training_image_size.md)
+ * [Freezing model](docs/freeze_module.md)
 
  </details>
 
docs/freeze_module.md ADDED
@@ -0,0 +1,37 @@
+ # Freeze module
+
+ This page guides users through freezing modules in YOLOX.
+ Exp controls everything in YOLOX, so let's start by creating an Exp object.
+
+ ## 1. Create your own experiment object
+
+ We take the YOLOX-S model on the COCO dataset as an example to make the guide clearer.
+
+ Import the config you want (or write your own Exp object that inherits from `yolox.exp.BaseExp`).
+ ```python
+ from yolox.exp.default.yolox_s import Exp as MyExp
+ ```
+
+ ## 2. Override the `get_model` method
+
+ Here is a simple snippet that freezes the backbone (FPN not included) of the model.
+ ```python
+ class Exp(MyExp):
+
+     def get_model(self):
+         from yolox.utils import freeze_module
+         model = super().get_model()
+         freeze_module(model.backbone.backbone)
+         return model
+ ```
+ If you also want to freeze the FPN, `freeze_module(model.backbone)` might help (see the sketch after this file's diff).
+
+ ## 3. Train
+ Suppose the path of your Exp file is `/path/to/my_exp.py`; use the following command to train your model.
+ ```bash
+ python3 -m yolox.tools.train -f /path/to/my_exp.py
+ ```
+ For more details on training, run the following command.
+ ```bash
+ python3 -m yolox.tools.train --help
+ ```
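In YOLOX, `model.backbone` is the YOLOPAFPN module and `model.backbone.backbone` is the CSPDarknet backbone inside it, so freezing `model.backbone` covers the FPN layers as well. Below is a minimal sketch (not part of this commit) of an Exp override that does so, assuming the same `MyExp` import used in the doc above:

```python
# Hedged sketch, not part of the commit: freeze the CSPDarknet backbone
# together with the FPN by freezing the whole YOLOPAFPN module.
from yolox.exp.default.yolox_s import Exp as MyExp
from yolox.utils import freeze_module


class Exp(MyExp):

    def get_model(self):
        model = super().get_model()
        # model.backbone is the YOLOPAFPN module; model.backbone.backbone
        # would be only the CSPDarknet part
        freeze_module(model.backbone)
        return model
```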
tests/__init__.py ADDED
@@ -0,0 +1,2 @@
+ #!/usr/bin/env python3
+ # -*- coding:utf-8 -*-
tests/utils/test_model_utils.py ADDED
@@ -0,0 +1,107 @@
+ #!/usr/bin/env python3
+ # -*- coding:utf-8 -*-
+ # Copyright (c) Megvii, Inc. and its affiliates.
+
+ import unittest
+
+ import torch
+ from torch import nn
+
+ from yolox.utils import adjust_status, freeze_module
+ from yolox.exp import get_exp
+
+
+ class TestModelUtils(unittest.TestCase):
+
+     def setUp(self):
+         self.model: nn.Module = get_exp(exp_name="yolox-s").get_model()
+
+     def test_model_state_adjust_status(self):
+         data = torch.ones(1, 10, 10, 10)
+         # use bn since bn changes state during train/val
+         model = nn.BatchNorm2d(10)
+         prev_state = model.state_dict()
+
+         modes = [False, True]
+         results = [True, False]
+
+         # test under train/eval mode
+         for mode, result in zip(modes, results):
+             with adjust_status(model, training=mode):
+                 model(data)
+             model_state = model.state_dict()
+             self.assertTrue(len(model_state) == len(prev_state))
+             self.assertEqual(
+                 result,
+                 all([torch.allclose(v, model_state[k]) for k, v in prev_state.items()])
+             )
+
+         # test recursive context case
+         prev_state = model.state_dict()
+         with adjust_status(model, training=False):
+             with adjust_status(model, training=False):
+                 model(data)
+         model_state = model.state_dict()
+         self.assertTrue(len(model_state) == len(prev_state))
+         self.assertTrue(
+             all([torch.allclose(v, model_state[k]) for k, v in prev_state.items()])
+         )
+
+     def test_model_effect_adjust_status(self):
+         # test context effect
+         self.model.train()
+         with adjust_status(self.model, training=False):
+             for module in self.model.modules():
+                 self.assertFalse(module.training)
+         # all modules are back in training mode after exit
+         for module in self.model.modules():
+             self.assertTrue(module.training)
+
+         # only backbone set to eval
+         self.model.backbone.eval()
+         with adjust_status(self.model, training=False):
+             for module in self.model.modules():
+                 self.assertFalse(module.training)
+
+         for name, module in self.model.named_modules():
+             if "backbone" in name:
+                 self.assertFalse(module.training)
+             else:
+                 self.assertTrue(module.training)
+
+     def test_freeze_module(self):
+         model = nn.Sequential(
+             nn.Conv2d(3, 10, 1),
+             nn.BatchNorm2d(10),
+             nn.ReLU(),
+         )
+         data = torch.rand(1, 3, 10, 10)
+         model.train()
+         assert isinstance(model[1], nn.BatchNorm2d)
+         before_states = model[1].state_dict()
+         freeze_module(model[1])
+         model(data)
+         after_states = model[1].state_dict()
+         self.assertTrue(
+             all([torch.allclose(v, after_states[k]) for k, v in before_states.items()])
+         )
+
+         # yolox test
+         self.model.train()
+         for module in self.model.modules():
+             self.assertTrue(module.training)
+
+         freeze_module(self.model, "backbone")
+         for module in self.model.backbone.modules():
+             self.assertFalse(module.training)
+         for p in self.model.backbone.parameters():
+             self.assertFalse(p.requires_grad)
+
+         for module in self.model.head.modules():
+             self.assertTrue(module.training)
+         for p in self.model.head.parameters():
+             self.assertTrue(p.requires_grad)
+
+
+ if __name__ == "__main__":
+     unittest.main()
yolox/core/trainer.py CHANGED
@@ -16,6 +16,7 @@ from yolox.utils import (
      MeterBuffer,
      ModelEMA,
      WandbLogger,
+     adjust_status,
      all_reduce_norm,
      get_local_rank,
      get_model_info,
@@ -169,7 +170,6 @@ class Trainer:
              self.ema_model.updates = self.max_iter * self.start_epoch
 
          self.model = model
-         self.model.train()
 
          self.evaluator = self.exp.get_evaluator(
              batch_size=self.args.batch_size, is_distributed=self.is_distributed
@@ -320,13 +320,14 @@ class Trainer:
          if is_parallel(evalmodel):
              evalmodel = evalmodel.module
 
-         ap50_95, ap50, summary = self.exp.eval(
-             evalmodel, self.evaluator, self.is_distributed
-         )
+         with adjust_status(evalmodel, training=False):
+             ap50_95, ap50, summary = self.exp.eval(
+                 evalmodel, self.evaluator, self.is_distributed
+             )
+
          update_best_ckpt = ap50_95 > self.best_ap
          self.best_ap = max(self.best_ap, ap50_95)
 
-         self.model.train()
          if self.rank == 0:
              if self.args.logger == "tensorboard":
                  self.tblogger.add_scalar("val/COCOAP50", ap50, self.epoch + 1)
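The change above swaps the manual `self.model.train()` bookkeeping for the new `adjust_status` context manager: the model is switched to eval mode only for the duration of evaluation, and every submodule's previous training flag is restored on exit. A minimal sketch of the pattern, using a toy module and a made-up `evaluate` helper rather than the actual Trainer:

```python
# Minimal sketch of the evaluation pattern introduced above; the toy model
# and evaluate() helper are illustrative, not YOLOX code.
import torch
from torch import nn

from yolox.utils import adjust_status


def evaluate(model: nn.Module) -> float:
    with torch.no_grad():
        return model(torch.rand(1, 3, 32, 32)).mean().item()


model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
model.train()

# run evaluation in eval mode so BN running statistics are not polluted
with adjust_status(model, training=False):
    score = evaluate(model)

# training flags are restored automatically after the block
assert all(m.training for m in model.modules())
```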
yolox/exp/yolox_base.py CHANGED
@@ -124,6 +124,7 @@ class Exp(BaseExp):
 
          self.model.apply(init_yolo)
          self.model.head.initialize_biases(1e-2)
+         self.model.train()
          return self.model
 
      def get_data_loader(
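Moving `self.model.train()` into `Exp.get_model` (and out of the trainer, see the trainer diff above) appears intended to keep a user's `freeze_module` call in a `get_model` override from being undone later: `nn.Module.train()` recursively re-enables training mode on every submodule, including ones that `freeze_module` put into eval mode. A small illustrative sketch of that interaction, using a toy module:

```python
# Illustrative only: a blanket train() call flips frozen submodules back to
# training mode (their parameters keep requires_grad=False, though).
from torch import nn

from yolox.utils import freeze_module

model = nn.Sequential(nn.Conv2d(3, 8, 1), nn.BatchNorm2d(8))
model.train()

freeze_module(model[1])   # BN switched to eval, its params made non-trainable
assert not model[1].training

model.train()             # recursive train() re-enables training on the BN...
assert model[1].training  # ...so freezing should happen after train(), not before
```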
yolox/utils/model_utils.py CHANGED
@@ -2,7 +2,9 @@
  # -*- coding:utf-8 -*-
  # Copyright (c) Megvii Inc. All rights reserved.
 
+ import contextlib
  from copy import deepcopy
+ from typing import Sequence
 
  import torch
  import torch.nn as nn
@@ -13,11 +15,12 @@ __all__ = [
      "fuse_model",
      "get_model_info",
      "replace_module",
+     "freeze_module",
+     "adjust_status",
  ]
 
 
- def get_model_info(model, tsize):
-
+ def get_model_info(model: nn.Module, tsize: Sequence[int]) -> str:
      stride = 64
      img = torch.zeros((1, 3, stride, stride), device=next(model.parameters()).device)
      flops, params = profile(deepcopy(model), inputs=(img,), verbose=False)
@@ -28,8 +31,18 @@ def get_model_info(model, tsize):
      return info
 
 
- def fuse_conv_and_bn(conv, bn):
-     # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
+ def fuse_conv_and_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
+     """
+     Fuse convolution and batchnorm layers.
+     check more info on https://tehnokv.com/posts/fusing-batchnorm-and-conv/
+
+     Args:
+         conv (nn.Conv2d): convolution to fuse.
+         bn (nn.BatchNorm2d): batchnorm to fuse.
+
+     Returns:
+         nn.Conv2d: fused convolution that behaves the same as the input conv and bn.
+     """
      fusedconv = (
          nn.Conv2d(
              conv.in_channels,
@@ -63,7 +76,15 @@ def fuse_conv_and_bn(conv, bn):
      return fusedconv
 
 
- def fuse_model(model):
+ def fuse_model(model: nn.Module) -> nn.Module:
+     """fuse conv and bn in the model
+
+     Args:
+         model (nn.Module): model to fuse
+
+     Returns:
+         nn.Module: fused model
+     """
      from yolox.models.network_blocks import BaseConv
 
      for m in model.modules():
@@ -74,7 +95,7 @@ def fuse_model(model):
      return model
 
 
- def replace_module(module, replaced_module_type, new_module_type, replace_func=None):
+ def replace_module(module, replaced_module_type, new_module_type, replace_func=None) -> nn.Module:
      """
      Replace given type in module to a new type. mostly used in deploy.
 
@@ -104,3 +125,61 @@ def replace_module(module, replaced_module_type, new_module_type, replace_func=N
          model.add_module(name, new_child)
 
      return model
+
+
+ def freeze_module(module: nn.Module, name=None) -> nn.Module:
+     """freeze module in place
+
+     Args:
+         module (nn.Module): module to freeze.
+         name (str, optional): name of the submodule to freeze. If not given, freeze the whole module.
+             Note that fuzzy match is not supported. Defaults to None.
+
+     Examples:
+         freeze the backbone of a model
+         >>> freeze_module(model.backbone)
+
+         or freeze the backbone of a model by name
+         >>> freeze_module(model, name="backbone")
+     """
+     for param_name, parameter in module.named_parameters():
+         if name is None or name in param_name:
+             parameter.requires_grad = False
+
+     # ensure that modules like BN and dropout are frozen as well
+     for module_name, sub_module in module.named_modules():
+         # strictly speaking, there is no need to call eval for every single sub_module
+         if name is None or name in module_name:
+             sub_module.eval()
+
+     return module
+
+
+ @contextlib.contextmanager
+ def adjust_status(module: nn.Module, training: bool = False) -> nn.Module:
+     """Adjust module to training/eval mode temporarily.
+
+     Args:
+         module (nn.Module): module to adjust status.
+         training (bool): training mode to set. True for train mode, False for eval mode.
+
+     Examples:
+         >>> with adjust_status(model, training=False):
+         ...     model(data)
+     """
+     status = {}
+
+     def backup_status(module):
+         for m in module.modules():
+             # save prev status to dict
+             status[m] = m.training
+             m.training = training
+
+     def recover_status(module):
+         for m in module.modules():
+             # recover prev status from dict
+             m.training = status.pop(m)
+
+     backup_status(module)
+     yield module
+     recover_status(module)
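Unlike a blanket `model.train()` after evaluation, `adjust_status` records and restores each submodule's own flag, so modules deliberately left in eval mode (for example, frozen ones) come back exactly as they were. A small self-contained sketch of that behaviour with a toy module (not a YOLOX model):

```python
# Hedged sketch of adjust_status's per-module restore behaviour; toy module only.
import torch
from torch import nn

from yolox.utils import adjust_status

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
model.train()
model[1].eval()  # e.g. a BN layer that was frozen beforehand

with adjust_status(model, training=False):
    # every submodule is in eval mode inside the context
    assert not any(m.training for m in model.modules())
    _ = model(torch.rand(1, 3, 16, 16))

# per-module flags are restored, not blanket-reset to train()
assert model[0].training and not model[1].training
```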