SpyC0der77 committed on
Commit 7419c44 · verified · Parent: a363ef6

Update app.py

Files changed (1):
  1. app.py +30 -20

app.py CHANGED
@@ -29,20 +29,20 @@ except Exception as e:
 def process_video_ai(video_file, zoom):
     """
     Generator function for Gradio:
-      - Generates motion data (CSV) from the input video using an AI model (RAFT if available, else Farneback)
+      - Generates motion data (CSV) from the input video using an AI model (RAFT if available, else Farneback).
       - Stabilizes the video using the generated motion data.
 
     Yields:
       A tuple of (original_video, stabilized_video, logs, progress)
       During processing, original_video and stabilized_video are None.
-      The final yield returns the video file paths along with final logs and progress=100.
+      The final yield returns the video file paths with final logs and progress=100.
     """
     logs = []
     def add_log(msg):
         logs.append(msg)
         return "\n".join(logs)
 
-    # Check and extract the file path
+    # Check and extract the file path.
     if isinstance(video_file, dict):
         video_file = video_file.get("name", None)
         if video_file is None:
@@ -63,7 +63,7 @@ def process_video_ai(video_file, zoom):
     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
     add_log(f"[INFO] Total frames in video: {total_frames}")
 
-    # Create temporary CSV file
+    # Create temporary CSV file.
     csv_file = tempfile.NamedTemporaryFile(delete=False, suffix='.csv').name
     with open(csv_file, 'w', newline='') as csvfile:
         fieldnames = ['frame', 'mag', 'ang', 'zoom']
@@ -85,11 +85,11 @@ def process_video_ai(video_file, zoom):
         add_log("[INFO] Using Farneback optical flow for computation.")
 
     frame_idx = 1
-    # Process each frame for CSV generation
     while True:
         ret, frame = cap.read()
         if not ret:
             break
+
         if raft_model is not None:
             curr_frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
             curr_tensor = torch.from_numpy(curr_frame_rgb).permute(2, 0, 1).float().unsqueeze(0) / 255.0
@@ -105,11 +105,11 @@ def process_video_ai(video_file, zoom):
                                                 iterations=3, poly_n=5, poly_sigma=1.2, flags=0)
             prev_gray = curr_gray
 
-        # Compute median magnitude and angle
+        # Compute median magnitude and angle.
         mag, ang = cv2.cartToPolar(flow[...,0], flow[...,1], angleInDegrees=True)
         median_mag = np.median(mag)
         median_ang = np.median(ang)
-        # Compute zoom factor: fraction of pixels moving away from center
+        # Compute zoom factor: fraction of pixels moving away from the center.
         h, w = flow.shape[:2]
         center_x, center_y = w / 2, h / 2
         x_coords, y_coords = np.meshgrid(np.arange(w), np.arange(h))
@@ -126,7 +126,7 @@ def process_video_ai(video_file, zoom):
         })
 
         if frame_idx % 10 == 0 or frame_idx == total_frames:
-            progress_csv = (frame_idx / total_frames) * 50  # CSV phase is 0-50%
+            progress_csv = (frame_idx / total_frames) * 50  # CSV phase: 0-50%
             add_log(f"[INFO] CSV: Processed frame {frame_idx}/{total_frames}")
             yield (None, None, add_log(""), progress_csv)
         frame_idx += 1
@@ -138,7 +138,7 @@ def process_video_ai(video_file, zoom):
     add_log("[INFO] Starting video stabilization...")
     yield (None, None, add_log("Starting stabilization..."), 51)
 
-    # Read the CSV and compute cumulative motion data
+    # Read the CSV and compute cumulative motion data.
     motion_data = {}
     cumulative_dx = 0.0
     cumulative_dy = 0.0
@@ -157,7 +157,7 @@ def process_video_ai(video_file, zoom):
     add_log("[INFO] Motion CSV read complete.")
     yield (None, None, add_log(""), 55)
 
-    # Re-open video for stabilization
+    # Re-open video for stabilization.
     cap = cv2.VideoCapture(video_file)
     fps = cap.get(cv2.CAP_PROP_FPS)
     width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
@@ -182,12 +182,13 @@ def process_video_ai(video_file, zoom):
             frame = zoomed_frame[start_y:start_y+height, start_x:start_x+width]
 
         dx, dy = motion_data.get(frame_idx, (0, 0))
-        transform = np.array([[1, 0, dx], [0, 1, dy]], dtype=np.float32)
+        transform = np.array([[1, 0, dx],
+                              [0, 1, dy]], dtype=np.float32)
         stabilized_frame = cv2.warpAffine(frame, transform, (width, height))
         out.write(stabilized_frame)
 
         if frame_idx % 10 == 0 or frame_idx == total_frames:
-            progress_stab = 50 + (frame_idx / total_frames) * 50  # Stabilization phase is 50-100%
+            progress_stab = 50 + (frame_idx / total_frames) * 50  # Stabilization phase: 50-100%
             add_log(f"[INFO] Stabilization: Processed frame {frame_idx}/{total_frames}")
             yield (None, None, add_log(""), progress_stab)
         frame_idx += 1
@@ -196,7 +197,7 @@ def process_video_ai(video_file, zoom):
     add_log("[INFO] Stabilization complete.")
     yield (video_file, output_file, add_log(""), 100)
 
-# Build the Gradio UI with streaming enabled.
+# Build the Gradio UI.
 with gr.Blocks() as demo:
     gr.Markdown("# AI-Powered Video Stabilization")
     gr.Markdown("Upload a video and select a zoom factor. The system will generate motion data using an AI model (RAFT if available, else Farneback) and then stabilize the video. Logs and progress will update during processing.")
@@ -212,13 +213,22 @@ with gr.Blocks() as demo:
     logs_output = gr.Textbox(label="Logs", lines=15)
     progress_bar = gr.Slider(label="Progress", minimum=0, maximum=100, value=0, interactive=False)
 
-    demo.queue()  # enable streaming
+    demo.queue()  # enable queue for streaming
 
-    process_button.click(
-        fn=process_video_ai,
-        inputs=[video_input, zoom_slider],
-        outputs=[original_video, stabilized_video, logs_output, progress_bar],
-        stream=True  # enable streaming updates
-    )
+    # Try using stream=True. If that raises a TypeError, fall back without it.
+    try:
+        process_button.click(
+            fn=process_video_ai,
+            inputs=[video_input, zoom_slider],
+            outputs=[original_video, stabilized_video, logs_output, progress_bar],
+            stream=True
+        )
+    except TypeError as e:
+        print("[WARNING] Streaming not supported in this version of Gradio. Disabling streaming.")
+        process_button.click(
+            fn=process_video_ai,
+            inputs=[video_input, zoom_slider],
+            outputs=[original_video, stabilized_video, logs_output, progress_bar]
+        )
 
 demo.launch()
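For reference, the RAFT branch above feeds pairs of RGB frames to the model as (1, 3, H, W) float tensors in [0, 1]. How app.py builds raft_model is outside this diff; the following is a minimal sketch of computing dense flow with torchvision's RAFT under that assumption:

    import torch
    from torchvision.models.optical_flow import raft_large, Raft_Large_Weights

    # Hypothetical setup; the actual loading code is not shown in this commit.
    weights = Raft_Large_Weights.DEFAULT
    raft_model = raft_large(weights=weights).eval()

    def raft_flow(prev_tensor, curr_tensor):
        # Inputs: (1, 3, H, W) float tensors in [0, 1], as built in the diff above.
        # The weights' transforms handle normalization; H and W must be divisible by 8.
        img1, img2 = weights.transforms()(prev_tensor, curr_tensor)
        with torch.no_grad():
            preds = raft_model(img1, img2)  # list of iterative refinements
        return preds[-1][0].permute(1, 2, 0).numpy()  # HxWx2, matching Farneback's layout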
 
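The hunk around old lines 112-115 sets up a pixel grid around the frame center, but the zoom-fraction computation itself falls outside the diff context. One plausible reading of "fraction of pixels moving away from the center" for a dense flow field (a sketch of the idea, not the shown code):

    import numpy as np

    def outward_fraction(flow):
        # flow: HxWx2 dense optical flow (dx, dy per pixel).
        h, w = flow.shape[:2]
        center_x, center_y = w / 2, h / 2
        x_coords, y_coords = np.meshgrid(np.arange(w), np.arange(h))
        # Radial direction from the center to each pixel.
        rx, ry = x_coords - center_x, y_coords - center_y
        # A pixel "moves away" when its flow has a positive component
        # along the radial direction (dot product > 0).
        outward = flow[..., 0] * rx + flow[..., 1] * ry > 0
        return float(np.mean(outward))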
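The stabilization warp is a pure per-frame translation. Pulled out into a self-contained form, the transform shown in the hunk at old lines 182-193 amounts to the following (the signs of dx and dy depend on the cumulative-motion convention, which this diff does not show):

    import cv2
    import numpy as np

    def translate_frame(frame, dx, dy):
        # 2x3 affine matrix that shifts the frame by (dx, dy);
        # cv2.warpAffine fills uncovered borders with black by default.
        h, w = frame.shape[:2]
        transform = np.array([[1, 0, dx],
                              [0, 1, dy]], dtype=np.float32)
        return cv2.warpAffine(frame, transform, (w, h))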
 
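On the try/except fallback at new lines 218-232: gr.Button.click does not accept a stream keyword in many Gradio releases, which is what raises the TypeError here; once demo.queue() is enabled, a generator handler like process_video_ai already streams each yield to its outputs without any extra flag. A minimal sketch of that pattern (slow_count is a hypothetical stand-in):

    import time
    import gradio as gr

    def slow_count(n):
        # Each yield is pushed to the Textbox as a streaming update.
        for i in range(int(n)):
            time.sleep(0.5)
            yield f"step {i + 1}/{int(n)}"

    with gr.Blocks() as demo:
        n = gr.Number(value=5, label="Steps")
        out = gr.Textbox(label="Progress")
        btn = gr.Button("Run")
        btn.click(fn=slow_count, inputs=n, outputs=out)

    demo.queue()
    demo.launch()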