import os
from os.path import join
import argparse
import numpy as np
import cv2
import torch
from tqdm import tqdm
from data import cfg_mnet, cfg_re50
from layers.functions.prior_box import PriorBox
from utils.nms.py_cpu_nms import py_cpu_nms
from models.retinaface import RetinaFace
from utils.box_utils import decode

np.random.seed(0)

def check_keys(model, pretrained_state_dict):
    """Report how the checkpoint's keys line up with the model's state_dict."""
    ckpt_keys = set(pretrained_state_dict.keys())
    model_keys = set(model.state_dict().keys())
    used_pretrained_keys = model_keys & ckpt_keys
    unused_pretrained_keys = ckpt_keys - model_keys
    missing_keys = model_keys - ckpt_keys
    print('Missing keys:{}'.format(len(missing_keys)))
    print('Unused checkpoint keys:{}'.format(len(unused_pretrained_keys)))
    print('Used keys:{}'.format(len(used_pretrained_keys)))
    assert len(used_pretrained_keys) > 0, 'load NONE from pretrained checkpoint'
    return True

def remove_prefix(state_dict, prefix):
    """Old-style checkpoints store all parameter names with a shared 'module.' prefix."""
    print('remove prefix \'{}\''.format(prefix))
    def f(x): return x.split(prefix, 1)[-1] if x.startswith(prefix) else x
    return {f(key): value for key, value in state_dict.items()}
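
# Illustrative example (not executed): a DataParallel checkpoint stores keys
# such as 'module.body.conv1.weight'; remove_prefix(state_dict, 'module.')
# maps that back to 'body.conv1.weight' so it matches a bare model's state_dict.
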
def load_model(model, pretrained_path, load_to_cpu):
    print('Loading pretrained model from {}'.format(pretrained_path))
    if load_to_cpu:
        pretrained_dict = torch.load(
            pretrained_path, map_location=lambda storage, loc: storage)
    else:
        # Map the checkpoint tensors straight onto the target device.
        pretrained_dict = torch.load(pretrained_path, map_location=device)
    if "state_dict" in pretrained_dict.keys():
        pretrained_dict = remove_prefix(
            pretrained_dict['state_dict'], 'module.')
    else:
        pretrained_dict = remove_prefix(pretrained_dict, 'module.')
    check_keys(model, pretrained_dict)
    model.load_state_dict(pretrained_dict, strict=False)
    model.to(device)
    return model

def detect(img_list, output_path, resize=1):
    os.makedirs(output_path, exist_ok=True)
    im_height, im_width, _ = img_list[0].shape
    scale = torch.Tensor([im_width, im_height, im_width, im_height])
    # NHWC -> NCHW for the network.
    img_x = torch.stack(img_list, dim=0).permute([0, 3, 1, 2])
    scale = scale.to(device)
    batch_size = args.bs
    # Number of forward passes needed to cover all frames.
    f_times = img_x.shape[0] // batch_size
    if img_x.shape[0] % batch_size != 0:
        f_times += 1
    locs_list = list()
    confs_list = list()
    for batch_idx in range(f_times):
        if batch_idx != f_times - 1:
            batch_img_x = img_x[batch_idx * batch_size:(batch_idx + 1) * batch_size]
        else:
            batch_img_x = img_x[batch_idx * batch_size:]  # last (possibly smaller) batch
        batch_img_x = batch_img_x.to(device).float()
        loc, conf, _ = net(batch_img_x)  # landmarks are unused here
        locs_list.append(loc)
        confs_list.append(conf)
    locs = torch.cat(locs_list, dim=0)
    confs = torch.cat(confs_list, dim=0)
    priorbox = PriorBox(cfg, image_size=(im_height, im_width))
    priors = priorbox.forward()
    priors = priors.to(device)
    prior_data = priors.data
    img_cpu = img_x.permute([0, 2, 3, 1]).cpu().numpy()
    i = 0
    for img, loc, conf in zip(img_cpu, locs, confs):
        boxes = decode(loc.data, prior_data, cfg['variance'])
        boxes = boxes * scale / resize
        boxes = boxes.cpu().numpy()
        scores = conf.data.cpu().numpy()[:, 1]
        # ignore low scores
        inds = np.where(scores > args.confidence_threshold)[0]
        boxes = boxes[inds]
        scores = scores[inds]
        # keep top-K before NMS
        order = scores.argsort()[::-1][:args.top_k]
        boxes = boxes[order]
        scores = scores[order]
        # do NMS
        dets = np.hstack((boxes, scores[:, np.newaxis])).astype(
            np.float32, copy=False)
        keep = py_cpu_nms(dets, args.nms_threshold)
        dets = dets[keep, :]
        # keep top-K after NMS
        dets = dets[:args.keep_top_k, :]
        if len(dets) == 0:
            continue
        # Take the highest-scoring detection and crop an enlarged face region.
        det = list(map(int, dets[0]))
        x, y, size_bb_x, size_bb_y = get_boundingbox(det, img.shape[1], img.shape[0])
        # Add the subtracted channel means back before writing to disk.
        cropped_img = img[y:y + size_bb_y, x:x + size_bb_x, :] + (104, 117, 123)
        cv2.imwrite(join(output_path, '{:04d}.png'.format(i)), cropped_img)
        i += 1
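
# Minimal usage sketch (illustrative; assumes `net`, `cfg`, `device`, and `args`
# are initialized as in the __main__ block below; 'example.mp4' is a
# hypothetical input path):
#     frames = extract_frames('example.mp4', interval=1)
#     detect(frames, './faces_out')
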
def extract_frames(data_path, interval=1):
    """Extract frames from a video (or load a single image) and preprocess them."""
    if data_path.split('.')[-1] == "mp4":
        reader = cv2.VideoCapture(data_path)
        frame_num = 0
        frames = list()
        while reader.isOpened():
            success, image = reader.read()
            if not success:
                break
            # Keep OpenCV's BGR channel order: RetinaFace expects BGR input,
            # and (104, 117, 123) are the matching per-channel means.
            image = torch.tensor(image) - torch.tensor([104, 117, 123])
            if frame_num % interval == 0:
                frames.append(image)
            frame_num += 1
            if len(frames) > args.max_frames:
                break
        reader.release()
        if len(frames) > args.max_frames:
            # Subsample down to max_frames without replacement.
            samples = np.random.choice(
                np.arange(0, len(frames)), size=args.max_frames, replace=False)
            return [frames[s] for s in samples]
        return frames
    else:
        image = cv2.imread(data_path)
        image = torch.tensor(image) - torch.tensor([104, 117, 123])
        return [image]

def get_boundingbox(bbox, width, height, scale=1.8, minsize=None):
    """Enlarge a detection box by `scale` around its center and clamp it to the image."""
    x1 = bbox[0]
    y1 = bbox[1]
    x2 = bbox[2]
    y2 = bbox[3]
    size_bb_x = int((x2 - x1) * scale)
    size_bb_y = int((y2 - y1) * scale)
    if minsize:
        if size_bb_x < minsize:
            size_bb_x = minsize
        if size_bb_y < minsize:
            size_bb_y = minsize
    center_x, center_y = (x1 + x2) // 2, (y1 + y2) // 2
    # Clamp the top-left corner to the image bounds.
    x1 = max(int(center_x - size_bb_x // 2), 0)
    y1 = max(int(center_y - size_bb_y // 2), 0)
    # Shrink the box if it runs past the right/bottom edge.
    size_bb_x = min(width - x1, size_bb_x)
    size_bb_y = min(height - y1, size_bb_y)
    return x1, y1, size_bb_x, size_bb_y
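
# Worked example (illustrative): for bbox = (100, 100, 200, 220) in a 640x480
# frame with scale=1.8, the enlarged size is (180, 216) around center (150, 160),
# so get_boundingbox returns (x1, y1, w, h) = (60, 52, 180, 216).
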
def extract_method_videos(data_path, interval):
    video = data_path.split('/')[-1]
    result_path = '/'.join(data_path.split('/')[:-1])
    images_path = join(result_path, 'images')
    image_folder = video.split('.')[0]
    try:
        print(data_path)
        image_list = extract_frames(data_path, interval)
        detect(image_list, join(images_path, image_folder))
    except Exception as ex:
        # Log failures and keep going rather than aborting the whole run.
        with open("failure.txt", "a", encoding="utf-8") as f:
            f.write(f"Exception for {image_folder}: {ex}\n")

if __name__ == '__main__':
    p = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    p.add_argument('--data_path', '-p', type=str, help='path to the data')
    p.add_argument('--confidence_threshold', default=0.05,
                   type=float, help='confidence threshold')
    p.add_argument('--top_k', default=5, type=int, help='top_k')
    p.add_argument('--nms_threshold', default=0.4,
                   type=float, help='nms threshold')
    p.add_argument('--keep_top_k', default=1, type=int, help='keep_top_k')
    p.add_argument('--bs', default=32, type=int, help='batch size')
    p.add_argument('--frame_interval', '-fi', default=1, type=int, help='frame interval')
    p.add_argument('--device', "-d", default="cuda:0", type=str, help='device')
    p.add_argument('--max_frames', default=100, type=int, help='maximum frames per video')
    args = p.parse_args()

    torch.set_grad_enabled(False)
    # use resnet-50 backbone
    cfg = cfg_re50
    pretrained_weights = './weights/Resnet50_Final.pth'
    torch.backends.cudnn.benchmark = True
    device = torch.device(args.device)
    print(device)
    # net and model
    net = RetinaFace(cfg=cfg, phase='test')
    # Only load onto CPU first when the target device actually is the CPU.
    net = load_model(net, pretrained_weights, load_to_cpu=(args.device == "cpu"))
    net.eval()
    print('Finished loading model!')
    extract_method_videos(args.data_path, args.frame_interval)
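
# Example invocation (the script filename below is an assumption):
#     python extract_faces.py -p /path/to/video.mp4 -d cuda:0 --bs 32
# Cropped faces are written as 0000.png, 0001.png, ... under
# <video_dir>/images/<video_name>/, and failures are appended to failure.txt.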