Spaces:
Build error
Build error
# Copyright (c) OpenMMLab. All rights reserved. | |
import pytest | |
import torch | |
from mmpose.models import HourglassAENet, HourglassNet | |
def test_hourglass_backbone(): | |
with pytest.raises(AssertionError): | |
# HourglassNet's num_stacks should larger than 0 | |
HourglassNet(num_stacks=0) | |
with pytest.raises(AssertionError): | |
# len(stage_channels) should equal len(stage_blocks) | |
HourglassNet( | |
stage_channels=[256, 256, 384, 384, 384], | |
stage_blocks=[2, 2, 2, 2, 2, 4]) | |
with pytest.raises(AssertionError): | |
# len(stage_channels) should larger than downsample_times | |
HourglassNet( | |
downsample_times=5, | |
stage_channels=[256, 256, 384, 384, 384], | |
stage_blocks=[2, 2, 2, 2, 2]) | |
# Test HourglassNet-52 | |
model = HourglassNet(num_stacks=1) | |
model.init_weights() | |
model.train() | |
imgs = torch.randn(1, 3, 256, 256) | |
feat = model(imgs) | |
assert len(feat) == 1 | |
assert feat[0].shape == torch.Size([1, 256, 64, 64]) | |
# Test HourglassNet-104 | |
model = HourglassNet(num_stacks=2) | |
model.init_weights() | |
model.train() | |
imgs = torch.randn(1, 3, 256, 256) | |
feat = model(imgs) | |
assert len(feat) == 2 | |
assert feat[0].shape == torch.Size([1, 256, 64, 64]) | |
assert feat[1].shape == torch.Size([1, 256, 64, 64]) | |
def test_hourglass_ae_backbone(): | |
with pytest.raises(AssertionError): | |
# HourglassAENet's num_stacks should larger than 0 | |
HourglassAENet(num_stacks=0) | |
with pytest.raises(AssertionError): | |
# len(stage_channels) should larger than downsample_times | |
HourglassAENet( | |
downsample_times=5, stage_channels=[256, 256, 384, 384, 384]) | |
# num_stack=1 | |
model = HourglassAENet(num_stacks=1) | |
model.init_weights() | |
model.train() | |
imgs = torch.randn(1, 3, 256, 256) | |
feat = model(imgs) | |
assert len(feat) == 1 | |
assert feat[0].shape == torch.Size([1, 34, 64, 64]) | |
# num_stack=2 | |
model = HourglassAENet(num_stacks=2) | |
model.init_weights() | |
model.train() | |
imgs = torch.randn(1, 3, 256, 256) | |
feat = model(imgs) | |
assert len(feat) == 2 | |
assert feat[0].shape == torch.Size([1, 34, 64, 64]) | |
assert feat[1].shape == torch.Size([1, 34, 64, 64]) | |