# text_to_speech_sync_video / Wav2Lip / get_face_detect.py
# Last commit by zmbfeng: "Wav2lip commits without big files" (e22ac1d)
# File size: 2.03 kB (page links: raw, history, blame)
from os import listdir, path
import numpy as np
import scipy, cv2, os, sys, argparse
import json, subprocess, random, string
from tqdm import tqdm
from glob import glob
import torch, face_detection
from models import Wav2Lip
import platform
import pickle
#import os
#os.system("!pip show Wav2Lip > /content/temp.txt")
# Choose the inference device once, at import time: CUDA when torch reports
# it as available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print('Using {} for inference.'.format(device))
def get_smoothened_boxes(boxes, T):
    """Smooth a sequence of boxes in place with a forward-looking moving average.

    ``boxes[i]`` is replaced by the mean (along axis 0) of the ``T`` entries
    starting at index ``i``. For the trailing indices, where fewer than ``T``
    entries remain, the window is the last ``T`` entries instead. When the
    whole sequence is shorter than ``T``, the window is the entire sequence.

    The update is done in place and in index order, so the windows used for
    later indices already contain the smoothed values of earlier ones.

    Args:
        boxes: Indexable, sliceable sequence of boxes (e.g. an ``(N, 4)``
            NumPy array). It is mutated. NOTE: if this is an integer NumPy
            array, the assigned means are cast back to that integer dtype.
        T: Window length, in number of boxes.

    Returns:
        The same ``boxes`` object that was passed in, after smoothing.
    """
    count = len(boxes)
    # Clamp at 0: with fewer than T boxes, ``count - T`` is negative, and a
    # negative slice start counts from the end, which would silently drop the
    # leading boxes from the window instead of using all of them.
    tail_start = max(0, count - T)
    for i in range(count):
        if i + T > count:
            window = boxes[tail_start:]
        else:
            window = boxes[i:i + T]
        boxes[i] = np.mean(window, axis=0)
    return boxes
def face_detect(images, *,
                output_path='/content/gdrive/MyDrive/Avatar/chat_bot/Wav2Lip/new_face_det_result.pkl',
                batch_size=16):
    """Detect one face per frame, smooth the boxes, and pickle the crops.

    The detector is run over ``images`` in batches. If a ``RuntimeError`` is
    raised during detection it is treated as an out-of-memory condition: the
    batch size is halved and detection restarts from the first frame, until
    it succeeds or the batch size is already 1.

    Each detected rectangle is clamped to the frame bounds, the boxes are
    smoothed over a window of 5 frames with ``get_smoothened_boxes``, and the
    result is written with ``pickle`` to ``output_path`` as a list of
    ``[face_crop, (y1, y2, x1, x2)]`` entries, one per frame.

    Args:
        images: Sequence of frames. Each frame is assumed to be a NumPy image
            array indexable as ``image[y, x]`` (its ``shape[0]``/``shape[1]``
            are used as height/width) - TODO confirm against the caller.
        output_path: Where the pickled results are written. Defaults to the
            previously hard-coded Google Drive location.
        batch_size: Initial number of frames per detector call.

    Returns:
        None. The results are persisted to ``output_path`` rather than
        returned.

    Raises:
        RuntimeError: If detection still fails with a batch size of 1.
        ValueError: If no face is detected in some frame. That frame is saved
            to ``temp/faulty_frame.jpg`` for inspection first.
    """
    detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
                                            flip_input=False, device=device)

    while True:
        # Start over on every retry so partial results from a failed pass
        # are never mixed with those of the next one.
        predictions = []
        try:
            for start in tqdm(range(0, len(images), batch_size)):
                batch = np.array(images[start:start + batch_size])
                predictions.extend(detector.get_detections_for_batch(batch))
        except RuntimeError as err:
            if batch_size == 1:
                raise RuntimeError('Image too big to run face detection on GPU. Please use the --resize_factor argument') from err
            batch_size //= 2
            print('Recovering from OOM error; New batch size: {}'.format(batch_size))
            continue
        break

    results = []
    for rect, image in zip(predictions, images):
        if rect is None:
            # cv2.imwrite does not create missing directories and reports
            # failure only through its return value, so make sure the target
            # directory exists or the debug frame is silently lost.
            os.makedirs('temp', exist_ok=True)
            cv2.imwrite('temp/faulty_frame.jpg', image)  # check this frame where the face was not detected.
            raise ValueError('Face not detected! Ensure the video contains a face in all the frames.')

        # rect is indexed as (x1, y1, x2, y2); clamp it to the frame bounds.
        y1 = max(0, rect[1])
        y2 = min(image.shape[0], rect[3])
        x1 = max(0, rect[0])
        x2 = min(image.shape[1], rect[2])
        results.append([x1, y1, x2, y2])

    boxes = np.array(results)
    boxes = get_smoothened_boxes(boxes, T=5)
    results = [[image[y1:y2, x1:x2], (y1, y2, x1, x2)]
               for image, (x1, y1, x2, y2) in zip(images, boxes)]

    # Drop the detector before serialising; it is no longer needed.
    del detector

    with open(output_path, 'wb') as file:
        pickle.dump(results, file)