Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3 | |
# -*- coding:utf-8 -*- | |
# Copyright (c) Megvii, Inc. and its affiliates. | |
import unittest | |
import torch | |
from torch import nn | |
from yolox.utils import adjust_status, freeze_module | |
from yolox.exp import get_exp | |
class TestModelUtils(unittest.TestCase): | |
def setUp(self): | |
self.model: nn.Module = get_exp(exp_name="yolox-s").get_model() | |
def test_model_state_adjust_status(self): | |
data = torch.ones(1, 10, 10, 10) | |
# use bn since bn changes state during train/val | |
model = nn.BatchNorm2d(10) | |
prev_state = model.state_dict() | |
modes = [False, True] | |
results = [True, False] | |
# test under train/eval mode | |
for mode, result in zip(modes, results): | |
with adjust_status(model, training=mode): | |
model(data) | |
model_state = model.state_dict() | |
self.assertTrue(len(model_state) == len(prev_state)) | |
self.assertEqual( | |
result, | |
all([torch.allclose(v, model_state[k]) for k, v in prev_state.items()]) | |
) | |
# test recurrsive context case | |
prev_state = model.state_dict() | |
with adjust_status(model, training=False): | |
with adjust_status(model, training=False): | |
model(data) | |
model_state = model.state_dict() | |
self.assertTrue(len(model_state) == len(prev_state)) | |
self.assertTrue( | |
all([torch.allclose(v, model_state[k]) for k, v in prev_state.items()]) | |
) | |
def test_model_effect_adjust_status(self): | |
# test context effect | |
self.model.train() | |
with adjust_status(self.model, training=False): | |
for module in self.model.modules(): | |
self.assertFalse(module.training) | |
# all training after exit | |
for module in self.model.modules(): | |
self.assertTrue(module.training) | |
# only backbone set to eval | |
self.model.backbone.eval() | |
with adjust_status(self.model, training=False): | |
for module in self.model.modules(): | |
self.assertFalse(module.training) | |
for name, module in self.model.named_modules(): | |
if "backbone" in name: | |
self.assertFalse(module.training) | |
else: | |
self.assertTrue(module.training) | |
def test_freeze_module(self): | |
model = nn.Sequential( | |
nn.Conv2d(3, 10, 1), | |
nn.BatchNorm2d(10), | |
nn.ReLU(), | |
) | |
data = torch.rand(1, 3, 10, 10) | |
model.train() | |
assert isinstance(model[1], nn.BatchNorm2d) | |
before_states = model[1].state_dict() | |
freeze_module(model[1]) | |
model(data) | |
after_states = model[1].state_dict() | |
self.assertTrue( | |
all([torch.allclose(v, after_states[k]) for k, v in before_states.items()]) | |
) | |
# yolox test | |
self.model.train() | |
for module in self.model.modules(): | |
self.assertTrue(module.training) | |
freeze_module(self.model, "backbone") | |
for module in self.model.backbone.modules(): | |
self.assertFalse(module.training) | |
for p in self.model.backbone.parameters(): | |
self.assertFalse(p.requires_grad) | |
for module in self.model.head.modules(): | |
self.assertTrue(module.training) | |
for p in self.model.head.parameters(): | |
self.assertTrue(p.requires_grad) | |
if __name__ == "__main__": | |
unittest.main() | |