File size: 3,435 Bytes
95f8bbc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import argparse
import os
import sys
import tarfile
import zipfile
from glob import glob
from shutil import rmtree

import h5py
import numpy as np

sys.path.append('../')

output_filename_pt = 'data_2d_h36m_sh_pt_mpii'
output_filename_ft = 'data_2d_h36m_sh_ft_h36m'
subjects = ['S1', 'S5', 'S6', 'S7', 'S8', 'S9', 'S11']
cam_map = {
    '54138969': 0,
    '55011271': 1,
    '58860488': 2,
    '60457274': 3,
}

metadata = {
    'num_joints': 16,
    'keypoints_symmetry': [
        [3, 4, 5, 13, 14, 15],
        [0, 1, 2, 10, 11, 12],
    ]
}


def process_subject(subject, file_list, output):
    if subject == 'S11':
        assert len(file_list) == 119, "Expected 119 files for subject " + subject + ", got " + str(len(file_list))
    else:
        assert len(file_list) == 120, "Expected 120 files for subject " + subject + ", got " + str(len(file_list))

    for f in file_list:
        action, cam = os.path.splitext(os.path.basename(f))[0].replace('_', ' ').split('.')

        if subject == 'S11' and action == 'Directions':
            continue  # Discard corrupted video

        if action not in output[subject]:
            output[subject][action] = [None, None, None, None]

        with h5py.File(f) as hf:
            positions = hf['poses'].value
            output[subject][action][cam_map[cam]] = positions.astype('float32')


if __name__ == '__main__':
    if os.path.basename(os.getcwd()) != 'data':
        print('This script must be launched from the "data" directory')
        exit(0)

    parser = argparse.ArgumentParser(description='Human3.6M dataset downloader/converter')

    parser.add_argument('-pt', '--pretrained', default='', type=str, metavar='PATH', help='convert pretrained dataset')
    parser.add_argument('-ft', '--fine-tuned', default='', type=str, metavar='PATH', help='convert fine-tuned dataset')

    args = parser.parse_args()

    if args.pretrained:
        print('Converting pretrained dataset from', args.pretrained)
        print('Extracting...')
        with zipfile.ZipFile(args.pretrained, 'r') as archive:
            archive.extractall('sh_pt')

        print('Converting...')
        output = {}
        for subject in subjects:
            output[subject] = {}
            file_list = glob('sh_pt/h36m/' + subject + '/StackedHourglass/*.h5')
            process_subject(subject, file_list, output)

        print('Saving...')
        np.savez_compressed(output_filename_pt, positions_2d=output, metadata=metadata)

        print('Cleaning up...')
        rmtree('sh_pt')

        print('Done.')

    if args.fine_tuned:
        print('Converting fine-tuned dataset from', args.fine_tuned)
        print('Extracting...')
        with tarfile.open(args.fine_tuned, 'r:gz') as archive:
            archive.extractall('sh_ft')

        print('Converting...')
        output = {}
        for subject in subjects:
            output[subject] = {}
            file_list = glob('sh_ft/' + subject + '/StackedHourglassFineTuned240/*.h5')
            process_subject(subject, file_list, output)

        print('Saving...')
        np.savez_compressed(output_filename_ft, positions_2d=output, metadata=metadata)

        print('Cleaning up...')
        rmtree('sh_ft')

        print('Done.')