SpyC0der77 committed on
Commit
89bc003
·
verified ·
1 Parent(s): 7561365

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -60
app.py CHANGED
@@ -7,24 +7,32 @@ import tempfile
7
  import os
8
  import gradio as gr
9
 
10
- # Load the RAFT model from torch.hub (uses the 'raft_small' variant)
11
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
  print(f"Using device: {device}")
13
- model = torch.hub.load("princeton-vl/RAFT", "raft_small", pretrained=True)
14
- model = model.to(device)
15
- model.eval()
 
 
 
 
 
 
 
 
 
 
16
 
17
  def generate_motion_csv(video_file, output_csv=None):
18
  """
19
- Uses the RAFT model to compute optical flow between consecutive frames,
20
- then writes a CSV file (with columns: frame, mag, ang, zoom) where:
21
- - mag: median magnitude of the flow,
22
- - ang: median angle (in degrees), and
23
- - zoom: fraction of pixels moving away from the image center.
24
-
25
  Args:
26
  video_file (str): Path to the input video.
27
- output_csv (str): Optional path for output CSV file. If None, a temporary file is used.
28
 
29
  Returns:
30
  output_csv (str): Path to the generated CSV file.
@@ -38,40 +46,46 @@ def generate_motion_csv(video_file, output_csv=None):
38
  if not cap.isOpened():
39
  raise ValueError("Could not open video file for CSV generation.")
40
 
41
- # Prepare CSV file for writing
42
  with open(output_csv, 'w', newline='') as csvfile:
43
  fieldnames = ['frame', 'mag', 'ang', 'zoom']
44
  writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
45
  writer.writeheader()
46
 
47
- ret, prev_frame = cap.read()
48
  if not ret:
49
  raise ValueError("Cannot read first frame from video.")
50
 
51
- # Convert the first frame to tensor
52
- prev_frame_rgb = cv2.cvtColor(prev_frame, cv2.COLOR_BGR2RGB)
53
- prev_tensor = torch.from_numpy(prev_frame_rgb).permute(2,0,1).float().unsqueeze(0) / 255.0
54
- prev_tensor = prev_tensor.to(device)
 
 
 
55
 
56
  frame_idx = 1
57
  while True:
58
  ret, frame = cap.read()
59
  if not ret:
60
  break
61
-
62
- curr_frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
63
- curr_tensor = torch.from_numpy(curr_frame_rgb).permute(2,0,1).float().unsqueeze(0) / 255.0
64
- curr_tensor = curr_tensor.to(device)
65
-
66
- # Use RAFT to compute optical flow between previous and current frame.
67
- with torch.no_grad():
68
- # The RAFT model returns a low-resolution flow and an upsampled (high-res) flow.
69
- flow_low, flow_up = model(prev_tensor, curr_tensor, iters=20, test_mode=True)
70
- # Convert flow to numpy array (shape: H x W x 2)
71
- flow = flow_up[0].permute(1,2,0).cpu().numpy()
72
-
73
- # Compute median magnitude and angle for the optical flow
74
- mag, ang = cv2.cartToPolar(flow[...,0], flow[...,1], angleInDegrees=True)
 
 
 
 
75
  median_mag = np.median(mag)
76
  median_ang = np.median(ang)
77
 
@@ -81,11 +95,9 @@ def generate_motion_csv(video_file, output_csv=None):
81
  x_coords, y_coords = np.meshgrid(np.arange(w), np.arange(h))
82
  x_offset = x_coords - center_x
83
  y_offset = y_coords - center_y
84
- # Dot product between flow vectors and pixel offsets:
85
- dot = flow[...,0] * x_offset + flow[...,1] * y_offset
86
  zoom_factor = np.count_nonzero(dot > 0) / (w * h)
87
 
88
- # Write the computed metrics to the CSV file.
89
  writer.writerow({
90
  'frame': frame_idx,
91
  'mag': median_mag,
@@ -93,21 +105,19 @@ def generate_motion_csv(video_file, output_csv=None):
93
  'zoom': zoom_factor
94
  })
95
 
96
- # Update for next iteration
97
- prev_tensor = curr_tensor.clone()
98
  frame_idx += 1
99
-
100
  cap.release()
101
  print(f"Motion CSV generated: {output_csv}")
102
  return output_csv
103
 
104
  def read_motion_csv(csv_filename):
105
  """
106
- Reads the CSV file (columns: frame, mag, ang, zoom) and computes a cumulative
107
- offset per frame to be used for stabilization.
108
 
109
  Returns:
110
- A dictionary mapping frame numbers to (dx, dy) offsets (the negative cumulative displacement).
111
  """
112
  motion_data = {}
113
  cumulative_dx = 0.0
@@ -118,13 +128,11 @@ def read_motion_csv(csv_filename):
118
  frame_num = int(row['frame'])
119
  mag = float(row['mag'])
120
  ang = float(row['ang'])
121
- # Convert angle (in degrees) to radians.
122
  rad = math.radians(ang)
123
  dx = mag * math.cos(rad)
124
  dy = mag * math.sin(rad)
125
  cumulative_dx += dx
126
  cumulative_dy += dy
127
- # Negative cumulative offset counteracts the detected motion.
128
  motion_data[frame_num] = (-cumulative_dx, -cumulative_dy)
129
  return motion_data
130
 
@@ -135,13 +143,12 @@ def stabilize_video_using_csv(video_file, csv_file, zoom=1.0, output_file=None):
135
  Args:
136
  video_file (str): Path to the input video.
137
  csv_file (str): Path to the motion CSV file.
138
- zoom (float): Zoom factor to apply before stabilization (default: 1.0, no zoom).
139
  output_file (str): Path for the output stabilized video. If None, a temporary file is created.
140
 
141
  Returns:
142
  output_file (str): Path to the stabilized video file.
143
  """
144
- # Read motion data from CSV
145
  motion_data = read_motion_csv(csv_file)
146
 
147
  cap = cv2.VideoCapture(video_file)
@@ -160,13 +167,12 @@ def stabilize_video_using_csv(video_file, csv_file, zoom=1.0, output_file=None):
160
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
161
  out = cv2.VideoWriter(output_file, fourcc, fps, (width, height))
162
 
163
- frame_num = 1
164
  while True:
165
  ret, frame = cap.read()
166
  if not ret:
167
  break
168
 
169
- # Optionally apply zoom (resize and center-crop)
170
  if zoom != 1.0:
171
  zoomed_frame = cv2.resize(frame, None, fx=zoom, fy=zoom, interpolation=cv2.INTER_LINEAR)
172
  zoomed_h, zoomed_w = zoomed_frame.shape[:2]
@@ -174,16 +180,13 @@ def stabilize_video_using_csv(video_file, csv_file, zoom=1.0, output_file=None):
174
  start_y = max((zoomed_h - height) // 2, 0)
175
  frame = zoomed_frame[start_y:start_y+height, start_x:start_x+width]
176
 
177
- # Get the stabilization offset for the current frame (default to (0,0) if not available)
178
- dx, dy = motion_data.get(frame_num, (0, 0))
179
-
180
- # Apply an affine transformation to counteract the motion.
181
  transform = np.array([[1, 0, dx],
182
  [0, 1, dy]], dtype=np.float32)
183
  stabilized_frame = cv2.warpAffine(frame, transform, (width, height))
184
 
185
  out.write(stabilized_frame)
186
- frame_num += 1
187
 
188
  cap.release()
189
  out.release()
@@ -192,29 +195,28 @@ def stabilize_video_using_csv(video_file, csv_file, zoom=1.0, output_file=None):
192
 
193
  def process_video_ai(video_file, zoom):
194
  """
195
- Gradio interface function: Given an input video and a zoom factor,
196
- it uses a deep learning model (RAFT) to generate motion data (video.flow.csv)
197
- and then stabilizes the video based on that data.
198
 
199
  Returns:
200
- A tuple containing the original video file path and the stabilized video file path.
201
  """
202
- # Ensure the input is a file path (if provided as a dict, extract the "name")
203
  if isinstance(video_file, dict):
204
  video_file = video_file.get("name", None)
205
  if video_file is None:
206
  raise ValueError("Please upload a video file.")
207
 
208
- # Generate motion CSV using AI-based optical flow (RAFT)
209
  csv_file = generate_motion_csv(video_file)
210
- # Stabilize the video using the generated CSV data
211
  stabilized_path = stabilize_video_using_csv(video_file, csv_file, zoom=zoom)
212
  return video_file, stabilized_path
213
 
214
- # Build the Gradio interface
215
  with gr.Blocks() as demo:
216
  gr.Markdown("# AI-Powered Video Stabilization")
217
- gr.Markdown("Upload a video and select a zoom factor. The system will automatically use a deep learning model (RAFT) to generate motion data and then stabilize the video.")
218
 
219
  with gr.Row():
220
  with gr.Column():
 
7
  import os
8
  import gradio as gr
9
 
10
+ # Set up device for torch
11
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
  print(f"Using device: {device}")
13
+
14
+ # Try to load the RAFT model from torch.hub.
15
+ # If it fails (e.g. due to repository structure changes), we will fall back to OpenCV optical flow.
16
+ try:
17
+ # torch.hub.load may ask for confirmation before running code from an untrusted repo; trust_repo=True skips that prompt.
18
+ raft_model = torch.hub.load("princeton-vl/RAFT", "raft_small", pretrained=True, trust_repo=True)
19
+ raft_model = raft_model.to(device)
20
+ raft_model.eval()
21
+ print("RAFT model loaded successfully.")
22
+ except Exception as e:
23
+ print("Error loading RAFT model:", e)
24
+ print("Falling back to OpenCV optical flow for motion CSV generation.")
25
+ raft_model = None
26
 
27
  def generate_motion_csv(video_file, output_csv=None):
28
  """
29
+ Generates a CSV file with motion data (columns: frame, mag, ang, zoom) from an input video.
30
+ If the RAFT model is available, it uses it to compute optical flow; otherwise, it falls back to
31
+ OpenCV's Farneback optical flow.
32
+
 
 
33
  Args:
34
  video_file (str): Path to the input video.
35
+ output_csv (str): Optional output CSV file path. If None, a temporary file is created.
36
 
37
  Returns:
38
  output_csv (str): Path to the generated CSV file.
 
46
  if not cap.isOpened():
47
  raise ValueError("Could not open video file for CSV generation.")
48
 
 
49
  with open(output_csv, 'w', newline='') as csvfile:
50
  fieldnames = ['frame', 'mag', 'ang', 'zoom']
51
  writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
52
  writer.writeheader()
53
 
54
+ ret, first_frame = cap.read()
55
  if not ret:
56
  raise ValueError("Cannot read first frame from video.")
57
 
58
+ if raft_model is not None:
59
+ # Convert the first frame to RGB and then to a torch tensor.
60
+ first_frame_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
61
+ prev_tensor = torch.from_numpy(first_frame_rgb).permute(2, 0, 1).float().unsqueeze(0) / 255.0
62
+ prev_tensor = prev_tensor.to(device)
63
+ else:
64
+ prev_gray = cv2.cvtColor(first_frame, cv2.COLOR_BGR2GRAY)
65
 
66
  frame_idx = 1
67
  while True:
68
  ret, frame = cap.read()
69
  if not ret:
70
  break
71
+
72
+ if raft_model is not None:
73
+ curr_frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
74
+ curr_tensor = torch.from_numpy(curr_frame_rgb).permute(2, 0, 1).float().unsqueeze(0) / 255.0
75
+ curr_tensor = curr_tensor.to(device)
76
+ with torch.no_grad():
77
+ flow_low, flow_up = raft_model(prev_tensor, curr_tensor, iters=20, test_mode=True)
78
+ flow = flow_up[0].permute(1, 2, 0).cpu().numpy()
79
+ prev_tensor = curr_tensor.clone()
80
+ else:
81
+ curr_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
82
+ flow = cv2.calcOpticalFlowFarneback(prev_gray, curr_gray, None,
83
+ pyr_scale=0.5, levels=3, winsize=15,
84
+ iterations=3, poly_n=5, poly_sigma=1.2, flags=0)
85
+ prev_gray = curr_gray
86
+
87
+ # Compute median magnitude and angle of the optical flow.
88
+ mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1], angleInDegrees=True)
89
  median_mag = np.median(mag)
90
  median_ang = np.median(ang)
91
 
 
95
  x_coords, y_coords = np.meshgrid(np.arange(w), np.arange(h))
96
  x_offset = x_coords - center_x
97
  y_offset = y_coords - center_y
98
+ dot = flow[..., 0] * x_offset + flow[..., 1] * y_offset
 
99
  zoom_factor = np.count_nonzero(dot > 0) / (w * h)
100
 
 
101
  writer.writerow({
102
  'frame': frame_idx,
103
  'mag': median_mag,
 
105
  'zoom': zoom_factor
106
  })
107
 
 
 
108
  frame_idx += 1
109
+
110
  cap.release()
111
  print(f"Motion CSV generated: {output_csv}")
112
  return output_csv
113
 
114
  def read_motion_csv(csv_filename):
115
  """
116
+ Reads a motion CSV file (with columns: frame, mag, ang, zoom) and computes a cumulative
117
+ offset per frame (the negative cumulative displacement) for stabilization.
118
 
119
  Returns:
120
+ A dictionary mapping frame numbers to (dx, dy) offsets.
121
  """
122
  motion_data = {}
123
  cumulative_dx = 0.0
 
128
  frame_num = int(row['frame'])
129
  mag = float(row['mag'])
130
  ang = float(row['ang'])
 
131
  rad = math.radians(ang)
132
  dx = mag * math.cos(rad)
133
  dy = mag * math.sin(rad)
134
  cumulative_dx += dx
135
  cumulative_dy += dy
 
136
  motion_data[frame_num] = (-cumulative_dx, -cumulative_dy)
137
  return motion_data
138
 
 
143
  Args:
144
  video_file (str): Path to the input video.
145
  csv_file (str): Path to the motion CSV file.
146
+ zoom (float): Zoom factor to apply before stabilization (default: 1.0).
147
  output_file (str): Path for the output stabilized video. If None, a temporary file is created.
148
 
149
  Returns:
150
  output_file (str): Path to the stabilized video file.
151
  """
 
152
  motion_data = read_motion_csv(csv_file)
153
 
154
  cap = cv2.VideoCapture(video_file)
 
167
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
168
  out = cv2.VideoWriter(output_file, fourcc, fps, (width, height))
169
 
170
+ frame_idx = 1
171
  while True:
172
  ret, frame = cap.read()
173
  if not ret:
174
  break
175
 
 
176
  if zoom != 1.0:
177
  zoomed_frame = cv2.resize(frame, None, fx=zoom, fy=zoom, interpolation=cv2.INTER_LINEAR)
178
  zoomed_h, zoomed_w = zoomed_frame.shape[:2]
 
180
  start_y = max((zoomed_h - height) // 2, 0)
181
  frame = zoomed_frame[start_y:start_y+height, start_x:start_x+width]
182
 
183
+ dx, dy = motion_data.get(frame_idx, (0, 0))
 
 
 
184
  transform = np.array([[1, 0, dx],
185
  [0, 1, dy]], dtype=np.float32)
186
  stabilized_frame = cv2.warpAffine(frame, transform, (width, height))
187
 
188
  out.write(stabilized_frame)
189
+ frame_idx += 1
190
 
191
  cap.release()
192
  out.release()
 
195
 
196
  def process_video_ai(video_file, zoom):
197
  """
198
+ Gradio interface function:
199
+ - Generates motion data (CSV) from the input video using an AI model (RAFT, if available).
200
+ - Stabilizes the video based on the generated motion data.
201
 
202
  Returns:
203
+ Tuple containing the original video file path and the stabilized video file path.
204
  """
 
205
  if isinstance(video_file, dict):
206
  video_file = video_file.get("name", None)
207
  if video_file is None:
208
  raise ValueError("Please upload a video file.")
209
 
210
+ # Generate motion CSV using the AI model (or fallback) for optical flow.
211
  csv_file = generate_motion_csv(video_file)
212
+ # Stabilize the video using the generated CSV.
213
  stabilized_path = stabilize_video_using_csv(video_file, csv_file, zoom=zoom)
214
  return video_file, stabilized_path
215
 
216
+ # Build the Gradio UI.
217
  with gr.Blocks() as demo:
218
  gr.Markdown("# AI-Powered Video Stabilization")
219
+ gr.Markdown("Upload a video and select a zoom factor. The system will automatically generate motion data (video.flow.csv) using an AI model (RAFT, if available) and then stabilize the video.")
220
 
221
  with gr.Row():
222
  with gr.Column():