Update app.py
app.py
CHANGED
@@ -63,77 +63,77 @@ transform = T.Compose([
 
 def gen_sources(deepfake_img):
     #----------------DeepFake Face Segmentation-----------------
-    ##----------------------Initialize:Face Segmentation----------------------------------
     segmenter = FaceSegmenter(threshold=0.5)
-    # Convert PIL Image to BGR numpy array for segmentation
     img_np = np.array(deepfake_img.convert('RGB'))
     img_bgr = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
-    # Segment the face
     segmented_np = segmenter.segment_face(img_bgr)
-    # Convert segmented numpy array (BGR) back to PIL Image
     deepfake_seg = Image.fromarray(cv2.cvtColor(segmented_np, cv2.COLOR_BGR2RGB))
 
+    #------------Initialize Models------------------------
     checkpoint_path_f = "./models/model_vaq1_ff.pth"
-    # Load model checkpoints
     checkpoint_f = torch.load(checkpoint_path_f, map_location=device)
-    # Load the state dictionary into the models
     model_vaq_f.load_state_dict(checkpoint_f, strict=True)
     model_vaq_f.eval()
 
     checkpoint_path_g = "./models/model_vaq2_gg.pth"
     checkpoint_g = torch.load(checkpoint_path_g, map_location=device)
-    # Load the state dictionary into the models
     model_vaq_g.load_state_dict(checkpoint_g, strict=True)
     model_vaq_g.eval()
 
     model_z1 = DeepfakeToSourceTransformer().to(device)
-    model_z1.load_state_dict(torch.load("./models/model_z1_ff.pth",map_location=device),strict=True)
+    model_z1.load_state_dict(torch.load("./models/model_z1_ff.pth", map_location=device), strict=True)
     model_z1.eval()
 
     model_z2 = DeepfakeToSourceTransformer().to(device)
-    model_z2.load_state_dict(torch.load("./models/model_z2_gg.pth",map_location=device),strict=True)
+    model_z2.load_state_dict(torch.load("./models/model_z2_gg.pth", map_location=device), strict=True)
     model_z2.eval()
 
     criterion = DF()
 
-    ##----------------------Operation-------------------------------------------------
     with torch.no_grad():
-        #img = Image.open(deepfake_img).convert('RGB')
-        df_img = transform(deepfake_img.convert('RGB')).unsqueeze(0).to(device)  # Shape: (1, 3, 256, 256)
+        df_img = transform(deepfake_img.convert('RGB')).unsqueeze(0).to(device)
         seg_img = transform(deepfake_seg).unsqueeze(0).to(device)
 
+        z_df, _, _ = model_vaq_f.encode(df_img)
+        z_seg, _, _ = model_vaq_g.encode(seg_img)
+        rec_z_img1 = model_z1(z_df)
         rec_z_img2 = model_z2(z_seg)
         rec_img1 = model_vaq_f.decode(rec_z_img1).squeeze(0)
         rec_img2 = model_vaq_g.decode(rec_z_img2).squeeze(0)
         rec_img1_pil = T.ToPILImage()(rec_img1)
         rec_img2_pil = T.ToPILImage()(rec_img2)
 
-        # Save PIL images to
+        # Save PIL images to temporary files
+        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp1, \
+             tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp2:
+            rec_img1_pil.save(temp1, format="PNG")
+            rec_img2_pil.save(temp2, format="PNG")
+            temp1_path = temp1.name
+            temp2_path = temp2.name
 
-        # Pass
+        # Pass file paths to Gradio client
         result = client.predict(
-            target=file(
-            source=file(
+            target=file(temp1_path),
+            source=file(temp2_path), slider=100, adv_slider=100,
             settings=["Adversarial Defense"], api_name="/run_inference"
         )
 
+        # Clean up temporary files
+        os.remove(temp1_path)
+        os.remove(temp2_path)
+
         # Load result and compute loss
         dfimage_pil = Image.open(result)
+        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp3:
+            dfimage_pil.save(temp3, format="PNG")
+        rec_df = transform(Image.open(temp3.name)).unsqueeze(0).to(device)
+        os.remove(temp3.name)
+
+        rec_loss, _ = criterion(df_img, rec_df)
 
-        return (rec_img1_pil, rec_img2_pil, dfimage_pil, round(rec_loss.item(),3))
+        return (rec_img1_pil, rec_img2_pil, dfimage_pil, round(rec_loss.item(), 3))
 
 #________________________Create the Gradio interface_________________________________
 interface = gr.Interface(
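The hunk leans on module-level setup that sits above line 63 of app.py: the transform pipeline named in the hunk header, the device, the two VQ models, and the Client/file pair from gradio_client that the new lines call, plus the tempfile and os modules they now use. As a reading aid, here is a minimal sketch of what that surrounding setup could look like; the Space ID and the exact transform stages are assumptions for illustration, not part of this commit.

# Sketch of the module-level context the hunk assumes. The Space ID and
# the transform stages are placeholders; only the Client/file, tempfile,
# and os usage is implied directly by the new lines.
import os
import tempfile

import cv2
import gradio as gr
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
from gradio_client import Client, file

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 256x256 tensors would match the "(1, 3, 256, 256)" shape noted in the old code.
transform = T.Compose([
    T.Resize((256, 256)),
    T.ToTensor(),
])

# Remote face-swap Space reached through client.predict(...); placeholder ID.
client = Client("<owner>/<face-swap-space>")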
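One detail worth noting in the new lines: the image returned by client.predict is written to a PNG and re-opened through the same transform pipeline before DF() scores it, so rec_loss compares two tensors produced exactly the way df_img was. A small helper capturing that round trip might look like this; the name reload_and_score is hypothetical, and the criterion signature is inferred from "rec_loss, _ = criterion(df_img, rec_df)" in the diff.

# Hypothetical helper mirroring the save/reload round trip in the hunk.
def reload_and_score(criterion, df_img, result_path, transform, device):
    # Write the client's output to a temp PNG, as the new code does.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        Image.open(result_path).save(tmp, format="PNG")
    # Re-read it through the same transform used for df_img.
    rec_df = transform(Image.open(tmp.name)).unsqueeze(0).to(device)
    os.remove(tmp.name)
    rec_loss, _ = criterion(df_img, rec_df)
    return rec_loss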
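The hunk ends where the unchanged interface definition begins at line 139. For orientation, a wiring consistent with gen_sources' signature (one PIL image in; two reconstructed sources, the regenerated deepfake, and a rounded loss out) would read roughly as follows; the labels and options are guesses, since the actual arguments fall outside this diff.

# Hypothetical gr.Interface wiring matching gen_sources' inputs/outputs;
# the real arguments live in the unchanged remainder of app.py.
interface = gr.Interface(
    fn=gen_sources,
    inputs=gr.Image(type="pil", label="DeepFake Image"),
    outputs=[
        gr.Image(type="pil", label="Reconstructed Source 1"),
        gr.Image(type="pil", label="Reconstructed Source 2"),
        gr.Image(type="pil", label="Regenerated DeepFake"),
        gr.Number(label="Reconstruction Loss"),
    ],
)

if __name__ == "__main__":
    interface.launch()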