import gradio as gr
import torch
import torch.nn.functional as F
from facenet_pytorch import MTCNN, InceptionResnetV1
import cv2
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from PIL import Image
import numpy as np
import warnings

warnings.filterwarnings("ignore")
# Download and Load Model
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

mtcnn = MTCNN(
    select_largest=False,
    post_process=False,
    device=DEVICE
).to(DEVICE).eval()
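# Note: with post_process=False, MTCNN returns the cropped face as a float
# tensor with raw pixel values in [0, 255], which is why the face is scaled
# by 1/255 before being fed to the classifier below.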
model = InceptionResnetV1(
    pretrained="vggface2",
    classify=True,
    num_classes=1,
    device=DEVICE
)

checkpoint = torch.load("resnetinceptionv1_epoch_32.pth", map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model_state_dict'])
model.to(DEVICE)
model.eval()
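# The checkpoint "resnetinceptionv1_epoch_32.pth" is assumed to sit next to
# this script and to be a dict with a 'model_state_dict' entry, as the
# load_state_dict call above expects.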
# Model Inference
def predict_frame(frame):
    """Predict whether the input frame contains a real or fake face."""
    # OpenCV delivers BGR frames; convert to RGB for MTCNN
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    frame_pil = Image.fromarray(frame)

    # Detect and crop the face
    face = mtcnn(frame_pil)
    if face is None:
        return None, None  # No face detected

    # Preprocess the face: resize to 256x256 and scale pixel values to [0, 1]
    face = F.interpolate(face.unsqueeze(0), size=(256, 256), mode='bilinear', align_corners=False)
    face = face.to(DEVICE, dtype=torch.float32) / 255.0
    # Predict: sigmoid output near 0 means "real", near 1 means "fake"
    with torch.no_grad():
        output = torch.sigmoid(model(face).squeeze(0))
    prediction = "real" if output.item() < 0.5 else "fake"

    # Per-class confidence scores (informational; only the hard label is returned below)
    confidences = {
        'real': 1 - output.item(),
        'fake': output.item()
    }
    # Visualize with Grad-CAM on the last layer of block8
    target_layers = [model.block8.branch1[-1]]
    use_cuda = torch.cuda.is_available()
    cam = GradCAM(model=model, target_layers=target_layers, use_cuda=use_cuda)
    targets = [ClassifierOutputTarget(0)]
    grayscale_cam = cam(input_tensor=face, targets=targets, eigen_smooth=True)
    grayscale_cam = grayscale_cam[0, :]

    face_np = face.squeeze(0).permute(1, 2, 0).cpu().numpy()
    # show_cam_on_image expects the image in [0, 1] and returns a uint8 overlay in [0, 255]
    visualization = show_cam_on_image(face_np, grayscale_cam, use_rgb=True)
    face_with_mask = cv2.addWeighted((face_np * 255).astype(np.uint8), 1, visualization, 0.5, 0)

    return prediction, face_with_mask
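
# Example (sketch, not executed): run the detector on a single image.
# "example.jpg" is a hypothetical path; any BGR image from cv2.imread works,
# since predict_frame converts BGR to RGB itself.
#
#   img = cv2.imread("example.jpg")
#   label, overlay = predict_frame(img)
#   print(label)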
def predict_video(input_video):
    cap = cv2.VideoCapture(input_video)
    frames = []        # Grad-CAM overlays (collected for potential display)
    predictions = []   # per-frame "real" / "fake" labels
    frame_count = 0
    skip_frames = 20   # analyse roughly one frame out of every 20

    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frame_count += 1
        if frame_count % skip_frames != 0:  # Skip frames not divisible by skip_frames
            continue

        prediction, frame_with_mask = predict_frame(frame)
        if prediction is None:  # No face detected in this frame
            continue
        frames.append(frame_with_mask)
        predictions.append(prediction)

    cap.release()

    if not predictions:
        return "no face detected"

    # Determine the final prediction by majority vote over the sampled frames
    final_prediction = 'fake' if predictions.count('fake') > predictions.count('real') else 'real'
    return final_prediction
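
# Example (sketch, not executed): classify a local clip without the Gradio UI.
# "sample.mp4" is a hypothetical path.
#
#   print(predict_video("sample.mp4"))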
# Gradio Interface
interface = gr.Interface(
    fn=predict_video,
    inputs=[
        gr.Video(label="Input Video")
    ],
    outputs=[
        gr.Label(label="Class"),
    ],
    title="Deepfake Video Detection",
    description="Detect whether a video is real or fake."
)

interface.launch()
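# When running locally rather than in a hosted Space, interface.launch(share=True)
# (a standard Gradio option) would also expose a temporary public link.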