|
import argparse |
|
import os |
|
|
|
import cv2 |
|
import torch |
|
import numpy as np |
|
from PIL import Image |
|
from insightface.app import FaceAnalysis |
|
import face_align |
|
|
|
# Shared InsightFace analyzer; create_target and create_source both call
# faceAnalysis.get() on it. Prepared once at import with a 512x512 detection
# size. ctx_id=-1 presumably selects CPU execution (matching get_device) —
# confirm against the insightface version in use.
faceAnalysis = FaceAnalysis(name='buffalo_l')
faceAnalysis.prepare(det_size=(512, 512), ctx_id=-1)
|
|
|
from StyleTransferModel_128 import StyleTransferModel |
|
import gradio as gr |
|
|
|
def parse_arguments(argv=None):
    """Parse the command-line options of the face-swap demo.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, exactly as before;
            passing a list lets callers and tests avoid the real command line.

    Returns:
        argparse.Namespace with ``modelPath`` (str, required) and
        ``resolution`` (int, default 128).

    Raises:
        SystemExit: from argparse when ``--modelPath`` is missing or an
            option value is invalid.
    """
    parser = argparse.ArgumentParser(description='Process command line arguments')
    parser.add_argument('--modelPath', required=True, help='Model path')
    parser.add_argument('--resolution', type=int, default=128, help='Resolution')
    return parser.parse_args(argv)
|
|
|
def get_device():
    """Return the torch device used for all inference in this script (CPU)."""
    cpu_device = torch.device('cpu')
    return cpu_device
|
|
|
def load_model(model_path):
    """Build a StyleTransferModel and load its weights from *model_path*.

    Args:
        model_path: path of the checkpoint passed to ``torch.load``.

    Returns:
        The model on ``get_device()``, switched to eval mode.
    """
    target_device = get_device()
    net = StyleTransferModel().to(target_device)
    # NOTE(review): torch.load unpickles the file, and model_path can come from
    # a user-editable textbox in the UI — only point it at trusted checkpoints.
    weights = torch.load(model_path, map_location=target_device)
    # strict=False: checkpoint keys that are missing or unexpected are ignored
    # instead of raising, so a mismatched file still loads without error.
    net.load_state_dict(weights, strict=False)
    net.eval()
    return net
|
|
|
def swap_face(model, target_face, source_face_latent):
    """Run the swap model on an aligned target face and a source latent.

    Args:
        model: network returned by load_model, called as
            ``model(target_tensor, source_tensor)``.
        target_face: aligned target face batch, either a torch tensor (what
            create_target/getBlob return) or a numpy array.
        source_face_latent: source identity vector as a numpy array or tensor
            (create_source returns getLatent's result). NOTE(review): it is
            passed through without adding a batch axis — confirm the shape
            StyleTransferModel expects.

    Returns:
        (swapped_face, swapped_tensor): the PIL image from postprocess_face and
        the raw model output tensor.
    """
    device = get_device()

    # torch.as_tensor accepts both ndarrays and tensors. torch.from_numpy
    # raised TypeError here because getBlob already returns a tensor.
    target_tensor = torch.as_tensor(target_face).to(device)
    source_tensor = torch.as_tensor(source_face_latent).to(device)

    with torch.no_grad():
        swapped_tensor = model(target_tensor, source_tensor)

    swapped_face = postprocess_face(swapped_tensor)
    return swapped_face, swapped_tensor
|
|
|
def create_target(target_image, resolution):
    """Detect, align and convert the face in *target_image* to a model input.

    Args:
        target_image: image accepted by ``np.array`` (the UI passes a PIL image).
        resolution: side length in pixels of the aligned square crop.

    Returns:
        (target_face_blob, M): the tensor produced by getBlob for the aligned
        crop, and the affine matrix returned by face_align.norm_crop2.

    Raises:
        ValueError: if no face is detected in the image.
    """
    # Convert once; the same array feeds detection and alignment.
    image_array = np.array(target_image)
    faces = faceAnalysis.get(image_array)
    if not faces:
        # Indexing [0] on an empty result would raise a bare IndexError.
        raise ValueError('No face detected in the target image.')
    # NOTE: when several faces are detected, the first detection is used.
    target_face = faces[0]

    aligned_target_face, M = face_align.norm_crop2(image_array, target_face.kps, resolution)
    target_face_blob = getBlob(aligned_target_face, (resolution, resolution))
    return target_face_blob, M
|
|
|
def create_source(source_image):
    """Detect the face in *source_image* and return its identity latent.

    Args:
        source_image: image accepted by ``np.array`` (the UI passes a PIL image).

    Returns:
        Whatever getLatent returns for the first detected face.

    Raises:
        ValueError: if no face is detected in the image.
    """
    faces = faceAnalysis.get(np.array(source_image))
    if not faces:
        # Indexing [0] on an empty result would raise a bare IndexError.
        raise ValueError('No face detected in the source image.')
    # NOTE: when several faces are detected, the first detection is used.
    return getLatent(faces[0])
|
|
|
|
|
def postprocess_face(swapped_tensor):
    """Convert the model's output batch into a PIL image of its first item.

    Args:
        swapped_tensor: torch tensor laid out as (batch, channel, height, width).
            Values are assumed to be in [0, 1], mirroring getBlob's /255 scaling.

    Returns:
        PIL image built from the first element of the batch.
    """
    array = swapped_tensor.detach().cpu().numpy()
    # (N, C, H, W) -> (N, H, W, C), the per-sample layout Image.fromarray expects.
    array = np.transpose(array, (0, 2, 3, 1))
    # Clip before the uint8 cast: out-of-range outputs (e.g. 1.01 or -0.02)
    # would otherwise wrap around or be undefined, producing speckle artifacts.
    array = (np.clip(array, 0.0, 1.0) * 255).astype(np.uint8)
    return Image.fromarray(array[0])
|
|
|
def getBlob(aligned_face, size):
    """Turn an aligned (H, W, C) face crop into a 1xCxHxW float32 tensor.

    Args:
        aligned_face: image array with three axes; pixel values are assumed
            to be in the 0-255 range (they are divided by 255).
        size: passed to ``cv2.resize`` as its destination size.

    Returns:
        float32 torch tensor with a leading batch axis of length 1.
    """
    resized = cv2.resize(aligned_face, size)
    normalized = resized / 255.0
    channels_first = normalized.transpose(2, 0, 1)  # (H, W, C) -> (C, H, W)
    batch = np.expand_dims(channels_first, axis=0)  # prepend batch axis of 1
    return torch.from_numpy(batch).float()
|
|
|
def getLatent(source_face):
    """Return the identity vector handed to the swap model as its source latent.

    This is the detected face's ``embedding`` attribute, returned unchanged.
    NOTE(review): the face object presumably also offers a normalized variant
    of the embedding — confirm which one StyleTransferModel was trained on.
    """
    latent = source_face.embedding
    return latent
|
|
|
|
|
def blend_swapped_image(swapped_face, target_img, M):
    """Paste the swapped face crop back into the full target image.

    Args:
        swapped_face: swapped face crop (PIL image or array) in the aligned
            frame produced by face_align.norm_crop2.
        target_img: full target image, as a PIL image or numpy array.
        M: 2x3 affine matrix returned by norm_crop2 alongside the crop.

    Returns:
        numpy uint8 array: the target image with the face region replaced.
    """
    # Convert first: a PIL image has no .shape attribute.
    target = np.asarray(target_img)
    height, width = target.shape[:2]
    face = np.asarray(swapped_face)

    # Assumes M is the forward (image -> aligned crop) matrix, as in
    # insightface's face_align; pasting the crop back needs its inverse.
    # TODO confirm against face_align.norm_crop2.
    M_inv = cv2.invertAffineTransform(M)
    warped_face = cv2.warpAffine(face, M_inv, (width, height))

    # Single-channel mask: 255 where the warped crop lands, 0 elsewhere.
    mask = np.full(face.shape[:2], 255, dtype=np.uint8)
    warped_mask = cv2.warpAffine(mask, M_inv, (width, height))

    # Image.blend only accepts a scalar alpha. Image.composite takes a
    # per-pixel "L" mask and picks the first image where the mask is 255.
    blended = Image.composite(
        Image.fromarray(warped_face).convert('RGB'),
        Image.fromarray(target).convert('RGB'),
        Image.fromarray(warped_mask).convert('L'),
    )
    return np.array(blended)
|
|
|
|
|
def process_images(target_image, source_image, model_path, resolution=128):
    """Gradio callback: swap the face of *source_image* onto *target_image*.

    Args:
        target_image: PIL image whose face is replaced.
        source_image: PIL image that supplies the identity.
        model_path: checkpoint path handed to load_model.
        resolution: side length of the aligned face crop. Defaults to 128,
            the value this callback has always used.

    Returns:
        PIL image of the target with the swapped face pasted back.

    Raises:
        gr.Error: if either image is missing (e.g. an empty gr.Image input),
            so Gradio shows the message in the UI.
    """
    # Everything comes from the callback's arguments; sys.argv is deliberately
    # not parsed here. A required --modelPath would otherwise raise SystemExit
    # on every request when the app is started without CLI flags.
    if target_image is None or source_image is None:
        raise gr.Error('Please provide both a target image and a source image.')

    model = load_model(model_path)

    target_face_blob, M = create_target(target_image, resolution)
    source_latent = create_source(source_image)
    swapped_face, _ = swap_face(model, target_face_blob, source_latent)

    # Pass an ndarray: blend_swapped_image reads target_img.shape.
    blended = blend_swapped_image(swapped_face, np.array(target_image), M)
    return Image.fromarray(blended)
|
|
|
|
|
# Minimal UI: target and source image inputs, a model-path box, a button,
# and the output image.
with gr.Blocks() as demo:
    target_image = gr.Image(label="Target Image", type="pil")
    source_image = gr.Image(label="Source Image", type="pil")
    model_path = gr.Textbox(label="Model Path", value="path/to/your/model.pth")
    output_image = gr.Image(label="Output Image", type="pil")
    btn = gr.Button("Swap Face")
    # Inputs are passed positionally as process_images(target, source, path).
    btn.click(fn=process_images, inputs=[target_image, source_image, model_path], outputs=output_image)

# Start the server only when run as a script, so importing this module (e.g.
# by Gradio's reload mode, which looks up `demo`) does not block on launch().
if __name__ == "__main__":
    demo.launch()
|
|