File size: 6,040 Bytes
d7a991a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
# ------------------------------------------------------------------------------
# Copyright and License Information
# https://github.com/microsoft/voxelpose-pytorch/blob/main/lib/models
# Original Licence: MIT License
# ------------------------------------------------------------------------------

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..builder import HEADS


@HEADS.register_module()
class CuboidCenterHead(nn.Module):
    """Get results from the 3D human center heatmap. In this module, human 3D
    centers are local maximums obtained from the 3D heatmap via NMS (max-
    pooling).

    Args:
        space_size (list[3]): The size of the 3D space.
        space_center (list[3]): The coordinate of space center.
        cube_size (list[3]): The size of the heatmap volume.
        max_num (int): Maximum of human center detections.
        max_pool_kernel (int): Kernel size of the max-pool kernel in nms.
            Must be a positive odd integer so that the pooled volume keeps
            the shape of the input volume.
    """

    def __init__(self,
                 space_size,
                 space_center,
                 cube_size,
                 max_num=10,
                 max_pool_kernel=3):
        super(CuboidCenterHead, self).__init__()
        # With padding (k - 1) // 2 an even kernel shrinks the pooled volume
        # by one voxel per axis, so the NMS comparison in `_max_pool` would
        # fail with a shape mismatch. Reject it here with a clear message.
        if max_pool_kernel < 1 or max_pool_kernel % 2 == 0:
            raise ValueError(
                'max_pool_kernel must be a positive odd integer, '
                'but got {}'.format(max_pool_kernel))
        # Registered as buffers so they move with the module on `.to(...)`.
        self.register_buffer('grid_size', torch.tensor(space_size))
        self.register_buffer('cube_size', torch.tensor(cube_size))
        self.register_buffer('grid_center', torch.tensor(space_center))

        self.num_candidates = max_num
        self.max_pool_kernel = max_pool_kernel
        self.loss = nn.MSELoss()

    def _get_real_locations(self, indices):
        """Convert voxel indices into world coordinates.

        Voxel index 0 maps to ``space_center - space_size / 2`` and index
        ``cube_size - 1`` maps to ``space_center + space_size / 2``.

        Args:
            indices (torch.Tensor(NxPx3)): Indices of points in the 3D tensor

        Returns:
            real_locations (torch.Tensor(NxPx3)): Locations of points
                in the world coordinate system
        """
        real_locations = indices.float() / (
                self.cube_size - 1) * self.grid_size + \
            self.grid_center - self.grid_size / 2.0
        return real_locations

    def _nms_by_max_pool(self, heatmap_volumes):
        """Keep local maxima of the heatmaps and pick the top candidates.

        Args:
            heatmap_volumes (torch.Tensor(NxLxWxH)): Center heatmaps.

        Returns:
            tuple:
                - topk_values (torch.Tensor(NxP)): Heatmap values of the
                  selected local maxima, sorted in descending order, where
                  P equals ``max_num``.
                - topk_unravel_index (torch.Tensor(NxPx3)): Their voxel
                  indices inside the LxWxH volume.
        """
        max_num = self.num_candidates
        batch_size = heatmap_volumes.shape[0]
        root_cubes_nms = self._max_pool(heatmap_volumes)
        root_cubes_nms_reshape = root_cubes_nms.reshape(batch_size, -1)
        topk_values, topk_index = root_cubes_nms_reshape.topk(max_num)
        topk_unravel_index = self._get_3d_indices(topk_index,
                                                  heatmap_volumes[0].shape)

        return topk_values, topk_unravel_index

    def _max_pool(self, inputs):
        """Zero out every voxel that is not the maximum of its neighborhood.

        A 4D input is interpreted by ``max_pool3d`` as (C, D, H, W), so each
        batch item is pooled independently.
        """
        kernel = self.max_pool_kernel
        padding = (kernel - 1) // 2
        pooled = F.max_pool3d(
            inputs, kernel_size=kernel, stride=1, padding=padding)
        keep = (inputs == pooled).float()
        return keep * inputs

    @staticmethod
    def _get_3d_indices(indices, shape):
        """Get indices in the 3-D tensor.

        Args:
            indices (torch.Tensor(NXp)): Indices of points in the 1D tensor
            shape (torch.Size(3)): The shape of the original 3D tensor

        Returns:
            indices (torch.Tensor(NXpX3)): Indices of points in the original
                3D tensor, unravelled in row-major order.
        """
        batch_size = indices.shape[0]
        num_people = indices.shape[1]
        indices_x = (indices //
                     (shape[1] * shape[2])).reshape(batch_size, num_people, -1)
        indices_y = ((indices % (shape[1] * shape[2])) //
                     shape[2]).reshape(batch_size, num_people, -1)
        indices_z = (indices % shape[2]).reshape(batch_size, num_people, -1)
        indices = torch.cat([indices_x, indices_y, indices_z], dim=2)
        return indices

    def forward(self, heatmap_volumes):
        """Detect human centers from the 3D center heatmaps.

        Args:
            heatmap_volumes (torch.Tensor(NXLXWXH)):
                3D human center heatmaps predicted by the network.
        Returns:
            human_centers (torch.Tensor(NXPX5)):
                Coordinates of human centers, with P equal to ``max_num``.
                Columns 0:3 hold world coordinates and column 4 holds the
                heatmap confidence. Column 3 is left as zero here.
                NOTE(review): presumably the caller fills column 3 in
                (e.g. a validity / matching flag); confirm.
        """
        batch_size = heatmap_volumes.shape[0]

        # NMS runs on a detached heatmap: no gradient flows through the
        # selected centers.
        topk_values, topk_unravel_index = self._nms_by_max_pool(
            heatmap_volumes.detach())

        real_locations = self._get_real_locations(topk_unravel_index)

        human_centers = torch.zeros(
            batch_size, self.num_candidates, 5, device=heatmap_volumes.device)
        human_centers[:, :, 0:3] = real_locations
        human_centers[:, :, 4] = topk_values

        return human_centers

    def get_loss(self, pred_cubes, gt):
        """Return the MSE between predicted and target center heatmaps as
        ``dict(loss_center=...)``."""
        return dict(loss_center=self.loss(pred_cubes, gt))


@HEADS.register_module()
class CuboidPoseHead(nn.Module):
    """Get results from the 3D human pose heatmap. Instead of obtaining
    maximums on the heatmap, this module regresses the coordinates of
    keypoints via integral pose regression. Refer to `paper
    <https://arxiv.org/abs/2004.06239>`_ for more details.

    Args:
        beta (float): Constant to adjust the magnification of soft-maxed
            heatmap. The heatmap is multiplied by ``beta`` before the
            softmax, so larger values give a sharper distribution.
    """

    def __init__(self, beta):
        super(CuboidPoseHead, self).__init__()
        self.beta = beta
        self.loss = nn.L1Loss()

    def forward(self, heatmap_volumes, grid_coordinates):
        """Regress keypoint coordinates as the softmax-weighted mean of the
        grid coordinates.

        Args:
            heatmap_volumes (torch.Tensor(NxKxLxWxH)):
                3D human pose heatmaps predicted by the network.
            grid_coordinates (torch.Tensor(Nx(LxWxH)x3)):
                Coordinates of the grids in the heatmap volumes, flattened in
                the same row-major order as ``heatmap_volumes``. A shared
                ``(LxWxH)x3`` tensor is also accepted and broadcast over N.
        Returns:
            human_poses (torch.Tensor(NxKx3)): Coordinates of human poses.
        """
        batch_size = heatmap_volumes.size(0)
        channel = heatmap_volumes.size(1)
        # (N, K, L*W*H): one probability distribution over voxels per joint.
        probs = F.softmax(
            self.beta * heatmap_volumes.reshape(batch_size, channel, -1),
            dim=-1)
        # Weighted sum over voxels as a batched matmul,
        # (N, K, V) @ (N, V, 3) -> (N, K, 3), instead of building the
        # (N, K, V, 3) element-wise product and summing it. matmul does not
        # promote dtypes the way torch.mul does, so promote explicitly.
        dtype = torch.promote_types(probs.dtype, grid_coordinates.dtype)
        human_poses = torch.matmul(
            probs.to(dtype), grid_coordinates.to(dtype))

        return human_poses

    def get_loss(self, preds, targets, weights):
        """Return the L1 loss between ``preds * weights`` and
        ``targets * weights`` as ``dict(loss_pose=...)``.

        NOTE(review): ``weights`` presumably masks out invalid joints;
        confirm against the caller.
        """
        return dict(loss_pose=self.loss(preds * weights, targets * weights))