import spaces
import gradio as gr
import torch
import torch.nn.functional as F
from facenet_pytorch import MTCNN, InceptionResnetV1
import os
import numpy as np
from PIL import Image
import zipfile
import cv2
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
import tempfile

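# Unpack the bundled example media into the working directory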
with zipfile.ZipFile("examples.zip","r") as zip_ref:
    zip_ref.extractall(".")

DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

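# Face detector: select_largest=False returns the highest-probability face,
# and post_process=False keeps the crop in raw [0, 255] pixel values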
mtcnn = MTCNN(
    select_largest=False,
    post_process=False,
    device=DEVICE
).to(DEVICE).eval()

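# Binary real/fake classifier: InceptionResnetV1 pretrained on VGGFace2 with a single-logit head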
model = InceptionResnetV1(
    pretrained="vggface2",
    classify=True,
    num_classes=1,
    device=DEVICE
)

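# Load the fine-tuned deepfake-detection weights (CPU-mapped first, then moved to DEVICE)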
checkpoint = torch.load("resnetinceptionv1_epoch_32.pth", map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model_state_dict'])
model.to(DEVICE)
model.eval()

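# Example files are assumed to encode their label as the filename prefix, e.g. "real_1.mp4" -> "real"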
EXAMPLES_FOLDER = 'examples'
examples_names = os.listdir(EXAMPLES_FOLDER)
examples = []
for example_name in examples_names:
    example_path = os.path.join(EXAMPLES_FOLDER, example_name)
    label = example_name.split('_')[0]
    example = {
        'path': example_path,
        'label': label
    }
    examples.append(example)
np.random.shuffle(examples)  # randomize the display order of the examples

@spaces.GPU
def process_frame(frame, mtcnn, model, cam, targets):
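    """Detect the face in a BGR frame, run Grad-CAM and the classifier on the crop,
    and return (visualization, prediction, confidence). If no face is found, the
    original frame is returned with prediction and confidence set to None."""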
    face = mtcnn(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
    if face is None:
        return frame, None, None
    
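    # Add a batch dimension and upsample the crop to 256x256 for the classifier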
    face = face.unsqueeze(0)
    face = F.interpolate(face, size=(256, 256), mode='bilinear', align_corners=False)
    
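    # Move to the compute device and scale the raw [0, 255] crop to [0, 1]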
    face = face.to(DEVICE)
    face = face.to(torch.float32)
    face = face / 255.0
    face_image_to_plot = face.squeeze(0).permute(1, 2, 0).cpu().detach().numpy()

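    # Grad-CAM heatmap for the single output logit, overlaid on the normalized face crop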
    grayscale_cam = cam(input_tensor=face, targets=targets, eigen_smooth=True)
    grayscale_cam = grayscale_cam[0, :]
    visualization = show_cam_on_image(face_image_to_plot, grayscale_cam, use_rgb=True)
    
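    # Single-logit head: a sigmoid output below 0.5 is read as "real", otherwise "fake"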
    with torch.no_grad():
        output = torch.sigmoid(model(face).squeeze(0))
        prediction = "real" if output.item() < 0.5 else "fake"
        confidence = 1 - output.item() if prediction == "real" else output.item()
    
    # show_cam_on_image returned RGB; convert to BGR so the cv2 video writer renders colors correctly
    return cv2.cvtColor(visualization, cv2.COLOR_RGB2BGR), prediction, confidence

@spaces.GPU
def predict_video(input_video: str):
    """Predict the labels for each frame of the input video"""
    cap = cv2.VideoCapture(input_video)
    fps = cap.get(cv2.CAP_PROP_FPS)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    
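    # Grad-CAM hooks the last conv layer of the final Inception-ResNet block; target 0 is the lone output logit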
    target_layers = [model.block8.branch1[-1]]
    cam = GradCAM(model=model, target_layers=target_layers)
    targets = [ClassifierOutputTarget(0)]
    
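    # Stream annotated frames into a temporary .mp4 that Gradio serves back to the client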
    temp_output = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
    out = cv2.VideoWriter(temp_output.name, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
    
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        
        processed_frame, prediction, confidence = process_frame(frame, mtcnn, model, cam, targets)
        
        if processed_frame is not None:
            # Resize the processed frame to match the original video dimensions
            processed_frame = cv2.resize(processed_frame, (width, height))
            
            # Add text with prediction and confidence
            if prediction is not None and confidence is not None:
                text = f"{prediction}: {confidence:.2f}"
            else:
                text = "No prediction available"
            cv2.putText(processed_frame, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
            
            out.write(processed_frame)
        else:
            # If no face is detected, write the original frame
            out.write(frame)
    
    cap.release()
    out.release()
    
    return temp_output.name

interface = gr.Interface(
    fn=predict_video,
    inputs=[
        gr.Video(label="Input Video")
    ],
    outputs=[
        gr.Video(label="Output Video")
    ],
    title="Video Deepfake Detection",
    description="Upload a video to detect deepfakes in each frame."
)

if __name__ == "__main__":
    interface.launch()