Spaces:
Running
on
L4
Running
on
L4
File size: 5,705 Bytes
9ab094a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
import os
import cv2
import time
import glob
import argparse
import numpy as np
from PIL import Image
import torch
from tqdm import tqdm
from itertools import cycle
from torch.multiprocessing import Pool, Process, set_start_method
from facexlib.alignment import landmark_98_to_68
from facexlib.detection import init_detection_model
from facexlib.utils import load_file_from_url
from facexlib.alignment.awing_arch import FAN
def init_alignment_model(model_name, half=False, device='cuda', model_rootpath=None):
    """Build the face-alignment network and load its pretrained weights.

    Args:
        model_name: Architecture identifier. Only ``'awing_fan'`` is handled.
        half: Accepted for signature compatibility; this function does not
            read it.
        device: Device passed to the ``FAN`` constructor, used as
            ``map_location`` when loading the checkpoint, and the device the
            model is finally moved to.
        model_rootpath: Forwarded to ``load_file_from_url`` as ``save_dir``.

    Returns:
        The ``FAN`` model in eval mode, moved to ``device``.

    Raises:
        NotImplementedError: If ``model_name`` is not ``'awing_fan'``.
    """
    # Reject unknown architectures before anything is built or downloaded.
    if model_name != 'awing_fan':
        raise NotImplementedError(f'{model_name} is not implemented.')

    network = FAN(num_modules=4, num_landmarks=98, device=device)
    weights_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/alignment_WFLW_4HG.pth'

    weights_file = load_file_from_url(
        url=weights_url,
        model_dir='facexlib/weights',
        progress=True,
        file_name=None,
        save_dir=model_rootpath,
    )

    # The checkpoint stores the weights under the 'state_dict' key.
    checkpoint = torch.load(weights_file, map_location=device)
    network.load_state_dict(checkpoint['state_dict'], strict=True)
    network.eval()
    return network.to(device)
class KeypointExtractor():
    """Detect a face in an image (or each frame of a list) and return its landmarks.

    Each image is first passed to a face detector; the first detection is
    cropped out and handed to the alignment network, whose 98-point output is
    converted with ``landmark_98_to_68``. Landmark coordinates are shifted
    back into the coordinate frame of the full image.

    A frame for which no landmarks could be produced is represented by a
    ``(68, 2)`` array filled with ``-1``.
    """

    # Shape of the per-frame landmark array (after the 98 -> 68 conversion).
    _KEYPOINT_SHAPE = (68, 2)

    def __init__(self, device='cuda'):
        """Load the alignment and detection models onto ``device``.

        The weights directory depends on whether a ``webui`` module can be
        imported: if so the extension-relative path is used, otherwise the
        plain ``gfpgan/weights`` directory.
        """
        try:
            import webui  # noqa: F401  # only probing whether we run inside webui
            root_path = 'extensions/SadTalker/gfpgan/weights'
        except Exception:
            # Any failure to import means "not inside webui"; unlike a bare
            # ``except`` this lets KeyboardInterrupt/SystemExit propagate.
            root_path = 'gfpgan/weights'

        self.detector = init_alignment_model('awing_fan', device=device, model_rootpath=root_path)
        self.det_net = init_detection_model('retinaface_resnet50', half=False, device=device, model_rootpath=root_path)

    def extract_keypoint(self, images, name=None, info=True):
        """Extract landmarks from one image or from a list of images.

        Args:
            images: A single image, or a ``list`` of images. Each image is
                converted with ``np.array`` for cropping and is also passed
                unchanged to the face detector.
            name: Optional output path. When given, the flattened landmarks
                are written with ``np.savetxt`` to this path with its
                extension replaced by ``.txt``. When ``None`` nothing is
                written.
            info: For list input, whether to show a tqdm progress bar.

        Returns:
            For a single image, an array of shape ``(68, 2)``. For a list,
            the per-frame arrays stacked along a new first axis; an empty
            list yields an array of shape ``(0, 68, 2)``.
        """
        if isinstance(images, list):
            keypoints = self._extract_sequence(images, info)
        else:
            keypoints = self._extract_single(images)

        # Applies to both branches: only write when a target path was given.
        if name is not None:
            np.savetxt(os.path.splitext(name)[0] + '.txt', keypoints.reshape(-1))
        return keypoints

    def _extract_sequence(self, images, info):
        """Run per-frame extraction and stack the results along axis 0."""
        if not images:
            # np.concatenate refuses an empty list; return an empty stack instead.
            return np.zeros((0,) + self._KEYPOINT_SHAPE)

        frames = tqdm(images, desc='landmark Det:') if info else images

        collected = []
        for image in frames:
            current_kp = self.extract_keypoint(image)
            if np.mean(current_kp) == -1 and collected:
                # Extraction failed on this frame: reuse the previous frame's entry.
                collected.append(collected[-1])
            else:
                collected.append(current_kp[None])
        return np.concatenate(collected, 0)

    def _extract_single(self, image):
        """Detect one face in ``image`` and return its landmarks.

        Returns the ``-1`` placeholder when no face is found or when a
        non-CUDA ``RuntimeError`` occurs. A ``RuntimeError`` whose message
        starts with ``'CUDA'`` is retried after a one-second sleep, with no
        upper bound on the number of retries.
        """
        while True:
            try:
                with torch.no_grad():
                    # face detection -> face alignment.
                    img = np.array(image)
                    # NOTE(review): 0.97 is presumably a detection confidence
                    # threshold - confirm against the detector's API.
                    detections = self.det_net.detect_faces(image, 0.97)
                    if detections is None or len(detections) == 0:
                        # Indexing an empty result would raise IndexError,
                        # which the TypeError handler below would not catch.
                        print('No face detected in this image')
                        return self._missing_keypoints()

                    bbox = detections[0]
                    # bbox[1]/bbox[3] bound the first array axis and
                    # bbox[0]/bbox[2] the second.
                    img = img[int(bbox[1]):int(bbox[3]), int(bbox[0]):int(bbox[2]), :]
                    keypoints = landmark_98_to_68(self.detector.get_landmarks(img))

                    # Shift landmarks from crop coordinates back to the original image.
                    keypoints[:, 0] += int(bbox[0])
                    keypoints[:, 1] += int(bbox[1])
                return keypoints
            except RuntimeError as e:
                if str(e).startswith('CUDA'):
                    print("Warning: out of memory, sleep for 1s")
                    time.sleep(1)
                    continue
                # Other runtime errors are reported and the frame is marked as
                # failed; leaving ``keypoints`` unbound here would raise
                # UnboundLocalError in the caller's return path.
                print(e)
                return self._missing_keypoints()
            except TypeError:
                print('No face detected in this image')
                return self._missing_keypoints()

    def _missing_keypoints(self):
        """Return the ``-1``-filled placeholder used for failed frames."""
        return -1. * np.ones(self._KEYPOINT_SHAPE)
def read_video(filename):
    """Decode every frame of a video into a list of RGB PIL images.

    Args:
        filename: Path handed to ``cv2.VideoCapture``.

    Returns:
        A list of ``PIL.Image`` frames in read order. The list is empty when
        the capture cannot be opened or yields no frames.
    """
    frames = []
    cap = cv2.VideoCapture(filename)
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                # A failed read ends the loop.
                break
            # Frames are converted BGR -> RGB before being wrapped as PIL images.
            rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(Image.fromarray(rgb))
    finally:
        # Release the capture even if conversion of a frame raises.
        cap.release()
    return frames
def run(data):
    """Worker entry point: extract and save landmarks for one video.

    Args:
        data: A ``(filename, opt, device)`` tuple. ``filename`` is the video
            path, ``opt`` carries ``output_dir``, and ``device`` is the
            string written to ``CUDA_VISIBLE_DEVICES``.

    The landmarks are written under
    ``<opt.output_dir>/<name of the video's parent directory>/``; the file
    name is derived from the video's name by ``extract_keypoint``.
    """
    filename, opt, device = data

    # NOTE(review): pool workers handle several tasks; changing this variable
    # after CUDA has been initialised in the process presumably has no
    # effect - confirm if strict per-task device pinning is required.
    os.environ['CUDA_VISIBLE_DEVICES'] = device
    kp_extractor = KeypointExtractor()
    images = read_video(filename)

    # Split with os.path rather than on a literal '/', so a bare file name or
    # OS-native separators do not raise IndexError.
    parent_dir, video_name = os.path.split(os.path.normpath(filename))
    group = os.path.basename(parent_dir)

    out_dir = os.path.join(opt.output_dir, group)
    os.makedirs(out_dir, exist_ok=True)
    kp_extractor.extract_keypoint(
        images,
        name=os.path.join(out_dir, video_name)
    )
if __name__ == '__main__':
    # Batch driver: find the videos in --input_dir and fan them out to a pool
    # of worker processes, cycling through the ids given in --device_ids.
    set_start_method('spawn')

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--input_dir', type=str, help='the folder of the input files')
    parser.add_argument('--output_dir', type=str, help='the folder of the output files')
    parser.add_argument('--device_ids', type=str, default='0,1')
    parser.add_argument('--workers', type=int, default=4)
    opt = parser.parse_args()

    VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
    VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})
    extensions = VIDEO_EXTENSIONS

    # Result is unused; the call is kept so a missing or unreadable input
    # directory still fails here, before any worker is started.
    os.listdir(f'{opt.input_dir}')

    # Accumulate matches across all extensions instead of overwriting the list
    # on each iteration. A set removes duplicates when both spellings match
    # the same file (case-insensitive filesystems).
    matched = set()
    for ext in sorted(extensions):
        print(f'{opt.input_dir}/*.{ext}')
        matched.update(glob.glob(f'{opt.input_dir}/*.{ext}'))
    filenames = sorted(matched)
    print('Total number of videos:', len(filenames))

    args_list = cycle([opt])
    device_ids = cycle(opt.device_ids.split(","))

    # zip() stops at the shortest input, so the two cycles are bounded by
    # ``filenames``. The context manager shuts the pool down once every
    # result has been consumed.
    with Pool(opt.workers) as pool:
        jobs = zip(filenames, args_list, device_ids)
        for _ in tqdm(pool.imap_unordered(run, jobs), total=len(filenames)):
            pass
|