# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np | |
import pytest | |
import torch | |
from mmpose.models import HMRMeshHead | |
from mmpose.models.misc.discriminator import SMPLDiscriminator | |
def _forward_and_check_hmr_head(**head_kwargs):
    """Build an ``HMRMeshHead``, run one forward pass and check the shapes.

    Args:
        **head_kwargs: Keyword arguments forwarded to ``HMRMeshHead``.

    Returns:
        tuple: ``(smpl_rotmat, smpl_shape, camera)`` as unpacked from the
        head's output for a random ``(1, 512, 8, 8)`` input.
    """
    head = HMRMeshHead(**head_kwargs)
    head.init_weights()

    inputs = _demo_inputs((1, 512, 8, 8))
    smpl_rotmat, smpl_shape, camera = head(inputs)

    assert smpl_rotmat.shape == torch.Size([1, 24, 3, 3])
    assert smpl_shape.shape == torch.Size([1, 10])
    assert camera.shape == torch.Size([1, 3])
    return smpl_rotmat, smpl_shape, camera


def test_mesh_hmr_head():
    """Test the HMR mesh head and the SMPL discriminator.

    Covers:
      * the head built with only ``in_channels``;
      * the head built with explicit ``smpl_mean_params`` and ``n_iter``;
      * the discriminator fed with the pose given either as the head's
        ``(1, 24, 3, 3)`` output or as a ``(1, 72)`` tensor;
      * the discriminator rejecting a list ``beta_channel`` (``TypeError``)
        and a single-element ``beta_channel`` (``ValueError``).
    """
    # Head with default settings.
    _forward_and_check_hmr_head(in_channels=512)

    # Head with assigned mean parameters and n_iter; its outputs are reused
    # below as the discriminator input.
    smpl_rotmat, smpl_shape, camera = _forward_and_check_hmr_head(
        in_channels=512,
        smpl_mean_params='tests/data/smpl/smpl_mean_params.npz',
        n_iter=3)

    disc = SMPLDiscriminator(
        beta_channel=(10, 10, 5, 1),
        per_joint_channel=(9, 32, 32, 16, 1),
        full_pose_channel=(23 * 16, 256, 1))

    # Discriminator with SMPL pose parameters
    # in rotation matrix representation.
    pred_score = disc((camera, smpl_rotmat, smpl_shape))
    assert pred_score.shape[1] == 25

    # Discriminator with SMPL pose parameters
    # in axis-angle representation.
    axis_angle_pose = camera.new_zeros([1, 72])
    pred_score = disc((camera, axis_angle_pose, smpl_shape))
    assert pred_score.shape[1] == 25

    # A list (rather than a tuple) for beta_channel must be rejected.
    with pytest.raises(TypeError):
        SMPLDiscriminator(
            beta_channel=[10, 10, 5, 1],
            per_joint_channel=(9, 32, 32, 16, 1),
            full_pose_channel=(23 * 16, 256, 1))

    # A beta_channel with a single entry must be rejected.
    with pytest.raises(ValueError):
        SMPLDiscriminator(
            beta_channel=(10, ),
            per_joint_channel=(9, 32, 32, 16, 1),
            full_pose_channel=(23 * 16, 256, 1))
def _demo_inputs(input_shape=(1, 3, 64, 64)): | |
"""Create a superset of inputs needed to run mesh head. | |
Args: | |
input_shape (tuple): input batch dimensions. | |
Default: (1, 3, 64, 64). | |
Returns: | |
Random input tensor with the size of input_shape. | |
""" | |
inps = np.random.random(input_shape) | |
inps = torch.FloatTensor(inps) | |
return inps | |