# Sam / VideoToNPZ / gen_skes.py
# Amanpreet - latest - revision 4e2aa43 - 4.18 kB
import torch
import sys
import os.path as osp
import os
import argparse
import cv2
import time
import h5py
from tqdm import tqdm
import numpy as np
import warnings
import signal
warnings.filterwarnings('ignore')
sys.path.insert(0, osp.dirname(osp.realpath(__file__)))
from tools.utils import get_path
from model.gast_net import SpatioTemporalModel, SpatioTemporalModelOptimized1f
from common.skeleton import Skeleton
from common.graph_utils import adj_mx_from_skeleton
from common.generators import *
from tools.preprocess import load_kpts_json, h36m_coco_format, revise_kpts, revise_skes
from tools.inference import gen_pose
from tools.vis_kpts import plot_keypoint
# Resolve the project directories relative to this script. get_path is a project
# helper; its five return values are bound in this order.
cur_dir, chk_root, data_root, lib_root, output_root = get_path(__file__)
# Directory holding the GAST-Net checkpoints. chk_root is concatenated directly, so it
# presumably already ends with a path separator -- confirm against get_path.
model_dir = chk_root + 'gastnet/'
# Temporarily put lib_root on sys.path so the `lib.pose` import below resolves.
sys.path.insert(1, lib_root)
from lib.pose import gen_video_kpts as hrnet_pose
# Undo both sys.path insertions: index 1 is lib_root (added just above) and index 0 is
# this script's directory (inserted before the project imports). This assumes none of
# the imports in between modified sys.path themselves.
sys.path.pop(1)
sys.path.pop(0)
# Skeleton description with 17 entries in `parents`; presumably parents[i] is the parent
# joint of joint i and -1 marks the root -- confirm against common.skeleton.Skeleton.
skeleton = Skeleton(parents=[-1, 0, 1, 2, 0, 4, 5, 0, 7, 8, 9, 8, 11, 12, 8, 14, 15],
joints_left=[4, 5, 6, 11, 12, 13], joints_right=[1, 2, 3, 14, 15, 16])
# Adjacency matrix derived from the skeleton; passed to the model constructor in
# load_model_layer.
adj = adj_mx_from_skeleton(skeleton)
# Left/right index lists. Both pairs hold the same values as the Skeleton arguments
# above; none of the four names is referenced by the functions defined in this file.
joints_left, joints_right = [4, 5, 6, 11, 12, 13], [1, 2, 3, 14, 15, 16]
kps_left, kps_right = [4, 5, 6, 11, 12, 13], [1, 2, 3, 14, 15, 16]
def signal_handler(sig, frame):
    """Handle SIGINT: stop a module-level worker pool, if any, then exit.

    Args:
        sig: Signal number passed in by the interpreter (unused).
        frame: Stack frame that was executing when the signal arrived (unused).

    Raises:
        SystemExit: Always, with exit status 0.
    """
    print("\nInterrupted by user, shutting down...")
    # A pool, if one exists, is a module-level name. locals() inside this
    # function only ever contains `sig` and `frame`, so the lookup has to go
    # through globals() for the cleanup to be reachable at all.
    active_pool = globals().get('pool')
    if active_pool is not None:
        active_pool.terminate()
        active_pool.join()
    sys.exit(0)


# Route Ctrl+C (SIGINT) through the handler above.
signal.signal(signal.SIGINT, signal_handler)
def load_model_layer():
    """Build the 3D pose model and load the `81_frame_model.bin` weights.

    The checkpoint is read from ``model_dir`` and its ``'model_pos'`` entry is
    loaded as the state dict of a ``SpatioTemporalModel`` built on the
    module-level adjacency matrix ``adj``.

    Returns:
        The model in eval mode, moved to CUDA when a GPU is available and left
        on the CPU otherwise.
    """
    chk = model_dir + '81_frame_model.bin'
    filters_width = [3, 3, 3, 3]
    channels = 64
    # The positional 17, 2, 17 are presumably joints-in, features-per-joint and
    # joints-out -- confirm against model.gast_net.SpatioTemporalModel.
    model_pos = SpatioTemporalModel(
        adj, 17, 2, 17,
        filter_widths=filters_width, channels=channels, dropout=0.05,
    )
    # Deserialize onto the CPU first. Without map_location, a checkpoint that
    # was saved from CUDA tensors cannot be loaded on a CPU-only machine, which
    # would defeat the cuda.is_available() fallback below.
    checkpoint = torch.load(chk, map_location='cpu')
    model_pos.load_state_dict(checkpoint['model_pos'])
    if torch.cuda.is_available():
        model_pos = model_pos.cuda()
    return model_pos.eval()
def generate_skeletons(video=''):
    """Estimate 2D keypoints and 3D poses for a video and save them as an NPZ.

    Progress is reported on stdout as ``PROGRESS:<value>`` lines; the 3D stage
    reports values from 100.00 up to 200.00. The result is written to
    ``<cwd>/../outputs/npz/<name>.npz`` under the key ``reconstruction``.

    Args:
        video: Path of the input video file.

    Returns:
        None. The output is the NPZ file written to disk.
    """
    # Only the frame size and frame count are read from this handle; the
    # keypoint detector below is handed the path, not the handle. Release it
    # immediately so the file is not held open for the whole run.
    cap = cv2.VideoCapture(video)
    try:
        width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
        height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        cap.release()

    # 2D keypoint generation (handled by gen_video_kpts).
    print('Generating 2D Keypoints:')
    sys.stdout.flush()
    keypoints, scores = hrnet_pose(video, det_dim=416, gen_output=True)
    keypoints, scores, valid_frames = h36m_coco_format(keypoints, scores)
    re_kpts = revise_kpts(keypoints, scores, valid_frames)

    model_pos = load_model_layer()
    # Half of an 81-frame window on each side of the target frame.
    pad = (81 - 1) // 2
    causal_shift = 0

    # 3D pose generation.
    print('Recording 3D Pose:')
    print("PROGRESS:100.00")  # The 3D stage starts at 100%.
    sys.stdout.flush()
    # NOTE(review): if valid_frames is a per-person list, len() counts people
    # rather than frames -- confirm against tools.preprocess.h36m_coco_format.
    total_valid_frames = len(valid_frames) if valid_frames else total_frames
    prediction = gen_pose(re_kpts, valid_frames, width, height, model_pos, pad, causal_shift)

    # gen_pose has already finished, so these lines are emitted only to keep
    # the stdout progress protocol (100-200% for the 3D stage) unchanged. No
    # artificial per-frame delay is added: it would only postpone saving.
    for i in range(total_valid_frames):
        progress = 100 + ((i + 1) / total_valid_frames * 100)
        print(f"PROGRESS:{progress:.2f}")
    sys.stdout.flush()

    # NOTE(review): this path is relative to the current working directory,
    # not to this script; the module-level output_root is not used here.
    output_dir = os.path.abspath('../outputs/')
    print(f"Creating output directory: {output_dir}")
    os.makedirs(output_dir, exist_ok=True)

    npz_dir = os.path.join(output_dir, 'npz')
    print(f"Creating NPZ directory: {npz_dir}")
    os.makedirs(npz_dir, exist_ok=True)

    # NOTE(review): split('.')[0] keeps only the text before the FIRST dot, so
    # "a.b.mp4" is saved as "a.npz". Kept as-is because consumers may derive
    # the expected filename the same way.
    output_npz = os.path.join(npz_dir, os.path.basename(video).split('.')[0] + '.npz')
    print(f"Saving NPZ to: {output_npz}")
    np.savez_compressed(output_npz, reconstruction=prediction)
    print(f"NPZ saved successfully: {output_npz}")
def arg_parse():
    """Parse the command-line arguments of the skeleton generation demo.

    Returns:
        argparse.Namespace with a single attribute, ``video`` (str).

    Raises:
        SystemExit: If ``-v/--video`` is missing or the arguments are invalid
            (argparse prints a usage message and exits with status 2).
    """
    parser = argparse.ArgumentParser('Generating skeleton demo.')
    # The option is mandatory: without it args.video would be None and the
    # path handling in the entry point would fail with an opaque TypeError.
    parser.add_argument(
        '-v', '--video', type=str, required=True,
        help='input video: an absolute path, or a file name inside the data video folder',
    )
    return parser.parse_args()
if __name__ == "__main__":
    # Script entry point: resolve the requested video and run the pipeline.
    args = arg_parse()
    # An absolute path is used verbatim; anything else is looked up inside
    # the "video" folder under data_root.
    video_path = (
        args.video
        if os.path.isabs(args.video)
        else os.path.join(data_root, 'video', args.video)
    )
    generate_skeletons(video=video_path)