File size: 2,454 Bytes
d7a991a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmpose.core.post_processing.group import HeatmapParser


def test_group():
    cfg = {}
    cfg['num_joints'] = 17
    cfg['detection_threshold'] = 0.1
    cfg['tag_threshold'] = 1
    cfg['use_detection_val'] = True
    cfg['ignore_too_much'] = False
    cfg['nms_kernel'] = 5
    cfg['nms_padding'] = 2
    cfg['tag_per_joint'] = True
    cfg['max_num_people'] = 1
    parser = HeatmapParser(cfg)
    fake_heatmap = torch.zeros(1, 1, 5, 5)
    fake_heatmap[0, 0, 3, 3] = 1
    fake_heatmap[0, 0, 3, 2] = 0.8
    assert parser.nms(fake_heatmap)[0, 0, 3, 2] == 0
    fake_heatmap = torch.zeros(1, 17, 32, 32)
    fake_tag = torch.zeros(1, 17, 32, 32, 1)
    fake_heatmap[0, 0, 10, 10] = 0.8
    fake_heatmap[0, 1, 12, 12] = 0.9
    fake_heatmap[0, 4, 8, 8] = 0.8
    fake_heatmap[0, 8, 6, 6] = 0.9
    fake_tag[0, 0, 10, 10] = 0.8
    fake_tag[0, 1, 12, 12] = 0.9
    fake_tag[0, 4, 8, 8] = 0.8
    fake_tag[0, 8, 6, 6] = 0.9
    grouped, scores = parser.parse(fake_heatmap, fake_tag, True, True)
    assert grouped[0][0, 0, 0] == 10.25
    assert abs(scores[0] - 0.2) < 0.001
    cfg['tag_per_joint'] = False
    parser = HeatmapParser(cfg)
    grouped, scores = parser.parse(fake_heatmap, fake_tag, False, False)
    assert grouped[0][0, 0, 0] == 10.
    grouped, scores = parser.parse(fake_heatmap, fake_tag, False, True)
    assert grouped[0][0, 0, 0] == 10.


def test_group_score_per_joint():
    cfg = {}
    cfg['num_joints'] = 17
    cfg['detection_threshold'] = 0.1
    cfg['tag_threshold'] = 1
    cfg['use_detection_val'] = True
    cfg['ignore_too_much'] = False
    cfg['nms_kernel'] = 5
    cfg['nms_padding'] = 2
    cfg['tag_per_joint'] = True
    cfg['max_num_people'] = 1
    cfg['score_per_joint'] = True
    parser = HeatmapParser(cfg)
    fake_heatmap = torch.zeros(1, 1, 5, 5)
    fake_heatmap[0, 0, 3, 3] = 1
    fake_heatmap[0, 0, 3, 2] = 0.8
    assert parser.nms(fake_heatmap)[0, 0, 3, 2] == 0
    fake_heatmap = torch.zeros(1, 17, 32, 32)
    fake_tag = torch.zeros(1, 17, 32, 32, 1)
    fake_heatmap[0, 0, 10, 10] = 0.8
    fake_heatmap[0, 1, 12, 12] = 0.9
    fake_heatmap[0, 4, 8, 8] = 0.8
    fake_heatmap[0, 8, 6, 6] = 0.9
    fake_tag[0, 0, 10, 10] = 0.8
    fake_tag[0, 1, 12, 12] = 0.9
    fake_tag[0, 4, 8, 8] = 0.8
    fake_tag[0, 8, 6, 6] = 0.9
    grouped, scores = parser.parse(fake_heatmap, fake_tag, True, True)
    assert len(scores[0]) == 17