File size: 4,498 Bytes
1cdc47e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82d4b57
 
 
 
 
 
 
 
 
 
 
1cdc47e
 
 
 
 
 
 
 
 
 
82d4b57
1cdc47e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82d4b57
 
 
 
 
1cdc47e
 
 
4e2aa43
1cdc47e
4e2aa43
 
 
1cdc47e
4e2aa43
1cdc47e
 
 
 
 
 
 
 
 
4e2aa43
1cdc47e
4e2aa43
 
 
 
 
 
 
 
 
 
1cdc47e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import torch
import sys
import os.path as osp
import os
import argparse
import cv2
import time
import h5py
from tqdm import tqdm
import numpy as np
import warnings
import signal

warnings.filterwarnings('ignore')

def signal_handler(sig, frame):
    """Handle SIGINT (Ctrl-C): stop cleanly instead of dumping a traceback.

    Joins a global ``loader_thread`` if one exists and is still running,
    releases cached GPU memory, then exits with status 0.

    Args:
        sig: Signal number delivered by the OS (unused beyond the handler signature).
        frame: Current stack frame at interrupt time (unused).
    """
    print("\nInterrupted by user, shutting down...")
    if 'loader_thread' in globals() and loader_thread.is_alive():
        loader_thread.join(timeout=1.0)  # Give the thread 1 second to finish
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # Free GPU memory immediately
    # BUG FIX: os.exit does not exist (the original raised AttributeError inside
    # the signal handler). sys.exit(0) performs a clean interpreter shutdown;
    # os._exit(0) would skip cleanup/atexit handlers.
    sys.exit(0)

# Register the SIGINT (Ctrl-C) handler so user interrupts trigger the graceful
# shutdown above instead of the default KeyboardInterrupt traceback.
signal.signal(signal.SIGINT, signal_handler)

# Temporarily put this script's directory first on sys.path so the project's
# local packages (tools/, model/, common/) resolve regardless of the CWD.
sys.path.insert(0, osp.dirname(osp.realpath(__file__)))
from tools.utils import get_path
from model.gast_net import SpatioTemporalModel, SpatioTemporalModelOptimized1f
from common.skeleton import Skeleton
from common.graph_utils import adj_mx_from_skeleton
from common.generators import *
from tools.preprocess import load_kpts_json, h36m_coco_format, revise_kpts, revise_skes
from tools.inference import gen_pose
from tools.vis_kpts import plot_keypoint


# Resolve project directory layout (checkpoints, data, libs, outputs).
cur_dir, chk_root, data_root, lib_root, output_root = get_path(__file__)
model_dir = chk_root + 'gastnet/'
# The 2D-pose library lives under lib_root; add it only long enough to import,
# then pop both path entries so later imports are not shadowed.
sys.path.insert(1, lib_root)
from lib.pose import gen_video_kpts as hrnet_pose
sys.path.pop(1)
sys.path.pop(0)

# 17-joint skeleton definition (parent indices plus left/right joint groups);
# presumably the Human3.6M topology used by GAST-Net — TODO confirm against
# common.skeleton. The adjacency matrix feeds the graph-convolution layers.
skeleton = Skeleton(parents=[-1, 0, 1, 2, 0, 4, 5, 0, 7, 8, 9, 8, 11, 12, 8, 14, 15],
                    joints_left=[4, 5, 6, 11, 12, 13], joints_right=[1, 2, 3, 14, 15, 16])
adj = adj_mx_from_skeleton(skeleton)

# Left/right index groups for 3D joints and 2D keypoints (identical here);
# used downstream for flip augmentation / pose symmetry handling.
joints_left, joints_right = [4, 5, 6, 11, 12, 13], [1, 2, 3, 14, 15, 16]
kps_left, kps_right = [4, 5, 6, 11, 12, 13], [1, 2, 3, 14, 15, 16]


def load_model_layer():
    """Build the 81-frame GAST-Net model and load its pretrained weights.

    Reads ``81_frame_model.bin`` from ``model_dir`` and restores the
    ``model_pos`` state dict from it.

    Returns:
        The SpatioTemporalModel in eval mode, moved to the GPU when CUDA
        is available, otherwise left on the CPU.
    """
    chk = model_dir + '81_frame_model.bin'
    filters_width = [3, 3, 3, 3]
    channels = 64

    model_pos = SpatioTemporalModel(adj, 17, 2, 17, filter_widths=filters_width,
                                    channels=channels, dropout=0.05)

    # BUG FIX: load to CPU explicitly. Without map_location, a checkpoint saved
    # on a CUDA device fails to deserialize on a CPU-only machine. Weights are
    # moved to the GPU afterwards when one is present.
    checkpoint = torch.load(chk, map_location='cpu')
    model_pos.load_state_dict(checkpoint['model_pos'])

    if torch.cuda.is_available():
        model_pos = model_pos.cuda()
    model_pos = model_pos.eval()

    return model_pos

def generate_skeletons(video=''):
    """Run the full 2D→3D pose pipeline on a video and save the result.

    Steps: extract 2D keypoints with HRNet, convert to H3.6M format, lift to
    3D with GAST-Net, and write the predictions as a compressed ``.npz`` under
    ``../outputs/npz/``. Emits ``PROGRESS:<pct>`` lines on stdout for an
    external progress monitor (100–200% covers the 3D stage).

    Args:
        video: Path to the input video file.
    """
    def force_exit(sig, frame):
        # Inside the heavy pipeline a second Ctrl-C must terminate immediately,
        # bypassing cleanup that could hang (worker threads, CUDA teardown).
        print("\nForce terminating...")
        os._exit(1)
    signal.signal(signal.SIGINT, force_exit)

    cap = cv2.VideoCapture(video)
    width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
    height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    # BUG FIX: release the capture once metadata is read — the original leaked
    # the underlying video handle for the rest of the run.
    cap.release()

    # 2D Keypoint Generation (handled by gen_video_kpts)
    print('Generating 2D Keypoints:')
    sys.stdout.flush()
    keypoints, scores = hrnet_pose(video, det_dim=416, gen_output=True)

    keypoints, scores, valid_frames = h36m_coco_format(keypoints, scores)
    re_kpts = revise_kpts(keypoints, scores, valid_frames)
    num_person = len(re_kpts)

    model_pos = load_model_layer()

    # Receptive field of the 81-frame model: pad each side by 40 frames.
    pad = (81 - 1) // 2
    causal_shift = 0

    # 3D Pose Generation
    print('Recording 3D Pose:')
    print(f"PROGRESS:100.00")  # Start 3D at 100%
    sys.stdout.flush()
    total_valid_frames = len(valid_frames) if valid_frames else total_frames
    prediction = gen_pose(re_kpts, valid_frames, width, height, model_pos, pad, causal_shift)
    # Simulate 3D progress (replace with gen_pose loop if shared).
    # BUG FIX: guard against total_valid_frames == 0 (empty/unreadable video),
    # which made the original divide by zero.
    if total_valid_frames > 0:
        for i in range(total_valid_frames):
            progress = 100 + ((i + 1) / total_valid_frames * 100)  # 100-200% for 3D
            print(f"PROGRESS:{progress:.2f}")
            sys.stdout.flush()
            time.sleep(0.01)  # Placeholder; remove if gen_pose has its own loop

    output_dir = os.path.abspath('../outputs/')
    print(f"Creating output directory: {output_dir}")
    os.makedirs(output_dir, exist_ok=True)

    npz_dir = os.path.join(output_dir, 'npz')
    print(f"Creating NPZ directory: {npz_dir}")
    os.makedirs(npz_dir, exist_ok=True)

    # BUG FIX: split('.')[0] truncated any filename containing extra dots
    # (e.g. 'clip.v2.mp4' -> 'clip'); splitext strips only the extension.
    base_name = os.path.splitext(os.path.basename(video))[0]
    output_npz = os.path.join(npz_dir, base_name + '.npz')
    print(f"Saving NPZ to: {output_npz}")
    np.savez_compressed(output_npz, reconstruction=prediction)
    print(f"NPZ saved successfully: {output_npz}")
    
def arg_parse():
    """Parse command-line options for the skeleton-generation demo.

    Returns:
        argparse.Namespace with a single ``video`` attribute (str or None).
    """
    parser = argparse.ArgumentParser('Generating skeleton demo.')
    parser.add_argument('-v', '--video', type=str)
    return parser.parse_args()

if __name__ == "__main__":
    args = arg_parse()
    if os.path.isabs(args.video):
        video_path = args.video
    else:
        video_path = os.path.join(data_root, 'video', args.video)
    generate_skeletons(video=video_path)