Spaces:
Runtime error
Runtime error
File size: 3,499 Bytes
0b7b08a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.
import unittest
import torch
from torch import nn
from yolox.utils import adjust_status, freeze_module
from yolox.exp import get_exp
class TestModelUtils(unittest.TestCase):
    """Unit tests for the ``adjust_status`` and ``freeze_module`` utilities."""

    def setUp(self):
        # Build a fresh YOLOX-S model for every test: the tests below mutate its
        # train/eval flags and ``requires_grad`` state, so it must not be shared.
        self.model: nn.Module = get_exp(exp_name="yolox-s").get_model()

    @staticmethod
    def _snapshot_state(module: nn.Module) -> dict:
        """Return an independent copy of ``module.state_dict()``.

        ``state_dict()`` hands back tensors that share storage with the live
        parameters/buffers. BatchNorm updates its running statistics in place,
        so an un-copied "before" snapshot would silently track those updates
        and every before/after comparison would trivially report "unchanged".
        """
        return {k: v.clone() for k, v in module.state_dict().items()}

    @staticmethod
    def _states_allclose(before: dict, after: dict) -> bool:
        """Return True if every tensor in ``before`` is allclose to ``after[key]``."""
        return all(torch.allclose(v, after[k]) for k, v in before.items())

    def test_model_state_adjust_status(self):
        data = torch.ones(1, 10, 10, 10)
        # use bn since bn changes state during train/val
        model = nn.BatchNorm2d(10)
        prev_state = self._snapshot_state(model)

        # (training mode, whether the BN state is expected to stay unchanged):
        # eval only reads the running stats, train updates them.
        for mode, unchanged in ((False, True), (True, False)):
            with adjust_status(model, training=mode):
                model(data)
            model_state = model.state_dict()
            self.assertEqual(len(model_state), len(prev_state))
            self.assertEqual(
                unchanged,
                self._states_allclose(prev_state, model_state),
            )

        # test recursive (nested) context case
        prev_state = self._snapshot_state(model)
        with adjust_status(model, training=False):
            with adjust_status(model, training=False):
                model(data)
        model_state = model.state_dict()
        self.assertEqual(len(model_state), len(prev_state))
        self.assertTrue(self._states_allclose(prev_state, model_state))
        # nesting must still restore the original (training) mode on exit
        self.assertTrue(model.training)

    def test_model_effect_adjust_status(self):
        # test context effect
        self.model.train()
        with adjust_status(self.model, training=False):
            for module in self.model.modules():
                self.assertFalse(module.training)
        # all training after exit
        for module in self.model.modules():
            self.assertTrue(module.training)

        # only backbone set to eval
        self.model.backbone.eval()
        with adjust_status(self.model, training=False):
            for module in self.model.modules():
                self.assertFalse(module.training)

        for name, module in self.model.named_modules():
            # exact prefix match: only ``backbone`` and its descendants were put
            # in eval mode above
            if name == "backbone" or name.startswith("backbone."):
                self.assertFalse(module.training)
            else:
                self.assertTrue(module.training)

    def test_freeze_module(self):
        model = nn.Sequential(
            nn.Conv2d(3, 10, 1),
            nn.BatchNorm2d(10),
            nn.ReLU(),
        )
        data = torch.rand(1, 3, 10, 10)
        model.train()
        assert isinstance(model[1], nn.BatchNorm2d)
        # clone: a plain state_dict() would alias the live buffers and make the
        # comparison below pass even if the running stats were updated
        before_states = self._snapshot_state(model[1])
        freeze_module(model[1])
        self.assertFalse(model[1].training)
        for p in model[1].parameters():
            self.assertFalse(p.requires_grad)
        model(data)
        after_states = model[1].state_dict()
        self.assertTrue(self._states_allclose(before_states, after_states))

        # yolox test
        self.model.train()
        for module in self.model.modules():
            self.assertTrue(module.training)

        freeze_module(self.model, "backbone")
        for module in self.model.backbone.modules():
            self.assertFalse(module.training)
        for p in self.model.backbone.parameters():
            self.assertFalse(p.requires_grad)

        for module in self.model.head.modules():
            self.assertTrue(module.training)
        for p in self.model.head.parameters():
            self.assertTrue(p.requires_grad)
if __name__ == "__main__":
    # Run every TestCase defined in this file when it is executed as a script.
    unittest.main(module="__main__")
|