victorgg committed
Commit 42d4d05 · verified · 1 Parent(s): ec36cc7

Update swap.py

Files changed (1): swap.py (+63 -38)
swap.py CHANGED
@@ -3,29 +3,27 @@ import os

import cv2
import torch
-import Image
+import numpy as np  # Import numpy explicitly
+from PIL import Image  # Use PIL for image processing
from insightface.app import FaceAnalysis
import face_align

faceAnalysis = FaceAnalysis(name='buffalo_l')
-faceAnalysis.prepare(ctx_id=0, det_size=(512, 512))
+faceAnalysis.prepare(ctx_id=-1, det_size=(512, 512))  # ctx_id=-1 for CPU

from StyleTransferModel_128 import StyleTransferModel
+import gradio as gr

def parse_arguments():
    parser = argparse.ArgumentParser(description='Process command line arguments')

-    parser.add_argument('--target', required=True, help='Target path')
-    parser.add_argument('--source', required=True, help='Source path')
-    parser.add_argument('--outputPath', required=True, help='Output path')
    parser.add_argument('--modelPath', required=True, help='Model path')
-    parser.add_argument('--no-paste-back', action='store_true', help='Disable pasting back the swapped face onto the original image')
    parser.add_argument('--resolution', type=int, default=128, help='Resolution')

    return parser.parse_args()

def get_device():
-    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    return torch.device('cpu')  # Force CPU

def load_model(model_path):
    device = get_device()
@@ -43,51 +41,78 @@ def swap_face(model, target_face, source_face_latent):
    with torch.no_grad():
        swapped_tensor = model(target_tensor, source_tensor)

-    swapped_face = Image.postprocess_face(swapped_tensor)
-
+    swapped_face = postprocess_face(swapped_tensor)  # Use PIL-based postprocess
+
    return swapped_face, swapped_tensor

def create_target(target_image, resolution):
-    if isinstance(target_image, str):
-        target_image = cv2.imread(target_image)
+    target_face = faceAnalysis.get(np.array(target_image))[0]  # Convert PIL to numpy

-    target_face = faceAnalysis.get(target_image)[0]
-    aligned_target_face, M = face_align.norm_crop2(target_image, target_face.kps, resolution)
-    target_face_blob = Image.getBlob(aligned_target_face, (resolution, resolution))
+    aligned_target_face, M = face_align.norm_crop2(np.array(target_image), target_face.kps, resolution)  # Convert PIL to numpy
+    target_face_blob = getBlob(aligned_target_face, (resolution, resolution))

    return target_face_blob, M

-def create_source(source_img_path):
-    source_image = cv2.imread(source_img_path)
+def create_source(source_image):
+    source_face = faceAnalysis.get(np.array(source_image))[0]  # Convert PIL to numpy
+    source_latent = getLatent(source_face)

-    source_face = faceAnalysis.get(source_image)[0]
+    return source_latent

-    source_latent = Image.getLatent(source_face)

-    return source_latent
+def postprocess_face(swapped_tensor):
+    swapped_tensor = swapped_tensor.cpu().numpy()
+    swapped_tensor = np.transpose(swapped_tensor, (0, 2, 3, 1))
+    swapped_tensor = (swapped_tensor * 255).astype(np.uint8)
+    swapped_face = Image.fromarray(swapped_tensor[0])  # Convert to PIL Image
+    return swapped_face
+
+def getBlob(aligned_face, size):
+    aligned_face = cv2.resize(aligned_face, size)
+    aligned_face = aligned_face / 255.0
+    aligned_face = np.transpose(aligned_face, (2, 0, 1))
+    aligned_face = np.expand_dims(aligned_face, axis=0)
+    aligned_face = torch.from_numpy(aligned_face).float()
+    return aligned_face
+
+def getLatent(source_face):
+    return source_face.embedding

-def main():
-    args = parse_arguments()
-
-    # Access the arguments
-    target_image_path = args.target
-    source = args.source
-    output_path = args.outputPath
-    model_path = args.modelPath

-    model = load_model(model_path)
+def blend_swapped_image(swapped_face, target_img, M):
+    swapped_face = np.array(swapped_face)  # PIL to numpy
+    swapped_face = cv2.warpAffine(swapped_face, M, (target_img.shape[1], target_img.shape[0]))
+    mask = np.ones_like(swapped_face) * 255
+    mask = cv2.warpAffine(mask, M, (target_img.shape[1], target_img.shape[0]))

-    target_img = cv2.imread(target_image_path)
-    target_face_blob, M = create_target(target_img, args.resolution)
-    source_latent = create_source(source)
+    target_img = np.array(target_img)  # PIL to numpy
+    swapped_face = Image.blend(Image.fromarray(target_img), Image.fromarray(swapped_face), Image.fromarray(mask).convert("L"))
+
+    return np.array(swapped_face)  # numpy to PIL
+
+
+def process_images(target_image, source_image, model_path):
+    args = parse_arguments()
+    args.modelPath = model_path
+    args.no_paste_back = False  # or True, as you prefer
+    args.resolution = 128
+    model = load_model(args.modelPath)
+
+    target_face_blob, M = create_target(target_image, args.resolution)
+    source_latent = create_source(source_image)
    swapped_face, _ = swap_face(model, target_face_blob, source_latent)

-    if not args.no_paste_back:
-        swapped_face = Image.blend_swapped_image(swapped_face, target_img, M)
+    swapped_face = blend_swapped_image(swapped_face, target_image, M)  # PIL images
+
+    return Image.fromarray(swapped_face)  # Return PIL image
+

-    output_folder = os.path.dirname(output_path)
-    os.makedirs(output_folder, exist_ok=True)
-    cv2.imwrite(output_path, swapped_face)
+with gr.Blocks() as demo:
+    target_image = gr.Image(label="Target Image", type="pil")  # Use PIL type
+    source_image = gr.Image(label="Source Image", type="pil")  # Use PIL type
+    model_path = gr.Textbox(label="Model Path", value="path/to/your/model.pth")  # Add model path input
+    output_image = gr.Image(label="Output Image", type="pil")  # Use PIL type
+    btn = gr.Button("Swap Face")
+    btn.click(fn=process_images, inputs=[target_image, source_image, model_path], outputs=output_image)

-if __name__ == "__main__":
-    main()
+demo.launch()  # no share = true for local running
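
A note for reviewers of the new blend_swapped_image helper: it passes a mask image as the third argument of PIL's Image.blend, but Image.blend(im1, im2, alpha) expects a single float alpha, so mask-driven compositing is usually written with Image.composite or plain NumPy instead. The sketch below is not the committed helper, just a minimal paste-back under stated assumptions: paste_back and swapped_crop are illustrative names, inputs are RGB uint8 arrays (or PIL images), and M is assumed to map the target frame into the aligned crop, as insightface's norm_crop conventionally does, so its inverse is used to warp the swapped crop back.

import cv2
import numpy as np

def paste_back(swapped_crop, target_img, M):
    # swapped_crop: aligned, swapped face crop; target_img: full target frame
    swapped_crop = np.asarray(swapped_crop, dtype=np.uint8)
    target_img = np.asarray(target_img, dtype=np.uint8)
    h, w = target_img.shape[:2]

    IM = cv2.invertAffineTransform(M)                     # crop -> target frame
    warped = cv2.warpAffine(swapped_crop, IM, (w, h))     # face placed in the frame
    mask = np.full(swapped_crop.shape[:2], 255, np.uint8)
    mask = cv2.warpAffine(mask, IM, (w, h))               # white where the crop landed

    alpha = (mask.astype(np.float32) / 255.0)[..., None]  # HxWx1 blend weights
    blended = warped * alpha + target_img * (1.0 - alpha)
    return blended.astype(np.uint8)

If that convention for M holds, the call in process_images could read swapped_face = paste_back(swapped_face, target_image, M), and the following Image.fromarray(swapped_face) would work unchanged on the returned uint8 array.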