Spaces:
Build error
Build error
File size: 2,512 Bytes
d7a991a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import pytest
import torch
from mmpose.models import HMRMeshHead
from mmpose.models.misc.discriminator import SMPLDiscriminator
def test_mesh_hmr_head():
    """Test HMRMeshHead output shapes and SMPLDiscriminator construction/forward.

    Covers:
      * a default ``HMRMeshHead``;
      * a head built with explicit SMPL mean parameters and ``n_iter=3``;
      * ``SMPLDiscriminator`` fed with rotation-matrix poses and with
        axis-angle poses;
      * ``SMPLDiscriminator`` rejecting a list-typed channel spec
        (``TypeError``) and a too-short ``beta_channel`` (``ValueError``).
    """
    # Default head.
    head = HMRMeshHead(in_channels=512)
    head.init_weights()
    _check_hmr_head_outputs(head)

    # Head with assigned mean parameters and number of iterations.
    # NOTE(review): the .npz path is relative to the current working
    # directory, so this test presumably has to be run from the repo root.
    head = HMRMeshHead(
        in_channels=512,
        smpl_mean_params='tests/data/smpl/smpl_mean_params.npz',
        n_iter=3)
    head.init_weights()
    smpl_rotmat, smpl_shape, camera = _check_hmr_head_outputs(head)

    # Discriminator with SMPL pose parameters in rotation-matrix
    # representation. The input tuple order is (camera, pose, shape).
    disc = SMPLDiscriminator(
        beta_channel=(10, 10, 5, 1),
        per_joint_channel=(9, 32, 32, 16, 1),
        full_pose_channel=(23 * 16, 256, 1))
    pred_theta = (camera, smpl_rotmat, smpl_shape)
    pred_score = disc(pred_theta)
    # Presumably 1 (shape) + 23 (per-joint) + 1 (full pose) scores, since each
    # channel spec above ends in 1 and the full-pose input is 23 * 16 wide.
    assert pred_score.shape[1] == 25

    # Discriminator with SMPL pose parameters in axis-angle representation
    # (a flat 72-value pose vector per sample).
    pred_theta = (camera, camera.new_zeros([1, 72]), smpl_shape)
    pred_score = disc(pred_theta)
    assert pred_score.shape[1] == 25

    # Channel specs given as a list instead of a tuple must be rejected.
    with pytest.raises(TypeError):
        _ = SMPLDiscriminator(
            beta_channel=[10, 10, 5, 1],
            per_joint_channel=(9, 32, 32, 16, 1),
            full_pose_channel=(23 * 16, 256, 1))

    # A single-element beta_channel spec must be rejected.
    with pytest.raises(ValueError):
        _ = SMPLDiscriminator(
            beta_channel=(10, ),
            per_joint_channel=(9, 32, 32, 16, 1),
            full_pose_channel=(23 * 16, 256, 1))


def _check_hmr_head_outputs(head, input_shape=(1, 512, 8, 8)):
    """Run ``head`` on random inputs and assert the shapes of its outputs.

    Args:
        head: An ``HMRMeshHead`` instance (already weight-initialised).
        input_shape (tuple): Shape of the random input feature map.
            The first entry is used as the expected batch size.
            Default: (1, 512, 8, 8).

    Returns:
        tuple: ``(smpl_rotmat, smpl_shape, camera)`` as produced by ``head``,
        so callers can reuse them (e.g. to feed a discriminator).
    """
    batch_size = input_shape[0]
    inputs = _demo_inputs(input_shape)
    smpl_rotmat, smpl_shape, camera = head(inputs)
    assert smpl_rotmat.shape == torch.Size([batch_size, 24, 3, 3])
    assert smpl_shape.shape == torch.Size([batch_size, 10])
    assert camera.shape == torch.Size([batch_size, 3])
    return smpl_rotmat, smpl_shape, camera
def _demo_inputs(input_shape=(1, 3, 64, 64)):
"""Create a superset of inputs needed to run mesh head.
Args:
input_shape (tuple): input batch dimensions.
Default: (1, 3, 64, 64).
Returns:
Random input tensor with the size of input_shape.
"""
inps = np.random.random(input_shape)
inps = torch.FloatTensor(inps)
return inps
|