Spaces:
Build error
Build error
# ------------------------------------------------------------------------------ | |
# Copyright and License Information | |
# https://github.com/microsoft/voxelpose-pytorch/blob/main/lib/models | |
# Original Licence: MIT License | |
# ------------------------------------------------------------------------------ | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from ..builder import HEADS | |
class CuboidCenterHead(nn.Module): | |
"""Get results from the 3D human center heatmap. In this module, human 3D | |
centers are local maximums obtained from the 3D heatmap via NMS (max- | |
pooling). | |
Args: | |
space_size (list[3]): The size of the 3D space. | |
cube_size (list[3]): The size of the heatmap volume. | |
space_center (list[3]): The coordinate of space center. | |
max_num (int): Maximum of human center detections. | |
max_pool_kernel (int): Kernel size of the max-pool kernel in nms. | |
""" | |
def __init__(self, | |
space_size, | |
space_center, | |
cube_size, | |
max_num=10, | |
max_pool_kernel=3): | |
super(CuboidCenterHead, self).__init__() | |
# use register_buffer | |
self.register_buffer('grid_size', torch.tensor(space_size)) | |
self.register_buffer('cube_size', torch.tensor(cube_size)) | |
self.register_buffer('grid_center', torch.tensor(space_center)) | |
self.num_candidates = max_num | |
self.max_pool_kernel = max_pool_kernel | |
self.loss = nn.MSELoss() | |
def _get_real_locations(self, indices): | |
""" | |
Args: | |
indices (torch.Tensor(NXP)): Indices of points in the 3D tensor | |
Returns: | |
real_locations (torch.Tensor(NXPx3)): Locations of points | |
in the world coordinate system | |
""" | |
real_locations = indices.float() / ( | |
self.cube_size - 1) * self.grid_size + \ | |
self.grid_center - self.grid_size / 2.0 | |
return real_locations | |
def _nms_by_max_pool(self, heatmap_volumes): | |
max_num = self.num_candidates | |
batch_size = heatmap_volumes.shape[0] | |
root_cubes_nms = self._max_pool(heatmap_volumes) | |
root_cubes_nms_reshape = root_cubes_nms.reshape(batch_size, -1) | |
topk_values, topk_index = root_cubes_nms_reshape.topk(max_num) | |
topk_unravel_index = self._get_3d_indices(topk_index, | |
heatmap_volumes[0].shape) | |
return topk_values, topk_unravel_index | |
def _max_pool(self, inputs): | |
kernel = self.max_pool_kernel | |
padding = (kernel - 1) // 2 | |
max = F.max_pool3d( | |
inputs, kernel_size=kernel, stride=1, padding=padding) | |
keep = (inputs == max).float() | |
return keep * inputs | |
def _get_3d_indices(indices, shape): | |
"""Get indices in the 3-D tensor. | |
Args: | |
indices (torch.Tensor(NXp)): Indices of points in the 1D tensor | |
shape (torch.Size(3)): The shape of the original 3D tensor | |
Returns: | |
indices: Indices of points in the original 3D tensor | |
""" | |
batch_size = indices.shape[0] | |
num_people = indices.shape[1] | |
indices_x = (indices // | |
(shape[1] * shape[2])).reshape(batch_size, num_people, -1) | |
indices_y = ((indices % (shape[1] * shape[2])) // | |
shape[2]).reshape(batch_size, num_people, -1) | |
indices_z = (indices % shape[2]).reshape(batch_size, num_people, -1) | |
indices = torch.cat([indices_x, indices_y, indices_z], dim=2) | |
return indices | |
def forward(self, heatmap_volumes): | |
""" | |
Args: | |
heatmap_volumes (torch.Tensor(NXLXWXH)): | |
3D human center heatmaps predicted by the network. | |
Returns: | |
human_centers (torch.Tensor(NXPX5)): | |
Coordinates of human centers. | |
""" | |
batch_size = heatmap_volumes.shape[0] | |
topk_values, topk_unravel_index = self._nms_by_max_pool( | |
heatmap_volumes.detach()) | |
topk_unravel_index = self._get_real_locations(topk_unravel_index) | |
human_centers = torch.zeros( | |
batch_size, self.num_candidates, 5, device=heatmap_volumes.device) | |
human_centers[:, :, 0:3] = topk_unravel_index | |
human_centers[:, :, 4] = topk_values | |
return human_centers | |
def get_loss(self, pred_cubes, gt): | |
return dict(loss_center=self.loss(pred_cubes, gt)) | |
class CuboidPoseHead(nn.Module): | |
def __init__(self, beta): | |
"""Get results from the 3D human pose heatmap. Instead of obtaining | |
maximums on the heatmap, this module regresses the coordinates of | |
keypoints via integral pose regression. Refer to `paper. | |
<https://arxiv.org/abs/2004.06239>` for more details. | |
Args: | |
beta: Constant to adjust the magnification of soft-maxed heatmap. | |
""" | |
super(CuboidPoseHead, self).__init__() | |
self.beta = beta | |
self.loss = nn.L1Loss() | |
def forward(self, heatmap_volumes, grid_coordinates): | |
""" | |
Args: | |
heatmap_volumes (torch.Tensor(NxKxLxWxH)): | |
3D human pose heatmaps predicted by the network. | |
grid_coordinates (torch.Tensor(Nx(LxWxH)x3)): | |
Coordinates of the grids in the heatmap volumes. | |
Returns: | |
human_poses (torch.Tensor(NxKx3)): Coordinates of human poses. | |
""" | |
batch_size = heatmap_volumes.size(0) | |
channel = heatmap_volumes.size(1) | |
x = heatmap_volumes.reshape(batch_size, channel, -1, 1) | |
x = F.softmax(self.beta * x, dim=2) | |
grid_coordinates = grid_coordinates.unsqueeze(1) | |
x = torch.mul(x, grid_coordinates) | |
human_poses = torch.sum(x, dim=2) | |
return human_poses | |
def get_loss(self, preds, targets, weights): | |
return dict(loss_pose=self.loss(preds * weights, targets * weights)) | |