# Page-scrape residue kept for provenance (not code):
#   quocanh34's picture
#   first commit
#   3773ad2
import sys
# sys.path.append('./CodeFormer/CodeFormer')
sys.path.append('./post_process/inswapper/CodeFormer/CodeFormer')
import time
import os
import cv2
import copy
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
# import insightface
# import onnxruntime
import numpy as np
import PIL
from PIL import Image
from typing import List, Union, Dict, Set, Tuple
# from basicsr.utils import imwrite, img2tensor, tensor2img
# from basicsr.utils.download_util import load_file_from_url
# from facelib.utils.face_restoration_helper import FaceRestoreHelper
# from facelib.utils.misc import is_gray
# from basicsr.archs.rrdbnet_arch import RRDBNet
# from basicsr.utils.realesrgan_utils import RealESRGANer
# from basicsr.utils.registry import ARCH_REGISTRY
# sess_options = onnxruntime.SessionOptions()
# sess_options.intra_op_num_threads = 1#os.cpu_count()
# sess_options.inter_op_num_threads = 1#os.cpu_count()
class FaceSwapper:
    """Swap faces from source image(s) onto a target image, then optionally restore.

    The face-swap model, the face analyser and the CodeFormer network are
    injected by the caller; nothing is loaded from disk by this class.
    """

    def __init__(self, faceswapper, faceanalyser, codeformer):
        """Store the injected models and the default swap/restoration settings.

        Args:
            faceswapper: Object whose ``get(frame, target_face, source_face,
                paste_back=True)`` performs one swap (see ``swap_face``).
            faceanalyser: Object whose ``get(frame)`` returns the detected
                faces; each face must expose a ``bbox`` sequence.
            codeformer: Forwarded unchanged to ``face_restoration`` by
                ``run_faceswap``.
        """
        self.faceswapper = faceswapper
        self.faceanalyser = faceanalyser
        self.codeformer = codeformer
        # "-1" means "no explicit selection"; otherwise a comma-separated
        # list of face indexes (faces are ordered by bbox[0]).
        self.source_indexes = "-1"
        self.target_indexes = "-1"
        # Settings forwarded to face_restoration() in run_faceswap().
        self.face_restore = True
        self.background_enhance = False
        self.face_upsample = False
        self.upscale = 1
        self.codeformer_fidelity = 0.8

    @staticmethod
    def get_one_face(face_analyser,
                     frame: np.ndarray):
        """Return the detected face with the smallest ``bbox[0]``, or None if none."""
        faces = face_analyser.get(frame)
        try:
            return min(faces, key=lambda x: x.bbox[0])
        except ValueError:
            # min() of an empty sequence: nothing was detected.
            return None

    @staticmethod
    def get_many_faces(face_analyser,
                       frame: np.ndarray):
        """Return all detected faces sorted by ascending ``bbox[0]``.

        Returns an empty list when the analyser finds nothing, and None if
        the analyser (or the sort) raises IndexError.
        """
        try:
            faces = face_analyser.get(frame)
            return sorted(faces, key=lambda x: x.bbox[0])
        except IndexError:
            return None

    @staticmethod
    def swap_face(face_swapper,
                  source_faces,
                  target_faces,
                  source_index,
                  target_index,
                  temp_frame):
        """Paste ``source_faces[source_index]`` over ``target_faces[target_index]``.

        Returns whatever ``face_swapper.get(..., paste_back=True)`` returns
        for ``temp_frame``.
        """
        source_face = source_faces[source_index]
        target_face = target_faces[target_index]
        return face_swapper.get(temp_frame, target_face, source_face, paste_back=True)

    @staticmethod
    def _to_bgr(image):
        """Convert an RGB image (anything ``np.array`` accepts) to a BGR array."""
        return cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

    @staticmethod
    def _pair_all_faces(num_source_faces, num_target_faces):
        """Build (source_index, target_index) pairs when no target indexes were given."""
        if num_source_faces == 1:
            print("Replacing all faces in target image with the same face from the source image")
            return [(0, i) for i in range(num_target_faces)]
        if num_source_faces < num_target_faces:
            print("There are less faces in the source image than the target image, replacing as many as we can")
            count = num_source_faces
        elif num_target_faces < num_source_faces:
            print("There are less faces in the target image than the source image, replacing as many as we can")
            count = num_target_faces
        else:
            print("Replacing all faces in the target image with the faces from the source image")
            count = num_target_faces
        return [(i, i) for i in range(count)]

    @staticmethod
    def _pair_selected_faces(source_indexes, target_indexes,
                             num_source_faces, num_target_faces):
        """Build (source_index, target_index) pairs from comma-separated index strings.

        Raises:
            Exception: More indexes were given than faces were detected.
            ValueError: An index is not an integer or is past the last face.
        """
        print("Replacing specific face(s) in the target image with specific face(s) from the source image")

        # "-1" expands to "every detected face, in order".
        if source_indexes == "-1":
            source_indexes = ','.join(str(i) for i in range(num_source_faces))
        if target_indexes == "-1":
            target_indexes = ','.join(str(i) for i in range(num_target_faces))

        source_items = source_indexes.split(',')
        target_items = target_indexes.split(',')

        if len(source_items) > num_source_faces:
            raise Exception("Number of source indexes is greater than the number of faces in the source image")
        if len(target_items) > num_target_faces:
            raise Exception("Number of target indexes is greater than the number of faces in the target image")

        if len(source_items) != len(target_items):
            # Kept as a no-op (not an error) so existing callers are not
            # broken, but it is no longer silent.
            print("Number of source indexes does not match number of target indexes, nothing was swapped")
            return []

        pairs = []
        for source_item, target_item in zip(source_items, target_items):
            source_index = int(source_item)
            target_index = int(target_item)
            if source_index > num_source_faces - 1:
                raise ValueError(f"Source index {source_index} is higher than the number of faces in the source image")
            if target_index > num_target_faces - 1:
                raise ValueError(f"Target index {target_index} is higher than the number of faces in the target image")
            pairs.append((source_index, target_index))
        return pairs

    def process(self,
                source_img: "Union[Image.Image, List]",
                target_img: "Image.Image",
                source_indexes: str,
                target_indexes: str,
                face_swapper,
                face_analyser):
        """Swap faces from ``source_img`` onto ``target_img``.

        Args:
            source_img: One RGB image or a list of RGB images.
            target_img: RGB image whose faces are replaced.
            source_indexes: "-1" or comma-separated face indexes in the source.
            target_indexes: "-1" or comma-separated face indexes in the target.
            face_swapper: See ``swap_face``.
            face_analyser: See ``get_many_faces``.

        Behaviour:
            * As many source images as target faces: the left-most face of
              the i-th source image replaces the i-th target face.
            * Exactly one source image: faces are paired automatically
              (``target_indexes == "-1"``) or by the given index strings.
            * No face in the target: the target is returned unchanged.

        Returns:
            A PIL image built from the swapped frame (converted BGR -> RGB).

        Raises:
            Exception: No face in a source image, too many indexes, or an
                unsupported number of source images.
            ValueError: An explicit index is invalid.
        """
        target_bgr = self._to_bgr(target_img)

        # detect faces that will be replaced in the target image
        started = time.time()
        target_faces = self.get_many_faces(face_analyser, target_bgr)
        print(f"gettargetfaces: {time.time() - started} s.")

        if not target_faces:
            # Covers both None and an empty list; previously this path
            # failed on len(None) or on the unbound ``result`` variable.
            print("No target faces found!")
            return Image.fromarray(cv2.cvtColor(target_bgr, cv2.COLOR_BGR2RGB))

        # The annotation allows a single image; len() would fail on one.
        if isinstance(source_img, (list, tuple)):
            source_images = list(source_img)
        else:
            source_images = [source_img]

        num_target_faces = len(target_faces)
        temp_frame = copy.deepcopy(target_bgr)

        if len(source_images) == num_target_faces:
            print("Replacing faces in target image from the left to the right by order")
            for target_index, image in enumerate(source_images):
                started = time.time()
                source_faces = self.get_many_faces(face_analyser, self._to_bgr(image))
                print(f"getsourcefaces: {time.time() - started} s.")
                if not source_faces:
                    raise Exception("No source faces found!")

                started = time.time()
                # Each source image contributes its own left-most face.
                # Indexing its face list with the image position (as before)
                # raised IndexError for single-face images beyond the first.
                temp_frame = self.swap_face(
                    face_swapper,
                    source_faces,
                    target_faces,
                    0,
                    target_index,
                    temp_frame
                )
                print(f"swapfacetime: {time.time() - started} s.")
        elif len(source_images) == 1:
            # detect source faces that will be replaced into the target image
            source_faces = self.get_many_faces(face_analyser, self._to_bgr(source_images[0]))
            num_source_faces = len(source_faces) if source_faces else 0
            print(f"Source faces: {num_source_faces}")
            print(f"Target faces: {num_target_faces}")
            if not source_faces:
                raise Exception("No source faces found!")

            if target_indexes == "-1":
                pairs = self._pair_all_faces(num_source_faces, num_target_faces)
            else:
                pairs = self._pair_selected_faces(
                    source_indexes, target_indexes,
                    num_source_faces, num_target_faces
                )

            for source_index, target_index in pairs:
                temp_frame = self.swap_face(
                    face_swapper,
                    source_faces,
                    target_faces,
                    source_index,
                    target_index,
                    temp_frame
                )
        else:
            raise Exception("Unsupported face configuration")

        return Image.fromarray(cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB))

    def run_faceswap(self, source_img: "List[PIL.Image.Image]", target_img: "PIL.Image.Image"):
        """Run the swap with this instance's models, then restore faces if enabled.

        Args:
            source_img: Source image(s) handed to ``process``.
            target_img: Target image handed to ``process``.

        Returns:
            The swapped PIL image, or the image built from the
            ``face_restoration`` output when ``self.face_restore`` is true.
        """
        print("Source image list:", source_img)

        result_image = self.process(source_img, target_img, self.source_indexes, self.target_indexes, self.faceswapper, self.faceanalyser)

        if self.face_restore:
            # Imported lazily so the restoration stack is only required
            # when restoration is actually requested.
            from post_process.inswapper.restoration import face_restoration

            # No background upsampler is configured.
            upsampler = None
            device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")

            result_image = cv2.cvtColor(np.array(result_image), cv2.COLOR_RGB2BGR)
            result_image = face_restoration(result_image,
                                            self.background_enhance,
                                            self.face_upsample,
                                            self.upscale,
                                            self.codeformer_fidelity,
                                            upsampler,
                                            self.codeformer,
                                            device)
            # NOTE(review): assumes face_restoration returns an RGB array,
            # since no colour conversion is applied here — confirm.
            result_image = Image.fromarray(result_image)
        return result_image