import sys
# sys.path.append('./CodeFormer/CodeFormer')
# NOTE: this path is relative to the current working directory, not to this file.
sys.path.append('./post_process/inswapper/CodeFormer/CodeFormer')

import time
import os
import cv2
import copy
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
# import insightface
# import onnxruntime
import numpy as np
import PIL
from PIL import Image
from typing import List, Union, Dict, Set, Tuple

# from basicsr.utils import imwrite, img2tensor, tensor2img
# from basicsr.utils.download_util import load_file_from_url
# from facelib.utils.face_restoration_helper import FaceRestoreHelper
# from facelib.utils.misc import is_gray
# from basicsr.archs.rrdbnet_arch import RRDBNet
# from basicsr.utils.realesrgan_utils import RealESRGANer
# from basicsr.utils.registry import ARCH_REGISTRY

# sess_options = onnxruntime.SessionOptions()
# sess_options.intra_op_num_threads = 1#os.cpu_count()
# sess_options.inter_op_num_threads = 1#os.cpu_count()


class FaceSwapper:
    """Swap faces between PIL images and optionally restore them with CodeFormer.

    The face swapper, face analyser and CodeFormer network are created by the
    caller and injected through the constructor; this class only orchestrates
    detection, swapping and (optional) restoration.
    """

    def __init__(self, faceswapper, faceanalyser, codeformer):
        # Object whose .get(frame, target_face, source_face, paste_back=True)
        # returns the frame with the face swapped in.
        self.faceswapper = faceswapper
        # Object whose .get(frame) returns detected faces exposing a .bbox.
        self.faceanalyser = faceanalyser
        # Restoration network handed to face_restoration() in run_faceswap().
        self.codeformer = codeformer
        # Comma-separated face indexes; "-1" selects all faces (see process()).
        self.source_indexes = "-1"
        self.target_indexes = "-1"
        # Options forwarded to face_restoration() in run_faceswap().
        self.face_restore = True
        self.background_enhance = False
        self.face_upsample = False
        self.upscale = 1
        self.codeformer_fidelity = 0.8

    @staticmethod
    def get_one_face(face_analyser, frame: np.ndarray):
        """Return the left-most detected face in ``frame``, or None if none is found."""
        face = face_analyser.get(frame)
        try:
            return min(face, key=lambda x: x.bbox[0])
        except ValueError:  # min() of an empty sequence
            return None

    @staticmethod
    def get_many_faces(face_analyser, frame: np.ndarray):
        """Return the faces detected in ``frame`` ordered from left to right.

        An analyser that finds nothing yields an empty list; None is returned
        only if the analyser itself raises IndexError.
        """
        try:
            face = face_analyser.get(frame)
            return sorted(face, key=lambda x: x.bbox[0])
        except IndexError:
            return None

    @staticmethod
    def swap_face(face_swapper, source_faces, target_faces, source_index, target_index, temp_frame):
        """Paste ``source_faces[source_index]`` over ``target_faces[target_index]``.

        Returns the frame produced by ``face_swapper.get(..., paste_back=True)``.
        """
        source_face = source_faces[source_index]
        target_face = target_faces[target_index]
        return face_swapper.get(temp_frame, target_face, source_face, paste_back=True)

    @staticmethod
    def _to_bgr(image) -> np.ndarray:
        """Convert a PIL image (treated as RGB) to the BGR array fed to the models."""
        return cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

    def process(self,
                source_img: Union[Image.Image, List],
                target_img: Image.Image,
                source_indexes: str,
                target_indexes: str,
                face_swapper,
                face_analyser):
        """Swap faces from ``source_img`` into ``target_img``.

        Two modes are supported:

        * ``source_img`` is a list with exactly one image per target face: the
          left-most face of image *i* replaces target face *i* (left to right).
        * ``source_img`` holds a single image: its faces are mapped onto the
          target according to ``source_indexes`` / ``target_indexes``.

        Args:
            source_img: list of source PIL images; a single PIL image is also
                accepted and treated as a one-element list.
            target_img: PIL image whose faces are replaced.
            source_indexes: comma-separated source face indexes, or "-1" for all.
            target_indexes: comma-separated target face indexes, or "-1" to
                replace as many target faces as the source allows.
            face_swapper: swap model (see ``swap_face``).
            face_analyser: face detector (see ``get_many_faces``).

        Returns:
            The resulting RGB PIL image. If no face is detected in the target,
            the target image is returned unchanged.

        Raises:
            Exception: no source face found, too many indexes requested, or an
                unsupported number of source images.
            ValueError: an explicit index exceeds the number of detected faces.
        """
        # The signature admits a single image, which len()/indexing below
        # cannot handle, so normalise it to a list.
        if isinstance(source_img, Image.Image):
            source_img = [source_img]

        target_frame = self._to_bgr(target_img)

        start = time.time()
        target_faces = self.get_many_faces(face_analyser, target_frame)
        print(f"gettargetfaces: {time.time() - start} s.")

        # Covers both None and an empty detection; previously this path either
        # crashed on len(None) or left the result unassigned (NameError).
        if not target_faces:
            print("No target faces found!")
            return Image.fromarray(cv2.cvtColor(target_frame, cv2.COLOR_BGR2RGB))

        temp_frame = copy.deepcopy(target_frame)
        num_source_images = len(source_img)

        if isinstance(source_img, list) and num_source_images == len(target_faces):
            temp_frame = self._swap_one_face_per_image(
                source_img, target_faces, face_swapper, face_analyser, temp_frame
            )
        elif num_source_images == 1:
            temp_frame = self._swap_from_single_image(
                source_img[0], target_faces, source_indexes, target_indexes,
                face_swapper, face_analyser, temp_frame
            )
        else:
            raise Exception("Unsupported face configuration")

        return Image.fromarray(cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB))

    def _swap_one_face_per_image(self, source_images, target_faces, face_swapper, face_analyser, temp_frame):
        """Replace target face i with the left-most face of source image i."""
        print("Replacing faces in target image from the left to the right by order")
        for target_index, image in enumerate(source_images):
            start = time.time()
            source_faces = self.get_many_faces(face_analyser, self._to_bgr(image))
            print(f"getsourcefaces: {time.time() - start} s.")

            if not source_faces:
                raise Exception("No source faces found!")

            start = time.time()
            # Each source image contributes its left-most face. Using the loop
            # index here (the i-th face *inside* image i) raised IndexError for
            # single-face source images after the first one.
            temp_frame = self.swap_face(
                face_swapper, source_faces, target_faces, 0, target_index, temp_frame
            )
            print(f"swapfacetime: {time.time() - start} s.")
        return temp_frame

    def _swap_from_single_image(self, image, target_faces, source_indexes, target_indexes,
                                face_swapper, face_analyser, temp_frame):
        """Map the faces of one source image onto the target faces."""
        source_faces = self.get_many_faces(face_analyser, self._to_bgr(image))
        # get_many_faces returns [] (not None) when nothing is detected.
        if not source_faces:
            raise Exception("No source faces found!")

        num_source_faces = len(source_faces)
        num_target_faces = len(target_faces)
        print(f"Source faces: {num_source_faces}")
        print(f"Target faces: {num_target_faces}")

        if target_indexes == "-1":
            # Automatic mode: source_indexes is ignored.
            if num_source_faces == 1:
                print("Replacing all faces in target image with the same face from the source image")
                num_iterations = num_target_faces
            elif num_source_faces < num_target_faces:
                print("There are less faces in the source image than the target image, replacing as many as we can")
                num_iterations = num_source_faces
            elif num_target_faces < num_source_faces:
                print("There are less faces in the target image than the source image, replacing as many as we can")
                num_iterations = num_target_faces
            else:
                print("Replacing all faces in the target image with the faces from the source image")
                num_iterations = num_target_faces

            for i in range(num_iterations):
                # A lone source face is reused for every target face.
                source_index = 0 if num_source_faces == 1 else i
                temp_frame = self.swap_face(
                    face_swapper, source_faces, target_faces, source_index, i, temp_frame
                )
            return temp_frame

        print("Replacing specific face(s) in the target image with specific face(s) from the source image")
        if source_indexes == "-1":
            source_indexes = ','.join(str(i) for i in range(num_source_faces))
        # target_indexes cannot be "-1" on this path, so it needs no expansion.
        source_list = source_indexes.split(',')
        target_list = target_indexes.split(',')

        if len(source_list) > num_source_faces:
            raise Exception("Number of source indexes is greater than the number of faces in the source image")
        if len(target_list) > num_target_faces:
            raise Exception("Number of target indexes is greater than the number of faces in the target image")

        if len(source_list) != len(target_list):
            # Unequal index lists were previously ignored without any message;
            # keep the no-op but make it visible.
            print("Number of source and target indexes differ; no faces were swapped")
            return temp_frame

        for raw_source, raw_target in zip(source_list, target_list):
            source_index = int(raw_source)
            target_index = int(raw_target)
            if source_index > num_source_faces - 1:
                raise ValueError(f"Source index {source_index} is higher than the number of faces in the source image")
            if target_index > num_target_faces - 1:
                raise ValueError(f"Target index {target_index} is higher than the number of faces in the target image")
            temp_frame = self.swap_face(
                face_swapper, source_faces, target_faces, source_index, target_index, temp_frame
            )
        return temp_frame

    def run_faceswap(self, source_img: List[PIL.Image.Image], target_img: PIL.Image.Image):
        """Swap faces with the injected models and settings, then optionally restore them.

        Args:
            source_img: source PIL images (see ``process``).
            target_img: PIL image whose faces are replaced.

        Returns:
            The resulting PIL image; passed through ``face_restoration`` when
            ``self.face_restore`` is true.
        """
        print("Source image list:", source_img)
        result_image = self.process(source_img, target_img, self.source_indexes, self.target_indexes,
                                    self.faceswapper, self.faceanalyser)

        if self.face_restore:
            # Imported here so the restoration module is only required when
            # face_restore is enabled.
            from post_process.inswapper.restoration import face_restoration

            upsampler = None  # no background upsampler is configured here
            device = torch.device(
                "mps" if torch.backends.mps.is_available()
                else "cuda" if torch.cuda.is_available()
                else "cpu"
            )

            result_image = cv2.cvtColor(np.array(result_image), cv2.COLOR_RGB2BGR)
            result_image = face_restoration(result_image, self.background_enhance, self.face_upsample,
                                            self.upscale, self.codeformer_fidelity, upsampler,
                                            self.codeformer, device)
            # NOTE(review): no BGR->RGB conversion happens here, so
            # face_restoration() is presumably expected to return RGB — confirm.
            result_image = Image.fromarray(result_image)
        return result_image