Spaces:
Build error
Build error
File size: 3,994 Bytes
d7a991a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from torch.nn.modules.batchnorm import _BatchNorm
from mmpose.models.backbones import LiteHRNet
from mmpose.models.backbones.litehrnet import LiteHRModule
from mmpose.models.backbones.resnet import Bottleneck
def is_norm(modules):
    """Tell whether ``modules`` is a batch-normalization layer.

    Any instance of a ``torch.nn.modules.batchnorm._BatchNorm`` subclass
    counts; every other module does not.
    """
    norm_types = (_BatchNorm, )
    return isinstance(modules, norm_types)
def all_zeros(modules):
    """Check whether a module's weight (and bias, if set) is all zero.

    Args:
        modules (torch.nn.Module): Module exposing a ``weight`` tensor
            parameter and optionally a ``bias`` one.

    Returns:
        bool: True if ``weight`` is all zeros and ``bias`` is either missing,
        ``None``, or all zeros; False otherwise.
    """
    weight_zero = torch.equal(modules.weight.data,
                              torch.zeros_like(modules.weight.data))
    # ``hasattr`` alone is not enough: layers constructed with ``bias=False``
    # still expose a ``bias`` attribute whose value is ``None``.
    bias = getattr(modules, 'bias', None)
    if bias is None:
        bias_zero = True
    else:
        bias_zero = torch.equal(bias.data, torch.zeros_like(bias.data))
    return weight_zero and bias_zero
def test_litehrmodule():
    """Exercise single-branch LiteHRModule forward for each module type.

    'LITE' and 'NAIVE' modules are built and their output shapes checked;
    an unrecognized ``module_type`` must raise ValueError.
    """

    def build(kind):
        # A fresh ``in_channels`` list per module, as in separate literals.
        return LiteHRModule(
            num_branches=1,
            num_blocks=1,
            in_channels=[40],
            reduce_ratio=8,
            module_type=kind)

    expected_shape = torch.Size([2, 40, 56, 56])

    # 'LITE' is fed and returns one extra level of list nesting.
    lite = build('LITE')
    lite_out = lite([[torch.randn(2, 40, 56, 56)]])
    assert lite_out[0][0].shape == expected_shape

    naive = build('NAIVE')
    naive_out = naive([torch.randn(2, 40, 56, 56)])
    assert naive_out[0].shape == expected_shape

    # Unknown module types are rejected at construction time.
    with pytest.raises(ValueError):
        build('none')
def _litehrnet_extra(module_type):
    """Build the LiteHRNet ``extra`` config used by the backbone test.

    All three stages use ``module_type``; every other setting is fixed.
    """
    return dict(
        stem=dict(stem_channels=32, out_channels=32, expand_ratio=1),
        num_stages=3,
        stages_spec=dict(
            num_modules=(2, 4, 2),
            num_branches=(2, 3, 4),
            num_blocks=(2, 2, 2),
            module_type=(module_type, ) * 3,
            with_fuse=(True, True, True),
            reduce_ratios=(8, 8, 8),
            num_channels=(
                (40, 80),
                (40, 80, 160),
                (40, 80, 160, 320),
            )),
        with_head=True)


def _assert_single_head_output(model):
    """Feed a random (2, 3, 224, 224) batch and check the single output."""
    imgs = torch.randn(2, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 1
    assert feat[0].shape == torch.Size([2, 40, 56, 56])


def test_litehrnet_backbone():
    """Check LiteHRNet forward output for 'LITE' and 'NAIVE' stages.

    For each module type, the model's output is checked once on a freshly
    constructed model and once after ``init_weights()`` + ``train()``.
    """
    for module_type in ('LITE', 'NAIVE'):
        extra = _litehrnet_extra(module_type)

        model = LiteHRNet(extra, in_channels=3)
        _assert_single_head_output(model)

        # Test HRNet zero initialization of residual
        model = LiteHRNet(extra, in_channels=3)
        model.init_weights()
        # NOTE(review): Bottleneck is imported from the ResNet backbone; if
        # LiteHRNet contains no Bottleneck modules this loop asserts nothing.
        # TODO confirm whether this check is meaningful here.
        for m in model.modules():
            if isinstance(m, Bottleneck):
                assert all_zeros(m.norm3)
        model.train()
        _assert_single_head_output(model)
|