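# Video deepfake analysis Space: MTCNN face detection, a fine-tuned
# InceptionResnetV1 real/fake classifier, and Grad-CAM heatmaps, served
# through a Gradio interface.
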
import spaces
import gradio as gr
import torch
import torch.nn.functional as F
from facenet_pytorch import MTCNN, InceptionResnetV1
import os
import numpy as np
from PIL import Image as PILImage
import zipfile
import cv2
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
import matplotlib.pyplot as plt
from tqdm import tqdm
import io
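
# Unpack the bundled example media shipped alongside the app.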
with zipfile.ZipFile("examples.zip", "r") as zip_ref:
    zip_ref.extractall(".")

DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
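
# MTCNN face detector: select_largest=False picks the face with the highest
# detection probability, and post_process=False leaves pixel values in
# [0, 255] so normalization can be done explicitly in process_frame().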
mtcnn = MTCNN(
    select_largest=False,
    post_process=False,
    device=DEVICE
).to(DEVICE).eval()
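
# Binary real/fake classifier: an InceptionResnetV1 backbone pretrained on
# VGGFace2 with a single-logit classification head.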
model = InceptionResnetV1(
    pretrained="vggface2",
    classify=True,
    num_classes=1,
    device=DEVICE
)
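
# Load the fine-tuned deepfake-detection weights.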
checkpoint = torch.load("resnetinceptionv1_epoch_32.pth", map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model_state_dict'])
model.to(DEVICE)
model.eval()
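
# Each example filename is assumed to encode its label as the prefix before
# the first underscore (e.g. a hypothetical "real_clip.mp4" -> "real").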
EXAMPLES_FOLDER = 'examples'
examples_names = os.listdir(EXAMPLES_FOLDER)
examples = []
for example_name in examples_names:
    example_path = os.path.join(EXAMPLES_FOLDER, example_name)
    label = example_name.split('_')[0]
    example = {
        'path': example_path,
        'label': label
    }
    examples.append(example)
np.random.shuffle(examples)  # shuffle so the ordering does not reveal the labels
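
# Per-frame pipeline: detect and crop the face, resize it to 256x256,
# normalize to [0, 1], overlay a Grad-CAM heatmap, and score the crop with
# the classifier (sigmoid < 0.5 -> "real", otherwise "fake").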
def process_frame(frame, mtcnn, model, cam, targets):
    face = mtcnn(PILImage.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
    if face is None:
        return frame, None, None
    face = face.unsqueeze(0)
    face = F.interpolate(face, size=(256, 256), mode='bilinear', align_corners=False)
    face = face.to(DEVICE)
    face = face.to(torch.float32)
    face = face / 255.0
    face_image_to_plot = face.squeeze(0).permute(1, 2, 0).cpu().detach().numpy()
    grayscale_cam = cam(input_tensor=face, targets=targets, eigen_smooth=True)
    grayscale_cam = grayscale_cam[0, :]
    visualization = show_cam_on_image(face_image_to_plot, grayscale_cam, use_rgb=True)
    with torch.no_grad():
        output = torch.sigmoid(model(face).squeeze(0))
    prediction = "real" if output.item() < 0.5 else "fake"
    confidence = 1 - output.item() if prediction == "real" else output.item()
    return visualization, prediction, confidence
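
# On Hugging Face ZeroGPU Spaces the GPU-touching entry point is typically
# decorated with @spaces.GPU (presumably why `spaces` is imported above);
# this assumes a Spaces deployment and is a no-op when run elsewhere.
@spaces.GPU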
def analyze_video(input_video: str):
    """Analyze the video for deepfake detection."""
    cap = cv2.VideoCapture(input_video)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
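    # Grad-CAM is hooked onto a late convolutional layer in the backbone's
    # final Inception-ResNet block; with a single output logit,
    # ClassifierOutputTarget(0) targets the "fake" score.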
    target_layers = [model.block8.branch1[-1]]
    cam = GradCAM(model=model, target_layers=target_layers)
    targets = [ClassifierOutputTarget(0)]
    frame_confidences = []
    frame_predictions = []
    for _ in tqdm(range(total_frames), desc="Analyzing video"):
        ret, frame = cap.read()
        if not ret:
            break
        _, prediction, confidence = process_frame(frame, mtcnn, model, cam, targets)
        if prediction is not None and confidence is not None:
            frame_confidences.append(confidence)
            frame_predictions.append(1 if prediction == "fake" else 0)
    cap.release()

    # Aggregate per-frame results.
    fake_percentage = (sum(frame_predictions) / len(frame_predictions)) * 100 if frame_predictions else 0
    avg_confidence = np.mean(frame_confidences) if frame_confidences else 0

    # Confidence-over-time and prediction-distribution plots.
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 12))
    ax1.plot(frame_confidences)
    ax1.set_title("Confidence Over Time")
    ax1.set_xlabel("Frame")
    ax1.set_ylabel("Confidence")
    ax1.set_ylim(0, 1)
    ax2.hist(frame_predictions, bins=[0, 0.5, 1], rwidth=0.8)
    ax2.set_title("Distribution of Predictions")
    ax2.set_xlabel("Prediction (0: Real, 1: Fake)")
    ax2.set_ylabel("Count")

    # Save the plots to an in-memory PNG, then close the figure so repeated
    # requests do not leak matplotlib state.
    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    buf.seek(0)
    plt.close(fig)

    # Horizontal bar showing the fake/real split as a percentage.
    progress_fig, progress_ax = plt.subplots(figsize=(8, 2))
    progress_ax.barh(["Fake"], [fake_percentage], color='red')
    progress_ax.barh(["Fake"], [100 - fake_percentage], left=[fake_percentage], color='green')
    progress_ax.set_xlim(0, 100)
    progress_ax.set_title("Fake Percentage")
    progress_ax.set_xlabel("Percentage")
    progress_ax.text(fake_percentage, 0, f"{fake_percentage:.1f}%", va='center', ha='left')
    progress_buf = io.BytesIO()
    progress_fig.savefig(progress_buf, format='png')
    progress_buf.seek(0)
    plt.close(progress_fig)

    return {
        "fake_percentage": fake_percentage,
        "avg_confidence": avg_confidence,
        "analysis_plot": buf,
        "progress_bar": progress_buf,
        "total_frames": total_frames,
        "processed_frames": len(frame_confidences)
    }
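
# Render the numeric results as a plain-text summary for the Textbox output.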
def format_results(results):
    return f"""
    Analysis Results:
    - Fake Percentage: {results['fake_percentage']:.2f}%
    - Average Confidence: {results['avg_confidence']:.2f}
    - Total Frames: {results['total_frames']}
    - Processed Frames: {results['processed_frames']}
    """
def analyze_and_format(input_video):
    results = analyze_video(input_video)
    text_results = format_results(results)
    # Convert the BytesIO buffers to PIL Images for the gr.Image outputs.
    analysis_plot = PILImage.open(results['analysis_plot'])
    progress_bar = PILImage.open(results['progress_bar'])
    return text_results, analysis_plot, progress_bar
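
# Build the Gradio UI: one video input; a text summary plus two images out.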
interface = gr.Interface(
    fn=analyze_and_format,
    inputs=[
        gr.Video(label="Input Video")
    ],
    outputs=[
        gr.Textbox(label="Analysis Results"),
        gr.Image(label="Analysis Plots"),
        gr.Image(label="Fake Percentage")
    ],
    title="Video Deepfake Analysis",
    description="Upload a video to analyze for potential deepfakes.",
    # Wire in the extracted example files, assuming they are videos that
    # gr.Video can play; the original left this list empty even though the
    # examples were collected above.
    examples=[[example['path']] for example in examples]
)
if __name__ == "__main__": | |
interface.launch(share=True) |