# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
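"""Convert Stacked Hourglass 2D detections for Human3.6M into compressed
NumPy archives usable by this repository.

Usage (run from the "data" directory; archive paths are placeholders):
    python prepare_data_2d_h36m_sh.py -pt path/to/pretrained_archive.zip
    python prepare_data_2d_h36m_sh.py -ft path/to/fine_tuned_archive.tar.gz
"""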
import argparse
import os
import sys
import tarfile
import zipfile
from glob import glob
from shutil import rmtree

import h5py
import numpy as np

sys.path.append('../')

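# Output archive names (saved as .npz via np.savez_compressed):
# "pt" = Stacked Hourglass pretrained on MPII, "ft" = fine-tuned on Human3.6M.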
output_filename_pt = 'data_2d_h36m_sh_pt_mpii'
output_filename_ft = 'data_2d_h36m_sh_ft_h36m'

subjects = ['S1', 'S5', 'S6', 'S7', 'S8', 'S9', 'S11']

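# Human3.6M camera serial number -> camera index.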
cam_map = {
    '54138969': 0,
    '55011271': 1,
    '58860488': 2,
    '60457274': 3,
}

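# Stacked Hourglass outputs 16 joints (MPII format); keypoints_symmetry
# lists the left/right joint index groups (used e.g. for flip augmentation).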
metadata = {
    'num_joints': 16,
    'keypoints_symmetry': [
        [3, 4, 5, 13, 14, 15],
        [0, 1, 2, 10, 11, 12],
    ]
}


def process_subject(subject, file_list, output):
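    """Parse one subject's per-action .h5 files into output[subject].

    Each file name encodes "<action>.<camera serial>"; detections are stored
    as float32 arrays indexed by camera via cam_map.
    """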
    if subject == 'S11':
        assert len(file_list) == 119, "Expected 119 files for subject " + subject + ", got " + str(len(file_list))
    else:
        assert len(file_list) == 120, "Expected 120 files for subject " + subject + ", got " + str(len(file_list))

    for f in file_list:
        action, cam = os.path.splitext(os.path.basename(f))[0].replace('_', ' ').split('.')

        if subject == 'S11' and action == 'Directions':
            continue  # Discard corrupted video

        if action not in output[subject]:
            output[subject][action] = [None, None, None, None]
        # h5py >= 3.0 removed Dataset.value; read the full dataset with [()]
        # and open the file read-only explicitly.
        with h5py.File(f, 'r') as hf:
            positions = hf['poses'][()]
            output[subject][action][cam_map[cam]] = positions.astype('float32')

if __name__ == '__main__':
    if os.path.basename(os.getcwd()) != 'data':
        print('This script must be launched from the "data" directory')
        exit(1)

    parser = argparse.ArgumentParser(description='Human3.6M dataset downloader/converter')
    parser.add_argument('-pt', '--pretrained', default='', type=str, metavar='PATH', help='convert pretrained dataset')
    parser.add_argument('-ft', '--fine-tuned', default='', type=str, metavar='PATH', help='convert fine-tuned dataset')
    args = parser.parse_args()

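    # Both branches below follow the same pipeline: extract the archive,
    # convert the per-subject .h5 files, save a compressed .npz, clean up.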
    if args.pretrained:
        print('Converting pretrained dataset from', args.pretrained)

        print('Extracting...')
        with zipfile.ZipFile(args.pretrained, 'r') as archive:
            archive.extractall('sh_pt')

        print('Converting...')
        output = {}
        for subject in subjects:
            output[subject] = {}
            file_list = glob('sh_pt/h36m/' + subject + '/StackedHourglass/*.h5')
            process_subject(subject, file_list, output)

        print('Saving...')
        np.savez_compressed(output_filename_pt, positions_2d=output, metadata=metadata)

        print('Cleaning up...')
        rmtree('sh_pt')

        print('Done.')

    if args.fine_tuned:
        print('Converting fine-tuned dataset from', args.fine_tuned)

        print('Extracting...')
        with tarfile.open(args.fine_tuned, 'r:gz') as archive:
            archive.extractall('sh_ft')

        print('Converting...')
        output = {}
        for subject in subjects:
            output[subject] = {}
            file_list = glob('sh_ft/' + subject + '/StackedHourglassFineTuned240/*.h5')
            process_subject(subject, file_list, output)

        print('Saving...')
        np.savez_compressed(output_filename_ft, positions_2d=output, metadata=metadata)

        print('Cleaning up...')
        rmtree('sh_ft')

        print('Done.')