dhairyashah commited on
Commit
adfe793
·
verified ·
1 Parent(s): 7396db7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -23
app.py CHANGED
@@ -1,12 +1,14 @@
1
- import spaces
2
  import os
3
- from flask import Flask, request, jsonify
4
- from werkzeug.utils import secure_filename
5
  import cv2
 
6
  import torch
 
7
  import torch.nn.functional as F
8
  from facenet_pytorch import MTCNN, InceptionResnetV1
9
- import numpy as np
 
 
 
10
 
11
  app = Flask(__name__)
12
 
@@ -21,20 +23,50 @@ os.makedirs(UPLOAD_FOLDER, exist_ok=True)
21
  # Device configuration
22
  DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
23
 
 
24
  mtcnn = MTCNN(select_largest=False, post_process=False, device=DEVICE).to(DEVICE).eval()
25
 
26
- model = InceptionResnetV1(pretrained="vggface2", classify=True, num_classes=1, device=DEVICE)
27
- checkpoint = torch.load("resnetinceptionv1_epoch_32.pth", map_location=torch.device('cpu'))
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  model.load_state_dict(checkpoint['model_state_dict'])
29
- model.to(DEVICE)
30
  model.eval()
31
 
32
  def allowed_file(filename):
33
  return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  @spaces.GPU
36
  def process_frame(frame):
37
- face = mtcnn(frame)
 
38
  if face is None:
39
  return None, None
40
 
@@ -46,7 +78,7 @@ def process_frame(frame):
46
  face = face / 255.0
47
 
48
  with torch.no_grad():
49
- output = torch.sigmoid(model(face).squeeze(0))
50
  prediction = "fake" if output.item() >= 0.5 else "real"
51
 
52
  return prediction, output.item()
@@ -76,14 +108,10 @@ def preprocess_video(video_path, output_path):
76
  out.release()
77
 
78
  @spaces.GPU
79
- def analyze_video(video_path):
80
- preprocessed_path = os.path.join(app.config['UPLOAD_FOLDER'], 'preprocessed.mp4')
81
- preprocess_video(video_path, preprocessed_path)
82
-
83
- cap = cv2.VideoCapture(preprocessed_path)
84
- frame_count = 0
85
- fake_count = 0
86
- total_processed = 0
87
 
88
  while cap.isOpened():
89
  ret, frame = cap.read()
@@ -94,17 +122,26 @@ def analyze_video(video_path):
94
  prediction, confidence = process_frame(rgb_frame)
95
 
96
  if prediction is not None:
97
- total_processed += 1
98
- if prediction == "fake":
99
- fake_count += 1
100
 
101
- frame_count += 1
 
 
102
 
103
  cap.release()
 
 
 
 
 
 
 
 
 
104
  os.remove(preprocessed_path)
105
 
106
- if total_processed > 0:
107
- fake_percentage = (fake_count / total_processed) * 100
108
  return fake_percentage
109
  else:
110
  return 0
 
 
1
  import os
 
 
2
  import cv2
3
+ import numpy as np
4
  import torch
5
+ import torch.nn as nn
6
  import torch.nn.functional as F
7
  from facenet_pytorch import MTCNN, InceptionResnetV1
8
+ from collections import deque
9
+ from flask import Flask, request, jsonify
10
+ from werkzeug.utils import secure_filename
11
+ import spaces
12
 
13
  app = Flask(__name__)
14
 
 
23
  # Device configuration
24
  DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
25
 
26
+ # Model initialization
27
  mtcnn = MTCNN(select_largest=False, post_process=False, device=DEVICE).to(DEVICE).eval()
28
 
29
+ @spaces.GPU
30
+ class EnsembleModel(nn.Module):
31
+ def __init__(self, num_models=3):
32
+ super(EnsembleModel, self).__init__()
33
+ self.models = nn.ModuleList([
34
+ InceptionResnetV1(pretrained="vggface2", classify=True, num_classes=1)
35
+ for _ in range(num_models)
36
+ ])
37
+
38
+ def forward(self, x):
39
+ outputs = [torch.sigmoid(model(x)) for model in self.models]
40
+ return torch.mean(torch.stack(outputs), dim=0)
41
+
42
+ model = EnsembleModel().to(DEVICE)
43
+ checkpoint = torch.load("ensemble_model.pth", map_location=torch.device('cpu'))
44
  model.load_state_dict(checkpoint['model_state_dict'])
 
45
  model.eval()
46
 
47
  def allowed_file(filename):
48
  return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
49
 
50
+ @spaces.GPU
51
+ def augment_frame(frame):
52
+ # Random horizontal flip
53
+ if np.random.rand() > 0.5:
54
+ frame = cv2.flip(frame, 1)
55
+
56
+ # Random brightness adjustment
57
+ brightness = np.random.uniform(0.8, 1.2)
58
+ frame = cv2.convertScaleAbs(frame, alpha=brightness, beta=0)
59
+
60
+ # Random contrast adjustment
61
+ contrast = np.random.uniform(0.8, 1.2)
62
+ frame = cv2.addWeighted(frame, contrast, frame, 0, 0)
63
+
64
+ return frame
65
+
66
  @spaces.GPU
67
  def process_frame(frame):
68
+ augmented_frame = augment_frame(frame)
69
+ face = mtcnn(augmented_frame)
70
  if face is None:
71
  return None, None
72
 
 
78
  face = face / 255.0
79
 
80
  with torch.no_grad():
81
+ output = model(face).squeeze(0)
82
  prediction = "fake" if output.item() >= 0.5 else "real"
83
 
84
  return prediction, output.item()
 
108
  out.release()
109
 
110
  @spaces.GPU
111
+ def analyze_temporal(video_path, window_size=5):
112
+ cap = cv2.VideoCapture(video_path)
113
+ predictions = deque(maxlen=window_size)
114
+ frame_predictions = []
 
 
 
 
115
 
116
  while cap.isOpened():
117
  ret, frame = cap.read()
 
122
  prediction, confidence = process_frame(rgb_frame)
123
 
124
  if prediction is not None:
125
+ predictions.append(1 if prediction == "fake" else 0)
 
 
126
 
127
+ if len(predictions) == window_size:
128
+ avg_prediction = sum(predictions) / window_size
129
+ frame_predictions.append(avg_prediction)
130
 
131
  cap.release()
132
+ return frame_predictions
133
+
134
+ @spaces.GPU
135
+ def analyze_video(video_path):
136
+ preprocessed_path = os.path.join(app.config['UPLOAD_FOLDER'], 'preprocessed.mp4')
137
+ preprocess_video(video_path, preprocessed_path)
138
+
139
+ frame_predictions = analyze_temporal(preprocessed_path)
140
+
141
  os.remove(preprocessed_path)
142
 
143
+ if frame_predictions:
144
+ fake_percentage = (sum(pred > 0.5 for pred in frame_predictions) / len(frame_predictions)) * 100
145
  return fake_percentage
146
  else:
147
  return 0