File size: 4,356 Bytes
742d952
 
 
 
 
42d4d05
 
742d952
 
 
 
42d4d05
742d952
 
42d4d05
742d952
 
 
 
 
 
 
 
 
 
42d4d05
742d952
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42d4d05
 
742d952
 
 
42d4d05
742d952
42d4d05
 
742d952
 
 
42d4d05
 
 
742d952
42d4d05
742d952
 
42d4d05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
742d952
 
42d4d05
 
 
 
 
742d952
42d4d05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
742d952
 
42d4d05
 
 
 
742d952
42d4d05
 
 
 
 
 
 
742d952
42d4d05
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import argparse
import os
import warnings

import cv2
import numpy as np  # Import numpy explicitly
import torch
from insightface.app import FaceAnalysis
from PIL import Image  # Use PIL for image processing

import face_align

# Shared insightface analyser used by create_target() and create_source().
# It is constructed and prepared once, at import time, so importing this
# module has the (heavy) side effect of setting up the face models.
faceAnalysis = FaceAnalysis(name='buffalo_l')
faceAnalysis.prepare(ctx_id=-1, det_size=(512, 512))  # ctx_id=-1 selects CPU; detector input size 512x512

from StyleTransferModel_128 import StyleTransferModel
import gradio as gr

def parse_arguments(argv=None):
    """Parse the command-line options for the face-swap script.

    Args:
        argv: optional list of argument strings to parse instead of
            ``sys.argv[1:]``. This lets callers and tests use the parser
            without depending on how the process was launched.

    Returns:
        argparse.Namespace with ``modelPath`` (str, required) and
        ``resolution`` (int, default 128).
    """
    parser = argparse.ArgumentParser(description='Process command line arguments')

    parser.add_argument('--modelPath', required=True, help='Model path')
    parser.add_argument('--resolution', type=int, default=128, help='Resolution')

    # parse_args(None) falls back to sys.argv[1:], preserving the old behaviour.
    return parser.parse_args(argv)

def get_device():
    """Return the torch device used for all inference in this script.

    Inference is pinned to the CPU regardless of available accelerators.
    """
    cpu_device = torch.device('cpu')
    return cpu_device

def load_model(model_path):
    """Build a StyleTransferModel and load its weights from ``model_path``.

    The state dict is loaded non-strictly (``strict=False``). Any missing or
    unexpected keys are reported with a ``UserWarning`` rather than silently
    ignored, because a mismatched checkpoint would otherwise run with
    randomly initialised layers.

    Args:
        model_path: filesystem path to a saved state dict.

    Returns:
        The model on ``get_device()``, switched to eval mode.
    """
    device = get_device()
    model = StyleTransferModel().to(device)
    # SECURITY NOTE(review): torch.load unpickles the file and can execute
    # arbitrary code. The path reaches here from a UI textbox, so only load
    # checkpoints you trust. If the file is a plain state dict, consider
    # torch.load(..., weights_only=True) on PyTorch >= 1.13.
    state_dict = torch.load(model_path, map_location=device)
    incompatible = model.load_state_dict(state_dict, strict=False)
    missing = list(incompatible.missing_keys)
    unexpected = list(incompatible.unexpected_keys)
    if missing or unexpected:
        warnings.warn(
            f"Checkpoint {model_path!r} does not match StyleTransferModel: "
            f"{len(missing)} missing key(s) {missing[:5]}, "
            f"{len(unexpected)} unexpected key(s) {unexpected[:5]}"
        )
    model.eval()
    return model

def swap_face(model, target_face, source_face_latent):
    """Run the style-transfer model on one aligned target face.

    Args:
        model: callable invoked as ``model(target_tensor, source_tensor)``
            (a StyleTransferModel in this script).
        target_face: aligned target blob. create_target() returns it as a
            float torch tensor; a numpy array is accepted as well.
        source_face_latent: source identity embedding, as returned by
            getLatent() (tensor or numpy array).

    Returns:
        ``(swapped_face, swapped_tensor)``: the PIL image produced by
        postprocess_face() and the raw model output tensor.
    """
    device = get_device()

    # torch.as_tensor accepts both tensors and ndarrays. torch.from_numpy
    # raises TypeError on the tensor that getBlob() produces. Casting to
    # float32 keeps the inputs in the dtype the model weights use.
    target_tensor = torch.as_tensor(target_face, dtype=torch.float32, device=device)
    # NOTE(review): the latent is passed with whatever shape getLatent()
    # returns (presumably a 1-D embedding); confirm whether the model
    # expects a leading batch dimension, i.e. (1, D).
    source_tensor = torch.as_tensor(source_face_latent, dtype=torch.float32, device=device)

    with torch.no_grad():
        swapped_tensor = model(target_tensor, source_tensor)

    swapped_face = postprocess_face(swapped_tensor)

    return swapped_face, swapped_tensor

def create_target(target_image, resolution):
    """Detect the target face and build the aligned model-input blob.

    Args:
        target_image: target picture as a PIL image (or an RGB numpy array).
        resolution: side length, in pixels, of the aligned face crop.

    Returns:
        ``(target_face_blob, M)``: the blob produced by getBlob() and the 2x3
        affine matrix returned by face_align.norm_crop2(). The matrix is
        needed later to paste the swapped face back.

    Raises:
        ValueError: if no face is detected in the image.
    """
    if isinstance(target_image, Image.Image):
        # Normalise palette/RGBA/greyscale inputs to 3-channel RGB.
        target_image = target_image.convert("RGB")
    target_rgb = np.array(target_image)

    # insightface's models expect OpenCV-style BGR input, but PIL/Gradio
    # images are RGB, so convert for the analysis call only. Keypoint
    # coordinates do not depend on channel order, so the crop is still taken
    # from the RGB image.
    faces = faceAnalysis.get(cv2.cvtColor(target_rgb, cv2.COLOR_RGB2BGR))
    if not faces:
        raise ValueError("No face detected in the target image.")
    # NOTE(review): when several faces are detected, the first one is used.
    target_face = faces[0]

    aligned_target_face, M = face_align.norm_crop2(target_rgb, target_face.kps, resolution)
    target_face_blob = getBlob(aligned_target_face, (resolution, resolution))

    return target_face_blob, M

def create_source(source_image):
    """Detect the source face and return its identity latent.

    Args:
        source_image: source picture as a PIL image (or an RGB numpy array).

    Returns:
        The latent produced by getLatent() for the first detected face.

    Raises:
        ValueError: if no face is detected in the image.
    """
    if isinstance(source_image, Image.Image):
        # Normalise palette/RGBA/greyscale inputs to 3-channel RGB.
        source_image = source_image.convert("RGB")
    source_rgb = np.array(source_image)

    # insightface expects BGR input; PIL/Gradio images are RGB. Feeding RGB
    # directly swaps the channels the recognition model sees and degrades
    # the embedding.
    faces = faceAnalysis.get(cv2.cvtColor(source_rgb, cv2.COLOR_RGB2BGR))
    if not faces:
        raise ValueError("No face detected in the source image.")
    # NOTE(review): when several faces are detected, the first one is used.
    source_face = faces[0]

    source_latent = getLatent(source_face)

    return source_latent


def postprocess_face(swapped_tensor):
    """Convert the model's output batch into a PIL image of its first item.

    Args:
        swapped_tensor: model output tensor. The (0, 2, 3, 1) transpose below
            implies a (batch, channels, height, width) layout. Values are
            assumed to lie nominally in [0, 1].

    Returns:
        PIL.Image built from the first element of the batch.
    """
    array = swapped_tensor.detach().cpu().numpy()
    array = np.transpose(array, (0, 2, 3, 1))
    # Clip before the uint8 cast. Casting floats outside [0, 255] to uint8 is
    # undefined and typically wraps around, producing speckle artefacts
    # wherever the model slightly over- or undershoots.
    array = (np.clip(array, 0.0, 1.0) * 255).astype(np.uint8)
    swapped_face = Image.fromarray(array[0])
    return swapped_face

def getBlob(aligned_face, size):
    """Turn an HxWxC image into a 1xCxHxW float32 tensor scaled to [0, 1].

    Args:
        aligned_face: image array (uint8 pixel values expected).
        size: ``(width, height)`` passed to cv2.resize.

    Returns:
        torch.FloatTensor with a leading batch dimension of 1.
    """
    resized = cv2.resize(aligned_face, size)
    channels_first = np.transpose(resized / 255.0, (2, 0, 1))
    batched = channels_first[np.newaxis, ...]
    return torch.from_numpy(batched).float()

def getLatent(source_face):
    """Return the identity embedding attached to a detected face object.

    NOTE(review): the raw ``embedding`` attribute is returned unchanged (no
    normalisation or projection); confirm this matches what the model was
    trained on.
    """
    latent = source_face.embedding
    return latent


def blend_swapped_image(swapped_face, target_img, M):
    """Paste the swapped aligned face back into the full target image.

    Args:
        swapped_face: swapped face in the aligned-crop frame (PIL image or
            HxWx3 array), e.g. the output of postprocess_face().
        target_img: the full target picture (PIL image or HxWx3 array).
        M: 2x3 affine matrix from face_align.norm_crop2(), which maps
            target-image coordinates into the aligned crop.

    Returns:
        HxWx3 uint8 numpy array the size of ``target_img``, with the swapped
        face composited over the original face region.
    """
    face = np.asarray(swapped_face)
    target = np.asarray(target_img)  # PIL images have no .shape; convert first
    height, width = target.shape[:2]

    # M maps the target into the crop, so its inverse maps the crop back.
    inverse_M = cv2.invertAffineTransform(M)
    # Replicate border pixels so interpolation at the crop edge does not pull
    # in black, which would darken the seam.
    warped_face = cv2.warpAffine(
        face, inverse_M, (width, height), borderMode=cv2.BORDER_REPLICATE
    )

    # Build the mask in the crop frame and warp it the same way, so it covers
    # exactly the region the face lands on. Bilinear warping gives
    # fractional values along the edge, which soften the seam.
    crop_mask = np.ones(face.shape[:2], dtype=np.float32)
    mask = cv2.warpAffine(crop_mask, inverse_M, (width, height))[..., np.newaxis]

    blended = mask * warped_face.astype(np.float32) + (1.0 - mask) * target.astype(np.float32)
    return np.clip(blended, 0, 255).astype(np.uint8)


def process_images(target_image, source_image, model_path, resolution=128):
    """Gradio callback: swap the source identity onto the target image.

    Args:
        target_image: PIL image whose face is replaced.
        source_image: PIL image providing the identity.
        model_path: path to the StyleTransferModel checkpoint.
        resolution: aligned-crop size. The default of 128 matches the
            StyleTransferModel_128 import.

    Returns:
        PIL image of the target with the swapped face pasted back.

    Raises:
        gr.Error: for missing inputs, a missing model file, or when no face
            is found. Gradio shows the message to the user.
    """
    # The command line is deliberately not parsed here. --modelPath is
    # required there, so parsing sys.argv inside a UI callback exits when
    # the app is started without it.
    if target_image is None or source_image is None:
        raise gr.Error("Please provide both a target image and a source image.")
    if not model_path or not os.path.isfile(model_path):
        raise gr.Error(f"Model file not found: {model_path!r}")

    model = load_model(model_path)

    try:
        target_face_blob, M = create_target(target_image, resolution)
        source_latent = create_source(source_image)
    except ValueError as err:
        raise gr.Error(str(err)) from err

    swapped_face, _ = swap_face(model, target_face_blob, source_latent)

    blended = blend_swapped_image(swapped_face, target_image, M)  # uint8 numpy array

    return Image.fromarray(blended)


# Gradio UI: two image inputs, a checkpoint path and a button that runs
# process_images. All images are exchanged with the callback as PIL images.
with gr.Blocks() as demo:
    target_image = gr.Image(label="Target Image", type="pil")
    source_image = gr.Image(label="Source Image", type="pil")
    model_path = gr.Textbox(label="Model Path", value="path/to/your/model.pth")
    output_image = gr.Image(label="Output Image", type="pil")
    btn = gr.Button("Swap Face")
    btn.click(fn=process_images, inputs=[target_image, source_image, model_path], outputs=output_image)

# Launch only when run as a script, so importing the module (e.g. for tests
# or `gradio` reload mode) does not start a server. Serves locally; pass
# share=True to launch() for a public link.
if __name__ == "__main__":
    demo.launch()