# Copyright (c) OpenMMLab. All rights reserved.
import cv2
import mmcv
import numpy as np
import torch

from mmpose.core.visualization.image import imshow_mesh_3d
from mmpose.models.misc.discriminator import SMPLDiscriminator
from .. import builder
from ..builder import POSENETS
from .base import BasePose

def set_requires_grad(nets, requires_grad=False):
    """Set `requires_grad` for all the networks.

    Args:
        nets (nn.Module | list[nn.Module]): A list of networks or a single
            network.
        requires_grad (bool): Whether the networks require gradients or not.
    """
    if not isinstance(nets, list):
        nets = [nets]
    for net in nets:
        if net is not None:
            for param in net.parameters():
                param.requires_grad = requires_grad
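
# Example (hedged sketch): alternating gradient flow in an adversarial
# setup, assuming `disc` is an nn.Module:
#   set_requires_grad(disc, False)  # freeze D while updating the generator
#   ...generator forward/backward/step...
#   set_requires_grad(disc, True)   # unfreeze D for its own update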

@POSENETS.register_module()
class ParametricMesh(BasePose):
    """Model-based 3D human mesh detector. It takes a single color image as
    input and outputs 3D joints, SMPL parameters and camera parameters.

    Args:
        backbone (dict): Backbone modules to extract feature.
        mesh_head (dict): Mesh head to process feature.
        smpl (dict): Config for SMPL model.
        disc (dict): Discriminator for SMPL parameters. Default: None.
        loss_gan (dict): Config for adversarial loss. Default: None.
        loss_mesh (dict): Config for mesh loss. Default: None.
        train_cfg (dict): Config for training. Default: None.
        test_cfg (dict): Config for testing. Default: None.
        pretrained (str): Path to the pretrained models.
    """
    def __init__(self,
                 backbone,
                 mesh_head,
                 smpl,
                 disc=None,
                 loss_gan=None,
                 loss_mesh=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super().__init__()

        self.backbone = builder.build_backbone(backbone)
        self.mesh_head = builder.build_head(mesh_head)
        self.generator = torch.nn.Sequential(self.backbone, self.mesh_head)

        self.smpl = builder.build_mesh_model(smpl)

        self.with_gan = disc is not None and loss_gan is not None
        if self.with_gan:
            self.discriminator = SMPLDiscriminator(**disc)
            self.loss_gan = builder.build_loss(loss_gan)
        self.disc_step_count = 0

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        self.loss_mesh = builder.build_loss(loss_mesh)
        self.init_weights(pretrained=pretrained)
    def init_weights(self, pretrained=None):
        """Weight initialization for model."""
        self.backbone.init_weights(pretrained)
        self.mesh_head.init_weights()
        if self.with_gan:
            self.discriminator.init_weights()
    def train_step(self, data_batch, optimizer, **kwargs):
        """Train step function.

        In this function, the detector will finish the train step following
        the pipeline:

        1. get fake and real SMPL parameters
        2. optimize discriminator (if present)
        3. optimize generator

        If `self.train_cfg.disc_step > 1`, the train step will contain
        multiple iterations for optimizing the discriminator with different
        input data, and only one iteration for optimizing the generator
        after `disc_step` iterations for the discriminator.

        Args:
            data_batch (dict): Batch of data as input.
            optimizer (dict[torch.optim.Optimizer]): Dict with optimizers for
                generator and discriminator (if present).

        Returns:
            outputs (dict): Dict with loss, information for logger,
                the number of samples.
        """
        img = data_batch['img']
        pred_smpl = self.generator(img)
        pred_pose, pred_beta, pred_camera = pred_smpl

        # optimize discriminator (if present)
        if self.train_cfg['disc_step'] > 0 and self.with_gan:
            set_requires_grad(self.discriminator, True)
            fake_data = (pred_camera.detach(), pred_pose.detach(),
                         pred_beta.detach())
            # mosh_theta holds real SMPL parameters: 3 camera params,
            # 72 pose params and 10 shape params.
            mosh_theta = data_batch['mosh_theta']
            real_data = (mosh_theta[:, :3], mosh_theta[:, 3:75],
                         mosh_theta[:, 75:])
            fake_score = self.discriminator(fake_data)
            real_score = self.discriminator(real_data)

            disc_losses = {}
            disc_losses['real_loss'] = self.loss_gan(
                real_score, target_is_real=True, is_disc=True)
            disc_losses['fake_loss'] = self.loss_gan(
                fake_score, target_is_real=False, is_disc=True)
            loss_disc, log_vars_d = self._parse_losses(disc_losses)

            optimizer['discriminator'].zero_grad()
            loss_disc.backward()
            optimizer['discriminator'].step()
            self.disc_step_count = \
                (self.disc_step_count + 1) % self.train_cfg['disc_step']

            if self.disc_step_count != 0:
                # Skip the generator update until the discriminator has been
                # optimized for `disc_step` iterations.
                outputs = dict(
                    loss=loss_disc,
                    log_vars=log_vars_d,
                    num_samples=len(next(iter(data_batch.values()))))
                return outputs

        # optimize generator
        pred_out = self.smpl(
            betas=pred_beta,
            body_pose=pred_pose[:, 1:],
            global_orient=pred_pose[:, :1])
        pred_vertices, pred_joints_3d = pred_out['vertices'], pred_out[
            'joints']

        gt_beta = data_batch['beta']
        gt_pose = data_batch['pose']
        gt_vertices = self.smpl(
            betas=gt_beta,
            body_pose=gt_pose[:, 3:],
            global_orient=gt_pose[:, :3])['vertices']

        pred = dict(
            pose=pred_pose,
            beta=pred_beta,
            camera=pred_camera,
            vertices=pred_vertices,
            joints_3d=pred_joints_3d)
        target = {
            key: data_batch[key]
            for key in [
                'pose', 'beta', 'has_smpl', 'joints_3d', 'joints_2d',
                'joints_3d_visible', 'joints_2d_visible'
            ]
        }
        target['vertices'] = gt_vertices

        losses = self.loss_mesh(pred, target)

        # Freeze the discriminator and add the adversarial loss to the
        # generator objective.
        if self.with_gan:
            set_requires_grad(self.discriminator, False)
            pred_theta = (pred_camera, pred_pose, pred_beta)
            pred_score = self.discriminator(pred_theta)
            loss_adv = self.loss_gan(
                pred_score, target_is_real=True, is_disc=False)
            losses['adv_loss'] = loss_adv

        loss, log_vars = self._parse_losses(losses)
        optimizer['generator'].zero_grad()
        loss.backward()
        optimizer['generator'].step()

        outputs = dict(
            loss=loss,
            log_vars=log_vars,
            num_samples=len(next(iter(data_batch.values()))))
        return outputs
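
    # A minimal sketch of how `train_step` is driven (illustrative; the
    # configs and learning rates below are hypothetical, not from this
    # file). The optimizer dict keys must match the lookups above:
    #
    #   model = ParametricMesh(backbone=..., mesh_head=..., smpl=...,
    #                          disc=dict(...), loss_gan=dict(...),
    #                          loss_mesh=dict(...),
    #                          train_cfg=dict(disc_step=1))
    #   optimizer = dict(
    #       generator=torch.optim.Adam(model.generator.parameters(),
    #                                  lr=2.5e-4),
    #       discriminator=torch.optim.Adam(
    #           model.discriminator.parameters(), lr=1e-4))
    #   outputs = model.train_step(data_batch, optimizer)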
    def forward_train(self, *args, **kwargs):
        """Forward function for training.

        For ParametricMesh, we do not use this interface.
        """
        raise NotImplementedError('This interface should not be used in '
                                  'current training schedule. Please use '
                                  '`train_step` for training.')
    def val_step(self, data_batch, **kwargs):
        """Forward function for evaluation.

        Args:
            data_batch (dict): Contain data for forward.

        Returns:
            dict: Contain the results from model.
        """
        output = self.forward_test(**data_batch, **kwargs)
        return output
    def forward_dummy(self, img):
        """Used for computing network FLOPs.

        See ``tools/get_flops.py``.

        Args:
            img (torch.Tensor): Input image.

        Returns:
            Tensor: Outputs.
        """
        output = self.generator(img)
        return output
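
    # Example (hedged): a dummy forward pass, e.g. for FLOPs counting.
    # The 224x224 input size is illustrative only:
    #   pred_pose, pred_beta, pred_camera = model.forward_dummy(
    #       torch.randn(1, 3, 224, 224))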
    def forward_test(self,
                     img,
                     img_metas,
                     return_vertices=False,
                     return_faces=False,
                     **kwargs):
        """Defines the computation performed at every call when testing."""
        pred_smpl = self.generator(img)
        pred_pose, pred_beta, pred_camera = pred_smpl
        pred_out = self.smpl(
            betas=pred_beta,
            body_pose=pred_pose[:, 1:],
            global_orient=pred_pose[:, :1])
        pred_vertices, pred_joints_3d = pred_out['vertices'], pred_out[
            'joints']

        all_preds = {}
        all_preds['keypoints_3d'] = pred_joints_3d.detach().cpu().numpy()
        all_preds['smpl_pose'] = pred_pose.detach().cpu().numpy()
        all_preds['smpl_beta'] = pred_beta.detach().cpu().numpy()
        all_preds['camera'] = pred_camera.detach().cpu().numpy()

        if return_vertices:
            all_preds['vertices'] = pred_vertices.detach().cpu().numpy()
        if return_faces:
            all_preds['faces'] = self.smpl.get_faces()

        all_boxes = []
        image_path = []
        for img_meta in img_metas:
            # Each box is encoded as [center_x, center_y, scale_x, scale_y,
            # area, score].
            box = np.zeros(6, dtype=np.float32)
            c = img_meta['center']
            s = img_meta['scale']
            if 'bbox_score' in img_meta:
                score = np.array(img_meta['bbox_score']).reshape(-1)
            else:
                score = 1.0
            box[0:2] = c
            box[2:4] = s
            # The scale is defined w.r.t. a 200-pixel reference box, so the
            # pixel-space box area is (s_x * 200) * (s_y * 200).
            box[4] = np.prod(s * 200.0, axis=0)
            box[5] = score
            all_boxes.append(box)
            image_path.append(img_meta['image_file'])

        all_preds['bboxes'] = np.stack(all_boxes, axis=0)
        all_preds['image_path'] = image_path
        return all_preds
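
    # Hedged usage sketch: reading the inference outputs.
    #   preds = model.forward_test(img, img_metas, return_vertices=True,
    #                              return_faces=True)
    #   preds['keypoints_3d']  # (N, K, 3) predicted 3D joints
    #   preds['bboxes']        # (N, 6): center(2), scale(2), area, score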
    def get_3d_joints_from_mesh(self, vertices):
        """Get 3D joints from 3D mesh using predefined joints regressor."""
        return torch.matmul(
            self.joints_regressor.to(vertices.device), vertices)
    def forward(self, img, img_metas=None, return_loss=False, **kwargs):
        """Forward function.

        Calls either forward_train or forward_test depending on whether
        return_loss=True.

        Note:
            - batch_size: N
            - num_img_channel: C (Default: 3)
            - img height: imgH
            - img width: imgW

        Args:
            img (torch.Tensor[N x C x imgH x imgW]): Input images.
            img_metas (list(dict)): Information about data augmentation.
                By default this includes:

                - "image_file": path to the image file
                - "center": center of the bbox
                - "scale": scale of the bbox
                - "rotation": rotation of the bbox
                - "bbox_score": score of bbox

            return_loss (bool): Option to `return loss`. `return_loss=True`
                for training, `return_loss=False` for validation & test.

        Returns:
            Return predicted 3D joints, SMPL parameters, boxes and image
            paths.
        """
        if return_loss:
            return self.forward_train(img, img_metas, **kwargs)
        return self.forward_test(img, img_metas, **kwargs)
    def show_result(self,
                    result,
                    img,
                    show=False,
                    out_file=None,
                    win_name='',
                    wait_time=0,
                    bbox_color='green',
                    mesh_color=(76, 76, 204),
                    **kwargs):
        """Visualize 3D mesh estimation results.

        Args:
            result (list[dict]): The mesh estimation results containing:

                - "bbox" (ndarray[4]): instance bounding bbox
                - "center" (ndarray[2]): bbox center
                - "scale" (ndarray[2]): bbox scale
                - "keypoints_3d" (ndarray[K,3]): predicted 3D keypoints
                - "camera" (ndarray[3]): camera parameters
                - "vertices" (ndarray[V, 3]): predicted 3D vertices
                - "faces" (ndarray[F, 3]): mesh faces

            img (str or Tensor): Optional. The image to visualize 2D inputs
                on.
            win_name (str): The window name.
            show (bool): Whether to show the image. Default: False.
            wait_time (int): Value of waitKey param. Default: 0.
            out_file (str or None): The filename to write the image.
                Default: None.
            bbox_color (str or tuple or :obj:`Color`): Color of bbox lines.
            mesh_color (str or tuple or :obj:`Color`): Color of mesh surface.

        Returns:
            ndarray: Visualized image, only if neither `show` nor `out_file`
            is set.
        """
        if img is not None:
            img = mmcv.imread(img)

        focal_length = self.loss_mesh.focal_length
        H, W, C = img.shape
        img_center = np.array([[0.5 * W], [0.5 * H]])

        # show bounding boxes
        bboxes = [res['bbox'] for res in result]
        bboxes = np.vstack(bboxes)
        mmcv.imshow_bboxes(
            img, bboxes, colors=bbox_color, top_k=-1, thickness=2, show=False)

        vertex_list = []
        face_list = []
        for res in result:
            vertices = res['vertices']
            faces = res['faces']
            camera = res['camera']
            camera_center = res['center']
            scale = res['scale']

            # Predicted vertices are in root-relative space;
            # translate them to camera space.
            translation = np.array([
                camera[1], camera[2],
                2 * focal_length / (scale[0] * 200.0 * camera[0] + 1e-9)
            ])
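            # Worked example (illustrative numbers): with focal_length=5000,
            # a 1.12-scale box (1.12 * 200 = 224 px) and a predicted
            # weak-perspective scale camera[0] = 1.0, the recovered depth is
            #   z = 2 * 5000 / 224 ~= 44.6
            # i.e. smaller boxes or smaller camera scales place the mesh
            # farther from the camera.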
            mean_depth = vertices[:, -1].mean() + translation[-1]
            translation[:2] += (camera_center -
                                img_center[:, 0]) / focal_length * mean_depth
            vertices += translation[None, :]

            vertex_list.append(vertices)
            face_list.append(faces)

        # render from front view
        img_vis = imshow_mesh_3d(
            img,
            vertex_list,
            face_list,
            img_center, [focal_length, focal_length],
            colors=mesh_color)

        # render from side view: rotate the mesh vertices by 90 degrees
        # around the vertical axis
        R = cv2.Rodrigues(np.array([0, np.radians(90.), 0]))[0]
        rot_vertex_list = [np.dot(vert, R) for vert in vertex_list]

        # get the 3D bbox containing all meshes
        rot_vertices = np.concatenate(rot_vertex_list, axis=0)
        min_corner = rot_vertices.min(0)
        max_corner = rot_vertices.max(0)
        center_3d = 0.5 * (min_corner + max_corner)
        ratio = 0.8
        bbox3d_size = max_corner - min_corner

        # set an appropriate translation so that all meshes appear in the
        # image
        z_x = bbox3d_size[0] * focal_length / (ratio * W) - min_corner[2]
        z_y = bbox3d_size[1] * focal_length / (ratio * H) - min_corner[2]
        z = max(z_x, z_y)
        translation = -center_3d
        translation[2] = z
        translation = translation[None, :]
        rot_vertex_list = [
            rot_vert + translation for rot_vert in rot_vertex_list
        ]
        img_side = imshow_mesh_3d(
            np.ones_like(img) * 255, rot_vertex_list, face_list, img_center,
            [focal_length, focal_length])

        # merge images from front view and side view
        img_vis = np.concatenate([img_vis, img_side], axis=1)

        if show:
            mmcv.visualization.imshow(img_vis, win_name, wait_time)

        if out_file is not None:
            mmcv.imwrite(img_vis, out_file)

        return img_vis
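
# Hedged usage sketch for `show_result` (the bbox corners and file names are
# illustrative placeholders). `forward_test` must be called with
# return_vertices=True and return_faces=True so that 'vertices' and 'faces'
# are available:
#
#   preds = model.forward_test(img, img_metas,
#                              return_vertices=True, return_faces=True)
#   result = [dict(bbox=np.array([x1, y1, x2, y2]),
#                  center=img_metas[0]['center'],
#                  scale=img_metas[0]['scale'],
#                  camera=preds['camera'][0],
#                  vertices=preds['vertices'][0],
#                  faces=preds['faces'])]
#   vis = model.show_result(result, 'demo.jpg', out_file='vis.jpg')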