dhairyashah committed
Commit 7396db7 · verified · 1 Parent(s): 657d714

Update app.py

Files changed (1)
  1. app.py +53 -29
app.py CHANGED
@@ -1,15 +1,12 @@
 import spaces
-from flask import Flask, request, jsonify
 import os
+from flask import Flask, request, jsonify
 from werkzeug.utils import secure_filename
 import cv2
 import torch
 import torch.nn.functional as F
 from facenet_pytorch import MTCNN, InceptionResnetV1
 import numpy as np
-from pytorch_grad_cam import GradCAM
-from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
-import os
 
 app = Flask(__name__)
 
@@ -27,7 +24,6 @@ DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
 mtcnn = MTCNN(select_largest=False, post_process=False, device=DEVICE).to(DEVICE).eval()
 
 model = InceptionResnetV1(pretrained="vggface2", classify=True, num_classes=1, device=DEVICE)
-# Model Credits: https://huggingface.co/spaces/dhairyashah/deepfake-alpha-version/blob/main/CREDITS.md
 checkpoint = torch.load("resnetinceptionv1_epoch_32.pth", map_location=torch.device('cpu'))
 model.load_state_dict(checkpoint['model_state_dict'])
 model.to(DEVICE)
@@ -41,75 +37,103 @@ def process_frame(frame):
     face = mtcnn(frame)
     if face is None:
         return None, None
-
+
     face = face.unsqueeze(0)
     face = F.interpolate(face, size=(256, 256), mode='bilinear', align_corners=False)
-
+
     face = face.to(DEVICE)
     face = face.to(torch.float32)
     face = face / 255.0
-
+
     with torch.no_grad():
         output = torch.sigmoid(model(face).squeeze(0))
         prediction = "fake" if output.item() >= 0.5 else "real"
-
+
     return prediction, output.item()
 
 @spaces.GPU
-def analyze_video(video_path, sample_rate=30):
+def preprocess_video(video_path, output_path):
     cap = cv2.VideoCapture(video_path)
+    fps = cap.get(cv2.CAP_PROP_FPS)
+    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+
+    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
+
+    while cap.isOpened():
+        ret, frame = cap.read()
+        if not ret:
+            break
+
+        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+        face = mtcnn.detect(rgb_frame)
+
+        if face[0] is not None:
+            out.write(frame)
+
+    cap.release()
+    out.release()
+
+@spaces.GPU
+def analyze_video(video_path):
+    preprocessed_path = os.path.join(app.config['UPLOAD_FOLDER'], 'preprocessed.mp4')
+    preprocess_video(video_path, preprocessed_path)
+
+    cap = cv2.VideoCapture(preprocessed_path)
     frame_count = 0
     fake_count = 0
     total_processed = 0
-
+
     while cap.isOpened():
         ret, frame = cap.read()
         if not ret:
             break
-
-        if frame_count % sample_rate == 0:
-            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-            prediction, confidence = process_frame(rgb_frame)
-
-            if prediction is not None:
-                total_processed += 1
-                if prediction == "fake":
-                    fake_count += 1
-
+
+        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+        prediction, confidence = process_frame(rgb_frame)
+
+        if prediction is not None:
+            total_processed += 1
+            if prediction == "fake":
+                fake_count += 1
+
         frame_count += 1
-
+
     cap.release()
-
+    os.remove(preprocessed_path)
+
     if total_processed > 0:
         fake_percentage = (fake_count / total_processed) * 100
         return fake_percentage
     else:
        return 0
 
+@spaces.GPU
 @app.route('/analyze', methods=['POST'])
 def analyze_video_api():
     if 'video' not in request.files:
         return jsonify({'error': 'No video file provided'}), 400
-
+
     file = request.files['video']
-
+
     if file.filename == '':
         return jsonify({'error': 'No selected file'}), 400
-
+
     if file and allowed_file(file.filename):
         filename = secure_filename(file.filename)
         filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
         file.save(filepath)
-
+
         try:
             fake_percentage = analyze_video(filepath)
             os.remove(filepath)  # Remove the file after analysis
-
+
             result = {
                 'fake_percentage': round(fake_percentage, 2),
                 'is_likely_deepfake': fake_percentage >= 60
             }
-
+
             return jsonify(result), 200
         except Exception as e:
             os.remove(filepath)  # Remove the file if an error occurs
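
For reference, a minimal client-side sketch of calling the /analyze endpoint defined in this file. It is not part of the commit: the Space URL, the sample file name, and the use of the requests library are assumptions; only the route, the 'video' multipart field, and the response keys come from the code above.

    # Hypothetical client for the /analyze endpoint (sketch, not from the commit)
    import requests

    with open("sample.mp4", "rb") as f:  # assumed local test file
        resp = requests.post(
            "https://<your-space-url>/analyze",  # assumed host; replace with the deployed URL
            files={"video": f},                  # field name 'video', as checked in analyze_video_api
        )

    print(resp.status_code)
    print(resp.json())  # e.g. {'fake_percentage': 73.33, 'is_likely_deepfake': True}

The response JSON mirrors the result dict built in analyze_video_api: fake_percentage is the share of face-bearing frames classified as fake, and is_likely_deepfake is true when that share is at least 60%.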