print("\rloading torch ", end="")
import torch
print("\rloading numpy ", end="")
import numpy as np
print("\rloading Image ", end="")
from PIL import Image
print("\rloading argparse ", end="")
import argparse
print("\rloading configparser", end="")
import configparser
print("\rloading math ", end="")
import math
print("\rloading os ", end="")
import os
print("\rloading subprocess ", end="")
import subprocess
print("\rloading pickle ", end="")
import pickle
print("\rloading cv2 ", end="")
import cv2
print("\rloading audio ", end="")
import audio
print("\rloading RetinaFace ", end="")
from batch_face import RetinaFace
print("\rloading re ", end="")
import re
print("\rloading partial ", end="")
from functools import partial
print("\rloading tqdm ", end="")
from tqdm import tqdm
print("\rloading warnings ", end="")
import warnings
warnings.filterwarnings(
"ignore", category=UserWarning, module="torchvision.transforms.functional_tensor"
)
print("\rloading upscale ", end="")
from enhance import upscale
print("\rloading load_sr ", end="")
from enhance import load_sr
print("\rloading load_model ", end="")
from easy_functions import load_model, g_colab
print("\rimports loaded! ")
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
gpu_id = 0 if torch.cuda.is_available() else -1
if device == 'cpu':
print('Warning: No GPU detected so inference will be done on the CPU which is VERY SLOW!')
parser = argparse.ArgumentParser(
description="Inference code to lip-sync videos in the wild using Wav2Lip models"
)
parser.add_argument(
"--checkpoint_path",
type=str,
help="Name of saved checkpoint to load weights from",
required=True,
)
parser.add_argument(
"--segmentation_path",
type=str,
default="checkpoints/face_segmentation.pth",
help="Name of saved checkpoint of segmentation network",
required=False,
)
parser.add_argument(
"--face",
type=str,
help="Filepath of video/image that contains faces to use",
required=True,
)
parser.add_argument(
"--audio",
type=str,
help="Filepath of video/audio file to use as raw audio source",
required=True,
)
parser.add_argument(
"--outfile",
type=str,
help="Video path to save result. See default for an e.g.",
default="results/result_voice.mp4",
)
parser.add_argument(
"--static",
type=bool,
help="If True, then use only first video frame for inference",
default=False,
)
parser.add_argument(
"--fps",
type=float,
help="Can be specified only if input is a static image (default: 25)",
default=25.0,
required=False,
)
parser.add_argument(
"--pads",
nargs="+",
type=int,
default=[0, 10, 0, 0],
help="Padding (top, bottom, left, right). Please adjust to include chin at least",
)
parser.add_argument(
"--wav2lip_batch_size", type=int, help="Batch size for Wav2Lip model(s)", default=1
)
parser.add_argument(
"--out_height",
default=480,
type=int,
help="Output video height. Best results are obtained at 480 or 720",
)
parser.add_argument(
"--crop",
nargs="+",
type=int,
default=[0, -1, 0, -1],
help="Crop video to a smaller region (top, bottom, left, right). Applied after resize_factor and rotate arg. "
"Useful if multiple face present. -1 implies the value will be auto-inferred based on height, width",
)
parser.add_argument(
"--box",
nargs="+",
type=int,
default=[-1, -1, -1, -1],
help="Specify a constant bounding box for the face. Use only as a last resort if the face is not detected."
"Also, might work only if the face is not moving around much. Syntax: (top, bottom, left, right).",
)
parser.add_argument(
"--rotate",
default=False,
action="store_true",
help="Sometimes videos taken from a phone can be flipped 90deg. If true, will flip video right by 90deg."
"Use if you get a flipped result, despite feeding a normal looking video",
)
parser.add_argument(
"--nosmooth",
type=str,
default=False,
help="Prevent smoothing face detections over a short temporal window",
)
parser.add_argument(
"--no_seg",
default=False,
action="store_true",
help="Prevent using face segmentation",
)
parser.add_argument(
"--no_sr", default=False, action="store_true", help="Prevent using super resolution"
)
parser.add_argument(
"--sr_model",
type=str,
default="gfpgan",
help="Name of upscaler - gfpgan or RestoreFormer",
required=False,
)
parser.add_argument(
"--fullres",
default=3,
type=int,
help="used only to determine if full res is used so that no resizing needs to be done if so",
)
parser.add_argument(
"--debug_mask",
type=str,
default=False,
help="Makes background grayscale to see the mask better",
)
parser.add_argument(
"--preview_settings", type=str, default=False, help="Processes only one frame"
)
parser.add_argument(
"--mouth_tracking",
type=str,
default=False,
help="Tracks the mouth in every frame for the mask",
)
parser.add_argument(
"--mask_dilation",
default=150,
type=float,
help="size of mask around mouth",
required=False,
)
parser.add_argument(
"--mask_feathering",
default=151,
type=int,
help="amount of feathering of mask around mouth",
required=False,
)
parser.add_argument(
"--quality",
type=str,
help="Choose between Fast, Improved and Enhanced",
default="Fast",
)
with open(os.path.join("checkpoints", "predictor.pkl"), "rb") as f:
predictor = pickle.load(f)
with open(os.path.join("checkpoints", "mouth_detector.pkl"), "rb") as f:
mouth_detector = pickle.load(f)
# creating variables to prevent failing when a face isn't detected
kernel = last_mask = x = y = w = h = None
g_colab = g_colab()
if not g_colab:
# Load the config file
config = configparser.ConfigParser()
config.read('config.ini')
# Get the value of the "preview_window" variable
preview_window = config.get('OPTIONS', 'preview_window')
all_mouth_landmarks = []
model = detector = detector_model = None
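# Load the Wav2Lip checkpoint and the RetinaFace face detector once, storing them in module-level globals.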
def do_load(checkpoint_path):
global model, detector, detector_model
model = load_model(checkpoint_path)
detector = RetinaFace(
gpu_id=gpu_id, model_path="checkpoints/mobilenet.pth", network="mobilenet"
)
detector_model = detector.model
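# Run RetinaFace over the frames in batches of face_batch_size and yield one face box per frame,
# reusing the most recent successful detection when a frame yields no face.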
def face_rect(images):
face_batch_size = 8
num_batches = math.ceil(len(images) / face_batch_size)
prev_ret = None
for i in range(num_batches):
batch = images[i * face_batch_size : (i + 1) * face_batch_size]
all_faces = detector(batch) # return faces list of all images
for faces in all_faces:
if faces:
box, landmarks, score = faces[0]
prev_ret = tuple(map(int, box))
yield prev_ret
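# Build a feathered mouth mask from landmark points 48-67 (the mouth region in the standard
# 68-point layout) and paste the generated face onto the original frame through that mask.
# The mouth is re-detected in every frame; the last successful mask is reused when detection fails.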
def create_tracked_mask(img, original_img):
global kernel, last_mask, x, y, w, h # Add last_mask to global variables
# Convert color space from BGR to RGB if necessary
cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)
cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB, original_img)
# Detect face
faces = mouth_detector(img)
if len(faces) == 0:
if last_mask is not None:
last_mask = cv2.resize(last_mask, (img.shape[1], img.shape[0]))
mask = last_mask # use the last successful mask
else:
cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)
return img, None
else:
face = faces[0]
shape = predictor(img, face)
# Get points for mouth
mouth_points = np.array(
[[shape.part(i).x, shape.part(i).y] for i in range(48, 68)]
)
# Calculate bounding box dimensions
x, y, w, h = cv2.boundingRect(mouth_points)
# Set kernel size as a fraction of bounding box size
kernel_size = int(max(w, h) * args.mask_dilation)
# if kernel_size % 2 == 0: # Ensure kernel size is odd
# kernel_size += 1
# Create kernel
kernel = np.ones((kernel_size, kernel_size), np.uint8)
# Create binary mask for mouth
mask = np.zeros(img.shape[:2], dtype=np.uint8)
cv2.fillConvexPoly(mask, mouth_points, 255)
last_mask = mask # Update last_mask with the new mask
# Dilate the mask
dilated_mask = cv2.dilate(mask, kernel)
# Calculate distance transform of dilated mask
dist_transform = cv2.distanceTransform(dilated_mask, cv2.DIST_L2, 5)
# Normalize distance transform
cv2.normalize(dist_transform, dist_transform, 0, 255, cv2.NORM_MINMAX)
# Convert normalized distance transform to binary mask and convert it to uint8
_, masked_diff = cv2.threshold(dist_transform, 50, 255, cv2.THRESH_BINARY)
masked_diff = masked_diff.astype(np.uint8)
# make sure blur is an odd number
blur = args.mask_feathering
if blur % 2 == 0:
blur += 1
# Scale the blur size by the mouth bounding box size
blur = int(max(w, h) * blur)
if blur % 2 == 0: # Ensure blur size is odd
blur += 1
masked_diff = cv2.GaussianBlur(masked_diff, (blur, blur), 0)
# Convert numpy arrays to PIL Images
input1 = Image.fromarray(img)
input2 = Image.fromarray(original_img)
# Use the feathered mask as the alpha channel when pasting the generated face
mask = Image.fromarray(masked_diff)
# Ensure images are the same size
assert input1.size == input2.size == mask.size
# Paste input1 onto input2 using the mask
input2.paste(input1, (0, 0), mask)
# Convert the final PIL Image back to a numpy array
input2 = np.array(input2)
# input2 = cv2.cvtColor(input2, cv2.COLOR_BGR2RGB)
cv2.cvtColor(input2, cv2.COLOR_BGR2RGB, input2)
return input2, mask
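# Same blending as create_tracked_mask, but reuses the cached last_mask when available instead of
# re-detecting the mouth in every frame (used when --mouth_tracking is off).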
def create_mask(img, original_img):
global kernel, last_mask, x, y, w, h # Add last_mask to global variables
# Convert color space from BGR to RGB if necessary
cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)
cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB, original_img)
if last_mask is not None:
last_mask = np.array(last_mask) # Convert PIL Image to numpy array
last_mask = cv2.resize(last_mask, (img.shape[1], img.shape[0]))
mask = last_mask # use the last successful mask
mask = Image.fromarray(mask)
else:
# Detect face
faces = mouth_detector(img)
if len(faces) == 0:
cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)
return img, None
else:
face = faces[0]
shape = predictor(img, face)
# Get points for mouth
mouth_points = np.array(
[[shape.part(i).x, shape.part(i).y] for i in range(48, 68)]
)
# Calculate bounding box dimensions
x, y, w, h = cv2.boundingRect(mouth_points)
# Set kernel size as a fraction of bounding box size
kernel_size = int(max(w, h) * args.mask_dilation)
# if kernel_size % 2 == 0: # Ensure kernel size is odd
# kernel_size += 1
# Create kernel
kernel = np.ones((kernel_size, kernel_size), np.uint8)
# Create binary mask for mouth
mask = np.zeros(img.shape[:2], dtype=np.uint8)
cv2.fillConvexPoly(mask, mouth_points, 255)
# Dilate the mask
dilated_mask = cv2.dilate(mask, kernel)
# Calculate distance transform of dilated mask
dist_transform = cv2.distanceTransform(dilated_mask, cv2.DIST_L2, 5)
# Normalize distance transform
cv2.normalize(dist_transform, dist_transform, 0, 255, cv2.NORM_MINMAX)
# Convert normalized distance transform to binary mask and convert it to uint8
_, masked_diff = cv2.threshold(dist_transform, 50, 255, cv2.THRESH_BINARY)
masked_diff = masked_diff.astype(np.uint8)
if not args.mask_feathering == 0:
blur = args.mask_feathering
# Scale the blur size by the mouth bounding box size
blur = int(max(w, h) * blur)
if blur % 2 == 0: # Ensure blur size is odd
blur += 1
masked_diff = cv2.GaussianBlur(masked_diff, (blur, blur), 0)
# Use the feathered mask as the alpha channel when pasting the generated face
mask = Image.fromarray(masked_diff)
last_mask = mask # Update last_mask with the final mask after dilation and feathering
# Convert numpy arrays to PIL Images
input1 = Image.fromarray(img)
input2 = Image.fromarray(original_img)
# Resize mask to match image size
# mask = Image.fromarray(mask)
mask = mask.resize(input1.size)
# Ensure images are the same size
assert input1.size == input2.size == mask.size
# Paste input1 onto input2 using the mask
input2.paste(input1, (0, 0), mask)
# Convert the final PIL Image back to a numpy array
input2 = np.array(input2)
# input2 = cv2.cvtColor(input2, cv2.COLOR_BGR2RGB)
cv2.cvtColor(input2, cv2.COLOR_BGR2RGB, input2)
return input2, mask
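# Smooth detected face boxes over a sliding window of T frames to reduce jitter between frames.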
def get_smoothened_boxes(boxes, T):
for i in range(len(boxes)):
if i + T > len(boxes):
window = boxes[len(boxes) - T :]
else:
window = boxes[i : i + T]
boxes[i] = np.mean(window, axis=0)
return boxes
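# Detect a padded face box in every frame, then pickle the results so that re-runs on the same
# input can reuse the previous detections instead of running the detector again.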
def face_detect(images, results_file="last_detected_face.pkl"):
# If results file exists, load it and return
if os.path.exists(results_file):
print("Using face detection data from last input")
with open(results_file, "rb") as f:
return pickle.load(f)
results = []
pady1, pady2, padx1, padx2 = args.pads
tqdm_partial = partial(tqdm, position=0, leave=True)
for image, (rect) in tqdm_partial(
zip(images, face_rect(images)),
total=len(images),
desc="detecting face in every frame",
ncols=100,
):
if rect is None:
cv2.imwrite(
"temp/faulty_frame.jpg", image
) # check this frame where the face was not detected.
raise ValueError(
"Face not detected! Ensure the video contains a face in all the frames."
)
y1 = max(0, rect[1] - pady1)
y2 = min(image.shape[0], rect[3] + pady2)
x1 = max(0, rect[0] - padx1)
x2 = min(image.shape[1], rect[2] + padx2)
results.append([x1, y1, x2, y2])
boxes = np.array(results)
if str(args.nosmooth) == "False":
boxes = get_smoothened_boxes(boxes, T=5)
results = [
[image[y1:y2, x1:x2], (y1, y2, x1, x2)]
for image, (x1, y1, x2, y2) in zip(images, boxes)
]
# Save results to file
with open(results_file, "wb") as f:
pickle.dump(results, f)
return results
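# Yield batches of (face crops, mel chunks, full frames, crop coords). The lower half of each
# face crop is zeroed out and concatenated with the unmasked crop along the channel axis,
# giving the 6-channel input that the Wav2Lip model expects.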
def datagen(frames, mels):
img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
print("\r" + " " * 100, end="\r")
if args.box[0] == -1:
if not args.static:
face_det_results = face_detect(frames)  # detect a face box in every frame
else:
face_det_results = face_detect([frames[0]])
else:
print("Using the specified bounding box instead of face detection...")
y1, y2, x1, x2 = args.box
face_det_results = [[f[y1:y2, x1:x2], (y1, y2, x1, x2)] for f in frames]
for i, m in enumerate(mels):
idx = 0 if args.static else i % len(frames)
frame_to_save = frames[idx].copy()
face, coords = face_det_results[idx].copy()
face = cv2.resize(face, (args.img_size, args.img_size))
img_batch.append(face)
mel_batch.append(m)
frame_batch.append(frame_to_save)
coords_batch.append(coords)
if len(img_batch) >= args.wav2lip_batch_size:
img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
img_masked = img_batch.copy()
img_masked[:, args.img_size // 2 :] = 0
img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.0
mel_batch = np.reshape(
mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1]
)
yield img_batch, mel_batch, frame_batch, coords_batch
img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
if len(img_batch) > 0:
img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
img_masked = img_batch.copy()
img_masked[:, args.img_size // 2 :] = 0
img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.0
mel_batch = np.reshape(
mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1]
)
yield img_batch, mel_batch, frame_batch, coords_batch
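# Number of mel frames fed to the model per output video frame; with the 80 mel frames per
# second of audio assumed below (mel_idx_multiplier), 16 frames cover roughly 0.2 s.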
mel_step_size = 16
def _load(checkpoint_path):
if device != "cpu":
checkpoint = torch.load(checkpoint_path)
else:
checkpoint = torch.load(
checkpoint_path, map_location=lambda storage, loc: storage
)
return checkpoint
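# Full pipeline: read and optionally resize/rotate/crop the input frames, build mel chunks from
# the audio, run Wav2Lip batch by batch, optionally upscale and mask-blend each face, then mux
# the silent result with the audio track via ffmpeg.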
def main():
args.img_size = 96
frame_number = 11
if os.path.isfile(args.face) and args.face.split(".")[-1].lower() in ["jpg", "png", "jpeg"]:
args.static = True
if not os.path.isfile(args.face):
raise ValueError("--face argument must be a valid path to video/image file")
elif args.face.split(".")[-1].lower() in ["jpg", "png", "jpeg"]:
full_frames = [cv2.imread(args.face)]
fps = args.fps
else:
if args.fullres != 1:
print("Resizing video...")
video_stream = cv2.VideoCapture(args.face)
fps = video_stream.get(cv2.CAP_PROP_FPS)
full_frames = []
while 1:
still_reading, frame = video_stream.read()
if not still_reading:
video_stream.release()
break
if args.fullres != 1:
aspect_ratio = frame.shape[1] / frame.shape[0]
frame = cv2.resize(
frame, (int(args.out_height * aspect_ratio), args.out_height)
)
if args.rotate:
frame = cv2.rotate(frame, cv2.ROTATE_90_CLOCKWISE)
y1, y2, x1, x2 = args.crop
if x2 == -1:
x2 = frame.shape[1]
if y2 == -1:
y2 = frame.shape[0]
frame = frame[y1:y2, x1:x2]
full_frames.append(frame)
if not args.audio.endswith(".wav"):
print("Converting audio to .wav")
subprocess.check_call(
[
"ffmpeg",
"-y",
"-loglevel",
"error",
"-i",
args.audio,
"temp/temp.wav",
]
)
args.audio = "temp/temp.wav"
print("analysing audio...")
wav = audio.load_wav(args.audio, 16000)
mel = audio.melspectrogram(wav)
if np.isnan(mel.reshape(-1)).sum() > 0:
raise ValueError(
"Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again"
)
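# Slice the mel spectrogram into one mel_step_size-frame chunk per video frame.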
mel_chunks = []
mel_idx_multiplier = 80.0 / fps
i = 0
while 1:
start_idx = int(i * mel_idx_multiplier)
if start_idx + mel_step_size > len(mel[0]):
mel_chunks.append(mel[:, len(mel[0]) - mel_step_size :])
break
mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size])
i += 1
full_frames = full_frames[: len(mel_chunks)]
if str(args.preview_settings) == "True":
full_frames = [full_frames[0]]
mel_chunks = [mel_chunks[0]]
print(str(len(full_frames)) + " frames to process")
batch_size = args.wav2lip_batch_size
if str(args.preview_settings) == "True":
gen = datagen(full_frames, mel_chunks)
else:
gen = datagen(full_frames.copy(), mel_chunks)
for i, (img_batch, mel_batch, frames, coords) in enumerate(
tqdm(
gen,
total=int(np.ceil(float(len(mel_chunks)) / batch_size)),
desc="Processing Wav2Lip",
ncols=100,
)
):
if i == 0:
if not args.quality == "Fast":
print(
f"mask size: {args.mask_dilation}, feathering: {args.mask_feathering}"
)
if not args.quality == "Improved":
print("Loading", args.sr_model)
run_params = load_sr()
print("Starting...")
frame_h, frame_w = full_frames[0].shape[:-1]
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
out = cv2.VideoWriter("temp/result.mp4", fourcc, fps, (frame_w, frame_h))
img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
with torch.no_grad():
pred = model(mel_batch, img_batch)
pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.0
for p, f, c in zip(pred, frames, coords):
# cv2.imwrite('temp/f.jpg', f)
y1, y2, x1, x2 = c
if (
str(args.debug_mask) == "True"
): # makes the background black & white so you can see the mask better
f = cv2.cvtColor(f, cv2.COLOR_BGR2GRAY)
f = cv2.cvtColor(f, cv2.COLOR_GRAY2BGR)
p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1))
cf = f[y1:y2, x1:x2]
if args.quality == "Enhanced":
p = upscale(p, run_params)
if args.quality in ["Enhanced", "Improved"]:
if str(args.mouth_tracking) == "True":
p, last_mask = create_tracked_mask(p, cf)
else:
p, last_mask = create_mask(p, cf)
f[y1:y2, x1:x2] = p
if not g_colab:
# Display the frame
if preview_window == "Face":
cv2.imshow("face preview - press Q to abort", p)
elif preview_window == "Full":
cv2.imshow("full preview - press Q to abort", f)
elif preview_window == "Both":
cv2.imshow("face preview - press Q to abort", p)
cv2.imshow("full preview - press Q to abort", f)
key = cv2.waitKey(1) & 0xFF
if key == ord('q'):
exit() # Exit the loop when 'Q' is pressed
if str(args.preview_settings) == "True":
cv2.imwrite("temp/preview.jpg", f)
if not g_colab:
cv2.imshow("preview - press Q to close", f)
if cv2.waitKey(-1) & 0xFF == ord('q'):
exit() # Exit the loop when 'Q' is pressed
else:
out.write(f)
# Close the window(s) when done
cv2.destroyAllWindows()
out.release()
if str(args.preview_settings) == "False":
print("converting to final video")
subprocess.check_call([
"ffmpeg",
"-y",
"-loglevel",
"error",
"-i",
"temp/result.mp4",
"-i",
args.audio,
"-c:v",
"libx264",
args.outfile
])
if __name__ == "__main__":
args = parser.parse_args()
do_load(args.checkpoint_path)
main()