Spaces:
Runtime error
Runtime error
import torch | |
import sys | |
import os.path as osp | |
import os | |
import argparse | |
import cv2 | |
import time | |
import h5py | |
from tqdm import tqdm | |
import numpy as np | |
import warnings | |
import signal | |
# Suppress every warning category for the whole process. simplefilter installs
# the same catch-all filter as warnings.filterwarnings('ignore').
warnings.simplefilter('ignore')
def signal_handler(sig, frame):
    """Handle Ctrl+C (SIGINT): clean up briefly, then exit with status 0.

    If a module-level ``loader_thread`` exists and is still alive, wait up to
    one second for it. Then release PyTorch's cached CUDA memory when a GPU
    is available.

    Args:
        sig: The delivered signal number (unused).
        frame: The interrupted stack frame (unused).

    Raises:
        SystemExit: Always, with exit code 0.
    """
    print("\nInterrupted by user, shutting down...")
    loader = globals().get('loader_thread')
    if loader is not None and loader.is_alive():
        loader.join(timeout=1.0)  # Give the thread 1 second to finish
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # Free GPU memory immediately
    # The os module has no `exit` function (os.exit raises AttributeError).
    # sys.exit raises SystemExit, so pending finally-blocks still run.
    sys.exit(0)


# Register the graceful-shutdown handler for Ctrl+C at import time.
signal.signal(signal.SIGINT, signal_handler)
# Make this script's directory importable so the project packages below
# (tools, model, common) resolve regardless of the current working directory.
sys.path.insert(0, osp.dirname(osp.realpath(__file__)))
try:
    from tools.utils import get_path
    from model.gast_net import SpatioTemporalModel, SpatioTemporalModelOptimized1f
    from common.skeleton import Skeleton
    from common.graph_utils import adj_mx_from_skeleton
    from common.generators import *
    from tools.preprocess import load_kpts_json, h36m_coco_format, revise_kpts, revise_skes
    from tools.inference import gen_pose
    from tools.vis_kpts import plot_keypoint

    # Directory roots returned by get_path (presumably derived from this
    # file's location — see tools.utils). model_dir keeps a trailing '/'
    # because callers append file names to it directly.
    cur_dir, chk_root, data_root, lib_root, output_root = get_path(__file__)
    model_dir = chk_root + 'gastnet/'

    # lib_root only needs to be importable long enough to load the 2D pose
    # estimator.
    sys.path.insert(1, lib_root)
    try:
        from lib.pose import gen_video_kpts as hrnet_pose
    finally:
        sys.path.pop(1)
finally:
    # Undo the script-directory insertion even if an import above fails.
    sys.path.pop(0)

# Left/right joint indices of the 17-joint skeleton below. They are defined
# once so the Skeleton and the module-level left/right lists cannot drift
# apart.
_LEFT_JOINTS = (4, 5, 6, 11, 12, 13)
_RIGHT_JOINTS = (1, 2, 3, 14, 15, 16)

skeleton = Skeleton(parents=[-1, 0, 1, 2, 0, 4, 5, 0, 7, 8, 9, 8, 11, 12, 8, 14, 15],
                    joints_left=list(_LEFT_JOINTS), joints_right=list(_RIGHT_JOINTS))
adj = adj_mx_from_skeleton(skeleton)
# Each name gets its own list object, so mutating one cannot affect another.
joints_left, joints_right = list(_LEFT_JOINTS), list(_RIGHT_JOINTS)
kps_left, kps_right = list(_LEFT_JOINTS), list(_RIGHT_JOINTS)
def load_model_layer():
    """Build the 81-frame GAST-Net model and load its pretrained weights.

    The checkpoint is read from ``model_dir + '81_frame_model.bin'``. It must
    contain a ``'model_pos'`` state dict. The model is moved to the GPU when
    CUDA is available.

    Returns:
        The ``SpatioTemporalModel`` in eval mode, ready for inference.
    """
    chk = model_dir + '81_frame_model.bin'
    # 3 * 3 * 3 * 3 = 81 matches the 81-frame checkpoint (presumably the
    # temporal receptive field of SpatioTemporalModel — confirm in model/).
    filters_width = [3, 3, 3, 3]
    channels = 64
    model_pos = SpatioTemporalModel(adj, 17, 2, 17, filter_widths=filters_width,
                                    channels=channels, dropout=0.05)

    # Deserialize onto the CPU first. A checkpoint saved from a GPU otherwise
    # fails to load on CPU-only machines. load_state_dict copies the tensors
    # into the model's own parameters, so their source device does not matter.
    checkpoint = torch.load(chk, map_location='cpu')
    model_pos.load_state_dict(checkpoint['model_pos'])

    if torch.cuda.is_available():
        model_pos = model_pos.cuda()
    return model_pos.eval()
def _read_video_properties(video):
    """Return ``(width, height, frame_count)`` as reported by OpenCV for *video*.

    Raises:
        OSError: If OpenCV cannot open *video*.
    """
    cap = cv2.VideoCapture(video)
    try:
        if not cap.isOpened():
            raise OSError(f"Cannot open video: {video}")
        width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
        height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        # Only the metadata is needed; always release the capture handle.
        cap.release()
    return width, height, total_frames


def _emit_3d_progress(total_valid_frames):
    """Print ``PROGRESS:<pct>`` lines stepping from just above 100 to 200.

    This runs after gen_pose has already finished. It only simulates 3D-stage
    progress, paced by a short sleep per frame.
    """
    for i in range(total_valid_frames):
        progress = 100 + ((i + 1) / total_valid_frames * 100)  # 100-200% for 3D
        print(f"PROGRESS:{progress:.2f}")
        sys.stdout.flush()
        time.sleep(0.01)  # Placeholder; remove if gen_pose has its own loop


def _save_prediction(video, prediction, output_dir):
    """Save *prediction* as ``<output_dir>/npz/<video stem>.npz``.

    Returns:
        The path of the written ``.npz`` file.
    """
    print(f"Creating output directory: {output_dir}")
    os.makedirs(output_dir, exist_ok=True)
    npz_dir = os.path.join(output_dir, 'npz')
    print(f"Creating NPZ directory: {npz_dir}")
    os.makedirs(npz_dir, exist_ok=True)
    # splitext strips only the final extension, so "a.b.mp4" -> "a.b".
    stem = os.path.splitext(os.path.basename(video))[0]
    output_npz = os.path.join(npz_dir, stem + '.npz')
    print(f"Saving NPZ to: {output_npz}")
    np.savez_compressed(output_npz, reconstruction=prediction)
    print(f"NPZ saved successfully: {output_npz}")
    return output_npz


def generate_skeletons(video='', output_dir=None):
    """Estimate 3D poses for *video* and save them as a compressed NPZ.

    Pipeline: 2D keypoints via ``hrnet_pose``, reformatting via
    ``h36m_coco_format`` / ``revise_kpts``, then lifting to 3D with the
    81-frame GAST-Net model via ``gen_pose``. The result is stored under the
    key ``'reconstruction'`` in ``<output_dir>/npz/<video stem>.npz``.

    While this runs, Ctrl+C terminates the process immediately with status 1.
    The previously installed SIGINT handler is restored on return.

    Args:
        video: Path to the input video.
        output_dir: Directory for results. Defaults to
            ``os.path.abspath('../outputs/')``. That default is resolved
            against the current working directory, not this file's location.

    Raises:
        OSError: If OpenCV cannot open *video*.
    """
    def force_exit(sig, frame):
        print("\nForce terminating...")
        os._exit(1)

    # os._exit skips all cleanup, so an interrupt stops the long-running
    # inference at once.
    previous_handler = signal.signal(signal.SIGINT, force_exit)
    try:
        width, height, total_frames = _read_video_properties(video)

        # 2D Keypoint Generation (handled by gen_video_kpts)
        print('Generating 2D Keypoints:')
        sys.stdout.flush()
        keypoints, scores = hrnet_pose(video, det_dim=416, gen_output=True)
        keypoints, scores, valid_frames = h36m_coco_format(keypoints, scores)
        re_kpts = revise_kpts(keypoints, scores, valid_frames)

        model_pos = load_model_layer()
        pad = (81 - 1) // 2  # 40; presumably frames of context per side for the 81-frame model
        causal_shift = 0

        # 3D Pose Generation
        print('Recording 3D Pose:')
        print("PROGRESS:100.00")  # Start 3D at 100%
        sys.stdout.flush()
        # Explicit len() instead of truthiness so an ndarray also works.
        if valid_frames is not None and len(valid_frames) > 0:
            total_valid_frames = len(valid_frames)
        else:
            total_valid_frames = total_frames
        prediction = gen_pose(re_kpts, valid_frames, width, height, model_pos, pad, causal_shift)
        _emit_3d_progress(total_valid_frames)

        if output_dir is None:
            output_dir = os.path.abspath('../outputs/')
        _save_prediction(video, prediction, output_dir)
    finally:
        # signal.signal returns None when the old handler was not installed
        # from Python. That value cannot be re-installed, so skip it.
        if previous_handler is not None:
            signal.signal(signal.SIGINT, previous_handler)
def arg_parse():
    """Parse the command-line arguments.

    Returns:
        argparse.Namespace with ``video``: an absolute path, or a file name
        that the entry point resolves under ``<data_root>/video/``.
    """
    # The first positional argument of ArgumentParser is `prog`, so this text
    # is passed as `description` to keep the usage line showing the real
    # program name.
    parser = argparse.ArgumentParser(description='Generating skeleton demo.')
    # Required: a missing --video would otherwise surface later as a TypeError
    # on None instead of a clear usage error.
    parser.add_argument('-v', '--video', type=str, required=True,
                        help='input video: absolute path, or file name under the data video directory')
    return parser.parse_args()
if __name__ == "__main__":
    args = arg_parse()
    # Absolute paths are used verbatim; bare names are looked up in
    # <data_root>/video/.
    video_path = (args.video if os.path.isabs(args.video)
                  else os.path.join(data_root, 'video', args.video))
    generate_skeletons(video=video_path)