dhairyashah committed
Commit 329a62d (verified)
1 parent: f207059

Update app.py

Files changed (1):
  app.py: +83 -31
app.py CHANGED
@@ -12,6 +12,9 @@ from pytorch_grad_cam import GradCAM
 from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
 from pytorch_grad_cam.utils.image import show_cam_on_image
 import tempfile
+import matplotlib.pyplot as plt
+from tqdm import tqdm
+import io
 
 with zipfile.ZipFile("examples.zip","r") as zip_ref:
     zip_ref.extractall(".")
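
Note: the new tqdm import only draws a progress bar in the server console; visitors to the Space never see it. Gradio ships its own gr.Progress for surfacing progress in the browser, which could wrap the frame-processing loop added in the next hunk. A minimal, hypothetical sketch of that pattern, separate from this commit (the stub function and the sleep are stand-ins for the real per-frame inference):

import time
import gradio as gr

def analyze_stub(n_frames: float, progress=gr.Progress()):
    # progress.tqdm wraps an iterable like tqdm, but renders the bar in the Gradio UI
    processed = 0
    for _ in progress.tqdm(range(int(n_frames)), desc="Analyzing video"):
        time.sleep(0.01)  # stand-in for MTCNN face detection + model inference on one frame
        processed += 1
    return f"Processed {processed} frames"

demo = gr.Interface(fn=analyze_stub, inputs=gr.Number(value=100), outputs=gr.Textbox())
demo.launch()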
 
@@ -75,59 +78,108 @@ def process_frame(frame, mtcnn, model, cam, targets):
     return visualization, prediction, confidence
 
 @spaces.GPU
-def predict_video(input_video: str):
-    """Predict the labels for each frame of the input video"""
+def analyze_video(input_video: str):
+    """Analyze the video for deepfake detection"""
     cap = cv2.VideoCapture(input_video)
-    fps = cap.get(cv2.CAP_PROP_FPS)
-    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
-    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
 
     target_layers = [model.block8.branch1[-1]]
     cam = GradCAM(model=model, target_layers=target_layers)
     targets = [ClassifierOutputTarget(0)]
 
-    temp_output = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
-    out = cv2.VideoWriter(temp_output.name, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
+    frame_confidences = []
+    frame_predictions = []
 
-    while cap.isOpened():
+    for _ in tqdm(range(total_frames), desc="Analyzing video"):
         ret, frame = cap.read()
         if not ret:
             break
 
-        processed_frame, prediction, confidence = process_frame(frame, mtcnn, model, cam, targets)
+        _, prediction, confidence = process_frame(frame, mtcnn, model, cam, targets)
 
-        if processed_frame is not None:
-            # Resize the processed frame to match the original video dimensions
-            processed_frame = cv2.resize(processed_frame, (width, height))
-
-            # Add text with prediction and confidence
-            if prediction is not None and confidence is not None:
-                text = f"{prediction}: {confidence:.2f}"
-            else:
-                text = "No prediction available"
-            cv2.putText(processed_frame, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
-
-            out.write(processed_frame)
-        else:
-            # If no face is detected, write the original frame
-            out.write(frame)
+        if prediction is not None and confidence is not None:
+            frame_confidences.append(confidence)
+            frame_predictions.append(1 if prediction == "fake" else 0)
 
     cap.release()
-    out.release()
 
-    return temp_output.name
+    # Calculate metrics
+    fake_percentage = (sum(frame_predictions) / len(frame_predictions)) * 100
+    avg_confidence = np.mean(frame_confidences)
+
+    # Create graphs
+    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 12))
+
+    # Confidence over time
+    ax1.plot(frame_confidences)
+    ax1.set_title("Confidence Over Time")
+    ax1.set_xlabel("Frame")
+    ax1.set_ylabel("Confidence")
+    ax1.set_ylim(0, 1)
+
+    # Prediction distribution
+    ax2.hist(frame_predictions, bins=[0, 0.5, 1], rwidth=0.8)
+    ax2.set_title("Distribution of Predictions")
+    ax2.set_xlabel("Prediction (0: Real, 1: Fake)")
+    ax2.set_ylabel("Count")
+
+    # Save plot to bytes
+    buf = io.BytesIO()
+    plt.savefig(buf, format='png')
+    buf.seek(0)
+
+    # Create progress bar image
+    progress_fig, progress_ax = plt.subplots(figsize=(8, 2))
+    progress_ax.barh(["Fake"], [fake_percentage], color='red')
+    progress_ax.barh(["Fake"], [100 - fake_percentage], left=[fake_percentage], color='green')
+    progress_ax.set_xlim(0, 100)
+    progress_ax.set_title("Fake Percentage")
+    progress_ax.set_xlabel("Percentage")
+    progress_ax.text(fake_percentage, 0, f"{fake_percentage:.1f}%", va='center', ha='left')
+
+    # Save progress bar to bytes
+    progress_buf = io.BytesIO()
+    progress_fig.savefig(progress_buf, format='png')
+    progress_buf.seek(0)
+
+    return {
+        "fake_percentage": fake_percentage,
+        "avg_confidence": avg_confidence,
+        "analysis_plot": buf,
+        "progress_bar": progress_buf,
+        "total_frames": total_frames,
+        "processed_frames": len(frame_confidences)
+    }
+
+def format_results(results):
+    return f"""
+    Analysis Results:
+    - Fake Percentage: {results['fake_percentage']:.2f}%
+    - Average Confidence: {results['avg_confidence']:.2f}
+    - Total Frames: {results['total_frames']}
+    - Processed Frames: {results['processed_frames']}
+    """
+
+def analyze_and_format(input_video):
+    results = analyze_video(input_video)
+    text_results = format_results(results)
+    return text_results, results['analysis_plot'], results['progress_bar']
 
 interface = gr.Interface(
-    fn=predict_video,
+    fn=analyze_and_format,
     inputs=[
         gr.Video(label="Input Video")
     ],
     outputs=[
-        gr.Video(label="Output Video")
+        gr.Textbox(label="Analysis Results"),
+        gr.Image(label="Analysis Plots"),
+        gr.Image(label="Fake Percentage")
     ],
-    title="Video Deepfake Detection",
-    description="Upload a video to detect deepfakes in each frame."
+    title="Video Deepfake Analysis",
+    description="Upload a video to analyze for potential deepfakes.",
+    examples=[],
+    interpretation="default"
 )
 
 if __name__ == "__main__":
-    interface.launch()
+    interface.launch(share=True)
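
Note: two caveats on the new output path, flagged here rather than fixed in the commit. First, gr.Image expects a numpy array, a PIL image, or a file path, so returning the raw io.BytesIO buffers from analyze_and_format may fail in Gradio's postprocessing step; the figures are also never closed with plt.close(), which accumulates memory across requests. A small hypothetical helper (fig_to_pil is an illustrative name, not from the commit) that addresses both, assuming Pillow is installed:

import io

import matplotlib
matplotlib.use("Agg")  # headless backend, appropriate for a server
import matplotlib.pyplot as plt
from PIL import Image

def fig_to_pil(fig) -> Image.Image:
    """Render a matplotlib figure to a PIL image and release the figure."""
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    plt.close(fig)  # free the figure so repeated requests don't leak memory
    buf.seek(0)
    return Image.open(buf)

# Usage sketch: analyze_video would return fig_to_pil(fig) and fig_to_pil(progress_fig)
# instead of raw BytesIO buffers; gr.Image can display PIL images directly.
fig, ax = plt.subplots(figsize=(4, 2))
ax.plot([0.2, 0.7, 0.9])
print(fig_to_pil(fig).size)  # (400, 200) at the default 100 dpi

Second, fake_percentage divides by len(frame_predictions); if no face is detected in any frame, that list is empty and analyze_video raises ZeroDivisionError, so an early return (or a "no faces detected" result) before the metrics block would make the Space more robust. Depending on the installed Gradio version, the interpretation="default" argument may also need to be dropped; it was removed from gr.Interface in Gradio 4.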