haijian06 commited on
Commit
8711d89
·
verified ·
1 Parent(s): 2c67320

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +94 -98
app.py CHANGED
@@ -10,68 +10,7 @@ from skimage.feature import graycomatrix, graycoprops, local_binary_pattern
10
  import timm
11
  import gradio as gr
12
 
13
- # GLCM feature extraction
14
- def extract_glcm_features(image):
15
- image_uint8 = (image * 255).astype(np.uint8)
16
- image_uint8 = image_uint8 // 4
17
-
18
- glcm = graycomatrix(
19
- image_uint8,
20
- distances=[1],
21
- angles=[0, np.pi / 4, np.pi / 2, 3 * np.pi / 4],
22
- levels=64,
23
- symmetric=True,
24
- normed=True
25
- )
26
-
27
- contrast = graycoprops(glcm, 'contrast').flatten()
28
- dissimilarity = graycoprops(glcm, 'dissimilarity').flatten()
29
- homogeneity = graycoprops(glcm, 'homogeneity').flatten()
30
- energy = graycoprops(glcm, 'energy').flatten()
31
- correlation = graycoprops(glcm, 'correlation').flatten()
32
-
33
- features = np.hstack([contrast, dissimilarity, homogeneity, energy, correlation])
34
- return features.astype(np.float32)
35
-
36
- # Spectrum analysis
37
- def analyze_spectrum(image, target_spectrum_length=181):
38
- f = fft2(image)
39
- fshift = fftshift(f)
40
- magnitude_spectrum = 20 * np.log(np.abs(fshift) + 1e-8)
41
-
42
- center = np.array(magnitude_spectrum.shape) // 2
43
- y, x = np.indices(magnitude_spectrum.shape)
44
- r = np.sqrt((x - center[1])**2 + (y - center[0])**2).astype(int)
45
-
46
- radial_mean = np.bincount(r.ravel(), magnitude_spectrum.ravel()) / np.bincount(r.ravel())
47
-
48
- if len(radial_mean) < target_spectrum_length:
49
- radial_mean = np.pad(radial_mean, (0, target_spectrum_length - len(radial_mean)), 'constant')
50
- else:
51
- radial_mean = radial_mean[:target_spectrum_length]
52
-
53
- return radial_mean.astype(np.float32)
54
-
55
- # Edge feature extraction
56
- def extract_edge_features(image):
57
- image_uint8 = (image * 255).astype(np.uint8)
58
- edges = cv2.Canny(image_uint8, 100, 200)
59
- edges_resized = cv2.resize(edges, (64, 64), interpolation=cv2.INTER_AREA)
60
- return edges_resized.astype(np.float32) / 255.0
61
-
62
- # LBP feature extraction
63
- def extract_lbp_features(image):
64
- radius = 1
65
- n_points = 8 * radius
66
- METHOD = 'uniform'
67
-
68
- lbp = local_binary_pattern(image, n_points, radius, METHOD)
69
- n_bins = n_points + 2
70
- hist, _ = np.histogram(lbp.ravel(), bins=n_bins, range=(0, n_bins), density=True)
71
-
72
- return hist.astype(np.float32)
73
-
74
- # Model architecture
75
  class AttentionBlock(nn.Module):
76
  def __init__(self, in_features):
77
  super(AttentionBlock, self).__init__()
@@ -90,7 +29,7 @@ class AdvancedFaceDetectionModel(nn.Module):
90
  def __init__(self, spectrum_length=181, lbp_n_bins=10):
91
  super(AdvancedFaceDetectionModel, self).__init__()
92
 
93
- self.efficientnet = timm.create_model('tf_efficientnetv2_b2', pretrained=False, num_classes=0)
94
  for param in self.efficientnet.conv_stem.parameters():
95
  param.requires_grad = False
96
  for param in self.efficientnet.bn1.parameters():
@@ -169,12 +108,72 @@ class AdvancedFaceDetectionModel(nn.Module):
169
  output = self.fusion(combined_features)
170
  return output.squeeze(1)
171
 
172
- # Initialize model and transform
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
174
  model = AdvancedFaceDetectionModel(spectrum_length=181, lbp_n_bins=10).to(device)
175
  model.load_state_dict(torch.load('best_model.pth', map_location=device))
176
  model.eval()
177
 
 
178
  transform = transforms.Compose([
179
  transforms.Resize((256, 256)),
180
  transforms.ToTensor(),
@@ -183,64 +182,61 @@ transform = transforms.Compose([
183
 
184
  def predict_image(image):
185
  """
186
- Process a single image and return prediction
187
  """
188
- # Convert to PIL Image if needed
189
- if not isinstance(image, Image.Image):
190
- image = Image.fromarray(image)
191
 
192
- # Convert to RGB if needed
193
- if image.mode != 'RGB':
194
- image = image.convert('RGB')
195
 
196
- # Apply transformations
197
  image_tensor = transform(image).unsqueeze(0)
198
-
199
- # Convert to NumPy array for feature extraction
200
  np_image = image_tensor.cpu().numpy().squeeze(0).transpose(1, 2, 0)
201
  np_image = np.clip(np_image, 0, 1)
202
-
203
- # Convert to grayscale
204
  gray_image = cv2.cvtColor((np_image * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY)
205
  gray_image = gray_image.astype(np.float32) / 255.0
206
-
207
- # Extract features
208
  glcm_features = extract_glcm_features(gray_image)
209
  spectrum_features = analyze_spectrum(gray_image)
210
  edge_features = extract_edge_features(gray_image)
211
  lbp_features = extract_lbp_features(gray_image)
212
-
213
- # Move everything to device
214
  with torch.no_grad():
215
  image_tensor = image_tensor.to(device)
216
  glcm_features = torch.from_numpy(glcm_features).unsqueeze(0).to(device)
217
  spectrum_features = torch.from_numpy(spectrum_features).unsqueeze(0).to(device)
218
  edge_features = torch.from_numpy(edge_features).unsqueeze(0).to(device)
219
  lbp_features = torch.from_numpy(lbp_features).unsqueeze(0).to(device)
220
-
221
- # Forward pass
222
  outputs = model(image_tensor, glcm_features, spectrum_features, edge_features, lbp_features)
223
- probability = torch.sigmoid(outputs).item()
224
- prediction = "Real Face" if probability > 0.5 else "Fake Face"
225
 
226
- return prediction, f"Confidence: {probability:.2%}"
 
 
 
 
227
 
228
- # Create Gradio interface
229
  iface = gr.Interface(
230
  fn=predict_image,
231
  inputs=gr.Image(type="pil"),
232
- outputs=[
233
- gr.Label(label="Prediction"),
234
- gr.Label(label="Confidence")
235
- ],
236
- title="Face Authentication System",
237
- description="Upload an image to determine if it contains a real or fake face.",
238
  examples=[
239
- ["example1.jpg"],
240
- ["example2.jpg"]
241
- ] if os.path.exists("example1.jpg") else None,
242
  )
243
 
244
- # Launch the app
245
- if __name__ == "__main__":
246
- iface.launch()
 
10
  import timm
11
  import gradio as gr
12
 
13
+ # [把你原来代码中的AttentionBlock和AdvancedFaceDetectionModel类定义放在这里]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  class AttentionBlock(nn.Module):
15
  def __init__(self, in_features):
16
  super(AttentionBlock, self).__init__()
 
29
  def __init__(self, spectrum_length=181, lbp_n_bins=10):
30
  super(AdvancedFaceDetectionModel, self).__init__()
31
 
32
+ self.efficientnet = timm.create_model('tf_efficientnetv2_b2', pretrained=True, num_classes=0)
33
  for param in self.efficientnet.conv_stem.parameters():
34
  param.requires_grad = False
35
  for param in self.efficientnet.bn1.parameters():
 
108
  output = self.fusion(combined_features)
109
  return output.squeeze(1)
110
 
111
+ # 特征提取函数
112
+ def extract_glcm_features(image):
113
+ image_uint8 = (image * 255).astype(np.uint8)
114
+ image_uint8 = image_uint8 // 4
115
+
116
+ glcm = graycomatrix(
117
+ image_uint8,
118
+ distances=[1],
119
+ angles=[0, np.pi / 4, np.pi / 2, 3 * np.pi / 4],
120
+ levels=64,
121
+ symmetric=True,
122
+ normed=True
123
+ )
124
+
125
+ contrast = graycoprops(glcm, 'contrast').flatten()
126
+ dissimilarity = graycoprops(glcm, 'dissimilarity').flatten()
127
+ homogeneity = graycoprops(glcm, 'homogeneity').flatten()
128
+ energy = graycoprops(glcm, 'energy').flatten()
129
+ correlation = graycoprops(glcm, 'correlation').flatten()
130
+
131
+ features = np.hstack([contrast, dissimilarity, homogeneity, energy, correlation])
132
+ return features.astype(np.float32)
133
+
134
+ def analyze_spectrum(image, target_spectrum_length=181):
135
+ f = fft2(image)
136
+ fshift = fftshift(f)
137
+ magnitude_spectrum = 20 * np.log(np.abs(fshift) + 1e-8)
138
+
139
+ center = np.array(magnitude_spectrum.shape) // 2
140
+ y, x = np.indices(magnitude_spectrum.shape)
141
+ r = np.sqrt((x - center[1])**2 + (y - center[0])**2).astype(int)
142
+
143
+ radial_mean = np.bincount(r.ravel(), magnitude_spectrum.ravel()) / np.bincount(r.ravel())
144
+
145
+ if len(radial_mean) < target_spectrum_length:
146
+ radial_mean = np.pad(radial_mean, (0, target_spectrum_length - len(radial_mean)), 'constant')
147
+ else:
148
+ radial_mean = radial_mean[:target_spectrum_length]
149
+
150
+ return radial_mean.astype(np.float32)
151
+
152
+ def extract_edge_features(image):
153
+ image_uint8 = (image * 255).astype(np.uint8)
154
+ edges = cv2.Canny(image_uint8, 100, 200)
155
+ edges_resized = cv2.resize(edges, (64, 64), interpolation=cv2.INTER_AREA)
156
+ return edges_resized.astype(np.float32) / 255.0
157
+
158
+ def extract_lbp_features(image):
159
+ radius = 1
160
+ n_points = 8 * radius
161
+ METHOD = 'uniform'
162
+
163
+ lbp = local_binary_pattern(image, n_points, radius, METHOD)
164
+
165
+ n_bins = n_points + 2
166
+ hist, _ = np.histogram(lbp.ravel(), bins=n_bins, range=(0, n_bins), density=True)
167
+
168
+ return hist.astype(np.float32)
169
+
170
+ # 加载模型
171
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
172
  model = AdvancedFaceDetectionModel(spectrum_length=181, lbp_n_bins=10).to(device)
173
  model.load_state_dict(torch.load('best_model.pth', map_location=device))
174
  model.eval()
175
 
176
+ # 图像预处理转换
177
  transform = transforms.Compose([
178
  transforms.Resize((256, 256)),
179
  transforms.ToTensor(),
 
182
 
183
  def predict_image(image):
184
  """
185
+ 处理上传的图片并返回预测结果
186
  """
187
+ if image is None:
188
+ return "请上传图片"
 
189
 
190
+ # 转换图片格式
191
+ if isinstance(image, np.ndarray):
192
+ image = Image.fromarray(image)
193
 
194
+ # 应用转换
195
  image_tensor = transform(image).unsqueeze(0)
196
+
197
+ # 准备特征提取用的图像
198
  np_image = image_tensor.cpu().numpy().squeeze(0).transpose(1, 2, 0)
199
  np_image = np.clip(np_image, 0, 1)
200
+
201
+ # 转换为灰度图
202
  gray_image = cv2.cvtColor((np_image * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY)
203
  gray_image = gray_image.astype(np.float32) / 255.0
204
+
205
+ # 提取特征
206
  glcm_features = extract_glcm_features(gray_image)
207
  spectrum_features = analyze_spectrum(gray_image)
208
  edge_features = extract_edge_features(gray_image)
209
  lbp_features = extract_lbp_features(gray_image)
210
+
211
+ # 转换为张量并移到设备
212
  with torch.no_grad():
213
  image_tensor = image_tensor.to(device)
214
  glcm_features = torch.from_numpy(glcm_features).unsqueeze(0).to(device)
215
  spectrum_features = torch.from_numpy(spectrum_features).unsqueeze(0).to(device)
216
  edge_features = torch.from_numpy(edge_features).unsqueeze(0).to(device)
217
  lbp_features = torch.from_numpy(lbp_features).unsqueeze(0).to(device)
218
+
219
+ # 模型预测
220
  outputs = model(image_tensor, glcm_features, spectrum_features, edge_features, lbp_features)
221
+ prediction = torch.sigmoid(outputs).item()
 
222
 
223
+ # 返回预测结果
224
+ if prediction > 0.5:
225
+ return "真实人脸图片"
226
+ else:
227
+ return "虚假人脸图片"
228
 
229
+ # 创建Gradio界面
230
  iface = gr.Interface(
231
  fn=predict_image,
232
  inputs=gr.Image(type="pil"),
233
+ outputs=gr.Text(label="预测结果"),
234
+ title="人脸真伪检测",
235
+ description="上传一张人脸图片,模型将判断是真实人脸还是虚假人脸。",
 
 
 
236
  examples=[
237
+ # 这里可以添加示例图片路径
238
+ ]
 
239
  )
240
 
241
+ # 启动应用
242
+ iface.launch()