Create app.py
app.py
ADDED
@@ -0,0 +1,246 @@
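"""Face Authentication System.

Gradio app that classifies an uploaded face image as real or fake by fusing
EfficientNetV2 deep features with handcrafted texture cues: GLCM statistics,
a radial frequency spectrum, a Canny edge map, and an LBP histogram.
"""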
import os
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import cv2
from scipy.fftpack import fft2, fftshift
from skimage.feature import graycomatrix, graycoprops, local_binary_pattern
import timm
import gradio as gr

# GLCM feature extraction
def extract_glcm_features(image):
    # Quantize the [0, 1] grayscale image to 64 gray levels so the
    # co-occurrence matrix stays small (matches levels=64 below)
    image_uint8 = (image * 255).astype(np.uint8)
    image_uint8 = image_uint8 // 4

    glcm = graycomatrix(
        image_uint8,
        distances=[1],
        angles=[0, np.pi / 4, np.pi / 2, 3 * np.pi / 4],
        levels=64,
        symmetric=True,
        normed=True
    )

    contrast = graycoprops(glcm, 'contrast').flatten()
    dissimilarity = graycoprops(glcm, 'dissimilarity').flatten()
    homogeneity = graycoprops(glcm, 'homogeneity').flatten()
    energy = graycoprops(glcm, 'energy').flatten()
    correlation = graycoprops(glcm, 'correlation').flatten()

    # 5 properties x 4 angles = 20-dim feature vector
    features = np.hstack([contrast, dissimilarity, homogeneity, energy, correlation])
    return features.astype(np.float32)

# Spectrum analysis
def analyze_spectrum(image, target_spectrum_length=181):
    f = fft2(image)
    fshift = fftshift(f)
    # Log magnitude; the small offset avoids log(0)
    magnitude_spectrum = 20 * np.log(np.abs(fshift) + 1e-8)

    # Radially average the spectrum around the DC component
    center = np.array(magnitude_spectrum.shape) // 2
    y, x = np.indices(magnitude_spectrum.shape)
    r = np.sqrt((x - center[1])**2 + (y - center[0])**2).astype(int)

    counts = np.bincount(r.ravel())
    sums = np.bincount(r.ravel(), weights=magnitude_spectrum.ravel())
    radial_mean = sums / np.maximum(counts, 1)  # guard against empty radius bins

    # Pad or truncate to a fixed length for the model
    if len(radial_mean) < target_spectrum_length:
        radial_mean = np.pad(radial_mean, (0, target_spectrum_length - len(radial_mean)), 'constant')
    else:
        radial_mean = radial_mean[:target_spectrum_length]

    return radial_mean.astype(np.float32)

# Edge feature extraction
def extract_edge_features(image):
    image_uint8 = (image * 255).astype(np.uint8)
    edges = cv2.Canny(image_uint8, 100, 200)
    edges_resized = cv2.resize(edges, (64, 64), interpolation=cv2.INTER_AREA)
    return edges_resized.astype(np.float32) / 255.0

# LBP feature extraction
def extract_lbp_features(image):
    radius = 1
    n_points = 8 * radius
    METHOD = 'uniform'

    # Note: skimage may warn about float input here; the float image is kept
    # so inference matches the rest of this app's preprocessing
    lbp = local_binary_pattern(image, n_points, radius, METHOD)
    n_bins = n_points + 2  # uniform LBP yields n_points + 2 distinct codes
    hist, _ = np.histogram(lbp.ravel(), bins=n_bins, range=(0, n_bins), density=True)

    return hist.astype(np.float32)

# Model architecture
class AttentionBlock(nn.Module):
    """Squeeze-and-excitation-style gating over a flat feature vector."""

    def __init__(self, in_features):
        super(AttentionBlock, self).__init__()
        self.attention = nn.Sequential(
            nn.Linear(in_features, max(in_features // 8, 1)),
            nn.ReLU(),
            nn.Linear(max(in_features // 8, 1), in_features),
            nn.Sigmoid()
        )

    def forward(self, x):
        attention_weights = self.attention(x)
        return x * attention_weights

class AdvancedFaceDetectionModel(nn.Module):
    # spectrum_length is unused here because AdaptiveAvgPool1d makes the
    # spectrum branch length-agnostic; it is kept for interface compatibility
    def __init__(self, spectrum_length=181, lbp_n_bins=10):
        super(AdvancedFaceDetectionModel, self).__init__()

        # EfficientNetV2 backbone as a feature extractor (num_classes=0 removes
        # the classifier head; weights come from the checkpoint loaded below).
        # Freeze the stem conv and its BatchNorm.
        self.efficientnet = timm.create_model('tf_efficientnetv2_b2', pretrained=False, num_classes=0)
        for param in self.efficientnet.conv_stem.parameters():
            param.requires_grad = False
        for param in self.efficientnet.bn1.parameters():
            param.requires_grad = False

        # Branch for the 20-dim GLCM feature vector
        self.glcm_fc = nn.Sequential(
            nn.Linear(20, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(0.5)
        )

        # Branch for the 1-D radial spectrum
        self.spectrum_conv = nn.Sequential(
            nn.Conv1d(1, 64, kernel_size=3, padding=1),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1)
        )

        # Branch for the 64x64 edge map
        self.edge_conv = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((8, 8))
        )

        # Branch for the LBP histogram
        self.lbp_fc = nn.Sequential(
            nn.Linear(lbp_n_bins, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(0.5)
        )

        # Per-branch attention, then a fusion MLP over the concatenation
        image_feature_size = self.efficientnet.num_features
        self.image_attention = AttentionBlock(image_feature_size)
        self.glcm_attention = AttentionBlock(64)
        self.spectrum_attention = AttentionBlock(64)
        self.edge_attention = AttentionBlock(32 * 8 * 8)
        self.lbp_attention = AttentionBlock(64)

        total_features = image_feature_size + 64 + 64 + (32 * 8 * 8) + 64
        self.fusion = nn.Sequential(
            nn.Linear(total_features, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 1)
        )

    def forward(self, image, glcm_features, spectrum_features, edge_features, lbp_features):
        image_features = self.efficientnet(image)
        image_features = self.image_attention(image_features)

        glcm_features = self.glcm_fc(glcm_features)
        glcm_features = self.glcm_attention(glcm_features)

        # (B, 181) -> (B, 1, 181) -> (B, 64, 1) -> (B, 64)
        spectrum_features = self.spectrum_conv(spectrum_features.unsqueeze(1))
        spectrum_features = spectrum_features.squeeze(2)
        spectrum_features = self.spectrum_attention(spectrum_features)

        # (B, 64, 64) -> (B, 1, 64, 64) -> (B, 32, 8, 8) -> (B, 2048)
        edge_features = self.edge_conv(edge_features.unsqueeze(1))
        edge_features = edge_features.view(edge_features.size(0), -1)
        edge_features = self.edge_attention(edge_features)

        lbp_features = self.lbp_fc(lbp_features)
        lbp_features = self.lbp_attention(lbp_features)

        combined_features = torch.cat(
            (image_features, glcm_features, spectrum_features, edge_features, lbp_features), dim=1
        )

        # Single logit; sigmoid is applied at inference time
        output = self.fusion(combined_features)
        return output.squeeze(1)

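# Expected forward() input shapes (batch size B), as produced by the
# feature extractors above:
#   image:             (B, 3, 256, 256)
#   glcm_features:     (B, 20)   (5 GLCM properties x 4 angles)
#   spectrum_features: (B, 181)
#   edge_features:     (B, 64, 64)
#   lbp_features:      (B, 10)
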
# Initialize model and transform
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AdvancedFaceDetectionModel(spectrum_length=181, lbp_n_bins=10).to(device)
model.load_state_dict(torch.load('best_model.pth', map_location=device))
model.eval()

transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet stats
])

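# Note: 'best_model.pth' is assumed to be a state_dict checkpoint saved from
# this exact architecture; torch.load raises if the file is missing and
# load_state_dict raises if the shapes do not match.
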
def predict_image(image):
    """
    Run the feature-extraction and model pipeline on a single image and
    return a (label, confidence string) pair.
    """
    # Convert to PIL Image if needed
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    # Convert to RGB if needed
    if image.mode != 'RGB':
        image = image.convert('RGB')

    # Apply transformations
    image_tensor = transform(image).unsqueeze(0)

    # Convert back to a NumPy array for handcrafted feature extraction.
    # The tensor is already ImageNet-normalized, so clipping to [0, 1] is
    # lossy; it is kept as-is so inference matches the pipeline the
    # checkpoint was presumably trained with.
    np_image = image_tensor.cpu().numpy().squeeze(0).transpose(1, 2, 0)
    np_image = np.clip(np_image, 0, 1)

    # Convert to grayscale
    gray_image = cv2.cvtColor((np_image * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY)
    gray_image = gray_image.astype(np.float32) / 255.0

    # Extract handcrafted features
    glcm_features = extract_glcm_features(gray_image)
    spectrum_features = analyze_spectrum(gray_image)
    edge_features = extract_edge_features(gray_image)
    lbp_features = extract_lbp_features(gray_image)

    # Move everything to the device and run the forward pass
    with torch.no_grad():
        image_tensor = image_tensor.to(device)
        glcm_features = torch.from_numpy(glcm_features).unsqueeze(0).to(device)
        spectrum_features = torch.from_numpy(spectrum_features).unsqueeze(0).to(device)
        edge_features = torch.from_numpy(edge_features).unsqueeze(0).to(device)
        lbp_features = torch.from_numpy(lbp_features).unsqueeze(0).to(device)

        outputs = model(image_tensor, glcm_features, spectrum_features, edge_features, lbp_features)
        probability = torch.sigmoid(outputs).item()

    # Report the confidence of the predicted class rather than the raw "real"
    # probability, so "Fake Face" is not shown with e.g. 10% confidence
    is_real = probability > 0.5
    prediction = "Real Face" if is_real else "Fake Face"
    confidence = probability if is_real else 1 - probability

    return prediction, f"Confidence: {confidence:.2%}"

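# Quick local sanity check (hypothetical file name; assumes the checkpoint
# and a test image are present):
#   label, conf = predict_image(Image.open("test_face.jpg"))
#   print(label, conf)
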
# Create Gradio interface
example_files = ["example1.jpg", "example2.jpg"]
iface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(label="Prediction"),
        gr.Label(label="Confidence")
    ],
    title="Face Authentication System",
    description="Upload an image to determine if it contains a real or fake face.",
    # Only offer examples when every listed file is actually present
    examples=[[f] for f in example_files] if all(os.path.exists(f) for f in example_files) else None,
)

# Launch the app
if __name__ == "__main__":
    iface.launch()
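# Running `python app.py` starts a local Gradio server (by default at
# http://127.0.0.1:7860) where images can be uploaded for classification.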