haijian06 committed
Commit 36c5bbb
1 Parent(s): a16b0cd

Create app.py

Files changed (1)
  1. app.py +246 -0
app.py ADDED
@@ -0,0 +1,246 @@
import os
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import cv2
from scipy.fftpack import fft2, fftshift
from skimage.feature import graycomatrix, graycoprops, local_binary_pattern
import timm
import gradio as gr

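# Each extractor below takes a single-channel float image in [0, 1] and returns
# a fixed-size float32 array: GLCM -> 20 values, radial spectrum -> 181 values,
# Canny edge map -> 64x64, uniform LBP histogram -> 10 bins. These sizes must
# match the corresponding input layers of AdvancedFaceDetectionModel further down.
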
# GLCM feature extraction
def extract_glcm_features(image):
    # Quantize the [0, 1] grayscale image to 64 gray levels for the co-occurrence matrix
    image_uint8 = (image * 255).astype(np.uint8)
    image_uint8 = image_uint8 // 4

    glcm = graycomatrix(
        image_uint8,
        distances=[1],
        angles=[0, np.pi / 4, np.pi / 2, 3 * np.pi / 4],
        levels=64,
        symmetric=True,
        normed=True
    )

    # 5 Haralick-style properties x 4 angles = 20 features
    contrast = graycoprops(glcm, 'contrast').flatten()
    dissimilarity = graycoprops(glcm, 'dissimilarity').flatten()
    homogeneity = graycoprops(glcm, 'homogeneity').flatten()
    energy = graycoprops(glcm, 'energy').flatten()
    correlation = graycoprops(glcm, 'correlation').flatten()

    features = np.hstack([contrast, dissimilarity, homogeneity, energy, correlation])
    return features.astype(np.float32)

# Spectrum analysis
def analyze_spectrum(image, target_spectrum_length=181):
    # Log-magnitude of the centered 2D FFT
    f = fft2(image)
    fshift = fftshift(f)
    magnitude_spectrum = 20 * np.log(np.abs(fshift) + 1e-8)

    # Radially average the spectrum around the image center
    center = np.array(magnitude_spectrum.shape) // 2
    y, x = np.indices(magnitude_spectrum.shape)
    r = np.sqrt((x - center[1])**2 + (y - center[0])**2).astype(int)

    radial_mean = np.bincount(r.ravel(), magnitude_spectrum.ravel()) / np.bincount(r.ravel())

    # Pad or truncate to a fixed length (181 is roughly the corner radius of a 256x256 input)
    if len(radial_mean) < target_spectrum_length:
        radial_mean = np.pad(radial_mean, (0, target_spectrum_length - len(radial_mean)), 'constant')
    else:
        radial_mean = radial_mean[:target_spectrum_length]

    return radial_mean.astype(np.float32)

# Edge feature extraction
def extract_edge_features(image):
    # Canny edges on the 8-bit image, downsampled to a 64x64 map scaled to [0, 1]
    image_uint8 = (image * 255).astype(np.uint8)
    edges = cv2.Canny(image_uint8, 100, 200)
    edges_resized = cv2.resize(edges, (64, 64), interpolation=cv2.INTER_AREA)
    return edges_resized.astype(np.float32) / 255.0

# LBP feature extraction
def extract_lbp_features(image):
    radius = 1
    n_points = 8 * radius
    METHOD = 'uniform'

    # Uniform LBP with P=8 yields P+2 = 10 distinct codes; histogram them
    lbp = local_binary_pattern(image, n_points, radius, METHOD)
    n_bins = n_points + 2
    hist, _ = np.histogram(lbp.ravel(), bins=n_bins, range=(0, n_bins), density=True)

    return hist.astype(np.float32)

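# Architecture overview: a frozen-stem EfficientNetV2-B2 backbone for RGB features,
# plus four small heads for the handcrafted features above. Each stream passes
# through an SE-style attention gate before concatenation into a 3-layer fusion
# MLP that emits a single logit.
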
# Model architecture
class AttentionBlock(nn.Module):
    """Squeeze-and-excitation-style gate: a bottleneck MLP emits per-feature weights in (0, 1)."""
    def __init__(self, in_features):
        super(AttentionBlock, self).__init__()
        self.attention = nn.Sequential(
            nn.Linear(in_features, max(in_features // 8, 1)),
            nn.ReLU(),
            nn.Linear(max(in_features // 8, 1), in_features),
            nn.Sigmoid()
        )

    def forward(self, x):
        attention_weights = self.attention(x)
        return x * attention_weights

class AdvancedFaceDetectionModel(nn.Module):
    def __init__(self, spectrum_length=181, lbp_n_bins=10):
        super(AdvancedFaceDetectionModel, self).__init__()

        # RGB backbone; pretrained=False since the full weights are loaded from the
        # checkpoint below. The stem (first conv + BN) is frozen.
        self.efficientnet = timm.create_model('tf_efficientnetv2_b2', pretrained=False, num_classes=0)
        for param in self.efficientnet.conv_stem.parameters():
            param.requires_grad = False
        for param in self.efficientnet.bn1.parameters():
            param.requires_grad = False

        # 20 GLCM features -> 64
        self.glcm_fc = nn.Sequential(
            nn.Linear(20, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(0.5)
        )

        # 1D conv over the radial spectrum, pooled to 64 channels
        self.spectrum_conv = nn.Sequential(
            nn.Conv1d(1, 64, kernel_size=3, padding=1),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1)
        )

        # 2D conv over the 64x64 edge map, pooled to 32 x 8 x 8
        self.edge_conv = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((8, 8))
        )

        # 10-bin LBP histogram -> 64
        self.lbp_fc = nn.Sequential(
            nn.Linear(lbp_n_bins, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(0.5)
        )

        # Per-stream attention gates
        image_feature_size = self.efficientnet.num_features
        self.image_attention = AttentionBlock(image_feature_size)
        self.glcm_attention = AttentionBlock(64)
        self.spectrum_attention = AttentionBlock(64)
        self.edge_attention = AttentionBlock(32 * 8 * 8)
        self.lbp_attention = AttentionBlock(64)

        # Fusion MLP over the concatenated streams, ending in a single logit
        total_features = image_feature_size + 64 + 64 + (32 * 8 * 8) + 64
        self.fusion = nn.Sequential(
            nn.Linear(total_features, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 1)
        )

    def forward(self, image, glcm_features, spectrum_features, edge_features, lbp_features):
        image_features = self.efficientnet(image)
        image_features = self.image_attention(image_features)

        glcm_features = self.glcm_fc(glcm_features)
        glcm_features = self.glcm_attention(glcm_features)

        # (B, 181) -> (B, 1, 181) -> (B, 64, 1) -> (B, 64)
        spectrum_features = self.spectrum_conv(spectrum_features.unsqueeze(1))
        spectrum_features = spectrum_features.squeeze(2)
        spectrum_features = self.spectrum_attention(spectrum_features)

        # (B, 64, 64) -> (B, 1, 64, 64) -> (B, 32, 8, 8) -> (B, 2048)
        edge_features = self.edge_conv(edge_features.unsqueeze(1))
        edge_features = edge_features.view(edge_features.size(0), -1)
        edge_features = self.edge_attention(edge_features)

        lbp_features = self.lbp_fc(lbp_features)
        lbp_features = self.lbp_attention(lbp_features)

        combined_features = torch.cat(
            (image_features, glcm_features, spectrum_features, edge_features, lbp_features), dim=1
        )

        output = self.fusion(combined_features)
        return output.squeeze(1)

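# The checkpoint is expected in the working directory and must match the architecture above exactly.
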
# Initialize model and transform
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AdvancedFaceDetectionModel(spectrum_length=181, lbp_n_bins=10).to(device)
model.load_state_dict(torch.load('best_model.pth', map_location=device))
model.eval()

transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

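# NOTE: predict_image below derives its grayscale features from the normalized
# tensor clipped back to [0, 1], not from the raw image; presumably this matches
# the training pipeline, so it is kept consistent rather than "corrected".
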
def predict_image(image):
    """Run the full feature-extraction and inference pipeline on a single image."""
    # Convert to PIL Image if needed
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    # Convert to RGB if needed
    if image.mode != 'RGB':
        image = image.convert('RGB')

    # Apply transformations
    image_tensor = transform(image).unsqueeze(0)

    # Convert to NumPy array for feature extraction
    np_image = image_tensor.cpu().numpy().squeeze(0).transpose(1, 2, 0)
    np_image = np.clip(np_image, 0, 1)

    # Convert to grayscale
    gray_image = cv2.cvtColor((np_image * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY)
    gray_image = gray_image.astype(np.float32) / 255.0

    # Extract handcrafted features
    glcm_features = extract_glcm_features(gray_image)
    spectrum_features = analyze_spectrum(gray_image)
    edge_features = extract_edge_features(gray_image)
    lbp_features = extract_lbp_features(gray_image)

    # Move everything to device
    with torch.no_grad():
        image_tensor = image_tensor.to(device)
        glcm_features = torch.from_numpy(glcm_features).unsqueeze(0).to(device)
        spectrum_features = torch.from_numpy(spectrum_features).unsqueeze(0).to(device)
        edge_features = torch.from_numpy(edge_features).unsqueeze(0).to(device)
        lbp_features = torch.from_numpy(lbp_features).unsqueeze(0).to(device)

        # Forward pass: the model emits a logit; sigmoid maps it to P(real)
        outputs = model(image_tensor, glcm_features, spectrum_features, edge_features, lbp_features)
        probability = torch.sigmoid(outputs).item()
        prediction = "Real Face" if probability > 0.5 else "Fake Face"
        # Report confidence in the predicted class, not the raw P(real)
        confidence = probability if probability > 0.5 else 1 - probability

    return prediction, f"Confidence: {confidence:.2%}"

# Create Gradio interface
example_files = [["example1.jpg"], ["example2.jpg"]]
iface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(label="Prediction"),
        gr.Label(label="Confidence")
    ],
    title="Face Authentication System",
    description="Upload an image to determine if it contains a real or fake face.",
    # Only pass examples that actually exist on disk
    examples=[e for e in example_files if os.path.exists(e[0])] or None,
)

# Launch the app
if __name__ == "__main__":
    iface.launch()
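
For a quick smoke test without the web UI, something like the following should work once best_model.pth is in place; "test.jpg" is a hypothetical stand-in for any local image:

# smoke_test.py — minimal sketch; assumes app.py and best_model.pth sit in the
# working directory. Importing app loads the model but does not launch Gradio,
# since iface.launch() is guarded by the __main__ check.
from PIL import Image
from app import predict_image

label, confidence = predict_image(Image.open("test.jpg"))
print(label, confidence)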