victorgg committed
Commit cf1cb9e · verified · 1 Parent(s): a6fbd76

Update app.py

Files changed (1)
  1. app.py +32 -28
app.py CHANGED
@@ -3,32 +3,34 @@ import os
 
 import cv2
 import torch
-import numpy as np # Import numpy explicitly
-from PIL import Image # Use PIL for image processing
+import numpy as np
+from PIL import Image
 from insightface.app import FaceAnalysis
 import face_align
 
 faceAnalysis = FaceAnalysis(name='buffalo_l')
-faceAnalysis.prepare(ctx_id=-1, det_size=(512, 512)) #ctx_id=-1 for CPU
+faceAnalysis.prepare(ctx_id=-1, det_size=(512, 512))
 
 from StyleTransferModel_128 import StyleTransferModel
 import gradio as gr
 
 def parse_arguments():
     parser = argparse.ArgumentParser(description='Process command line arguments')
-
-    parser.add_argument('--modelPath', required=True, help='Model path')
-    parser.add_argument('--resolution', type=int, default=128, help='Resolution')
+    parser.add_argument('--resolution', type=int, default=128, help='Resolution') #Removed model path
 
     return parser.parse_args()
 
 def get_device():
-    return torch.device('cpu') # Force CPU
+    return torch.device('cpu')
 
 def load_model(model_path):
     device = get_device()
     model = StyleTransferModel().to(device)
-    model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
+    try:
+        model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
+    except FileNotFoundError:
+        print(f"Error: Model file not found at {model_path}")
+        return None
     model.eval()
     return model
 
@@ -41,20 +43,20 @@ def swap_face(model, target_face, source_face_latent):
     with torch.no_grad():
         swapped_tensor = model(target_tensor, source_tensor)
 
-    swapped_face = postprocess_face(swapped_tensor) # Use PIL-based postprocess
+    swapped_face = postprocess_face(swapped_tensor)
 
     return swapped_face, swapped_tensor
 
 def create_target(target_image, resolution):
-    target_face = faceAnalysis.get(np.array(target_image))[0] # Convert PIL to numpy
+    target_face = faceAnalysis.get(np.array(target_image))[0]
 
-    aligned_target_face, M = face_align.norm_crop2(np.array(target_image), target_face.kps, resolution) # Convert PIL to numpy
+    aligned_target_face, M = face_align.norm_crop2(np.array(target_image), target_face.kps, resolution)
     target_face_blob = getBlob(aligned_target_face, (resolution, resolution))
 
     return target_face_blob, M
 
 def create_source(source_image):
-    source_face = faceAnalysis.get(np.array(source_image))[0] # Convert PIL to numpy
+    source_face = faceAnalysis.get(np.array(source_image))[0]
     source_latent = getLatent(source_face)
 
     return source_latent
@@ -64,7 +66,7 @@ def postprocess_face(swapped_tensor):
     swapped_tensor = swapped_tensor.cpu().numpy()
     swapped_tensor = np.transpose(swapped_tensor, (0, 2, 3, 1))
     swapped_tensor = (swapped_tensor * 255).astype(np.uint8)
-    swapped_face = Image.fromarray(swapped_tensor[0]) # Convert to PIL Image
+    swapped_face = Image.fromarray(swapped_tensor[0])
     return swapped_face
 
 def getBlob(aligned_face, size):
@@ -80,39 +82,41 @@ def getLatent(source_face):
 
 
 def blend_swapped_image(swapped_face, target_img, M):
-    swapped_face = np.array(swapped_face) # PIL to numpy
+    swapped_face = np.array(swapped_face)
     swapped_face = cv2.warpAffine(swapped_face, M, (target_img.shape[1], target_img.shape[0]))
     mask = np.ones_like(swapped_face) * 255
    mask = cv2.warpAffine(mask, M, (target_img.shape[1], target_img.shape[0]))
 
-    target_img = np.array(target_img) # PIL to numpy
+    target_img = np.array(target_img)
     swapped_face = Image.blend(Image.fromarray(target_img), Image.fromarray(swapped_face), Image.fromarray(mask).convert("L"))
 
-    return np.array(swapped_face) # numpy to PIL
+    return np.array(swapped_face)
 
 
-def process_images(target_image, source_image, model_path):
+def process_images(target_image, source_image):
     args = parse_arguments()
-    args.modelPath = model_path
-    args.no_paste_back = False # or True, as you prefer
     args.resolution = 128
-    model = load_model(args.modelPath)
+
+    model_path = "reswapper-429500.pth" # Hardcoded model path
+
+    model = load_model(model_path)
+    if model is None:
+        return "Error: Could not load the model. Check the path."
 
     target_face_blob, M = create_target(target_image, args.resolution)
     source_latent = create_source(source_image)
     swapped_face, _ = swap_face(model, target_face_blob, source_latent)
 
-    swapped_face = blend_swapped_image(swapped_face, target_image, M) # PIL images
+    swapped_face = blend_swapped_image(swapped_face, target_image, M)
 
-    return Image.fromarray(swapped_face) # Return PIL image
+    return Image.fromarray(swapped_face)
 
 
 with gr.Blocks() as demo:
-    target_image = gr.Image(label="Target Image", type="pil") # Use PIL type
-    source_image = gr.Image(label="Source Image", type="pil") # Use PIL type
-    model_path = gr.Textbox(label="Model Path", value="path/to/your/model.pth") # Add model path input
-    output_image = gr.Image(label="Output Image", type="pil") # Use PIL type
+    target_image = gr.Image(label="Target Image", type="pil")
+    source_image = gr.Image(label="Source Image", type="pil")
+    output_image = gr.Image(label="Output Image", type="pil")
     btn = gr.Button("Swap Face")
-    btn.click(fn=process_images, inputs=[target_image, source_image, model_path], outputs=output_image)
+    btn.click(fn=process_images, inputs=[target_image, source_image], outputs=output_image)
 
-demo.launch() #no share = true for local running
+demo.launch()
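
Because the checkpoint path is now hardcoded instead of being taken from a UI textbox, the app assumes reswapper-429500.pth sits in the working directory when app.py is started. A minimal pre-flight check along these lines (a sketch, not part of the commit; the script name and location are assumptions) surfaces a missing checkpoint before the Gradio UI launches:

    # Sketch only: verify the checkpoint hardcoded in this commit exists before
    # launching the app. The filename matches process_images(); everything else
    # here is an assumption, not taken from the repository.
    import os
    import sys

    MODEL_PATH = "reswapper-429500.pth"  # hardcoded in process_images()

    if not os.path.isfile(MODEL_PATH):
        sys.exit(f"Missing checkpoint: place {MODEL_PATH} next to app.py before running it.")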