Spaces:
Build error
Build error
File size: 6,040 Bytes
d7a991a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
# ------------------------------------------------------------------------------
# Copyright and License Information
# https://github.com/microsoft/voxelpose-pytorch/blob/main/lib/models
# Original Licence: MIT License
# ------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..builder import HEADS
@HEADS.register_module()
class CuboidCenterHead(nn.Module):
    """Get results from the 3D human center heatmap.

    In this module, human 3D centers are local maximums obtained from the 3D
    heatmap via NMS (max-pooling).

    Args:
        space_size (list[3]): The size of the 3D space.
        space_center (list[3]): The coordinate of space center.
        cube_size (list[3]): The size of the heatmap volume.
        max_num (int): Maximum of human center detections.
        max_pool_kernel (int): Kernel size of the max-pool kernel in nms.
            The padding is ``(kernel - 1) // 2`` with stride 1, so only an
            odd kernel keeps the pooled volume the same shape as the input.
    """

    def __init__(self,
                 space_size,
                 space_center,
                 cube_size,
                 max_num=10,
                 max_pool_kernel=3):
        super().__init__()
        # Buffers (not parameters) so they follow the module across devices
        # and are stored in the state dict without being trained.
        self.register_buffer('grid_size', torch.tensor(space_size))
        self.register_buffer('cube_size', torch.tensor(cube_size))
        self.register_buffer('grid_center', torch.tensor(space_center))
        self.num_candidates = max_num
        self.max_pool_kernel = max_pool_kernel
        self.loss = nn.MSELoss()

    def _get_real_locations(self, indices):
        """Convert voxel indices to world coordinates.

        Index ``0`` maps to ``grid_center - grid_size / 2`` and index
        ``cube_size - 1`` maps to ``grid_center + grid_size / 2``; values in
        between are interpolated linearly.

        Args:
            indices (torch.Tensor(NXPx3)): Indices of points in the 3D tensor

        Returns:
            real_locations (torch.Tensor(NXPx3)): Locations of points
                in the world coordinate system
        """
        normalized = indices.float() / (self.cube_size - 1)
        space_origin = self.grid_center - self.grid_size / 2.0
        return normalized * self.grid_size + space_origin

    def _nms_by_max_pool(self, heatmap_volumes):
        """Pick the strongest local maximums of every heatmap volume.

        Args:
            heatmap_volumes (torch.Tensor(NXLXWXH)): 3D center heatmaps.

        Returns:
            tuple: ``(topk_values, topk_unravel_index)`` where
                ``topk_values`` (NXP) holds the scores of the ``P =
                num_candidates`` best voxels left after NMS and
                ``topk_unravel_index`` (NXPx3) holds their 3D voxel indices.
        """
        batch_size = heatmap_volumes.shape[0]
        suppressed = self._max_pool(heatmap_volumes)
        flat_scores = suppressed.reshape(batch_size, -1)
        topk_values, topk_index = flat_scores.topk(self.num_candidates)
        topk_unravel_index = self._get_3d_indices(topk_index,
                                                  heatmap_volumes[0].shape)
        return topk_values, topk_unravel_index

    def _max_pool(self, inputs):
        """Zero out every voxel that is not the maximum of its neighbourhood.

        Args:
            inputs (torch.Tensor): Heatmap volumes.

        Returns:
            torch.Tensor: ``inputs`` with all non-maximum voxels set to zero.
        """
        kernel = self.max_pool_kernel
        padding = (kernel - 1) // 2
        # Named `pooled` rather than `max` so the builtin is not shadowed.
        pooled = F.max_pool3d(
            inputs, kernel_size=kernel, stride=1, padding=padding)
        # A voxel survives only if it equals the max of its own window.
        keep = (inputs == pooled).float()
        return keep * inputs

    @staticmethod
    def _get_3d_indices(indices, shape):
        """Get indices in the 3-D tensor.

        Args:
            indices (torch.Tensor(NXp)): Indices of points in the 1D tensor
            shape (torch.Size(3)): The shape of the original 3D tensor

        Returns:
            indices (torch.Tensor(NXpx3)): Indices of points in the original
                3D tensor
        """
        # Row-major unravelling: flat = x * (S1 * S2) + y * S2 + z.
        plane = shape[1] * shape[2]
        indices_x = indices // plane
        indices_y = (indices % plane) // shape[2]
        indices_z = indices % shape[2]
        return torch.stack([indices_x, indices_y, indices_z], dim=2)

    def forward(self, heatmap_volumes):
        """Detect human center candidates from the center heatmaps.

        Args:
            heatmap_volumes (torch.Tensor(NXLXWXH)):
                3D human center heatmaps predicted by the network.

        Returns:
            human_centers (torch.Tensor(NXPX5)):
                Coordinates of human centers. Channels ``0:3`` hold the world
                coordinates and channel ``4`` holds the heatmap score.
                Channel ``3`` is left at zero by this module (presumably
                filled in by the caller -- confirm against the detector).
        """
        batch_size = heatmap_volumes.shape[0]
        # Candidate selection is done on a detached copy, so no gradient
        # flows back through the NMS / top-k step.
        topk_values, topk_unravel_index = self._nms_by_max_pool(
            heatmap_volumes.detach())
        center_locations = self._get_real_locations(topk_unravel_index)

        human_centers = torch.zeros(
            batch_size, self.num_candidates, 5, device=heatmap_volumes.device)
        human_centers[:, :, 0:3] = center_locations
        human_centers[:, :, 4] = topk_values
        return human_centers

    def get_loss(self, pred_cubes, gt):
        """Compute the MSE loss between predicted and target heatmaps.

        Args:
            pred_cubes (torch.Tensor): Predicted center heatmaps.
            gt (torch.Tensor): Target center heatmaps.

        Returns:
            dict: ``{'loss_center': loss}``.
        """
        return dict(loss_center=self.loss(pred_cubes, gt))
@HEADS.register_module()
class CuboidPoseHead(nn.Module):
    """Get results from the 3D human pose heatmap.

    Instead of obtaining maximums on the heatmap, this module regresses the
    coordinates of keypoints via integral pose regression. Refer to
    `paper <https://arxiv.org/abs/2004.06239>`_ for more details.

    Args:
        beta: Constant to adjust the magnification of soft-maxed heatmap.
    """

    def __init__(self, beta):
        super().__init__()
        self.beta = beta
        self.loss = nn.L1Loss()

    def forward(self, heatmap_volumes, grid_coordinates):
        """Regress keypoint coordinates as a soft-argmax over the volume.

        Args:
            heatmap_volumes (torch.Tensor(NxKxLxWxH)):
                3D human pose heatmaps predicted by the network.
            grid_coordinates (torch.Tensor(Nx(LxWxH)x3)):
                Coordinates of the grids in the heatmap volumes.

        Returns:
            human_poses (torch.Tensor(NxKx3)): Coordinates of human poses.
        """
        batch_size = heatmap_volumes.size(0)
        channel = heatmap_volumes.size(1)
        # Flatten every volume to one voxel axis; the trailing singleton
        # axis broadcasts against the 3 coordinate components below.
        scores = heatmap_volumes.reshape(batch_size, channel, -1, 1)
        # Larger beta sharpens the distribution towards a hard argmax.
        weights = F.softmax(self.beta * scores, dim=2)
        # Insert a keypoint axis so the same grid is shared by all channels.
        grid_coordinates = grid_coordinates.unsqueeze(1)
        # Expected coordinate = sum over voxels of probability * position.
        human_poses = torch.sum(torch.mul(weights, grid_coordinates), dim=2)
        return human_poses

    def get_loss(self, preds, targets, weights):
        """Compute the weighted L1 loss between predicted and target poses.

        Args:
            preds (torch.Tensor): Predicted keypoint coordinates.
            targets (torch.Tensor): Target keypoint coordinates.
            weights (torch.Tensor): Weights multiplied into both ``preds``
                and ``targets`` before the loss is taken.

        Returns:
            dict: ``{'loss_pose': loss}``.
        """
        return dict(loss_pose=self.loss(preds * weights, targets * weights))
|