Jannat24 commited on
Commit
c6178c5
·
verified ·
1 Parent(s): 5db35c9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -39
app.py CHANGED
@@ -63,77 +63,77 @@ transform = T.Compose([
63
 
64
  def gen_sources(deepfake_img):
65
  #----------------DeepFake Face Segmentation-----------------
66
- ##----------------------Initialize:Face Segmentation----------------------------------
67
  segmenter = FaceSegmenter(threshold=0.5)
68
- # Convert PIL Image to BGR numpy array for segmentation
69
  img_np = np.array(deepfake_img.convert('RGB'))
70
  img_bgr = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
71
- # Segment the face
72
  segmented_np = segmenter.segment_face(img_bgr)
73
- # Convert segmented numpy array (BGR) back to PIL Image
74
  deepfake_seg = Image.fromarray(cv2.cvtColor(segmented_np, cv2.COLOR_BGR2RGB))
75
- #------------Initialize:Decoder-F------------------------
 
76
  checkpoint_path_f = "./models/model_vaq1_ff.pth"
77
- # Load model checkpoints
78
  checkpoint_f = torch.load(checkpoint_path_f, map_location=device)
79
- # Load the state dictionary into the models
80
  model_vaq_f.load_state_dict(checkpoint_f, strict=True)
81
  model_vaq_f.eval()
82
- #------------Initialize:Decoder-G------------------------
83
  checkpoint_path_g = "./models/model_vaq2_gg.pth"
84
  checkpoint_g = torch.load(checkpoint_path_g, map_location=device)
85
- # Load the state dictionary into the models
86
  model_vaq_g.load_state_dict(checkpoint_g, strict=True)
87
  model_vaq_g.eval()
88
- ##------------------------Initialize Model-F-------------------------------------
89
  model_z1 = DeepfakeToSourceTransformer().to(device)
90
- model_z1.load_state_dict(torch.load("./models/model_z1_ff.pth",map_location=device),strict=True)
91
  model_z1.eval()
92
- ##------------------------Initialize Model-G-------------------------------------
93
  model_z2 = DeepfakeToSourceTransformer().to(device)
94
- model_z2.load_state_dict(torch.load("./models/model_z2_gg.pth",map_location=device),strict=True)
95
  model_z2.eval()
96
- ##--------------------Initialize:Evaluation---------------------------------------
97
  criterion = DF()
98
-
99
- ##----------------------Operation-------------------------------------------------
100
  with torch.no_grad():
101
- # Load and preprocess input image
102
- #img = Image.open(deepfake_img).convert('RGB')
103
- df_img = transform(deepfake_img.convert('RGB')).unsqueeze(0).to(device) # Shape: (1, 3, 256, 256)
104
  seg_img = transform(deepfake_seg).unsqueeze(0).to(device)
105
-
106
- # Calculate quantized_block for all images
107
- z_df, _, _ = model_vaq_f.encode(df_img)
108
- z_seg, _, _ = model_vaq_g.encode(seg_img)
109
- rec_z_img1 = model_z1(z_df)
110
- rec_z_img2 = model_z2(z_seg)
111
  rec_img1 = model_vaq_f.decode(rec_z_img1).squeeze(0)
112
  rec_img2 = model_vaq_g.decode(rec_z_img2).squeeze(0)
113
  rec_img1_pil = T.ToPILImage()(rec_img1)
114
  rec_img2_pil = T.ToPILImage()(rec_img2)
115
 
116
- # Save PIL images to in-memory buffers
117
- buffer1 = BytesIO()
118
- buffer2 = BytesIO()
119
- rec_img1_pil.save(buffer1, format="PNG")
120
- rec_img2_pil.save(buffer2, format="PNG")
 
 
 
 
121
 
122
- # Pass buffers to Gradio client
123
  result = client.predict(
124
- target=file(buffer1),
125
- source=file(buffer2), slider=100, adv_slider=100,
126
  settings=["Adversarial Defense"], api_name="/run_inference"
127
  )
128
 
 
 
 
 
129
  # Load result and compute loss
130
- dfimage_pil = Image.open(result) # Open the resulting image
131
- buffer3 = BytesIO()
132
- dfimage_pil.save(buffer3, format="PNG")
133
- rec_df = transform(Image.open(buffer3)).unsqueeze(0).to(device)
134
- rec_loss,_ = criterion(df_img, rec_df)
 
 
135
 
136
- return (rec_img1_pil, rec_img2_pil, dfimage_pil, round(rec_loss.item(),3))
137
 
138
  #________________________Create the Gradio interface_________________________________
139
  interface = gr.Interface(
 
63
 
64
  def gen_sources(deepfake_img):
65
  #----------------DeepFake Face Segmentation-----------------
 
66
  segmenter = FaceSegmenter(threshold=0.5)
 
67
  img_np = np.array(deepfake_img.convert('RGB'))
68
  img_bgr = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
 
69
  segmented_np = segmenter.segment_face(img_bgr)
 
70
  deepfake_seg = Image.fromarray(cv2.cvtColor(segmented_np, cv2.COLOR_BGR2RGB))
71
+
72
+ #------------Initialize Models------------------------
73
  checkpoint_path_f = "./models/model_vaq1_ff.pth"
 
74
  checkpoint_f = torch.load(checkpoint_path_f, map_location=device)
 
75
  model_vaq_f.load_state_dict(checkpoint_f, strict=True)
76
  model_vaq_f.eval()
77
+
78
  checkpoint_path_g = "./models/model_vaq2_gg.pth"
79
  checkpoint_g = torch.load(checkpoint_path_g, map_location=device)
 
80
  model_vaq_g.load_state_dict(checkpoint_g, strict=True)
81
  model_vaq_g.eval()
82
+
83
  model_z1 = DeepfakeToSourceTransformer().to(device)
84
+ model_z1.load_state_dict(torch.load("./models/model_z1_ff.pth", map_location=device), strict=True)
85
  model_z1.eval()
86
+
87
  model_z2 = DeepfakeToSourceTransformer().to(device)
88
+ model_z2.load_state_dict(torch.load("./models/model_z2_gg.pth", map_location=device), strict=True)
89
  model_z2.eval()
90
+
91
  criterion = DF()
92
+
 
93
  with torch.no_grad():
94
+ df_img = transform(deepfake_img.convert('RGB')).unsqueeze(0).to(device)
 
 
95
  seg_img = transform(deepfake_seg).unsqueeze(0).to(device)
96
+
97
+ z_df, _, _ = model_vaq_f.encode(df_img)
98
+ z_seg, _, _ = model_vaq_g.encode(seg_img)
99
+ rec_z_img1 = model_z1(z_df)
100
+ rec_z_img2 = model_z2(z_seg)
 
101
  rec_img1 = model_vaq_f.decode(rec_z_img1).squeeze(0)
102
  rec_img2 = model_vaq_g.decode(rec_z_img2).squeeze(0)
103
  rec_img1_pil = T.ToPILImage()(rec_img1)
104
  rec_img2_pil = T.ToPILImage()(rec_img2)
105
 
106
+ # Save PIL images to temporary files
107
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp1, \
108
+ tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp2:
109
+
110
+ rec_img1_pil.save(temp1, format="PNG")
111
+ rec_img2_pil.save(temp2, format="PNG")
112
+
113
+ temp1_path = temp1.name
114
+ temp2_path = temp2.name
115
 
116
+ # Pass file paths to Gradio client
117
  result = client.predict(
118
+ target=file(temp1_path),
119
+ source=file(temp2_path), slider=100, adv_slider=100,
120
  settings=["Adversarial Defense"], api_name="/run_inference"
121
  )
122
 
123
+ # Clean up temporary files
124
+ os.remove(temp1_path)
125
+ os.remove(temp2_path)
126
+
127
  # Load result and compute loss
128
+ dfimage_pil = Image.open(result)
129
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp3:
130
+ dfimage_pil.save(temp3, format="PNG")
131
+ rec_df = transform(Image.open(temp3.name)).unsqueeze(0).to(device)
132
+ os.remove(temp3.name)
133
+
134
+ rec_loss, _ = criterion(df_img, rec_df)
135
 
136
+ return (rec_img1_pil, rec_img2_pil, dfimage_pil, round(rec_loss.item(), 3))
137
 
138
  #________________________Create the Gradio interface_________________________________
139
  interface = gr.Interface(