dhairyashah commited on
Commit
7a60200
·
verified ·
1 Parent(s): adfe793

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -102
app.py CHANGED
@@ -1,14 +1,15 @@
 
 
1
  import os
 
2
  import cv2
3
- import numpy as np
4
  import torch
5
- import torch.nn as nn
6
  import torch.nn.functional as F
7
  from facenet_pytorch import MTCNN, InceptionResnetV1
8
- from collections import deque
9
- from flask import Flask, request, jsonify
10
- from werkzeug.utils import secure_filename
11
- import spaces
12
 
13
  app = Flask(__name__)
14
 
@@ -23,154 +24,92 @@ os.makedirs(UPLOAD_FOLDER, exist_ok=True)
23
  # Device configuration
24
  DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
25
 
26
- # Model initialization
27
  mtcnn = MTCNN(select_largest=False, post_process=False, device=DEVICE).to(DEVICE).eval()
28
 
29
- @spaces.GPU
30
- class EnsembleModel(nn.Module):
31
- def __init__(self, num_models=3):
32
- super(EnsembleModel, self).__init__()
33
- self.models = nn.ModuleList([
34
- InceptionResnetV1(pretrained="vggface2", classify=True, num_classes=1)
35
- for _ in range(num_models)
36
- ])
37
-
38
- def forward(self, x):
39
- outputs = [torch.sigmoid(model(x)) for model in self.models]
40
- return torch.mean(torch.stack(outputs), dim=0)
41
-
42
- model = EnsembleModel().to(DEVICE)
43
- checkpoint = torch.load("ensemble_model.pth", map_location=torch.device('cpu'))
44
  model.load_state_dict(checkpoint['model_state_dict'])
 
45
  model.eval()
46
 
47
  def allowed_file(filename):
48
  return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
49
 
50
- @spaces.GPU
51
- def augment_frame(frame):
52
- # Random horizontal flip
53
- if np.random.rand() > 0.5:
54
- frame = cv2.flip(frame, 1)
55
-
56
- # Random brightness adjustment
57
- brightness = np.random.uniform(0.8, 1.2)
58
- frame = cv2.convertScaleAbs(frame, alpha=brightness, beta=0)
59
-
60
- # Random contrast adjustment
61
- contrast = np.random.uniform(0.8, 1.2)
62
- frame = cv2.addWeighted(frame, contrast, frame, 0, 0)
63
-
64
- return frame
65
-
66
  @spaces.GPU
67
  def process_frame(frame):
68
- augmented_frame = augment_frame(frame)
69
- face = mtcnn(augmented_frame)
70
  if face is None:
71
  return None, None
72
-
73
  face = face.unsqueeze(0)
74
  face = F.interpolate(face, size=(256, 256), mode='bilinear', align_corners=False)
75
-
76
  face = face.to(DEVICE)
77
  face = face.to(torch.float32)
78
  face = face / 255.0
79
-
80
  with torch.no_grad():
81
- output = model(face).squeeze(0)
82
  prediction = "fake" if output.item() >= 0.5 else "real"
83
-
84
  return prediction, output.item()
85
 
86
  @spaces.GPU
87
- def preprocess_video(video_path, output_path):
88
  cap = cv2.VideoCapture(video_path)
89
- fps = cap.get(cv2.CAP_PROP_FPS)
90
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
91
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
92
-
93
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
94
- out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
95
-
96
- while cap.isOpened():
97
- ret, frame = cap.read()
98
- if not ret:
99
- break
100
-
101
- rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
102
- face = mtcnn.detect(rgb_frame)
103
-
104
- if face[0] is not None:
105
- out.write(frame)
106
-
107
- cap.release()
108
- out.release()
109
 
110
- @spaces.GPU
111
- def analyze_temporal(video_path, window_size=5):
112
- cap = cv2.VideoCapture(video_path)
113
- predictions = deque(maxlen=window_size)
114
- frame_predictions = []
115
-
116
  while cap.isOpened():
117
  ret, frame = cap.read()
118
  if not ret:
119
  break
120
-
121
- rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
122
- prediction, confidence = process_frame(rgb_frame)
123
-
124
- if prediction is not None:
125
- predictions.append(1 if prediction == "fake" else 0)
126
-
127
- if len(predictions) == window_size:
128
- avg_prediction = sum(predictions) / window_size
129
- frame_predictions.append(avg_prediction)
130
-
 
131
  cap.release()
132
- return frame_predictions
133
 
134
- @spaces.GPU
135
- def analyze_video(video_path):
136
- preprocessed_path = os.path.join(app.config['UPLOAD_FOLDER'], 'preprocessed.mp4')
137
- preprocess_video(video_path, preprocessed_path)
138
-
139
- frame_predictions = analyze_temporal(preprocessed_path)
140
-
141
- os.remove(preprocessed_path)
142
-
143
- if frame_predictions:
144
- fake_percentage = (sum(pred > 0.5 for pred in frame_predictions) / len(frame_predictions)) * 100
145
  return fake_percentage
146
  else:
147
  return 0
148
 
149
- @spaces.GPU
150
  @app.route('/analyze', methods=['POST'])
151
  def analyze_video_api():
152
  if 'video' not in request.files:
153
  return jsonify({'error': 'No video file provided'}), 400
154
-
155
  file = request.files['video']
156
-
157
  if file.filename == '':
158
  return jsonify({'error': 'No selected file'}), 400
159
-
160
  if file and allowed_file(file.filename):
161
  filename = secure_filename(file.filename)
162
  filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
163
  file.save(filepath)
164
-
165
  try:
166
  fake_percentage = analyze_video(filepath)
167
  os.remove(filepath) # Remove the file after analysis
168
-
169
  result = {
170
  'fake_percentage': round(fake_percentage, 2),
171
  'is_likely_deepfake': fake_percentage >= 60
172
  }
173
-
174
  return jsonify(result), 200
175
  except Exception as e:
176
  os.remove(filepath) # Remove the file if an error occurs
 
1
+ import spaces
2
+ from flask import Flask, request, jsonify
3
  import os
4
+ from werkzeug.utils import secure_filename
5
  import cv2
 
6
  import torch
 
7
  import torch.nn.functional as F
8
  from facenet_pytorch import MTCNN, InceptionResnetV1
9
+ import numpy as np
10
+ from pytorch_grad_cam import GradCAM
11
+ from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
12
+ import os
13
 
14
  app = Flask(__name__)
15
 
 
24
  # Device configuration
25
  DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
26
 
 
27
  mtcnn = MTCNN(select_largest=False, post_process=False, device=DEVICE).to(DEVICE).eval()
28
 
29
+ model = InceptionResnetV1(pretrained="vggface2", classify=True, num_classes=1, device=DEVICE)
30
+ # Model Credits: https://huggingface.co/spaces/dhairyashah/deepfake-alpha-version/blob/main/CREDITS.md
31
+ checkpoint = torch.load("resnetinceptionv1_epoch_32.pth", map_location=torch.device('cpu'))
 
 
 
 
 
 
 
 
 
 
 
 
32
  model.load_state_dict(checkpoint['model_state_dict'])
33
+ model.to(DEVICE)
34
  model.eval()
35
 
36
  def allowed_file(filename):
37
  return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  @spaces.GPU
40
  def process_frame(frame):
41
+ face = mtcnn(frame)
 
42
  if face is None:
43
  return None, None
44
+
45
  face = face.unsqueeze(0)
46
  face = F.interpolate(face, size=(256, 256), mode='bilinear', align_corners=False)
47
+
48
  face = face.to(DEVICE)
49
  face = face.to(torch.float32)
50
  face = face / 255.0
51
+
52
  with torch.no_grad():
53
+ output = torch.sigmoid(model(face).squeeze(0))
54
  prediction = "fake" if output.item() >= 0.5 else "real"
55
+
56
  return prediction, output.item()
57
 
58
  @spaces.GPU
59
+ def analyze_video(video_path, sample_rate=30):
60
  cap = cv2.VideoCapture(video_path)
61
+ frame_count = 0
62
+ fake_count = 0
63
+ total_processed = 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
 
 
 
 
 
 
65
  while cap.isOpened():
66
  ret, frame = cap.read()
67
  if not ret:
68
  break
69
+
70
+ if frame_count % sample_rate == 0:
71
+ rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
72
+ prediction, confidence = process_frame(rgb_frame)
73
+
74
+ if prediction is not None:
75
+ total_processed += 1
76
+ if prediction == "fake":
77
+ fake_count += 1
78
+
79
+ frame_count += 1
80
+
81
  cap.release()
 
82
 
83
+ if total_processed > 0:
84
+ fake_percentage = (fake_count / total_processed) * 100
 
 
 
 
 
 
 
 
 
85
  return fake_percentage
86
  else:
87
  return 0
88
 
 
89
  @app.route('/analyze', methods=['POST'])
90
  def analyze_video_api():
91
  if 'video' not in request.files:
92
  return jsonify({'error': 'No video file provided'}), 400
93
+
94
  file = request.files['video']
95
+
96
  if file.filename == '':
97
  return jsonify({'error': 'No selected file'}), 400
98
+
99
  if file and allowed_file(file.filename):
100
  filename = secure_filename(file.filename)
101
  filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
102
  file.save(filepath)
103
+
104
  try:
105
  fake_percentage = analyze_video(filepath)
106
  os.remove(filepath) # Remove the file after analysis
107
+
108
  result = {
109
  'fake_percentage': round(fake_percentage, 2),
110
  'is_likely_deepfake': fake_percentage >= 60
111
  }
112
+
113
  return jsonify(result), 200
114
  except Exception as e:
115
  os.remove(filepath) # Remove the file if an error occurs