# Uploaded by ndkhanh95 ("Upload 226 files", commit 29d411b, verified).
import argparse
import json
import os
import time
from PIL import Image
import cv2
import numpy as np
import torch
import tqdm
from easy_ViTPose.vit_utils.inference import NumpyEncoder, VideoReader
from easy_ViTPose.inference import VitInference
from easy_ViTPose.vit_utils.visualization import joints_dict
# Optional dependency: `has_onnx` records whether onnxruntime can be imported,
# so the script can decide between ONNX and PyTorch detector weights.
try:
    import onnxruntime  # noqa: F401
    has_onnx = True
except ImportError:
    # Catch ImportError (the parent of ModuleNotFoundError): an installed but
    # broken onnxruntime, e.g. with missing native libraries, raises plain
    # ImportError and should also fall back instead of crashing at startup.
    has_onnx = False
# Lower-case file extensions (without the dot) that are decoded as videos;
# any other file is opened as a single image with PIL.
_VIDEO_EXTENSIONS = ('mp4', 'mov', 'avi', 'mkv', 'webm')

# Output frame rate used when OpenCV reports no usable FPS for the source
# (it can return 0 for some cameras / containers).
_FALLBACK_FPS = 30.0


def _webcam_index(input_path):
    """Return *input_path* as an int camera index, or None if it is a path."""
    try:
        return int(input_path)
    except ValueError:
        return None


def _parse_args():
    """Build the CLI, parse ``sys.argv`` and validate the input/output flags.

    Invalid combinations exit through ``parser.error`` (usage message and
    exit status 2) rather than ``assert``, which is stripped under ``-O``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--input', type=str, required=True,
                        help='path to image / video or webcam ID (=cv2)')
    parser.add_argument('--output-path', type=str, default='',
                        help='output path, if the path provided is a directory '
                             'output files are "input_name +_result{extension}".')
    parser.add_argument('--model', type=str, required=True,
                        help='checkpoint path of the model')
    parser.add_argument('--yolo', type=str, required=False, default=None,
                        help='checkpoint path of the yolo model')
    parser.add_argument('--dataset', type=str, required=False, default=None,
                        help="Name of the dataset. If None it's extracted from the file name. "
                             '["coco", "coco_25", "wholebody", "mpii", "ap10k", "apt36k", "aic"]')
    parser.add_argument('--det-class', type=str, required=False, default=None,
                        help='["human", "cat", "dog", "horse", "sheep", '
                             '"cow", "elephant", "bear", "zebra", "giraffe", "animals"]')
    parser.add_argument('--model-name', type=str, required=False, choices=['s', 'b', 'l', 'h'],
                        help='[s: ViT-S, b: ViT-B, l: ViT-L, h: ViT-H]')
    parser.add_argument('--yolo-size', type=int, required=False, default=320,
                        help='YOLOv8 image size during inference')
    parser.add_argument('--conf-threshold', type=float, required=False, default=0.5,
                        help='Minimum confidence for keypoints to be drawn. [0, 1] range')
    parser.add_argument('--rotate', type=int, choices=[0, 90, 180, 270],
                        required=False, default=0,
                        help='Rotate the image of [90, 180, 270] degrees counterclockwise')
    parser.add_argument('--yolo-step', type=int,
                        required=False, default=1,
                        help='The tracker can be used to predict the bboxes instead of yolo for performance, '
                             'this flag specifies how often yolo is applied (e.g. 1 applies yolo every frame). '
                             'This does not have any effect when is_video is False')
    parser.add_argument('--single-pose', default=False, action='store_true',
                        help='Do not use SORT tracker because single pose is expected in the video')
    parser.add_argument('--show', default=False, action='store_true',
                        help='preview result during inference')
    parser.add_argument('--show-yolo', default=False, action='store_true',
                        help='draw yolo results')
    parser.add_argument('--show-raw-yolo', default=False, action='store_true',
                        help='draw yolo result before that SORT is applied for tracking'
                             ' (only valid during video inference)')
    parser.add_argument('--save-img', default=False, action='store_true',
                        help='save image results')
    parser.add_argument('--save-json', default=False, action='store_true',
                        help='save json results')
    args = parser.parse_args()

    if _webcam_index(args.input) is None and not os.path.isfile(args.input):
        parser.error('The input file does not exist')
    if (args.save_img or args.save_json) and not args.output_path:
        parser.error('Specify an output path if using save-img or save-json flags')
    return args


def _output_paths(input_path, output_path, ext):
    """Return ``(media_path, json_path)`` where results will be written.

    If *output_path* is an existing directory the files are named
    ``<input stem>_result<ext>`` and ``<input stem>_result.json`` inside it;
    otherwise *output_path* is used as a prefix and the extensions appended.
    """
    if os.path.isdir(output_path):
        # splitext strips only the final extension (str.replace would hit
        # every occurrence of it in the name).
        stem = os.path.splitext(os.path.basename(input_path))[0]
        return (os.path.join(output_path, f'{stem}_result{ext}'),
                os.path.join(output_path, f'{stem}_result.json'))
    return output_path + ext, output_path + '.json'


def _pick_video_codec():
    """Return ``'h264'`` if this OpenCV build can open an H.264 writer, else ``'MJPG'``."""
    import tempfile

    # Probe inside a private temp dir (portable, unlike a hard-coded /tmp path)
    # and always release the probe writer so no file handle is leaked.
    with tempfile.TemporaryDirectory() as tmp_dir:
        probe = None
        try:
            probe = cv2.VideoWriter(os.path.join(tmp_dir, 'checkcodec.mp4'),
                                    cv2.VideoWriter_fourcc(*'h264'), 30, (32, 32))
            opened = probe.isOpened()
        except Exception:
            opened = False
        finally:
            if probe is not None:
                probe.release()
    return 'h264' if opened else 'MJPG'


def main():
    """Run pose estimation on an image, a video file or a webcam stream."""
    args = _parse_args()

    use_mps = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
    use_cuda = torch.cuda.is_available()

    # Default detector weights: the ONNX export is used only when onnxruntime
    # imported successfully and no CUDA/MPS device is available.
    yolo = args.yolo
    if yolo is None:
        weights_ext = '.onnx' if has_onnx and not (use_mps or use_cuda) else '.pt'
        yolo = 'easy_ViTPose/yolov8s' + weights_ext

    input_path = args.input
    webcam_index = _webcam_index(input_path)
    if webcam_index is not None:
        is_video = True
        # cv2.VideoCapture opens a camera from an int index; a str argument
        # is handled as a file name / URL instead.
        capture_source = webcam_index
    else:
        file_ext = os.path.splitext(input_path)[1][1:].lower()
        is_video = file_ext in _VIDEO_EXTENSIONS
        capture_source = input_path

    ext = '.mp4' if is_video else '.png'
    output_path_img = output_path_json = None
    if args.output_path:
        output_path_img, output_path_json = _output_paths(input_path, args.output_path, ext)

    wait = 0  # cv2.waitKey delay: 0 blocks until a key press (single image)
    total_frames = 1
    video_fps = None
    if is_video:
        # NOTE(review): VideoReader gets the raw string as before; it is
        # assumed to handle webcam IDs itself — confirm in vit_utils.inference.
        reader = VideoReader(input_path, args.rotate)
        cap = cv2.VideoCapture(capture_source)  # type: ignore
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        video_fps = cap.get(cv2.CAP_PROP_FPS)
        cap.release()
        if total_frames <= 0:
            # Unknown length (e.g. a live webcam): let tqdm just count frames.
            total_frames = None
        if args.save_img and not video_fps > 0:  # also rejects NaN
            print(f'>>> Could not read the input FPS, writing output at {_FALLBACK_FPS:g} FPS')
            video_fps = _FALLBACK_FPS
        wait = 15
    else:
        # convert('RGB') guarantees three channels (RGBA / grayscale / palette
        # images would break the RGB->BGR flip below); expand=True keeps the
        # whole picture for 90/270 degree rotations instead of cropping it.
        with Image.open(input_path) as pil_img:
            reader = [np.array(pil_img.convert('RGB').rotate(args.rotate, expand=True))]

    codec = _pick_video_codec() if is_video and args.save_img else None

    model = VitInference(args.model, yolo, args.model_name,
                         args.det_class, args.dataset,
                         args.yolo_size, is_video=is_video,
                         single_pose=args.single_pose,
                         yolo_step=args.yolo_step)  # type: ignore
    print(f">>> Model loaded: {args.model}")
    print(f'>>> Running inference on {input_path}')

    keypoints = []
    frame_times = []  # seconds spent in model.inference per frame
    out_writer = None
    try:
        for img in tqdm.tqdm(reader, total=total_frames):
            t0 = time.perf_counter()
            frame_keypoints = model.inference(img)
            frame_times.append(time.perf_counter() - t0)
            keypoints.append(frame_keypoints)

            if not (args.show or args.save_img):
                continue
            # Draw result and transform RGB -> BGR for OpenCV.
            img = model.draw(args.show_yolo, args.show_raw_yolo, args.conf_threshold)[..., ::-1]
            if args.save_img:
                # TODO: If exists add (1), (2), ...
                if is_video:
                    if out_writer is None:
                        # Size the writer from the first drawn frame so it
                        # matches what is written, including --rotate.
                        frame_size = (img.shape[1], img.shape[0])
                        out_writer = cv2.VideoWriter(output_path_img,
                                                     cv2.VideoWriter_fourcc(*codec),
                                                     video_fps, frame_size)  # type: ignore
                    out_writer.write(img)
                else:
                    print('>>> Saving output image')
                    cv2.imwrite(output_path_img, img)
            if args.show:
                cv2.imshow('preview', img)
                cv2.waitKey(wait)
    finally:
        if out_writer is not None:
            out_writer.release()
        # Only touch HighGUI when a window was requested: headless OpenCV
        # builds raise on destroyAllWindows().
        if args.show:
            cv2.destroyAllWindows()

    n_frames = len(keypoints)
    if is_video and n_frames:
        tot_time = sum(frame_times)
        tot_poses = sum(len(k) for k in keypoints)
        print(f'>>> Mean inference FPS: {1 / np.mean(frame_times):.2f}')
        print(f'>>> Total poses predicted: {tot_poses} mean per frame: '
              f'{(tot_poses / n_frames):.2f}')
        if tot_time > 0:
            print(f'>>> Mean FPS per pose: {(tot_poses / tot_time):.2f}')

    if args.save_json:
        print('>>> Saving output json')
        # Build the payload first so a lookup error cannot leave an empty file.
        out = {'keypoints': keypoints,
               'skeleton': joints_dict()[model.dataset]['keypoints']}
        with open(output_path_json, 'w') as f:
            json.dump(out, f, cls=NumpyEncoder)


if __name__ == "__main__":
    main()