# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import pytest
import torch
import torch.nn as nn

from mmpose.models.backbones import TCN
from mmpose.models.backbones.tcn import BasicTemporalBlock


def test_basic_temporal_block():
    with pytest.raises(AssertionError):
        # padding (+ shift) should not be larger than x.shape[2]
        block = BasicTemporalBlock(1024, 1024, dilation=81)
        x = torch.rand(2, 1024, 150)
        x_out = block(x)

    with pytest.raises(AssertionError):
        # when use_stride_conv is True, shift + kernel_size // 2 should
        # not be larger than x.shape[2]
        block = BasicTemporalBlock(
            1024, 1024, kernel_size=5, causal=True, use_stride_conv=True)
        x = torch.rand(2, 1024, 3)
        x_out = block(x)
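
    # Note on the shape checks below (inferred from the asserted shapes,
    # not from the implementation itself): with the default kernel_size=3
    # and dilation=3, a non-strided block trims
    # (kernel_size - 1) * dilation = 6 frames from the temporal axis,
    # so 241 input frames become 235.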

    # BasicTemporalBlock with causal == False
    block = BasicTemporalBlock(1024, 1024)
    x = torch.rand(2, 1024, 241)
    x_out = block(x)
    assert x_out.shape == torch.Size([2, 1024, 235])

    # BasicTemporalBlock with causal == True
    block = BasicTemporalBlock(1024, 1024, causal=True)
    x = torch.rand(2, 1024, 241)
    x_out = block(x)
    assert x_out.shape == torch.Size([2, 1024, 235])

    # BasicTemporalBlock with residual == False
    block = BasicTemporalBlock(1024, 1024, residual=False)
    x = torch.rand(2, 1024, 241)
    x_out = block(x)
    assert x_out.shape == torch.Size([2, 1024, 235])
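
    # Note on the stride-conv checks below (again inferred from the
    # asserted shapes): with use_stride_conv=True the block convolves with
    # stride == kernel_size, so the temporal axis shrinks by a factor of 3
    # and 81 input frames become 27.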

    # BasicTemporalBlock, use_stride_conv == True
    block = BasicTemporalBlock(1024, 1024, use_stride_conv=True)
    x = torch.rand(2, 1024, 81)
    x_out = block(x)
    assert x_out.shape == torch.Size([2, 1024, 27])

    # BasicTemporalBlock with use_stride_conv == True and causal == True
    block = BasicTemporalBlock(1024, 1024, use_stride_conv=True, causal=True)
    x = torch.rand(2, 1024, 81)
    x_out = block(x)
    assert x_out.shape == torch.Size([2, 1024, 27])


def test_tcn_backbone():
    with pytest.raises(AssertionError):
        # num_blocks should equal len(kernel_sizes) - 1
        TCN(in_channels=34, num_blocks=3, kernel_sizes=(3, 3, 3))

    with pytest.raises(AssertionError):
        # kernel size should be odd
        TCN(in_channels=34, kernel_sizes=(3, 4, 3))
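
    # Frame-count bookkeeping for the checks below (inferred from the
    # asserted shapes): the stem conv (kernel 3, dilation 1) trims 2 frames
    # (243 -> 241), and each block trims (3 - 1) * dilation frames, with
    # the implied dilations growing as 3, 9, 27, 81.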

    # Test TCN with 2 blocks (use_stride_conv == False)
    model = TCN(in_channels=34, num_blocks=2, kernel_sizes=(3, 3, 3))
    pose2d = torch.rand((2, 34, 243))
    feat = model(pose2d)
    assert len(feat) == 2
    assert feat[0].shape == (2, 1024, 235)
    assert feat[1].shape == (2, 1024, 217)

    # Test TCN with 4 blocks and weight norm clip
    max_norm = 0.1
    model = TCN(
        in_channels=34,
        num_blocks=4,
        kernel_sizes=(3, 3, 3, 3, 3),
        max_norm=max_norm)
    pose2d = torch.rand((2, 34, 243))
    feat = model(pose2d)
    assert len(feat) == 4
    assert feat[0].shape == (2, 1024, 235)
    assert feat[1].shape == (2, 1024, 217)
    assert feat[2].shape == (2, 1024, 163)
    assert feat[3].shape == (2, 1024, 1)
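
    # The check below verifies the max_norm weight clipping: for every conv
    # weight, np.maximum(norm, max_norm) can only equal max_norm (within
    # rtol) if norm <= max_norm.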
    for module in model.modules():
        if isinstance(module, torch.nn.modules.conv._ConvNd):
            norm = module.weight.norm().item()
            np.testing.assert_allclose(
                np.maximum(norm, max_norm), max_norm, rtol=1e-4)

    # Test TCN with 4 blocks (use_stride_conv == True)
    model = TCN(
        in_channels=34,
        num_blocks=4,
        kernel_sizes=(3, 3, 3, 3, 3),
        use_stride_conv=True)
    pose2d = torch.rand((2, 34, 243))
    feat = model(pose2d)
    assert len(feat) == 4
    assert feat[0].shape == (2, 1024, 27)
    assert feat[1].shape == (2, 1024, 9)
    assert feat[2].shape == (2, 1024, 3)
    assert feat[3].shape == (2, 1024, 1)
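
    # With five kernel-3 stages (stem + 4 blocks), the receptive field is
    # 3 ** 5 = 243 frames, so a 243-frame input collapses to a single
    # output frame in both the dilated and the strided variant.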

    # Check that models with and without use_stride_conv produce the same
    # outputs and gradients after a forward+backward pass
    model1 = TCN(
        in_channels=34,
        stem_channels=4,
        num_blocks=1,
        kernel_sizes=(3, 3),
        dropout=0,
        residual=False,
        norm_cfg=None)
    model2 = TCN(
        in_channels=34,
        stem_channels=4,
        num_blocks=1,
        kernel_sizes=(3, 3),
        dropout=0,
        residual=False,
        norm_cfg=None,
        use_stride_conv=True)
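
    # Both models are built without dropout, residual connections or norm
    # layers, and all conv weights are set to the same constants below, so
    # the strided variant (which only computes the frames the final output
    # depends on) should reproduce the dense variant exactly.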
    for m in model1.modules():
        if isinstance(m, nn.Conv1d):
            nn.init.constant_(m.weight, 0.5)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
    for m in model2.modules():
        if isinstance(m, nn.Conv1d):
            nn.init.constant_(m.weight, 0.5)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    input1 = torch.rand((1, 34, 9))
    input2 = input1.clone()
    outputs1 = model1(input1)
    outputs2 = model2(input2)
    for output1, output2 in zip(outputs1, outputs2):
        assert torch.isclose(output1, output2).all()

    criterion = nn.MSELoss()
    target = torch.rand(output1.shape)
    loss1 = criterion(output1, target)
    loss2 = criterion(output2, target)
    loss1.backward()
    loss2.backward()
    for m1, m2 in zip(model1.modules(), model2.modules()):
        if isinstance(m1, nn.Conv1d):
            assert torch.isclose(m1.weight.grad, m2.weight.grad).all()
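

if __name__ == '__main__':
    # Minimal standalone sketch: build a TCN and run one forward pass,
    # mirroring the shapes exercised above. The default stem/block width
    # of 1024 is an assumption read off the asserted feature shapes.
    model = TCN(in_channels=34, num_blocks=2, kernel_sizes=(3, 3, 3))
    feats = model(torch.rand(2, 34, 243))
    # expect [(2, 1024, 235), (2, 1024, 217)]
    print([tuple(f.shape) for f in feats])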