SpyC0der77 committed on
Commit
02f987c
·
verified ·
1 Parent(s): 1314b4f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -24
app.py CHANGED
@@ -15,7 +15,6 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
  print(f"[INFO] Using device: {device}")
16
 
17
  # Try to load the RAFT model from torch.hub.
18
- # If it fails, we fall back to OpenCV optical flow.
19
  try:
20
  print("[INFO] Attempting to load RAFT model from torch.hub...")
21
  raft_model = torch.hub.load("princeton-vl/RAFT", "raft_small", pretrained=True, trust_repo=True)
@@ -26,13 +25,14 @@ except Exception as e:
26
  print("[ERROR] Error loading RAFT model:", e)
27
  print("[INFO] Falling back to OpenCV Farneback optical flow.")
28
  raft_model = None
 
29
 
30
  def generate_motion_csv(video_file, output_csv=None, progress=gr.Progress(), progress_offset=0.0, progress_scale=0.5):
31
  """
32
  Generates a CSV file with motion data (columns: frame, mag, ang, zoom) from an input video.
33
  Uses RAFT if available, otherwise falls back to OpenCV's Farneback optical flow.
34
 
35
- The progress bar is updated from progress_offset to progress_offset+progress_scale.
36
  """
37
  start_time = time.time()
38
  if output_csv is None:
@@ -42,7 +42,7 @@ def generate_motion_csv(video_file, output_csv=None, progress=gr.Progress(), pro
42
 
43
  cap = cv2.VideoCapture(video_file)
44
  if not cap.isOpened():
45
- raise ValueError("[ERROR] Could not open video file for CSV generation.")
46
 
47
  print(f"[INFO] Generating motion CSV for video: {video_file}")
48
  with open(output_csv, 'w', newline='') as csvfile:
@@ -52,7 +52,7 @@ def generate_motion_csv(video_file, output_csv=None, progress=gr.Progress(), pro
52
 
53
  ret, first_frame = cap.read()
54
  if not ret:
55
- raise ValueError("[ERROR] Cannot read first frame from video.")
56
 
57
  if raft_model is not None:
58
  first_frame_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
@@ -86,12 +86,12 @@ def generate_motion_csv(video_file, output_csv=None, progress=gr.Progress(), pro
86
  iterations=3, poly_n=5, poly_sigma=1.2, flags=0)
87
  prev_gray = curr_gray
88
 
89
- # Compute median magnitude and angle of the optical flow.
90
  mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1], angleInDegrees=True)
91
  median_mag = np.median(mag)
92
  median_ang = np.median(ang)
93
 
94
- # Compute a "zoom factor": fraction of pixels moving away from the center.
95
  h, w = flow.shape[:2]
96
  center_x, center_y = w / 2, h / 2
97
  x_coords, y_coords = np.meshgrid(np.arange(w), np.arange(h))
@@ -110,7 +110,6 @@ def generate_motion_csv(video_file, output_csv=None, progress=gr.Progress(), pro
110
  if frame_idx % 10 == 0 or frame_idx == total_frames:
111
  print(f"[INFO] Processed frame {frame_idx}/{total_frames}")
112
 
113
- # Update progress for this phase.
114
  progress(progress_offset + (frame_idx / total_frames) * progress_scale, desc="Generating Motion CSV")
115
  frame_idx += 1
116
 
@@ -121,11 +120,9 @@ def generate_motion_csv(video_file, output_csv=None, progress=gr.Progress(), pro
121
 
122
  def read_motion_csv(csv_filename):
123
  """
124
- Reads a motion CSV file (with columns: frame, mag, ang, zoom) and computes a cumulative
125
- offset per frame for stabilization.
126
 
127
- Returns:
128
- A dictionary mapping frame numbers to (dx, dy) offsets.
129
  """
130
  print(f"[INFO] Reading motion CSV: {csv_filename}")
131
  motion_data = {}
@@ -148,10 +145,10 @@ def read_motion_csv(csv_filename):
148
 
149
  def stabilize_video_using_csv(video_file, csv_file, zoom=1.0, vertical_only=False, progress=gr.Progress(), progress_offset=0.5, progress_scale=0.5, output_file=None):
150
  """
151
- Stabilizes the input video using motion data from the CSV file.
152
- If vertical_only is True, only vertical motion is corrected (horizontal displacement is ignored).
153
 
154
- The progress bar is updated from progress_offset to progress_offset+progress_scale.
155
  """
156
  start_time = time.time()
157
  print(f"[INFO] Starting stabilization using CSV: {csv_file}")
@@ -159,7 +156,7 @@ def stabilize_video_using_csv(video_file, csv_file, zoom=1.0, vertical_only=Fals
159
 
160
  cap = cv2.VideoCapture(video_file)
161
  if not cap.isOpened():
162
- raise ValueError("[ERROR] Could not open video file for stabilization.")
163
 
164
  fps = cap.get(cv2.CAP_PROP_FPS)
165
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
@@ -192,7 +189,7 @@ def stabilize_video_using_csv(video_file, csv_file, zoom=1.0, vertical_only=Fals
192
 
193
  dx, dy = motion_data.get(frame_idx, (0, 0))
194
  if vertical_only:
195
- dx = 0 # Ignore horizontal motion for vertical-only stabilization.
196
  transform = np.array([[1, 0, dx],
197
  [0, 1, dy]], dtype=np.float32)
198
  stabilized_frame = cv2.warpAffine(frame, transform, (width, height))
@@ -201,7 +198,6 @@ def stabilize_video_using_csv(video_file, csv_file, zoom=1.0, vertical_only=Fals
201
  if frame_idx % 10 == 0 or frame_idx == total_frames:
202
  print(f"[INFO] Stabilized frame {frame_idx}/{total_frames}")
203
 
204
- # Update progress for stabilization phase.
205
  progress(progress_offset + (frame_idx / total_frames) * progress_scale, desc="Stabilizing Video")
206
  frame_idx += 1
207
 
@@ -214,26 +210,28 @@ def stabilize_video_using_csv(video_file, csv_file, zoom=1.0, vertical_only=Fals
214
  def process_video_ai(video_file, zoom, vertical_only, progress=gr.Progress(track_tqdm=True)):
215
  """
216
  Gradio interface function:
217
- - Generates motion data (CSV) from the input video using an AI model (RAFT if available, else Farneback).
218
  - Stabilizes the video based on the generated motion data.
219
  - If vertical_only is True, only vertical stabilization is applied.
220
 
221
  Returns:
222
- Tuple containing the original video file path, the stabilized video file path, and log output.
223
  """
 
 
 
224
  log_buffer = io.StringIO()
225
  with redirect_stdout(log_buffer):
226
  if isinstance(video_file, dict):
227
  video_file = video_file.get("name", None)
228
  if video_file is None:
229
- raise ValueError("[ERROR] Please upload a video file.")
230
 
231
- print("[INFO] Starting AI-powered video processing...")
232
- # First half: Generate motion CSV.
233
  csv_file = generate_motion_csv(video_file, progress=progress, progress_offset=0.0, progress_scale=0.5)
234
- # Second half: Stabilize video.
235
  stabilized_path = stabilize_video_using_csv(video_file, csv_file, zoom=zoom, vertical_only=vertical_only,
236
  progress=progress, progress_offset=0.5, progress_scale=0.5)
 
237
  print("[INFO] Video processing complete.")
238
  logs = log_buffer.getvalue()
239
  return video_file, stabilized_path, logs
@@ -241,7 +239,7 @@ def process_video_ai(video_file, zoom, vertical_only, progress=gr.Progress(track
241
  # Build the Gradio UI.
242
  with gr.Blocks() as demo:
243
  gr.Markdown("# AI-Powered Video Stabilization")
244
- gr.Markdown("Upload a video, select a zoom factor, and choose whether to apply only vertical stabilization. The system will generate motion data using an AI model (RAFT if available) and then stabilize the video with live progress updates.")
245
 
246
  with gr.Row():
247
  with gr.Column():
 
15
  print(f"[INFO] Using device: {device}")
16
 
17
  # Try to load the RAFT model from torch.hub.
 
18
  try:
19
  print("[INFO] Attempting to load RAFT model from torch.hub...")
20
  raft_model = torch.hub.load("princeton-vl/RAFT", "raft_small", pretrained=True, trust_repo=True)
 
25
  print("[ERROR] Error loading RAFT model:", e)
26
  print("[INFO] Falling back to OpenCV Farneback optical flow.")
27
  raft_model = None
28
+ gr.Warning("Falling back to OpenCV Farneback optical flow.")
29
 
30
  def generate_motion_csv(video_file, output_csv=None, progress=gr.Progress(), progress_offset=0.0, progress_scale=0.5):
31
  """
32
  Generates a CSV file with motion data (columns: frame, mag, ang, zoom) from an input video.
33
  Uses RAFT if available, otherwise falls back to OpenCV's Farneback optical flow.
34
 
35
+ Updates progress from progress_offset to progress_offset+progress_scale.
36
  """
37
  start_time = time.time()
38
  if output_csv is None:
 
42
 
43
  cap = cv2.VideoCapture(video_file)
44
  if not cap.isOpened():
45
+ raise gr.Error("Could not open video file for CSV generation.")
46
 
47
  print(f"[INFO] Generating motion CSV for video: {video_file}")
48
  with open(output_csv, 'w', newline='') as csvfile:
 
52
 
53
  ret, first_frame = cap.read()
54
  if not ret:
55
+ raise gr.Error("Cannot read first frame from video.")
56
 
57
  if raft_model is not None:
58
  first_frame_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
 
86
  iterations=3, poly_n=5, poly_sigma=1.2, flags=0)
87
  prev_gray = curr_gray
88
 
89
+ # Compute median magnitude and angle.
90
  mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1], angleInDegrees=True)
91
  median_mag = np.median(mag)
92
  median_ang = np.median(ang)
93
 
94
+ # Compute "zoom factor": fraction of pixels moving away from center.
95
  h, w = flow.shape[:2]
96
  center_x, center_y = w / 2, h / 2
97
  x_coords, y_coords = np.meshgrid(np.arange(w), np.arange(h))
 
110
  if frame_idx % 10 == 0 or frame_idx == total_frames:
111
  print(f"[INFO] Processed frame {frame_idx}/{total_frames}")
112
 
 
113
  progress(progress_offset + (frame_idx / total_frames) * progress_scale, desc="Generating Motion CSV")
114
  frame_idx += 1
115
 
 
120
 
121
  def read_motion_csv(csv_filename):
122
  """
123
+ Reads a motion CSV file and computes cumulative offset per frame.
 
124
 
125
+ Returns a dictionary mapping frame numbers to (dx, dy) offsets.
 
126
  """
127
  print(f"[INFO] Reading motion CSV: {csv_filename}")
128
  motion_data = {}
 
145
 
146
  def stabilize_video_using_csv(video_file, csv_file, zoom=1.0, vertical_only=False, progress=gr.Progress(), progress_offset=0.5, progress_scale=0.5, output_file=None):
147
  """
148
+ Stabilizes the input video using motion data from the CSV.
149
+ If vertical_only is True, only vertical motion is corrected.
150
 
151
+ Updates progress from progress_offset to progress_offset+progress_scale.
152
  """
153
  start_time = time.time()
154
  print(f"[INFO] Starting stabilization using CSV: {csv_file}")
 
156
 
157
  cap = cv2.VideoCapture(video_file)
158
  if not cap.isOpened():
159
+ raise gr.Error("Could not open video file for stabilization.")
160
 
161
  fps = cap.get(cv2.CAP_PROP_FPS)
162
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
 
189
 
190
  dx, dy = motion_data.get(frame_idx, (0, 0))
191
  if vertical_only:
192
+ dx = 0 # Only vertical stabilization.
193
  transform = np.array([[1, 0, dx],
194
  [0, 1, dy]], dtype=np.float32)
195
  stabilized_frame = cv2.warpAffine(frame, transform, (width, height))
 
198
  if frame_idx % 10 == 0 or frame_idx == total_frames:
199
  print(f"[INFO] Stabilized frame {frame_idx}/{total_frames}")
200
 
 
201
  progress(progress_offset + (frame_idx / total_frames) * progress_scale, desc="Stabilizing Video")
202
  frame_idx += 1
203
 
 
210
  def process_video_ai(video_file, zoom, vertical_only, progress=gr.Progress(track_tqdm=True)):
211
  """
212
  Gradio interface function:
213
+ - Generates motion data from the input video using an AI model (RAFT if available, else Farneback).
214
  - Stabilizes the video based on the generated motion data.
215
  - If vertical_only is True, only vertical stabilization is applied.
216
 
217
  Returns:
218
+ Tuple: (original video file path, stabilized video file path, log output)
219
  """
220
+ # Display an info alert.
221
+ gr.Info("Starting AI-powered video processing...")
222
+
223
  log_buffer = io.StringIO()
224
  with redirect_stdout(log_buffer):
225
  if isinstance(video_file, dict):
226
  video_file = video_file.get("name", None)
227
  if video_file is None:
228
+ raise gr.Error("Please upload a video file.")
229
 
 
 
230
  csv_file = generate_motion_csv(video_file, progress=progress, progress_offset=0.0, progress_scale=0.5)
231
+ gr.Info("Motion CSV generated successfully.")
232
  stabilized_path = stabilize_video_using_csv(video_file, csv_file, zoom=zoom, vertical_only=vertical_only,
233
  progress=progress, progress_offset=0.5, progress_scale=0.5)
234
+ gr.Info("Video stabilization complete.")
235
  print("[INFO] Video processing complete.")
236
  logs = log_buffer.getvalue()
237
  return video_file, stabilized_path, logs
 
239
  # Build the Gradio UI.
240
  with gr.Blocks() as demo:
241
  gr.Markdown("# AI-Powered Video Stabilization")
242
+ gr.Markdown("Upload a video, select a zoom factor, and choose whether to apply only vertical stabilization. The system will generate motion data using an AI model (RAFT if available) and then stabilize the video with live progress updates and alerts.")
243
 
244
  with gr.Row():
245
  with gr.Column():