haijian06 commited on
Commit
612a8c2
1 Parent(s): 53c822f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -20
app.py CHANGED
@@ -182,33 +182,33 @@ transform = transforms.Compose([
182
 
183
  def predict_image(image):
184
  """
185
- 处理上传的图片并返回预测结果
186
  """
187
  if image is None:
188
- return "请上传图片"
189
 
190
- # 转换图片格式
191
  if isinstance(image, np.ndarray):
192
  image = Image.fromarray(image)
193
 
194
- # 应用转换
195
  image_tensor = transform(image).unsqueeze(0)
196
 
197
- # 准备特征提取用的图像
198
  np_image = image_tensor.cpu().numpy().squeeze(0).transpose(1, 2, 0)
199
  np_image = np.clip(np_image, 0, 1)
200
 
201
- # 转换为灰度图
202
  gray_image = cv2.cvtColor((np_image * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY)
203
  gray_image = gray_image.astype(np.float32) / 255.0
204
 
205
- # 提取特征
206
  glcm_features = extract_glcm_features(gray_image)
207
  spectrum_features = analyze_spectrum(gray_image)
208
  edge_features = extract_edge_features(gray_image)
209
  lbp_features = extract_lbp_features(gray_image)
210
 
211
- # 转换为张量并移到设备
212
  with torch.no_grad():
213
  image_tensor = image_tensor.to(device)
214
  glcm_features = torch.from_numpy(glcm_features).unsqueeze(0).to(device)
@@ -216,27 +216,32 @@ def predict_image(image):
216
  edge_features = torch.from_numpy(edge_features).unsqueeze(0).to(device)
217
  lbp_features = torch.from_numpy(lbp_features).unsqueeze(0).to(device)
218
 
219
- # 模型预测
220
  outputs = model(image_tensor, glcm_features, spectrum_features, edge_features, lbp_features)
221
  prediction = torch.sigmoid(outputs).item()
222
 
223
- # 返回预测结果
224
- if prediction > 0.5:
225
- return "真实人脸图片"
226
  else:
227
- return "虚假人脸图片"
228
 
229
- # 创建Gradio界面
230
  iface = gr.Interface(
231
  fn=predict_image,
232
  inputs=gr.Image(type="pil"),
233
- outputs=gr.Text(label="预测结果"),
234
- title="人脸真伪检测",
235
- description="上传一张人脸图片,模型将判断是真实人脸还是虚假人脸。",
236
  examples=[
237
- # 这里可以添加示例图片路径
238
- ]
 
 
 
 
 
239
  )
240
 
241
- # 启动应用
242
  iface.launch()
 
182
 
183
  def predict_image(image):
184
  """
185
+ Process uploaded image and return prediction result
186
  """
187
  if image is None:
188
+ return "Please upload an image"
189
 
190
+ # Convert image format
191
  if isinstance(image, np.ndarray):
192
  image = Image.fromarray(image)
193
 
194
+ # Apply transformations
195
  image_tensor = transform(image).unsqueeze(0)
196
 
197
+ # Prepare image for feature extraction
198
  np_image = image_tensor.cpu().numpy().squeeze(0).transpose(1, 2, 0)
199
  np_image = np.clip(np_image, 0, 1)
200
 
201
+ # Convert to grayscale
202
  gray_image = cv2.cvtColor((np_image * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY)
203
  gray_image = gray_image.astype(np.float32) / 255.0
204
 
205
+ # Extract features
206
  glcm_features = extract_glcm_features(gray_image)
207
  spectrum_features = analyze_spectrum(gray_image)
208
  edge_features = extract_edge_features(gray_image)
209
  lbp_features = extract_lbp_features(gray_image)
210
 
211
+ # Convert to tensors and move to device
212
  with torch.no_grad():
213
  image_tensor = image_tensor.to(device)
214
  glcm_features = torch.from_numpy(glcm_features).unsqueeze(0).to(device)
 
216
  edge_features = torch.from_numpy(edge_features).unsqueeze(0).to(device)
217
  lbp_features = torch.from_numpy(lbp_features).unsqueeze(0).to(device)
218
 
219
+ # Model prediction
220
  outputs = model(image_tensor, glcm_features, spectrum_features, edge_features, lbp_features)
221
  prediction = torch.sigmoid(outputs).item()
222
 
223
+ # Return prediction result (corrected logic)
224
+ if prediction < 0.5: # Changed from > to <
225
+ return "Real Face"
226
  else:
227
+ return "AI-Generated Face"
228
 
229
+ # Create Gradio interface
230
  iface = gr.Interface(
231
  fn=predict_image,
232
  inputs=gr.Image(type="pil"),
233
+ outputs=gr.Text(label="Prediction Result"),
234
+ title="Face Authentication System",
235
+ description="Upload a face image to determine if it's a real face or an AI-generated face.",
236
  examples=[
237
+ # Add example image paths here
238
+ ],
239
+ article="""
240
+ This system uses advanced deep learning techniques to detect whether a face image is real or AI-generated.
241
+ The model analyzes various image features including texture patterns, frequency spectrum, and local binary patterns
242
+ to make its determination.
243
+ """
244
  )
245
 
246
+ # Launch the application
247
  iface.launch()