File size: 3,394 Bytes
11e6f7b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# ------------------------------------------------------------------------------
# https://github.dev/HRNet/HigherHRNet-Human-Pose-Estimation
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Written by Bin Xiao ([email protected])
# Modified by Bowen Cheng ([email protected])
# ------------------------------------------------------------------------------

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

import logging

import torch
import torch.nn as nn
from pdb import set_trace as st

logger = logging.getLogger(__name__)


class HeatmapGenerator():
    """Render per-joint Gaussian heatmaps on a square grid.

    A single Gaussian patch is precomputed at construction time and then
    pasted (with an element-wise maximum) around every visible joint.

    Args:
        heatmap_size: side length of the square output heatmap.
        num_joints: number of heatmap channels to allocate.
        sigma: standard deviation of the Gaussian, in heatmap pixels.
            A negative value selects ``heatmap_size / 64`` instead.
    """

    def __init__(self, heatmap_size, num_joints=68, sigma=2):
        self.heatmap_size = heatmap_size
        self.num_joints = num_joints
        # A negative sigma means "derive it from the heatmap size".
        self.sigma = self.heatmap_size / 64 if sigma < 0 else sigma

        # Patch spans 6*sigma + 3 samples; its peak sits at index 3*sigma + 1
        # along both axes.
        peak = 3 * self.sigma + 1
        cols = np.arange(0, 6 * self.sigma + 3, 1, float)
        rows = cols[:, np.newaxis]
        squared_dist = (cols - peak)**2 + (rows - peak)**2
        self.g = np.exp(-squared_dist / (2 * self.sigma**2))

    def __call__(self, joints, image_size: int):
        """Build the ground-truth heatmaps for one set of joints.

        Args:
            joints: sequence of points; for each point, index 0 is x,
                index 1 is y and index 2 is a flag that must be > 0 for
                the joint to be drawn. The position of a point in the
                sequence selects its heatmap channel, so it should hold
                no more than ``num_joints`` entries.
            image_size: side length of the image the joint coordinates
                refer to; coordinates are divided by
                ``image_size / heatmap_size`` to land on the heatmap grid.

        Returns:
            float32 array of shape (num_joints, heatmap_size, heatmap_size).
        """
        side = self.heatmap_size
        heatmaps = np.zeros((self.num_joints, side, side), dtype=np.float32)
        reach = 3 * self.sigma
        stride = image_size / side

        for channel, joint in enumerate(joints):
            if not joint[2] > 0:
                continue

            # Scale image coordinates down to heatmap coordinates.
            col = int(joint[0] / stride + 0.5)
            row = int(joint[1] / stride + 0.5)
            if not (col >= 0 and row >= 0 and col < side and row < side):
                continue

            # Unclipped patch footprint: [left, right) x [top, bottom).
            left = int(np.round(col - reach - 1))
            top = int(np.round(row - reach - 1))
            right = min(int(np.round(col + reach + 2)), side)
            bottom = min(int(np.round(row + reach + 2)), side)

            # Part of the precomputed patch that falls inside the heatmap.
            patch = self.g[max(0, -top):bottom - top,
                           max(0, -left):right - left]

            # Matching window in the heatmap; keep the larger value where
            # something has already been drawn in this channel.
            win_rows = slice(max(0, top), bottom)
            win_cols = slice(max(0, left), right)
            heatmaps[channel, win_rows, win_cols] = np.maximum(
                heatmaps[channel, win_rows, win_cols], patch)

        return heatmaps


class HeatmapLoss(nn.Module):
    """Mean squared error between predicted and target heatmaps.

    When a mask is supplied, the squared error is multiplied by it before
    averaging; the average is still taken over every element of ``pred``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, pred, gt, mask=None):
        """Return the scalar loss.

        Args:
            pred: predicted heatmaps.
            gt: target heatmaps; must have the same size as ``pred``.
            mask: optional weights indexed as (dim 0, dim 2, dim 3) of
                ``pred``; a singleton axis is inserted at dim 1 and the
                result is expanded to the size of ``pred``.

        Returns:
            Scalar tensor holding the mean of the (optionally masked)
            squared error.
        """
        assert pred.size() == gt.size()
        squared_error = (pred - gt)**2
        if mask is None:
            return squared_error.mean()
        # Share one spatial mask across the whole of dim 1.
        weights = mask[:, None, :, :].expand_as(pred)
        return (squared_error * weights).mean()