# VideoToNPZ / gen_skes.py
# Author: Amanpreet — commit 82d4b57 ("fixed")
import torch
import sys
import os.path as osp
import os
import argparse
import cv2
import time
import h5py
from tqdm import tqdm
import numpy as np
import warnings
import signal
warnings.filterwarnings('ignore')
def signal_handler(sig, frame):
    """SIGINT handler: wind down the loader thread, free GPU memory, exit cleanly.

    Args:
        sig: the signal number (unused).
        frame: the interrupted stack frame (unused).
    """
    print("\nInterrupted by user, shutting down...")
    # loader_thread may be created elsewhere at runtime; guard against it not existing yet
    if 'loader_thread' in globals() and loader_thread.is_alive():
        loader_thread.join(timeout=1.0)  # give the thread 1 second to finish
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # free GPU memory immediately
    # BUG FIX: os.exit() does not exist (AttributeError); sys.exit(0) raises
    # SystemExit for a clean, zero-status shutdown from the main thread.
    sys.exit(0)
# Register the signal handler for Ctrl+C
signal.signal(signal.SIGINT, signal_handler)
sys.path.insert(0, osp.dirname(osp.realpath(__file__)))
from tools.utils import get_path
from model.gast_net import SpatioTemporalModel, SpatioTemporalModelOptimized1f
from common.skeleton import Skeleton
from common.graph_utils import adj_mx_from_skeleton
from common.generators import *
from tools.preprocess import load_kpts_json, h36m_coco_format, revise_kpts, revise_skes
from tools.inference import gen_pose
from tools.vis_kpts import plot_keypoint
cur_dir, chk_root, data_root, lib_root, output_root = get_path(__file__)
model_dir = chk_root + 'gastnet/'
sys.path.insert(1, lib_root)
from lib.pose import gen_video_kpts as hrnet_pose
sys.path.pop(1)
sys.path.pop(0)
# Human3.6M 17-joint skeleton: parent index per joint, plus left/right joint
# groupings shared by the joint-space (joints_*) and keypoint-space (kps_*) lists.
_LEFT_SIDE = [4, 5, 6, 11, 12, 13]
_RIGHT_SIDE = [1, 2, 3, 14, 15, 16]
skeleton = Skeleton(
    parents=[-1, 0, 1, 2, 0, 4, 5, 0, 7, 8, 9, 8, 11, 12, 8, 14, 15],
    joints_left=list(_LEFT_SIDE),
    joints_right=list(_RIGHT_SIDE),
)
adj = adj_mx_from_skeleton(skeleton)
joints_left, joints_right = list(_LEFT_SIDE), list(_RIGHT_SIDE)
kps_left, kps_right = list(_LEFT_SIDE), list(_RIGHT_SIDE)
def load_model_layer():
    """Build the 81-frame GAST-Net lifting model and load its checkpoint.

    Returns:
        The SpatioTemporalModel in eval mode, moved to GPU when one is available.
    """
    chk = model_dir + '81_frame_model.bin'
    # Receptive-field configuration of the 81-frame model (3^4 = 81).
    filters_width = [3, 3, 3, 3]
    channels = 64
    model_pos = SpatioTemporalModel(adj, 17, 2, 17, filter_widths=filters_width,
                                    channels=channels, dropout=0.05)
    # BUG FIX: load onto CPU storage first so a checkpoint saved on a GPU box
    # still loads on CPU-only machines; move to CUDA afterwards if available.
    checkpoint = torch.load(chk, map_location='cpu')
    model_pos.load_state_dict(checkpoint['model_pos'])
    if torch.cuda.is_available():
        model_pos = model_pos.cuda()
    model_pos = model_pos.eval()
    return model_pos
def generate_skeletons(video=''):
    """Run 2D keypoint detection and 3D pose lifting on `video`, save an .npz.

    The 3D reconstruction is written to ../outputs/npz/<video-stem>.npz.
    "PROGRESS:<pct>" lines are printed for a parent process to parse
    (0-100 covers the 2D stage, 100-200 the 3D stage).

    Args:
        video: path to the input video file.
    """
    def force_exit(sig, frame):
        # Hard-kill on Ctrl+C during heavy inference; skips Python cleanup.
        print("\nForce terminating...")
        os._exit(1)
    signal.signal(signal.SIGINT, force_exit)

    # Only video metadata is needed here; release the handle promptly
    # (BUG FIX: the capture was previously never released).
    cap = cv2.VideoCapture(video)
    try:
        width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
        height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        cap.release()

    # 2D Keypoint Generation (handled by gen_video_kpts)
    print('Generating 2D Keypoints:')
    sys.stdout.flush()
    keypoints, scores = hrnet_pose(video, det_dim=416, gen_output=True)
    keypoints, scores, valid_frames = h36m_coco_format(keypoints, scores)
    re_kpts = revise_kpts(keypoints, scores, valid_frames)

    model_pos = load_model_layer()

    # Symmetric padding for the 81-frame receptive field; no causal shift.
    pad = (81 - 1) // 2
    causal_shift = 0

    # 3D Pose Generation
    print('Recording 3D Pose:')
    print("PROGRESS:100.00")  # 2D stage complete; 3D stage reports 100-200%
    sys.stdout.flush()

    total_valid_frames = len(valid_frames) if valid_frames else total_frames
    prediction = gen_pose(re_kpts, valid_frames, width, height, model_pos, pad, causal_shift)

    # Simulated 3D progress (placeholder; replace with a real callback from gen_pose).
    # Guard against division by zero on empty/unreadable videos.
    if total_valid_frames > 0:
        for i in range(total_valid_frames):
            progress = 100 + ((i + 1) / total_valid_frames * 100)  # 100-200% for 3D
            print(f"PROGRESS:{progress:.2f}")
            sys.stdout.flush()
            time.sleep(0.01)  # Placeholder; remove if gen_pose has its own loop

    output_dir = os.path.abspath('../outputs/')
    print(f"Creating output directory: {output_dir}")
    os.makedirs(output_dir, exist_ok=True)
    npz_dir = os.path.join(output_dir, 'npz')
    print(f"Creating NPZ directory: {npz_dir}")
    os.makedirs(npz_dir, exist_ok=True)

    # BUG FIX: splitext keeps dots inside the file name intact
    # (split('.')[0] truncated names like "my.clip.mp4" to "my").
    output_npz = os.path.join(npz_dir, os.path.splitext(os.path.basename(video))[0] + '.npz')
    print(f"Saving NPZ to: {output_npz}")
    np.savez_compressed(output_npz, reconstruction=prediction)
    print(f"NPZ saved successfully: {output_npz}")
def arg_parse():
    """Parse command-line arguments.

    Returns:
        argparse.Namespace with `video`: the input clip, either an absolute
        path or a name resolved under <data_root>/video by the caller.
    """
    # BUG FIX: the text must be passed as `description`; ArgumentParser's first
    # positional parameter is `prog`, which replaced the program name in usage
    # output instead of describing the tool.
    parser = argparse.ArgumentParser(description='Generating skeleton demo.')
    parser.add_argument('-v', '--video', type=str,
                        help='input video file name or path')
    args = parser.parse_args()
    return args
if __name__ == "__main__":
    # Resolve relative video names against the project's data/video directory.
    cli_args = arg_parse()
    video_path = (cli_args.video
                  if os.path.isabs(cli_args.video)
                  else os.path.join(data_root, 'video', cli_args.video))
    generate_skeletons(video=video_path)