Update app.py
Browse files
app.py
CHANGED
@@ -1,12 +1,14 @@
|
|
1 |
-
import spaces
|
2 |
import os
|
3 |
-
from flask import Flask, request, jsonify
|
4 |
-
from werkzeug.utils import secure_filename
|
5 |
import cv2
|
|
|
6 |
import torch
|
|
|
7 |
import torch.nn.functional as F
|
8 |
from facenet_pytorch import MTCNN, InceptionResnetV1
|
9 |
-
|
|
|
|
|
|
|
10 |
|
11 |
app = Flask(__name__)
|
12 |
|
@@ -21,20 +23,50 @@ os.makedirs(UPLOAD_FOLDER, exist_ok=True)
|
|
21 |
# Device configuration
|
22 |
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
|
23 |
|
|
|
24 |
mtcnn = MTCNN(select_largest=False, post_process=False, device=DEVICE).to(DEVICE).eval()
|
25 |
|
26 |
-
|
27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
model.load_state_dict(checkpoint['model_state_dict'])
|
29 |
-
model.to(DEVICE)
|
30 |
model.eval()
|
31 |
|
32 |
def allowed_file(filename):
    """Return True when ``filename`` ends in an extension listed in ALLOWED_EXTENSIONS.

    The extension is the text after the last ``.`` and is compared
    case-insensitively. A name without any ``.`` is rejected.
    """
    _, dot, extension = filename.rpartition('.')
    if not dot:
        # No separator at all: there is no extension to check.
        return False
    return extension.lower() in ALLOWED_EXTENSIONS
|
34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
@spaces.GPU
|
36 |
def process_frame(frame):
|
37 |
-
|
|
|
38 |
if face is None:
|
39 |
return None, None
|
40 |
|
@@ -46,7 +78,7 @@ def process_frame(frame):
|
|
46 |
face = face / 255.0
|
47 |
|
48 |
with torch.no_grad():
|
49 |
-
output =
|
50 |
prediction = "fake" if output.item() >= 0.5 else "real"
|
51 |
|
52 |
return prediction, output.item()
|
@@ -76,14 +108,10 @@ def preprocess_video(video_path, output_path):
|
|
76 |
out.release()
|
77 |
|
78 |
@spaces.GPU
|
79 |
-
def
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
cap = cv2.VideoCapture(preprocessed_path)
|
84 |
-
frame_count = 0
|
85 |
-
fake_count = 0
|
86 |
-
total_processed = 0
|
87 |
|
88 |
while cap.isOpened():
|
89 |
ret, frame = cap.read()
|
@@ -94,17 +122,26 @@ def analyze_video(video_path):
|
|
94 |
prediction, confidence = process_frame(rgb_frame)
|
95 |
|
96 |
if prediction is not None:
|
97 |
-
|
98 |
-
if prediction == "fake":
|
99 |
-
fake_count += 1
|
100 |
|
101 |
-
|
|
|
|
|
102 |
|
103 |
cap.release()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
104 |
os.remove(preprocessed_path)
|
105 |
|
106 |
-
if
|
107 |
-
fake_percentage = (
|
108 |
return fake_percentage
|
109 |
else:
|
110 |
return 0
|
|
|
|
|
1 |
import os
|
|
|
|
|
2 |
import cv2
|
3 |
+
import numpy as np
|
4 |
import torch
|
5 |
+
import torch.nn as nn
|
6 |
import torch.nn.functional as F
|
7 |
from facenet_pytorch import MTCNN, InceptionResnetV1
|
8 |
+
from collections import deque
|
9 |
+
from flask import Flask, request, jsonify
|
10 |
+
from werkzeug.utils import secure_filename
|
11 |
+
import spaces
|
12 |
|
13 |
app = Flask(__name__)
|
14 |
|
|
|
23 |
# Device configuration
|
24 |
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
|
25 |
|
26 |
+
# Model initialization
|
27 |
mtcnn = MTCNN(select_largest=False, post_process=False, device=DEVICE).to(DEVICE).eval()
|
28 |
|
29 |
+
# NOTE: deliberately not decorated with @spaces.GPU. That decorator is meant
# for GPU-bound functions; wrapping the class would rebind the name
# `EnsembleModel` to whatever the decorator returns instead of the class.
class EnsembleModel(nn.Module):
    """Ensemble of InceptionResnetV1 classifiers whose sigmoid outputs are averaged.

    Args:
        num_models: Number of InceptionResnetV1 members to build. Each member
            is created with ``pretrained="vggface2"``, ``classify=True`` and a
            single output class.
    """

    def __init__(self, num_models=3):
        # Zero-argument super() does not depend on the module-level name
        # still pointing at this class.
        super().__init__()
        self.models = nn.ModuleList([
            InceptionResnetV1(pretrained="vggface2", classify=True, num_classes=1)
            for _ in range(num_models)
        ])

    def forward(self, x):
        """Run every member on ``x`` and return the mean of their sigmoid outputs.

        The members' outputs are stacked along a new leading dimension and
        averaged over it, so the result has the shape of one member's output.
        """
        outputs = [torch.sigmoid(member(x)) for member in self.models]
        return torch.mean(torch.stack(outputs), dim=0)
|
41 |
+
|
42 |
+
# Build the ensemble, restore its trained weights from the checkpoint (always
# deserialized onto the CPU), then move it to the inference device and switch
# it to evaluation mode.
model = EnsembleModel()
checkpoint = torch.load("ensemble_model.pth", map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(DEVICE)
model.eval()
|
46 |
|
47 |
def allowed_file(filename):
    """Return True when ``filename`` ends in an extension listed in ALLOWED_EXTENSIONS.

    The extension is the text after the last ``.`` and is compared
    case-insensitively. A name without any ``.`` is rejected.
    """
    _, dot, extension = filename.rpartition('.')
    if not dot:
        # No separator at all: there is no extension to check.
        return False
    return extension.lower() in ALLOWED_EXTENSIONS
|
49 |
|
50 |
+
# NOTE: not decorated with @spaces.GPU -- everything below is NumPy/OpenCV
# work on the CPU, so there is nothing here that needs a GPU allocation.
def augment_frame(frame):
    """Return a randomly flipped, brightness- and contrast-jittered copy of ``frame``.

    Args:
        frame: Image array as accepted by ``cv2.flip`` / ``cv2.convertScaleAbs``
            (assumes an 8-bit image such as a decoded video frame -- TODO confirm).

    Returns:
        The augmented frame. Randomness comes from the global ``np.random``
        state, so results differ between calls unless that state is seeded.
    """
    # Random horizontal flip
    if np.random.rand() > 0.5:
        frame = cv2.flip(frame, 1)

    # Random brightness adjustment: multiplicative gain, saturated to 8 bits.
    brightness = np.random.uniform(0.8, 1.2)
    frame = cv2.convertScaleAbs(frame, alpha=brightness, beta=0)

    # Random contrast adjustment: stretch or compress pixel values around the
    # frame mean. The previous form (gamma=0) was only a second brightness
    # gain. addWeighted saturates, so values pushed below 0 clamp to 0.
    contrast = np.random.uniform(0.8, 1.2)
    mean_level = float(frame.mean())
    frame = cv2.addWeighted(frame, contrast, frame, 0, (1.0 - contrast) * mean_level)

    return frame
|
65 |
+
|
66 |
@spaces.GPU
|
67 |
def process_frame(frame):
|
68 |
+
augmented_frame = augment_frame(frame)
|
69 |
+
face = mtcnn(augmented_frame)
|
70 |
if face is None:
|
71 |
return None, None
|
72 |
|
|
|
78 |
face = face / 255.0
|
79 |
|
80 |
with torch.no_grad():
|
81 |
+
output = model(face).squeeze(0)
|
82 |
prediction = "fake" if output.item() >= 0.5 else "real"
|
83 |
|
84 |
return prediction, output.item()
|
|
|
108 |
out.release()
|
109 |
|
110 |
@spaces.GPU
|
111 |
+
def analyze_temporal(video_path, window_size=5):
|
112 |
+
cap = cv2.VideoCapture(video_path)
|
113 |
+
predictions = deque(maxlen=window_size)
|
114 |
+
frame_predictions = []
|
|
|
|
|
|
|
|
|
115 |
|
116 |
while cap.isOpened():
|
117 |
ret, frame = cap.read()
|
|
|
122 |
prediction, confidence = process_frame(rgb_frame)
|
123 |
|
124 |
if prediction is not None:
|
125 |
+
predictions.append(1 if prediction == "fake" else 0)
|
|
|
|
|
126 |
|
127 |
+
if len(predictions) == window_size:
|
128 |
+
avg_prediction = sum(predictions) / window_size
|
129 |
+
frame_predictions.append(avg_prediction)
|
130 |
|
131 |
cap.release()
|
132 |
+
return frame_predictions
|
133 |
+
|
134 |
+
@spaces.GPU
def analyze_video(video_path):
    """Return the percentage of temporal windows in a video that look fake.

    The video is first written to a preprocessed copy in the upload folder,
    that copy is scored by ``analyze_temporal``, and the copy is deleted
    afterwards -- including when preprocessing or analysis raises.

    Args:
        video_path: Path of the uploaded video to analyze.

    Returns:
        Share (0-100) of window-averaged predictions strictly above 0.5, or
        ``0`` when ``analyze_temporal`` produced no predictions.
    """
    # NOTE(review): the file name is fixed, so two overlapping calls would
    # share this path -- confirm requests are handled one at a time.
    preprocessed_path = os.path.join(app.config['UPLOAD_FOLDER'], 'preprocessed.mp4')

    try:
        preprocess_video(video_path, preprocessed_path)
        frame_predictions = analyze_temporal(preprocessed_path)
    finally:
        # Always clean up; the file may not exist if preprocessing failed
        # before writing anything.
        if os.path.exists(preprocessed_path):
            os.remove(preprocessed_path)

    if not frame_predictions:
        return 0

    fake_windows = sum(pred > 0.5 for pred in frame_predictions)
    fake_percentage = (fake_windows / len(frame_predictions)) * 100
    return fake_percentage
|