SpyC0der77 commited on
Commit
7561365
·
verified ·
1 Parent(s): 95f91c4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +127 -37
app.py CHANGED
@@ -2,18 +2,112 @@ import cv2
2
  import numpy as np
3
  import csv
4
  import math
 
5
  import tempfile
6
  import os
7
  import gradio as gr
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  def read_motion_csv(csv_filename):
10
  """
11
- Reads a CSV file with columns: frame, mag, ang, zoom.
12
- For each row, computes a displacement from mag and ang and
13
- accumulates these to build a per-frame cumulative offset.
14
 
15
  Returns:
16
- A dictionary mapping frame numbers to (dx, dy) offsets.
17
  """
18
  motion_data = {}
19
  cumulative_dx = 0.0
@@ -24,15 +118,13 @@ def read_motion_csv(csv_filename):
24
  frame_num = int(row['frame'])
25
  mag = float(row['mag'])
26
  ang = float(row['ang'])
27
- # Convert angle (in degrees) to radians
28
  rad = math.radians(ang)
29
- # Compute displacement vector from magnitude and angle
30
  dx = mag * math.cos(rad)
31
  dy = mag * math.sin(rad)
32
- # Accumulate the displacement over frames
33
  cumulative_dx += dx
34
  cumulative_dy += dy
35
- # Store the negative cumulative offset to counteract motion
36
  motion_data[frame_num] = (-cumulative_dx, -cumulative_dy)
37
  return motion_data
38
 
@@ -42,25 +134,24 @@ def stabilize_video_using_csv(video_file, csv_file, zoom=1.0, output_file=None):
42
 
43
  Args:
44
  video_file (str): Path to the input video.
45
- csv_file (str): Path to the CSV file generated by the detection code.
46
- zoom (float): Optional zoom factor to apply before stabilization (default: 1.0).
47
  output_file (str): Path for the output stabilized video. If None, a temporary file is created.
48
-
49
  Returns:
50
- output_file (str): The path to the stabilized video file.
51
  """
52
  # Read motion data from CSV
53
  motion_data = read_motion_csv(csv_file)
54
 
55
  cap = cv2.VideoCapture(video_file)
56
  if not cap.isOpened():
57
- raise ValueError("Could not open video file.")
58
 
59
  fps = cap.get(cv2.CAP_PROP_FPS)
60
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
61
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
62
 
63
- # Create a temporary file for output if not provided
64
  if output_file is None:
65
  temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
66
  output_file = temp_file.name
@@ -83,10 +174,10 @@ def stabilize_video_using_csv(video_file, csv_file, zoom=1.0, output_file=None):
83
  start_y = max((zoomed_h - height) // 2, 0)
84
  frame = zoomed_frame[start_y:start_y+height, start_x:start_x+width]
85
 
86
- # Retrieve stabilization offset from CSV data (if available)
87
  dx, dy = motion_data.get(frame_num, (0, 0))
88
 
89
- # Apply an affine transformation to counteract the motion
90
  transform = np.array([[1, 0, dx],
91
  [0, 1, dy]], dtype=np.float32)
92
  stabilized_frame = cv2.warpAffine(frame, transform, (width, height))
@@ -96,48 +187,47 @@ def stabilize_video_using_csv(video_file, csv_file, zoom=1.0, output_file=None):
96
 
97
  cap.release()
98
  out.release()
 
99
  return output_file
100
 
101
- def process_video(video_file, csv_file, zoom):
102
- """
103
- Gradio interface function to stabilize a video.
104
- Accepts an input video file, a motion CSV file, and a zoom factor.
105
- Returns the original video and the stabilized video.
106
  """
107
- # Ensure the video file is provided
108
- if video_file is None:
109
- raise ValueError("Please upload a video file.")
110
 
111
- # Convert file inputs to file paths if they come as dictionaries
 
 
 
112
  if isinstance(video_file, dict):
113
  video_file = video_file.get("name", None)
114
- if isinstance(csv_file, dict):
115
- csv_file = csv_file.get("name", None)
116
-
117
- # Check that both file paths are available
118
  if video_file is None:
119
- raise ValueError("Video file path is missing.")
120
- if csv_file is None:
121
- raise ValueError("CSV file path is missing. Please upload a CSV file.")
122
 
 
 
 
123
  stabilized_path = stabilize_video_using_csv(video_file, csv_file, zoom=zoom)
124
  return video_file, stabilized_path
125
 
 
126
  with gr.Blocks() as demo:
127
- gr.Markdown("# Video Stabilization with Motion Data")
 
 
128
  with gr.Row():
129
  with gr.Column():
130
  video_input = gr.Video(label="Input Video")
131
- csv_input = gr.File(label="Motion CSV File (e.g., video.flow.csv)", file_count="single")
132
  zoom_slider = gr.Slider(minimum=1.0, maximum=2.0, step=0.1, value=1.0, label="Zoom Factor")
133
- process_button = gr.Button("Stabilize Video")
134
  with gr.Column():
135
  original_video = gr.Video(label="Original Video")
136
  stabilized_video = gr.Video(label="Stabilized Video")
137
 
138
  process_button.click(
139
- fn=process_video,
140
- inputs=[video_input, csv_input, zoom_slider],
141
  outputs=[original_video, stabilized_video]
142
  )
143
 
 
2
  import numpy as np
3
  import csv
4
  import math
5
+ import torch
6
  import tempfile
7
  import os
8
  import gradio as gr
9
 
10
+ # Load the RAFT model from torch.hub (uses the 'raft_small' variant)
11
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
+ print(f"Using device: {device}")
13
+ model = torch.hub.load("princeton-vl/RAFT", "raft_small", pretrained=True)
14
+ model = model.to(device)
15
+ model.eval()
16
+
17
+ def generate_motion_csv(video_file, output_csv=None):
18
+ """
19
+ Uses the RAFT model to compute optical flow between consecutive frames,
20
+ then writes a CSV file (with columns: frame, mag, ang, zoom) where:
21
+ - mag: median magnitude of the flow,
22
+ - ang: median angle (in degrees), and
23
+ - zoom: fraction of pixels moving away from the image center.
24
+
25
+ Args:
26
+ video_file (str): Path to the input video.
27
+ output_csv (str): Optional path for output CSV file. If None, a temporary file is used.
28
+
29
+ Returns:
30
+ output_csv (str): Path to the generated CSV file.
31
+ """
32
+ if output_csv is None:
33
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.csv')
34
+ output_csv = temp_file.name
35
+ temp_file.close()
36
+
37
+ cap = cv2.VideoCapture(video_file)
38
+ if not cap.isOpened():
39
+ raise ValueError("Could not open video file for CSV generation.")
40
+
41
+ # Prepare CSV file for writing
42
+ with open(output_csv, 'w', newline='') as csvfile:
43
+ fieldnames = ['frame', 'mag', 'ang', 'zoom']
44
+ writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
45
+ writer.writeheader()
46
+
47
+ ret, prev_frame = cap.read()
48
+ if not ret:
49
+ raise ValueError("Cannot read first frame from video.")
50
+
51
+ # Convert the first frame to tensor
52
+ prev_frame_rgb = cv2.cvtColor(prev_frame, cv2.COLOR_BGR2RGB)
53
+ prev_tensor = torch.from_numpy(prev_frame_rgb).permute(2,0,1).float().unsqueeze(0) / 255.0
54
+ prev_tensor = prev_tensor.to(device)
55
+
56
+ frame_idx = 1
57
+ while True:
58
+ ret, frame = cap.read()
59
+ if not ret:
60
+ break
61
+
62
+ curr_frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
63
+ curr_tensor = torch.from_numpy(curr_frame_rgb).permute(2,0,1).float().unsqueeze(0) / 255.0
64
+ curr_tensor = curr_tensor.to(device)
65
+
66
+ # Use RAFT to compute optical flow between previous and current frame.
67
+ with torch.no_grad():
68
+ # The RAFT model returns a low-resolution flow and an upsampled (high-res) flow.
69
+ flow_low, flow_up = model(prev_tensor, curr_tensor, iters=20, test_mode=True)
70
+ # Convert flow to numpy array (shape: H x W x 2)
71
+ flow = flow_up[0].permute(1,2,0).cpu().numpy()
72
+
73
+ # Compute median magnitude and angle for the optical flow
74
+ mag, ang = cv2.cartToPolar(flow[...,0], flow[...,1], angleInDegrees=True)
75
+ median_mag = np.median(mag)
76
+ median_ang = np.median(ang)
77
+
78
+ # Compute a "zoom factor": fraction of pixels moving away from the center.
79
+ h, w = flow.shape[:2]
80
+ center_x, center_y = w / 2, h / 2
81
+ x_coords, y_coords = np.meshgrid(np.arange(w), np.arange(h))
82
+ x_offset = x_coords - center_x
83
+ y_offset = y_coords - center_y
84
+ # Dot product between flow vectors and pixel offsets:
85
+ dot = flow[...,0] * x_offset + flow[...,1] * y_offset
86
+ zoom_factor = np.count_nonzero(dot > 0) / (w * h)
87
+
88
+ # Write the computed metrics to the CSV file.
89
+ writer.writerow({
90
+ 'frame': frame_idx,
91
+ 'mag': median_mag,
92
+ 'ang': median_ang,
93
+ 'zoom': zoom_factor
94
+ })
95
+
96
+ # Update for next iteration
97
+ prev_tensor = curr_tensor.clone()
98
+ frame_idx += 1
99
+
100
+ cap.release()
101
+ print(f"Motion CSV generated: {output_csv}")
102
+ return output_csv
103
+
104
  def read_motion_csv(csv_filename):
105
  """
106
+ Reads the CSV file (columns: frame, mag, ang, zoom) and computes a cumulative
107
+ offset per frame to be used for stabilization.
 
108
 
109
  Returns:
110
+ A dictionary mapping frame numbers to (dx, dy) offsets (the negative cumulative displacement).
111
  """
112
  motion_data = {}
113
  cumulative_dx = 0.0
 
118
  frame_num = int(row['frame'])
119
  mag = float(row['mag'])
120
  ang = float(row['ang'])
121
+ # Convert angle (in degrees) to radians.
122
  rad = math.radians(ang)
 
123
  dx = mag * math.cos(rad)
124
  dy = mag * math.sin(rad)
 
125
  cumulative_dx += dx
126
  cumulative_dy += dy
127
+ # Negative cumulative offset counteracts the detected motion.
128
  motion_data[frame_num] = (-cumulative_dx, -cumulative_dy)
129
  return motion_data
130
 
 
134
 
135
  Args:
136
  video_file (str): Path to the input video.
137
+ csv_file (str): Path to the motion CSV file.
138
+ zoom (float): Zoom factor to apply before stabilization (default: 1.0, no zoom).
139
  output_file (str): Path for the output stabilized video. If None, a temporary file is created.
140
+
141
  Returns:
142
+ output_file (str): Path to the stabilized video file.
143
  """
144
  # Read motion data from CSV
145
  motion_data = read_motion_csv(csv_file)
146
 
147
  cap = cv2.VideoCapture(video_file)
148
  if not cap.isOpened():
149
+ raise ValueError("Could not open video file for stabilization.")
150
 
151
  fps = cap.get(cv2.CAP_PROP_FPS)
152
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
153
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
154
 
 
155
  if output_file is None:
156
  temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
157
  output_file = temp_file.name
 
174
  start_y = max((zoomed_h - height) // 2, 0)
175
  frame = zoomed_frame[start_y:start_y+height, start_x:start_x+width]
176
 
177
+ # Get the stabilization offset for the current frame (default to (0,0) if not available)
178
  dx, dy = motion_data.get(frame_num, (0, 0))
179
 
180
+ # Apply an affine transformation to counteract the motion.
181
  transform = np.array([[1, 0, dx],
182
  [0, 1, dy]], dtype=np.float32)
183
  stabilized_frame = cv2.warpAffine(frame, transform, (width, height))
 
187
 
188
  cap.release()
189
  out.release()
190
+ print(f"Stabilized video saved to: {output_file}")
191
  return output_file
192
 
193
+ def process_video_ai(video_file, zoom):
 
 
 
 
194
  """
195
+ Gradio interface function: Given an input video and a zoom factor,
196
+ it uses a deep learning model (RAFT) to generate motion data (video.flow.csv)
197
+ and then stabilizes the video based on that data.
198
 
199
+ Returns:
200
+ A tuple containing the original video file path and the stabilized video file path.
201
+ """
202
+ # Ensure the input is a file path (if provided as a dict, extract the "name")
203
  if isinstance(video_file, dict):
204
  video_file = video_file.get("name", None)
 
 
 
 
205
  if video_file is None:
206
+ raise ValueError("Please upload a video file.")
 
 
207
 
208
+ # Generate motion CSV using AI-based optical flow (RAFT)
209
+ csv_file = generate_motion_csv(video_file)
210
+ # Stabilize the video using the generated CSV data
211
  stabilized_path = stabilize_video_using_csv(video_file, csv_file, zoom=zoom)
212
  return video_file, stabilized_path
213
 
214
+ # Build the Gradio interface
215
  with gr.Blocks() as demo:
216
+ gr.Markdown("# AI-Powered Video Stabilization")
217
+ gr.Markdown("Upload a video and select a zoom factor. The system will automatically use a deep learning model (RAFT) to generate motion data and then stabilize the video.")
218
+
219
  with gr.Row():
220
  with gr.Column():
221
  video_input = gr.Video(label="Input Video")
 
222
  zoom_slider = gr.Slider(minimum=1.0, maximum=2.0, step=0.1, value=1.0, label="Zoom Factor")
223
+ process_button = gr.Button("Process Video")
224
  with gr.Column():
225
  original_video = gr.Video(label="Original Video")
226
  stabilized_video = gr.Video(label="Stabilized Video")
227
 
228
  process_button.click(
229
+ fn=process_video_ai,
230
+ inputs=[video_input, zoom_slider],
231
  outputs=[original_video, stabilized_video]
232
  )
233