"""Generate a 3D pose reconstruction from a video and save it as an NPZ file.

The script runs a 2D keypoint detector over the video, converts and revises
the keypoints, lifts them to 3D with a pre-trained spatio-temporal model and
writes the result to ``../outputs/npz/<video-name>.npz`` (relative to the
current working directory).

Progress is reported on stdout as ``PROGRESS:<value>`` lines.
NOTE(review): these lines are presumably parsed by a parent process; keep
their format stable -- confirm against the consumer.
"""
import torch
import sys
import os.path as osp
import os
import argparse
import cv2
import time
import h5py
from tqdm import tqdm
import numpy as np
import warnings
import signal

warnings.filterwarnings('ignore')


def signal_handler(sig, frame):
    """Handle SIGINT: release what we can, then terminate the process.

    Args:
        sig: Signal number delivered by the interpreter.
        frame: Stack frame that was executing when the signal arrived.
    """
    print("\nInterrupted by user, shutting down...")

    # `loader_thread` is not defined in this file; it is only joined when some
    # other code has published it as a module-level global.
    worker = globals().get('loader_thread')
    if worker is not None and worker.is_alive():
        worker.join(timeout=1.0)  # Give the thread 1 second to finish

    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # Free GPU memory immediately

    # The original called `os.exit(0)`, which does not exist and raised
    # AttributeError inside the handler. `os._exit` terminates even if a
    # non-daemon thread is still running, but it skips stdio flushing, so
    # flush explicitly first to keep the message above.
    sys.stdout.flush()
    os._exit(0)


# Register the signal handler
signal.signal(signal.SIGINT, signal_handler)

# Temporarily expose this file's directory so the project packages resolve.
sys.path.insert(0, osp.dirname(osp.realpath(__file__)))
from tools.utils import get_path
from model.gast_net import SpatioTemporalModel, SpatioTemporalModelOptimized1f
from common.skeleton import Skeleton
from common.graph_utils import adj_mx_from_skeleton
from common.generators import *
from tools.preprocess import load_kpts_json, h36m_coco_format, revise_kpts, revise_skes
from tools.inference import gen_pose
from tools.vis_kpts import plot_keypoint

cur_dir, chk_root, data_root, lib_root, output_root = get_path(__file__)
model_dir = chk_root + 'gastnet/'

# The 2D detector lives under `lib_root`; add it only for this import and
# then restore sys.path to its previous state.
sys.path.insert(1, lib_root)
from lib.pose import gen_video_kpts as hrnet_pose
sys.path.pop(1)
sys.path.pop(0)

# 17-joint skeleton: parent index per joint (-1 marks the root) together with
# the left/right joint index groups.
skeleton = Skeleton(parents=[-1, 0, 1, 2, 0, 4, 5, 0, 7, 8, 9, 8, 11, 12, 8, 14, 15],
                    joints_left=[4, 5, 6, 11, 12, 13],
                    joints_right=[1, 2, 3, 14, 15, 16])
adj = adj_mx_from_skeleton(skeleton)

joints_left, joints_right = [4, 5, 6, 11, 12, 13], [1, 2, 3, 14, 15, 16]
kps_left, kps_right = [4, 5, 6, 11, 12, 13], [1, 2, 3, 14, 15, 16]


def load_model_layer():
    """Build the 3D lifting model and load its pre-trained weights.

    Reads ``81_frame_model.bin`` from ``model_dir`` and loads the
    ``'model_pos'`` entry of the checkpoint into a ``SpatioTemporalModel``.

    Returns:
        The model in eval mode, moved to the GPU when CUDA is available.
    """
    chk = model_dir + '81_frame_model.bin'
    filters_width = [3, 3, 3, 3]
    channels = 64

    model_pos = SpatioTemporalModel(adj, 17, 2, 17, filter_widths=filters_width,
                                    channels=channels, dropout=0.05)

    # Without `map_location`, a checkpoint saved from a GPU cannot be loaded
    # on a CPU-only host, even though the code below supports running on CPU.
    map_location = None if torch.cuda.is_available() else 'cpu'
    checkpoint = torch.load(chk, map_location=map_location)
    model_pos.load_state_dict(checkpoint['model_pos'])

    if torch.cuda.is_available():
        model_pos = model_pos.cuda()
    model_pos = model_pos.eval()

    return model_pos


def generate_skeletons(video=''):
    """Run 2D detection and 3D lifting on ``video`` and save the result.

    The reconstruction is stored under the key ``reconstruction`` in
    ``../outputs/npz/<name>.npz`` (resolved against the current working
    directory), where ``<name>`` is the video's base name up to its first dot.

    Args:
        video: Path to the input video file.
    """
    def force_exit(sig, frame):
        # While the pipeline is running, Ctrl-C terminates the process
        # immediately instead of going through the module-level handler.
        print("\nForce terminating...")
        sys.stdout.flush()  # os._exit does not flush stdio buffers
        os._exit(1)

    signal.signal(signal.SIGINT, force_exit)

    # Only the video metadata is read here; release the capture handle as soon
    # as it is no longer needed (the original never released it).
    cap = cv2.VideoCapture(video)
    try:
        width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
        height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        cap.release()

    # 2D Keypoint Generation (handled by gen_video_kpts)
    print('Generating 2D Keypoints:')
    sys.stdout.flush()
    keypoints, scores = hrnet_pose(video, det_dim=416, gen_output=True)
    keypoints, scores, valid_frames = h36m_coco_format(keypoints, scores)
    re_kpts = revise_kpts(keypoints, scores, valid_frames)
    num_person = len(re_kpts)

    model_pos = load_model_layer()

    # NOTE(review): 81 presumably matches the frame count in the checkpoint
    # name ('81_frame_model.bin'); pad is half of that window -- confirm.
    pad = (81 - 1) // 2
    causal_shift = 0

    # 3D Pose Generation
    print('Recording 3D Pose:')
    print("PROGRESS:100.00")  # Start 3D at 100%
    sys.stdout.flush()

    total_valid_frames = len(valid_frames) if valid_frames else total_frames
    prediction = gen_pose(re_kpts, valid_frames, width, height, model_pos, pad, causal_shift)

    # Simulated 3D progress: gen_pose has already finished at this point, so
    # this loop only emits the 100-200% range for the progress consumer.
    for i in range(total_valid_frames):
        progress = 100 + ((i + 1) / total_valid_frames * 100)  # 100-200% for 3D
        print(f"PROGRESS:{progress:.2f}")
        sys.stdout.flush()
        time.sleep(0.01)  # Placeholder; remove if gen_pose has its own loop

    output_dir = os.path.abspath('../outputs/')
    print(f"Creating output directory: {output_dir}")
    os.makedirs(output_dir, exist_ok=True)

    npz_dir = os.path.join(output_dir, 'npz')
    print(f"Creating NPZ directory: {npz_dir}")
    os.makedirs(npz_dir, exist_ok=True)

    output_npz = os.path.join(npz_dir, os.path.basename(video).split('.')[0] + '.npz')
    print(f"Saving NPZ to: {output_npz}")
    np.savez_compressed(output_npz, reconstruction=prediction)
    print(f"NPZ saved successfully: {output_npz}")


def arg_parse():
    """Parse the command-line arguments.

    Returns:
        The parsed namespace; ``video`` holds the value of ``-v/--video``.
    """
    parser = argparse.ArgumentParser('Generating skeleton demo.')
    # The video is mandatory: without it the entry point would pass None to
    # os.path.isabs and fail with an opaque TypeError.
    parser.add_argument('-v', '--video', type=str, required=True,
                        help='absolute path to a video, or a file name '
                             'relative to the data "video" directory')
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = arg_parse()
    # Relative names are looked up inside <data_root>/video.
    if os.path.isabs(args.video):
        video_path = args.video
    else:
        video_path = os.path.join(data_root, 'video', args.video)
    generate_skeletons(video=video_path)