import os
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import cv2
from scipy.fftpack import fft2, fftshift
from skimage.feature import graycomatrix, graycoprops, local_binary_pattern
import timm
import gradio as gr


class AttentionBlock(nn.Module):
    """Feature-wise gating: scale every input feature by a learned weight in (0, 1).

    A bottleneck MLP (in_features -> in_features // 8 -> in_features) followed by
    a sigmoid produces the weights, which are multiplied element-wise with the input.
    """

    def __init__(self, in_features):
        super().__init__()
        # max(..., 1) avoids a zero-width bottleneck when in_features < 8.
        hidden = max(in_features // 8, 1)
        self.attention = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, in_features),
            nn.Sigmoid(),
        )

    def forward(self, x):
        attention_weights = self.attention(x)
        return x * attention_weights


class AdvancedFaceDetectionModel(nn.Module):
    """Binary real-vs-generated face classifier fusing CNN and hand-crafted features.

    Five branches are computed, gated by an AttentionBlock each, concatenated and
    passed through an MLP that emits one logit per sample:
      * EfficientNetV2-B2 image embedding,
      * GLCM texture statistics (20 values),
      * radially averaged FFT spectrum (1-D conv, globally pooled),
      * Canny edge map (2-D conv, pooled to 8x8),
      * LBP histogram (``lbp_n_bins`` values).

    Args:
        spectrum_length: accepted for API compatibility; the spectrum branch ends
            in AdaptiveAvgPool1d, so the architecture does not depend on it.
        lbp_n_bins: length of the LBP histogram fed to ``lbp_fc``.
        pretrained: whether timm should download ImageNet weights for the
            backbone. Pass False when a full checkpoint is loaded right after
            construction, since those weights would be overwritten anyway.
    """

    def __init__(self, spectrum_length=181, lbp_n_bins=10, pretrained=True):
        super().__init__()
        # num_classes=0 makes timm return pooled features instead of class logits.
        self.efficientnet = timm.create_model(
            'tf_efficientnetv2_b2', pretrained=pretrained, num_classes=0
        )
        # Freeze the stem convolution and its batch-norm.
        for param in self.efficientnet.conv_stem.parameters():
            param.requires_grad = False
        for param in self.efficientnet.bn1.parameters():
            param.requires_grad = False

        # 20 = 5 GLCM properties x 4 angles (see extract_glcm_features).
        self.glcm_fc = nn.Sequential(
            nn.Linear(20, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(0.5),
        )
        self.spectrum_conv = nn.Sequential(
            nn.Conv1d(1, 64, kernel_size=3, padding=1),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1),
        )
        self.edge_conv = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((8, 8)),
        )
        self.lbp_fc = nn.Sequential(
            nn.Linear(lbp_n_bins, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(0.5),
        )

        image_feature_size = self.efficientnet.num_features
        self.image_attention = AttentionBlock(image_feature_size)
        self.glcm_attention = AttentionBlock(64)
        self.spectrum_attention = AttentionBlock(64)
        self.edge_attention = AttentionBlock(32 * 8 * 8)
        self.lbp_attention = AttentionBlock(64)

        total_features = image_feature_size + 64 + 64 + (32 * 8 * 8) + 64
        self.fusion = nn.Sequential(
            nn.Linear(total_features, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
        )

    def forward(self, image, glcm_features, spectrum_features, edge_features, lbp_features):
        """Return one logit per sample, shape (B,).

        Args:
            image: (B, 3, H, W) normalized image batch.
            glcm_features: (B, 20).
            spectrum_features: (B, L) radial spectrum; a channel dim is added here.
            edge_features: (B, H', W') edge map; a channel dim is added here.
            lbp_features: (B, lbp_n_bins).
        """
        image_features = self.efficientnet(image)
        image_features = self.image_attention(image_features)

        glcm_features = self.glcm_fc(glcm_features)
        glcm_features = self.glcm_attention(glcm_features)

        # (B, L) -> (B, 1, L) -> conv/pool -> (B, 64, 1) -> (B, 64)
        spectrum_features = self.spectrum_conv(spectrum_features.unsqueeze(1))
        spectrum_features = spectrum_features.squeeze(2)
        spectrum_features = self.spectrum_attention(spectrum_features)

        # (B, H', W') -> (B, 1, H', W') -> (B, 32, 8, 8) -> (B, 2048)
        edge_features = self.edge_conv(edge_features.unsqueeze(1))
        edge_features = edge_features.view(edge_features.size(0), -1)
        edge_features = self.edge_attention(edge_features)

        lbp_features = self.lbp_fc(lbp_features)
        lbp_features = self.lbp_attention(lbp_features)

        combined_features = torch.cat(
            (image_features, glcm_features, spectrum_features, edge_features, lbp_features),
            dim=1,
        )
        output = self.fusion(combined_features)
        return output.squeeze(1)


# Feature extraction functions


def extract_glcm_features(image):
    """Return 20 GLCM texture statistics for a grayscale image in [0, 1].

    The image is quantized to 64 gray levels. A symmetric, normalized GLCM at
    distance 1 and angles 0, 45, 90 and 135 degrees is built. Contrast,
    dissimilarity, homogeneity, energy and correlation are reported for each
    angle, giving 5 x 4 float32 values.
    """
    image_uint8 = (image * 255).astype(np.uint8)
    image_uint8 = image_uint8 // 4  # 0..255 -> 0..63 to match levels=64
    glcm = graycomatrix(
        image_uint8,
        distances=[1],
        angles=[0, np.pi / 4, np.pi / 2, 3 * np.pi / 4],
        levels=64,
        symmetric=True,
        normed=True,
    )
    contrast = graycoprops(glcm, 'contrast').flatten()
    dissimilarity = graycoprops(glcm, 'dissimilarity').flatten()
    homogeneity = graycoprops(glcm, 'homogeneity').flatten()
    energy = graycoprops(glcm, 'energy').flatten()
    correlation = graycoprops(glcm, 'correlation').flatten()
    features = np.hstack([contrast, dissimilarity, homogeneity, energy, correlation])
    return features.astype(np.float32)


def analyze_spectrum(image, target_spectrum_length=181):
    """Return the radially averaged log-magnitude FFT spectrum of a 2-D image.

    The centred spectrum magnitude is mapped to ``20 * log(|F| + 1e-8)`` and then
    averaged over rings of integer radius around the centre. The result is
    zero-padded or truncated to ``target_spectrum_length`` float32 values. For the
    256x256 images produced by ``transform``, radii run 0..181, so the default
    keeps the first 181 rings.
    """
    f = fft2(image)
    fshift = fftshift(f)
    magnitude_spectrum = 20 * np.log(np.abs(fshift) + 1e-8)

    center = np.array(magnitude_spectrum.shape) // 2
    y, x = np.indices(magnitude_spectrum.shape)
    r = np.sqrt((x - center[1]) ** 2 + (y - center[0]) ** 2).astype(int)
    # Sum of magnitudes per radius divided by pixel count per radius = ring mean.
    radial_mean = np.bincount(r.ravel(), magnitude_spectrum.ravel()) / np.bincount(r.ravel())

    if len(radial_mean) < target_spectrum_length:
        radial_mean = np.pad(
            radial_mean, (0, target_spectrum_length - len(radial_mean)), 'constant'
        )
    else:
        radial_mean = radial_mean[:target_spectrum_length]
    return radial_mean.astype(np.float32)


def extract_edge_features(image):
    """Return a 64x64 float32 Canny edge map in [0, 1] for a grayscale image in [0, 1]."""
    image_uint8 = (image * 255).astype(np.uint8)
    edges = cv2.Canny(image_uint8, 100, 200)
    edges_resized = cv2.resize(edges, (64, 64), interpolation=cv2.INTER_AREA)
    return edges_resized.astype(np.float32) / 255.0


def extract_lbp_features(image):
    """Return a normalized histogram of uniform LBP codes (P=8, R=1), 10 float32 bins.

    Uniform LBP with 8 neighbours yields codes 0..9, hence ``n_points + 2`` bins.
    """
    radius = 1
    n_points = 8 * radius
    METHOD = 'uniform'
    lbp = local_binary_pattern(image, n_points, radius, METHOD)
    n_bins = n_points + 2
    hist, _ = np.histogram(lbp.ravel(), bins=n_bins, range=(0, n_bins), density=True)
    return hist.astype(np.float32)


# Load the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# pretrained=False: the checkpoint below is loaded strictly and replaces every
# backbone weight, so downloading ImageNet weights at startup would be wasted.
model = AdvancedFaceDetectionModel(
    spectrum_length=181, lbp_n_bins=10, pretrained=False
).to(device)
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (recent PyTorch versions also support weights_only=True).
model.load_state_dict(torch.load('best_model.pth', map_location=device))
model.eval()

# Image preprocessing transforms (ImageNet normalization statistics)
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def predict_image(image):
    """
    Process uploaded image and return prediction result

    Args:
        image: a PIL image (as delivered by the Gradio input) or a NumPy array;
            None when nothing has been uploaded.

    Returns:
        "Real Face", "AI-Generated Face", or an upload prompt when image is None.
    """
    if image is None:
        return "Please upload an image"

    # Convert image format
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # Normalize (3-channel statistics) and COLOR_RGB2GRAY both require exactly
    # three channels; RGBA, palette or grayscale inputs would otherwise crash.
    image = image.convert("RGB")

    # Apply transformations
    image_tensor = transform(image).unsqueeze(0)

    # NOTE(review): the hand-crafted features are computed from the *normalized*
    # tensor clipped to [0, 1], not from the raw pixels. Kept as-is because the
    # checkpoint was presumably trained with the same pipeline — confirm against
    # the training code before changing it.
    np_image = image_tensor.cpu().numpy().squeeze(0).transpose(1, 2, 0)
    np_image = np.clip(np_image, 0, 1)

    # Convert to grayscale in [0, 1]
    gray_image = cv2.cvtColor((np_image * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY)
    gray_image = gray_image.astype(np.float32) / 255.0

    # Extract features
    glcm_features = extract_glcm_features(gray_image)
    spectrum_features = analyze_spectrum(gray_image)
    edge_features = extract_edge_features(gray_image)
    lbp_features = extract_lbp_features(gray_image)

    # Convert to batched tensors, move to device and run the model
    with torch.no_grad():
        image_tensor = image_tensor.to(device)
        glcm_features = torch.from_numpy(glcm_features).unsqueeze(0).to(device)
        spectrum_features = torch.from_numpy(spectrum_features).unsqueeze(0).to(device)
        edge_features = torch.from_numpy(edge_features).unsqueeze(0).to(device)
        lbp_features = torch.from_numpy(lbp_features).unsqueeze(0).to(device)

        outputs = model(image_tensor, glcm_features, spectrum_features, edge_features, lbp_features)
        prediction = torch.sigmoid(outputs).item()

    # A probability below 0.5 is reported as real; presumably label 1 meant
    # "AI-generated" during training — TODO confirm against the training code.
    if prediction < 0.5:
        return "Real Face"
    return "AI-Generated Face"


# Create Gradio interface
iface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Text(label="Prediction Result"),
    title="Face Authentication System",
    description="Upload a face image to determine if it's a real face or an AI-generated face.",
    examples=[
        # Add example image paths here
    ],
    article="""
    This system uses advanced deep learning techniques to detect whether a face image is real or AI-generated.
    The model analyzes various image features including texture patterns, frequency spectrum, and local binary patterns to make its determination.
    """,
)

# Launch the application only when run as a script, not when imported
if __name__ == "__main__":
    iface.launch()