print("\rloading torch ", end="") import torch print("\rloading numpy ", end="") import numpy as np print("\rloading Image ", end="") from PIL import Image print("\rloading argparse ", end="") import argparse print("\rloading configparser", end="") import configparser print("\rloading math ", end="") import math print("\rloading os ", end="") import os print("\rloading subprocess ", end="") import subprocess print("\rloading pickle ", end="") import pickle print("\rloading cv2 ", end="") import cv2 print("\rloading audio ", end="") import audio print("\rloading RetinaFace ", end="") from batch_face import RetinaFace print("\rloading re ", end="") import re print("\rloading partial ", end="") from functools import partial print("\rloading tqdm ", end="") from tqdm import tqdm print("\rloading warnings ", end="") import warnings warnings.filterwarnings( "ignore", category=UserWarning, module="torchvision.transforms.functional_tensor" ) print("\rloading upscale ", end="") from enhance import upscale print("\rloading load_sr ", end="") from enhance import load_sr print("\rloading load_model ", end="") from easy_functions import load_model, g_colab print("\rimports loaded! ") device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu' gpu_id = 0 if torch.cuda.is_available() else -1 if device == 'cpu': print('Warning: No GPU detected so inference will be done on the CPU which is VERY SLOW!') parser = argparse.ArgumentParser( description="Inference code to lip-sync videos in the wild using Wav2Lip models" ) parser.add_argument( "--checkpoint_path", type=str, help="Name of saved checkpoint to load weights from", required=True, ) parser.add_argument( "--segmentation_path", type=str, default="checkpoints/face_segmentation.pth", help="Name of saved checkpoint of segmentation network", required=False, ) parser.add_argument( "--face", type=str, help="Filepath of video/image that contains faces to use", required=True, ) parser.add_argument( "--audio", type=str, help="Filepath of video/audio file to use as raw audio source", required=True, ) parser.add_argument( "--outfile", type=str, help="Video path to save result. See default for an e.g.", default="results/result_voice.mp4", ) parser.add_argument( "--static", type=bool, help="If True, then use only first video frame for inference", default=False, ) parser.add_argument( "--fps", type=float, help="Can be specified only if input is a static image (default: 25)", default=25.0, required=False, ) parser.add_argument( "--pads", nargs="+", type=int, default=[0, 10, 0, 0], help="Padding (top, bottom, left, right). Please adjust to include chin at least", ) parser.add_argument( "--wav2lip_batch_size", type=int, help="Batch size for Wav2Lip model(s)", default=1 ) parser.add_argument( "--out_height", default=480, type=int, help="Output video height. Best results are obtained at 480 or 720", ) parser.add_argument( "--crop", nargs="+", type=int, default=[0, -1, 0, -1], help="Crop video to a smaller region (top, bottom, left, right). Applied after resize_factor and rotate arg. " "Useful if multiple face present. -1 implies the value will be auto-inferred based on height, width", ) parser.add_argument( "--box", nargs="+", type=int, default=[-1, -1, -1, -1], help="Specify a constant bounding box for the face. Use only as a last resort if the face is not detected." "Also, might work only if the face is not moving around much. 
parser.add_argument(
    "--rotate",
    default=False,
    action="store_true",
    help="Sometimes videos taken from a phone can be flipped 90deg. If true, will flip video right by 90deg. "
    "Use if you get a flipped result, despite feeding a normal looking video",
)
parser.add_argument(
    "--nosmooth",
    type=str,
    default=False,
    help="Prevent smoothing face detections over a short temporal window",
)
parser.add_argument(
    "--no_seg",
    default=False,
    action="store_true",
    help="Prevent using face segmentation",
)
parser.add_argument(
    "--no_sr", default=False, action="store_true", help="Prevent using super resolution"
)
parser.add_argument(
    "--sr_model",
    type=str,
    default="gfpgan",
    help="Name of upscaler - gfpgan or RestoreFormer",
    required=False,
)
parser.add_argument(
    "--fullres",
    default=3,
    type=int,
    help="used only to determine if full res is used so that no resizing needs to be done if so",
)
parser.add_argument(
    "--debug_mask",
    type=str,
    default=False,
    help="Makes background grayscale to see the mask better",
)
parser.add_argument(
    "--preview_settings", type=str, default=False, help="Processes only one frame"
)
parser.add_argument(
    "--mouth_tracking",
    type=str,
    default=False,
    help="Tracks the mouth in every frame for the mask",
)
parser.add_argument(
    "--mask_dilation",
    default=150,
    type=float,
    help="size of mask around mouth",
    required=False,
)
parser.add_argument(
    "--mask_feathering",
    default=151,
    type=int,
    help="amount of feathering of mask around mouth",
    required=False,
)
parser.add_argument(
    "--quality",
    type=str,
    help="Choose between Fast, Improved and Enhanced",
    default="Fast",
)

with open(os.path.join("checkpoints", "predictor.pkl"), "rb") as f:
    predictor = pickle.load(f)

with open(os.path.join("checkpoints", "mouth_detector.pkl"), "rb") as f:
    mouth_detector = pickle.load(f)

# creating variables to prevent failing when a face isn't detected
kernel = last_mask = x = y = w = h = None

g_colab = g_colab()

if not g_colab:
    # Load the config file
    config = configparser.ConfigParser()
    config.read("config.ini")

    # Get the value of the "preview_window" variable
    preview_window = config.get("OPTIONS", "preview_window")

all_mouth_landmarks = []

model = detector = detector_model = None


def do_load(checkpoint_path):
    global model, detector, detector_model

    model = load_model(checkpoint_path)

    detector = RetinaFace(
        gpu_id=gpu_id, model_path="checkpoints/mobilenet.pth", network="mobilenet"
    )
    detector_model = detector.model


def face_rect(images):
    face_batch_size = 8
    num_batches = math.ceil(len(images) / face_batch_size)
    prev_ret = None
    for i in range(num_batches):
        batch = images[i * face_batch_size : (i + 1) * face_batch_size]
        all_faces = detector(batch)  # return faces list of all images
        for faces in all_faces:
            if faces:
                box, landmarks, score = faces[0]
                prev_ret = tuple(map(int, box))
            yield prev_ret


def create_tracked_mask(img, original_img):
    global kernel, last_mask, x, y, w, h  # Add last_mask to global variables

    # Convert color space from BGR to RGB if necessary
    cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)
    cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB, original_img)

    # Detect face
    faces = mouth_detector(img)
    if len(faces) == 0:
        if last_mask is not None:
            last_mask = cv2.resize(last_mask, (img.shape[1], img.shape[0]))
            mask = last_mask  # use the last successful mask
        else:
            cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)
            return img, None

    else:
        face = faces[0]
        shape = predictor(img, face)

        # Get points for mouth
        mouth_points = np.array(
            [[shape.part(i).x, shape.part(i).y] for i in range(48, 68)]
        )
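        # Indices 48-67 are the outer and inner lip points in the standard dlib
        # 68-point landmark layout (assuming the pickled predictor uses that layout).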
        # Calculate bounding box dimensions
        x, y, w, h = cv2.boundingRect(mouth_points)

        # Set kernel size as a fraction of bounding box size
        kernel_size = int(max(w, h) * args.mask_dilation)
        # if kernel_size % 2 == 0:  # Ensure kernel size is odd
        #     kernel_size += 1

        # Create kernel
        kernel = np.ones((kernel_size, kernel_size), np.uint8)

        # Create binary mask for mouth
        mask = np.zeros(img.shape[:2], dtype=np.uint8)
        cv2.fillConvexPoly(mask, mouth_points, 255)

        last_mask = mask  # Update last_mask with the new mask

    # Dilate the mask
    dilated_mask = cv2.dilate(mask, kernel)

    # Calculate distance transform of dilated mask
    dist_transform = cv2.distanceTransform(dilated_mask, cv2.DIST_L2, 5)

    # Normalize distance transform
    cv2.normalize(dist_transform, dist_transform, 0, 255, cv2.NORM_MINMAX)

    # Convert normalized distance transform to binary mask and convert it to uint8
    _, masked_diff = cv2.threshold(dist_transform, 50, 255, cv2.THRESH_BINARY)
    masked_diff = masked_diff.astype(np.uint8)

    # make sure blur is an odd number
    blur = args.mask_feathering
    if blur % 2 == 0:
        blur += 1
    # Set blur size as a fraction of bounding box size
    blur = int(max(w, h) * blur)  # 10% of bounding box size
    if blur % 2 == 0:  # Ensure blur size is odd
        blur += 1
    masked_diff = cv2.GaussianBlur(masked_diff, (blur, blur), 0)

    # Convert numpy arrays to PIL Images
    input1 = Image.fromarray(img)
    input2 = Image.fromarray(original_img)

    # Convert mask to single channel where pixel values are from the alpha channel of the current mask
    mask = Image.fromarray(masked_diff)

    # Ensure images are the same size
    assert input1.size == input2.size == mask.size

    # Paste input1 onto input2 using the mask
    input2.paste(input1, (0, 0), mask)

    # Convert the final PIL Image back to a numpy array
    input2 = np.array(input2)

    # input2 = cv2.cvtColor(input2, cv2.COLOR_BGR2RGB)
    cv2.cvtColor(input2, cv2.COLOR_BGR2RGB, input2)

    return input2, mask


def create_mask(img, original_img):
    global kernel, last_mask, x, y, w, h  # Add last_mask to global variables

    # Convert color space from BGR to RGB if necessary
    cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)
    cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB, original_img)

    if last_mask is not None:
        last_mask = np.array(last_mask)  # Convert PIL Image to numpy array
        last_mask = cv2.resize(last_mask, (img.shape[1], img.shape[0]))
        mask = last_mask  # use the last successful mask
        mask = Image.fromarray(mask)

    else:
        # Detect face
        faces = mouth_detector(img)
        if len(faces) == 0:
            cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)
            return img, None

        else:
            face = faces[0]
            shape = predictor(img, face)

            # Get points for mouth
            mouth_points = np.array(
                [[shape.part(i).x, shape.part(i).y] for i in range(48, 68)]
            )

            # Calculate bounding box dimensions
            x, y, w, h = cv2.boundingRect(mouth_points)

            # Set kernel size as a fraction of bounding box size
            kernel_size = int(max(w, h) * args.mask_dilation)
            # if kernel_size % 2 == 0:  # Ensure kernel size is odd
            #     kernel_size += 1

            # Create kernel
            kernel = np.ones((kernel_size, kernel_size), np.uint8)

            # Create binary mask for mouth
            mask = np.zeros(img.shape[:2], dtype=np.uint8)
            cv2.fillConvexPoly(mask, mouth_points, 255)

            # Dilate the mask
            dilated_mask = cv2.dilate(mask, kernel)

            # Calculate distance transform of dilated mask
            dist_transform = cv2.distanceTransform(dilated_mask, cv2.DIST_L2, 5)

            # Normalize distance transform
            cv2.normalize(dist_transform, dist_transform, 0, 255, cv2.NORM_MINMAX)

            # Convert normalized distance transform to binary mask and convert it to uint8
            _, masked_diff = cv2.threshold(dist_transform, 50, 255, cv2.THRESH_BINARY)
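            # The normalized distance transform peaks at the centre of the dilated
            # mouth region; thresholding it at 50 trims the outer fringe before the
            # Gaussian feathering applied below.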
            masked_diff = masked_diff.astype(np.uint8)

            if not args.mask_feathering == 0:
                blur = args.mask_feathering
                # Set blur size as a fraction of bounding box size
                blur = int(max(w, h) * blur)  # 10% of bounding box size
                if blur % 2 == 0:  # Ensure blur size is odd
                    blur += 1
                masked_diff = cv2.GaussianBlur(masked_diff, (blur, blur), 0)

            # Convert mask to single channel where pixel values are from the alpha channel of the current mask
            mask = Image.fromarray(masked_diff)

            last_mask = mask  # Update last_mask with the final mask after dilation and feathering

    # Convert numpy arrays to PIL Images
    input1 = Image.fromarray(img)
    input2 = Image.fromarray(original_img)

    # Resize mask to match image size
    # mask = Image.fromarray(mask)
    mask = mask.resize(input1.size)

    # Ensure images are the same size
    assert input1.size == input2.size == mask.size

    # Paste input1 onto input2 using the mask
    input2.paste(input1, (0, 0), mask)

    # Convert the final PIL Image back to a numpy array
    input2 = np.array(input2)

    # input2 = cv2.cvtColor(input2, cv2.COLOR_BGR2RGB)
    cv2.cvtColor(input2, cv2.COLOR_BGR2RGB, input2)

    return input2, mask


def get_smoothened_boxes(boxes, T):
    for i in range(len(boxes)):
        if i + T > len(boxes):
            window = boxes[len(boxes) - T :]
        else:
            window = boxes[i : i + T]
        boxes[i] = np.mean(window, axis=0)
    return boxes


def face_detect(images, results_file="last_detected_face.pkl"):
    # If results file exists, load it and return
    if os.path.exists(results_file):
        print("Using face detection data from last input")
        with open(results_file, "rb") as f:
            return pickle.load(f)

    results = []
    pady1, pady2, padx1, padx2 = args.pads

    tqdm_partial = partial(tqdm, position=0, leave=True)
    for image, rect in tqdm_partial(
        zip(images, face_rect(images)),
        total=len(images),
        desc="detecting face in every frame",
        ncols=100,
    ):
        if rect is None:
            cv2.imwrite(
                "temp/faulty_frame.jpg", image
            )  # check this frame where the face was not detected.
            raise ValueError(
                "Face not detected! Ensure the video contains a face in all the frames."
            )
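        # Expand the detected box by the user-supplied --pads (top, bottom, left,
        # right) and clamp it to the frame, so the crop fed to Wav2Lip keeps the chin.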
        y1 = max(0, rect[1] - pady1)
        y2 = min(image.shape[0], rect[3] + pady2)
        x1 = max(0, rect[0] - padx1)
        x2 = min(image.shape[1], rect[2] + padx2)

        results.append([x1, y1, x2, y2])

    boxes = np.array(results)
    if str(args.nosmooth) == "False":
        boxes = get_smoothened_boxes(boxes, T=5)
    results = [
        [image[y1:y2, x1:x2], (y1, y2, x1, x2)]
        for image, (x1, y1, x2, y2) in zip(images, boxes)
    ]

    # Save results to file
    with open(results_file, "wb") as f:
        pickle.dump(results, f)

    return results


def datagen(frames, mels):
    img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []

    print("\r" + " " * 100, end="\r")
    if args.box[0] == -1:
        if not args.static:
            face_det_results = face_detect(frames)  # BGR2RGB for CNN face detection
        else:
            face_det_results = face_detect([frames[0]])
    else:
        print("Using the specified bounding box instead of face detection...")
        y1, y2, x1, x2 = args.box
        face_det_results = [[f[y1:y2, x1:x2], (y1, y2, x1, x2)] for f in frames]

    for i, m in enumerate(mels):
        idx = 0 if args.static else i % len(frames)
        frame_to_save = frames[idx].copy()
        face, coords = face_det_results[idx].copy()

        face = cv2.resize(face, (args.img_size, args.img_size))

        img_batch.append(face)
        mel_batch.append(m)
        frame_batch.append(frame_to_save)
        coords_batch.append(coords)

        if len(img_batch) >= args.wav2lip_batch_size:
            img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)

            img_masked = img_batch.copy()
            img_masked[:, args.img_size // 2 :] = 0

            img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.0
            mel_batch = np.reshape(
                mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1]
            )

            yield img_batch, mel_batch, frame_batch, coords_batch
            img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []

    if len(img_batch) > 0:
        img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)

        img_masked = img_batch.copy()
        img_masked[:, args.img_size // 2 :] = 0

        img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.0
        mel_batch = np.reshape(
            mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1]
        )

        yield img_batch, mel_batch, frame_batch, coords_batch


mel_step_size = 16


def _load(checkpoint_path):
    if device != "cpu":
        checkpoint = torch.load(checkpoint_path)
    else:
        checkpoint = torch.load(
            checkpoint_path, map_location=lambda storage, loc: storage
        )
    return checkpoint


def main():
    args.img_size = 96
    frame_number = 11

    if os.path.isfile(args.face) and args.face.split(".")[-1] in ["jpg", "png", "jpeg"]:
        args.static = True

    if not os.path.isfile(args.face):
        raise ValueError("--face argument must be a valid path to video/image file")

    elif args.face.split(".")[-1] in ["jpg", "png", "jpeg"]:
        full_frames = [cv2.imread(args.face)]
        fps = args.fps

    else:
        if args.fullres != 1:
            print("Resizing video...")
        video_stream = cv2.VideoCapture(args.face)
        fps = video_stream.get(cv2.CAP_PROP_FPS)

        full_frames = []
        while 1:
            still_reading, frame = video_stream.read()
            if not still_reading:
                video_stream.release()
                break

            if args.fullres != 1:
                aspect_ratio = frame.shape[1] / frame.shape[0]
                frame = cv2.resize(
                    frame, (int(args.out_height * aspect_ratio), args.out_height)
                )

            if args.rotate:
                frame = cv2.rotate(frame, cv2.ROTATE_90_CLOCKWISE)

            y1, y2, x1, x2 = args.crop
            if x2 == -1:
                x2 = frame.shape[1]
            if y2 == -1:
                y2 = frame.shape[0]

            frame = frame[y1:y2, x1:x2]

            full_frames.append(frame)

    if not args.audio.endswith(".wav"):
        print("Converting audio to .wav")
        subprocess.check_call(
            [
                "ffmpeg",
                "-y",
                "-loglevel",
                "error",
                "-i",
                args.audio,
                "temp/temp.wav",
            ]
        )
        args.audio = "temp/temp.wav"

    print("analysing audio...")
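    # Load the (possibly converted) wav at 16 kHz and compute its mel spectrogram;
    # each video frame is paired with a 16-frame mel window below, and the
    # 80.0 / fps multiplier assumes the mel is computed at 80 frames per second of audio.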
audio...") wav = audio.load_wav(args.audio, 16000) mel = audio.melspectrogram(wav) if np.isnan(mel.reshape(-1)).sum() > 0: raise ValueError( "Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again" ) mel_chunks = [] mel_idx_multiplier = 80.0 / fps i = 0 while 1: start_idx = int(i * mel_idx_multiplier) if start_idx + mel_step_size > len(mel[0]): mel_chunks.append(mel[:, len(mel[0]) - mel_step_size :]) break mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size]) i += 1 full_frames = full_frames[: len(mel_chunks)] if str(args.preview_settings) == "True": full_frames = [full_frames[0]] mel_chunks = [mel_chunks[0]] print(str(len(full_frames)) + " frames to process") batch_size = args.wav2lip_batch_size if str(args.preview_settings) == "True": gen = datagen(full_frames, mel_chunks) else: gen = datagen(full_frames.copy(), mel_chunks) for i, (img_batch, mel_batch, frames, coords) in enumerate( tqdm( gen, total=int(np.ceil(float(len(mel_chunks)) / batch_size)), desc="Processing Wav2Lip", ncols=100, ) ): if i == 0: if not args.quality == "Fast": print( f"mask size: {args.mask_dilation}, feathering: {args.mask_feathering}" ) if not args.quality == "Improved": print("Loading", args.sr_model) run_params = load_sr() print("Starting...") frame_h, frame_w = full_frames[0].shape[:-1] fourcc = cv2.VideoWriter_fourcc(*"mp4v") out = cv2.VideoWriter("temp/result.mp4", fourcc, fps, (frame_w, frame_h)) img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device) mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device) with torch.no_grad(): pred = model(mel_batch, img_batch) pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.0 for p, f, c in zip(pred, frames, coords): # cv2.imwrite('temp/f.jpg', f) y1, y2, x1, x2 = c if ( str(args.debug_mask) == "True" ): # makes the background black & white so you can see the mask better f = cv2.cvtColor(f, cv2.COLOR_BGR2GRAY) f = cv2.cvtColor(f, cv2.COLOR_GRAY2BGR) p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1)) cf = f[y1:y2, x1:x2] if args.quality == "Enhanced": p = upscale(p, run_params) if args.quality in ["Enhanced", "Improved"]: if str(args.mouth_tracking) == "True": p, last_mask = create_tracked_mask(p, cf) else: p, last_mask = create_mask(p, cf) f[y1:y2, x1:x2] = p if not g_colab: # Display the frame if preview_window == "Face": cv2.imshow("face preview - press Q to abort", p) elif preview_window == "Full": cv2.imshow("full preview - press Q to abort", f) elif preview_window == "Both": cv2.imshow("face preview - press Q to abort", p) cv2.imshow("full preview - press Q to abort", f) key = cv2.waitKey(1) & 0xFF if key == ord('q'): exit() # Exit the loop when 'Q' is pressed if str(args.preview_settings) == "True": cv2.imwrite("temp/preview.jpg", f) if not g_colab: cv2.imshow("preview - press Q to close", f) if cv2.waitKey(-1) & 0xFF == ord('q'): exit() # Exit the loop when 'Q' is pressed else: out.write(f) # Close the window(s) when done cv2.destroyAllWindows() out.release() if str(args.preview_settings) == "False": print("converting to final video") subprocess.check_call([ "ffmpeg", "-y", "-loglevel", "error", "-i", "temp/result.mp4", "-i", args.audio, "-c:v", "libx264", args.outfile ]) if __name__ == "__main__": args = parser.parse_args() do_load(args.checkpoint_path) main()