dhairyashah committed
Commit f17100d (verified)
1 Parent(s): 5f3487d

Update app.py

Files changed (1):
  1. app.py +30 -89
app.py CHANGED
@@ -5,16 +5,13 @@ import torch.nn.functional as F
 from facenet_pytorch import MTCNN, InceptionResnetV1
 import os
 import numpy as np
-from PIL import Image as PILImage
+from PIL import Image
 import zipfile
 import cv2
 from pytorch_grad_cam import GradCAM
 from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
 from pytorch_grad_cam.utils.image import show_cam_on_image
 import tempfile
-import matplotlib.pyplot as plt
-from tqdm import tqdm
-import io
 
 with zipfile.ZipFile("examples.zip","r") as zip_ref:
     zip_ref.extractall(".")
@@ -54,7 +51,7 @@ np.random.shuffle(examples) # shuffle
 
 @spaces.GPU
 def process_frame(frame, mtcnn, model, cam, targets):
-    face = mtcnn(PILImage.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
+    face = mtcnn(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
     if face is None:
         return frame, None, None
 
@@ -78,112 +75,56 @@ def process_frame(frame, mtcnn, model, cam, targets):
     return visualization, prediction, confidence
 
 @spaces.GPU
-def analyze_video(input_video: str):
-    """Analyze the video for deepfake detection"""
+def predict_video(input_video: str):
+    """Predict the labels for each frame of the input video"""
     cap = cv2.VideoCapture(input_video)
-    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+    fps = cap.get(cv2.CAP_PROP_FPS)
+    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
 
     target_layers = [model.block8.branch1[-1]]
     cam = GradCAM(model=model, target_layers=target_layers)
     targets = [ClassifierOutputTarget(0)]
 
-    frame_confidences = []
-    frame_predictions = []
+    temp_output = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
+    out = cv2.VideoWriter(temp_output.name, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
 
-    for _ in tqdm(range(total_frames), desc="Analyzing video"):
+    while cap.isOpened():
         ret, frame = cap.read()
         if not ret:
             break
 
-        _, prediction, confidence = process_frame(frame, mtcnn, model, cam, targets)
+        processed_frame, prediction, confidence = process_frame(frame, mtcnn, model, cam, targets)
 
-        if prediction is not None and confidence is not None:
-            frame_confidences.append(confidence)
-            frame_predictions.append(1 if prediction == "fake" else 0)
+        if processed_frame is not None:
+            # Resize the processed frame to match the original video dimensions
+            processed_frame = cv2.resize(processed_frame, (width, height))
+
+            # Add text with prediction and confidence
+            text = f"{prediction}: {confidence:.2f}"
+            cv2.putText(processed_frame, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
+
+            out.write(processed_frame)
+        else:
+            # If no face is detected, write the original frame
+            out.write(frame)
 
     cap.release()
+    out.release()
 
-    # Calculate metrics
-    fake_percentage = (sum(frame_predictions) / len(frame_predictions)) * 100 if frame_predictions else 0
-    avg_confidence = np.mean(frame_confidences) if frame_confidences else 0
-
-    # Create graphs
-    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 12))
-
-    # Confidence over time
-    ax1.plot(frame_confidences)
-    ax1.set_title("Confidence Over Time")
-    ax1.set_xlabel("Frame")
-    ax1.set_ylabel("Confidence")
-    ax1.set_ylim(0, 1)
-
-    # Prediction distribution
-    ax2.hist(frame_predictions, bins=[0, 0.5, 1], rwidth=0.8)
-    ax2.set_title("Distribution of Predictions")
-    ax2.set_xlabel("Prediction (0: Real, 1: Fake)")
-    ax2.set_ylabel("Count")
-
-    # Save plot to bytes
-    buf = io.BytesIO()
-    plt.savefig(buf, format='png')
-    buf.seek(0)
-
-    # Create progress bar image
-    progress_fig, progress_ax = plt.subplots(figsize=(8, 2))
-    progress_ax.barh(["Fake"], [fake_percentage], color='red')
-    progress_ax.barh(["Fake"], [100 - fake_percentage], left=[fake_percentage], color='green')
-    progress_ax.set_xlim(0, 100)
-    progress_ax.set_title("Fake Percentage")
-    progress_ax.set_xlabel("Percentage")
-    progress_ax.text(fake_percentage, 0, f"{fake_percentage:.1f}%", va='center', ha='left')
-
-    # Save progress bar to bytes
-    progress_buf = io.BytesIO()
-    progress_fig.savefig(progress_buf, format='png')
-    progress_buf.seek(0)
-
-    return {
-        "fake_percentage": fake_percentage,
-        "avg_confidence": avg_confidence,
-        "analysis_plot": buf,
-        "progress_bar": progress_buf,
-        "total_frames": total_frames,
-        "processed_frames": len(frame_confidences)
-    }
-
-def format_results(results):
-    return f"""
-    Analysis Results:
-    - Fake Percentage: {results['fake_percentage']:.2f}%
-    - Average Confidence: {results['avg_confidence']:.2f}
-    - Total Frames: {results['total_frames']}
-    - Processed Frames: {results['processed_frames']}
-    """
-
-def analyze_and_format(input_video):
-    results = analyze_video(input_video)
-    text_results = format_results(results)
-
-    # Convert BytesIO to PIL Images
-    analysis_plot = PILImage.open(results['analysis_plot'])
-    progress_bar = PILImage.open(results['progress_bar'])
-
-    return text_results, analysis_plot, progress_bar
+    return temp_output.name
 
 interface = gr.Interface(
-    fn=analyze_and_format,
+    fn=predict_video,
     inputs=[
         gr.Video(label="Input Video")
     ],
     outputs=[
-        gr.Textbox(label="Analysis Results"),
-        gr.Image(label="Analysis Plots"),
-        gr.Image(label="Fake Percentage")
+        gr.Video(label="Output Video")
     ],
-    title="Video Deepfake Analysis",
-    description="Upload a video to analyze for potential deepfakes.",
-    examples=[]
+    title="Video Deepfake Detection",
+    description="Upload a video to detect deepfakes in each frame."
 )
 
 if __name__ == "__main__":
-    interface.launch(share=True)
+    interface.launch()
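A caveat worth flagging on the new loop: process_frame returns (frame, None, None) when MTCNN finds no face, so processed_frame is never None, the else branch is unreachable, and f"{prediction}: {confidence:.2f}" formats a None confidence and raises a TypeError on frames without a face. A minimal guard, keyed on prediction rather than the frame (a sketch, not part of this commit):

    if prediction is not None:
        # A face was found: resize, overlay the label, and write the frame
        processed_frame = cv2.resize(processed_frame, (width, height))
        text = f"{prediction}: {confidence:.2f}"
        cv2.putText(processed_frame, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
        out.write(processed_frame)
    else:
        # No face detected: write the original frame unchanged
        out.write(frame)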
 
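On why the commit resizes before out.write: cv2.VideoWriter expects every frame to match the (width, height) declared at construction, and with most backends mismatched frames appear to be dropped silently, leaving a broken output file. A self-contained round-trip illustrating the contract the new code relies on ("input.mp4" is a placeholder path, not from the commit):

    import cv2
    import tempfile

    cap = cv2.VideoCapture("input.mp4")  # placeholder input path
    fps = cap.get(cv2.CAP_PROP_FPS)
    size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
            int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))

    # delete=False keeps the file on disk so it can be read (e.g. by gr.Video)
    # after the writing function returns
    tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    out = cv2.VideoWriter(tmp.name, cv2.VideoWriter_fourcc(*"mp4v"), fps, size)

    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        out.write(cv2.resize(frame, size))  # enforce the declared frame size
    cap.release()
    out.release()

Returning tmp.name from the Gradio handler, as predict_video does, lets the gr.Video output component serve the written file directly.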