SpyC0der77 committed
Commit d5a0042 · verified · 1 Parent(s): 3aedaee

Update app.py

Files changed (1)
  1. app.py +98 -77
app.py CHANGED
@@ -8,13 +8,14 @@ import os
 import gradio as gr
 import time
 import io
+from contextlib import redirect_stdout
 
 # Set up device for torch
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 print(f"[INFO] Using device: {device}")
 
 # Try to load the RAFT model from torch.hub.
-# If it fails, fall back to OpenCV's Farneback optical flow.
+# If it fails, we fall back to OpenCV optical flow.
 try:
     print("[INFO] Attempting to load RAFT model from torch.hub...")
     raft_model = torch.hub.load("princeton-vl/RAFT", "raft_small", pretrained=True, trust_repo=True)
@@ -26,70 +27,48 @@ except Exception as e:
     print("[INFO] Falling back to OpenCV Farneback optical flow.")
     raft_model = None
 
-def process_video_ai(video_file, zoom):
+def generate_motion_csv(video_file, output_csv=None):
     """
-    Generator function for Gradio:
-      - Generates motion data (CSV) from the input video using an AI model (RAFT if available, else Farneback)
-      - Stabilizes the video using the generated motion data.
-
-    Yields:
-      A tuple of (original_video, stabilized_video, logs, progress)
-      During processing, original_video and stabilized_video are None.
-      The final yield returns the video file paths along with final logs and progress=100.
+    Generates a CSV file with motion data (columns: frame, mag, ang, zoom) from an input video.
+    Uses RAFT if available, otherwise falls back to OpenCV's Farneback optical flow.
     """
-    logs = []
-    def add_log(msg):
-        logs.append(msg)
-        return "\n".join(logs)
-
-    # Check and extract the file path
-    if isinstance(video_file, dict):
-        video_file = video_file.get("name", None)
-    if video_file is None:
-        yield (None, None, "[ERROR] Please upload a video file.", 0)
-        return
-
-    add_log("[INFO] Starting AI-powered video processing...")
-    yield (None, None, add_log("Starting processing..."), 0)
-
-    # === CSV Generation Phase ===
-    add_log("[INFO] Starting motion CSV generation...")
-    yield (None, None, add_log("Starting CSV generation..."), 0)
+    start_time = time.time()
+    if output_csv is None:
+        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.csv')
+        output_csv = temp_file.name
+        temp_file.close()
 
     cap = cv2.VideoCapture(video_file)
     if not cap.isOpened():
-        yield (None, None, add_log("[ERROR] Could not open video file for CSV generation."), 0)
-        return
-    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-    add_log(f"[INFO] Total frames in video: {total_frames}")
+        raise ValueError("[ERROR] Could not open video file for CSV generation.")
 
-    # Create temporary CSV file
-    csv_file = tempfile.NamedTemporaryFile(delete=False, suffix='.csv').name
-    with open(csv_file, 'w', newline='') as csvfile:
+    print(f"[INFO] Generating motion CSV for video: {video_file}")
+    with open(output_csv, 'w', newline='') as csvfile:
         fieldnames = ['frame', 'mag', 'ang', 'zoom']
         writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
         writer.writeheader()
 
         ret, first_frame = cap.read()
         if not ret:
-            yield (None, None, add_log("[ERROR] Cannot read first frame from video."), 0)
-            return
+            raise ValueError("[ERROR] Cannot read first frame from video.")
 
         if raft_model is not None:
             first_frame_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
             prev_tensor = torch.from_numpy(first_frame_rgb).permute(2, 0, 1).float().unsqueeze(0) / 255.0
             prev_tensor = prev_tensor.to(device)
-            add_log("[INFO] Using RAFT model for optical flow computation.")
+            print("[INFO] Using RAFT model for optical flow computation.")
         else:
             prev_gray = cv2.cvtColor(first_frame, cv2.COLOR_BGR2GRAY)
-            add_log("[INFO] Using Farneback optical flow for computation.")
+            print("[INFO] Using OpenCV Farneback optical flow for computation.")
 
         frame_idx = 1
-        # Process each frame for CSV generation
+        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+        print(f"[INFO] Total frames to process: {total_frames}")
        while True:
             ret, frame = cap.read()
             if not ret:
                 break
+
             if raft_model is not None:
                 curr_frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                 curr_tensor = torch.from_numpy(curr_frame_rgb).permute(2, 0, 1).float().unsqueeze(0) / 255.0
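Note: the new generate_motion_csv keeps the existing temp-file pattern: create a named file, keep only its path, and close the handle before reopening. A minimal standalone sketch of that pattern (standard library only; the header row here is just for illustration):

import tempfile

# Create a named temp file and close the handle right away; on Windows an open
# NamedTemporaryFile cannot be reopened by name, so only the path is reused.
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.csv')
path = temp_file.name
temp_file.close()

with open(path, 'w', newline='') as f:
    f.write('frame,mag,ang,zoom\n')  # same header the function writes via csv.DictWriter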
@@ -104,12 +83,13 @@ def process_video_ai(video_file, zoom):
                                                     pyr_scale=0.5, levels=3, winsize=15,
                                                     iterations=3, poly_n=5, poly_sigma=1.2, flags=0)
                 prev_gray = curr_gray
-
-            # Compute median magnitude and angle
-            mag, ang = cv2.cartToPolar(flow[...,0], flow[...,1], angleInDegrees=True)
+
+            # Compute median magnitude and angle of the optical flow.
+            mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1], angleInDegrees=True)
             median_mag = np.median(mag)
             median_ang = np.median(ang)
-            # Compute zoom factor: fraction of pixels moving away from center
+
+            # Compute a "zoom factor": fraction of pixels moving away from the center.
             h, w = flow.shape[:2]
             center_x, center_y = w / 2, h / 2
             x_coords, y_coords = np.meshgrid(np.arange(w), np.arange(h))
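The per-frame motion summary relies on cv2.cartToPolar returning per-pixel (magnitude, angle) arrays, with angles in degrees when angleInDegrees=True. A self-contained sanity check of that convention:

import numpy as np
import cv2

# A 1x1 flow field: one vector of (dx=3, dy=4) in image coordinates.
flow = np.zeros((1, 1, 2), dtype=np.float32)
flow[0, 0] = (3.0, 4.0)

mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1], angleInDegrees=True)
print(mag[0, 0])  # 5.0, the length of the (3, 4) vector
print(ang[0, 0])  # ~53.13 degrees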
@@ -126,23 +106,27 @@ def process_video_ai(video_file, zoom):
             })
 
             if frame_idx % 10 == 0 or frame_idx == total_frames:
-                progress_csv = (frame_idx / total_frames) * 50  # CSV phase is 0-50%
-                add_log(f"[INFO] CSV: Processed frame {frame_idx}/{total_frames}")
-                yield (None, None, add_log(""), progress_csv)
+                print(f"[INFO] Processed frame {frame_idx}/{total_frames}")
             frame_idx += 1
+
     cap.release()
-    add_log("[INFO] CSV generation complete.")
-    yield (None, None, add_log(""), 50)
-
-    # === Stabilization Phase ===
-    add_log("[INFO] Starting video stabilization...")
-    yield (None, None, add_log("Starting stabilization..."), 51)
+    elapsed = time.time() - start_time
+    print(f"[INFO] Motion CSV generated: {output_csv} in {elapsed:.2f} seconds")
+    return output_csv
+
+def read_motion_csv(csv_filename):
+    """
+    Reads a motion CSV file (with columns: frame, mag, ang, zoom) and computes a cumulative
+    offset per frame for stabilization.
 
-    # Read the CSV and compute cumulative motion data
+    Returns:
+        A dictionary mapping frame numbers to (dx, dy) offsets.
+    """
+    print(f"[INFO] Reading motion CSV: {csv_filename}")
     motion_data = {}
     cumulative_dx = 0.0
     cumulative_dy = 0.0
-    with open(csv_file, 'r') as csvfile:
+    with open(csv_filename, 'r') as csvfile:
         reader = csv.DictReader(csvfile)
         for row in reader:
             frame_num = int(row['frame'])
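read_motion_csv stores the negated running sum of per-frame motion, so each frame can be shifted back toward the first frame's position. A toy worked example with made-up per-frame (dx, dy) values (the actual mag/ang-to-dx/dy conversion sits in unchanged lines not shown in this hunk):

# Per-frame camera drift, starting at frame 2 (frame 1 is the reference).
per_frame_motion = [(2.0, 0.0), (2.0, 1.0), (-1.0, 1.0)]

motion_data = {}
cumulative_dx = 0.0
cumulative_dy = 0.0
for frame_num, (dx, dy) in enumerate(per_frame_motion, start=2):
    cumulative_dx += dx
    cumulative_dy += dy
    # Negate the accumulated drift to get the corrective offset for this frame.
    motion_data[frame_num] = (-cumulative_dx, -cumulative_dy)

print(motion_data)  # {2: (-2.0, -0.0), 3: (-4.0, -1.0), 4: (-3.0, -2.0)}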
@@ -154,26 +138,43 @@ def process_video_ai(video_file, zoom):
             cumulative_dx += dx
             cumulative_dy += dy
             motion_data[frame_num] = (-cumulative_dx, -cumulative_dy)
-    add_log("[INFO] Motion CSV read complete.")
-    yield (None, None, add_log(""), 55)
+    print("[INFO] Completed reading motion CSV.")
+    return motion_data
+
+def stabilize_video_using_csv(video_file, csv_file, zoom=1.0, output_file=None):
+    """
+    Stabilizes the input video using motion data from the CSV file.
+    """
+    start_time = time.time()
+    print(f"[INFO] Starting stabilization using CSV: {csv_file}")
+    motion_data = read_motion_csv(csv_file)
 
-    # Re-open video for stabilization
     cap = cv2.VideoCapture(video_file)
+    if not cap.isOpened():
+        raise ValueError("[ERROR] Could not open video file for stabilization.")
+
     fps = cap.get(cv2.CAP_PROP_FPS)
     width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
     height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
-    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
-    output_file = temp_file.name
-    temp_file.close()
+    print(f"[INFO] Video properties - FPS: {fps}, Width: {width}, Height: {height}")
+
+    if output_file is None:
+        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
+        output_file = temp_file.name
+        temp_file.close()
+
     fourcc = cv2.VideoWriter_fourcc(*'mp4v')
     out = cv2.VideoWriter(output_file, fourcc, fps, (width, height))
 
     frame_idx = 1
     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+    print(f"[INFO] Total frames to stabilize: {total_frames}")
     while True:
         ret, frame = cap.read()
         if not ret:
             break
+
+        # Optionally apply zoom (resize and center-crop)
         if zoom != 1.0:
             zoomed_frame = cv2.resize(frame, None, fx=zoom, fy=zoom, interpolation=cv2.INTER_LINEAR)
             zoomed_h, zoomed_w = zoomed_frame.shape[:2]
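The zoom option enlarges the frame and then crops back to the original size; the crop origin (start_x, start_y) is presumably the centered one computed in the unchanged lines between this hunk and the next. A quick sketch of that geometry under that assumption:

import numpy as np
import cv2

frame = np.zeros((100, 200, 3), dtype=np.uint8)
height, width = frame.shape[:2]
zoom = 1.5

zoomed_frame = cv2.resize(frame, None, fx=zoom, fy=zoom, interpolation=cv2.INTER_LINEAR)
zoomed_h, zoomed_w = zoomed_frame.shape[:2]  # 150 x 300
start_x = (zoomed_w - width) // 2            # 50: centered crop, assumed
start_y = (zoomed_h - height) // 2           # 25
cropped = zoomed_frame[start_y:start_y + height, start_x:start_x + width]
assert cropped.shape[:2] == (height, width)  # crop restores the original frame size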
@@ -182,24 +183,48 @@ def process_video_ai(video_file, zoom):
             frame = zoomed_frame[start_y:start_y+height, start_x:start_x+width]
 
         dx, dy = motion_data.get(frame_idx, (0, 0))
-        transform = np.array([[1, 0, dx], [0, 1, dy]], dtype=np.float32)
+        transform = np.array([[1, 0, dx],
+                              [0, 1, dy]], dtype=np.float32)
         stabilized_frame = cv2.warpAffine(frame, transform, (width, height))
-        out.write(stabilized_frame)
 
+        out.write(stabilized_frame)
         if frame_idx % 10 == 0 or frame_idx == total_frames:
-            progress_stab = 50 + (frame_idx / total_frames) * 50  # Stabilization phase is 50-100%
-            add_log(f"[INFO] Stabilization: Processed frame {frame_idx}/{total_frames}")
-            yield (None, None, add_log(""), progress_stab)
+            print(f"[INFO] Stabilized frame {frame_idx}/{total_frames}")
         frame_idx += 1
+
     cap.release()
     out.release()
-    add_log("[INFO] Stabilization complete.")
-    yield (video_file, output_file, add_log(""), 100)
+    elapsed = time.time() - start_time
+    print(f"[INFO] Stabilized video saved to: {output_file} in {elapsed:.2f} seconds")
+    return output_file
+
+def process_video_ai(video_file, zoom):
+    """
+    Gradio interface function:
+      - Generates motion data (CSV) from the input video using an AI model (RAFT if available, else Farneback).
+      - Stabilizes the video based on the generated motion data.
+
+    Returns:
+        Tuple containing the original video file path, the stabilized video file path, and log output.
+    """
+    log_buffer = io.StringIO()
+    with redirect_stdout(log_buffer):
+        if isinstance(video_file, dict):
+            video_file = video_file.get("name", None)
+        if video_file is None:
+            raise ValueError("[ERROR] Please upload a video file.")
+
+        print("[INFO] Starting AI-powered video processing...")
+        csv_file = generate_motion_csv(video_file)
+        stabilized_path = stabilize_video_using_csv(video_file, csv_file, zoom=zoom)
+        print("[INFO] Video processing complete.")
+    logs = log_buffer.getvalue()
+    return video_file, stabilized_path, logs
 
-# Build the Gradio UI with streaming enabled.
+# Build the Gradio UI.
 with gr.Blocks() as demo:
     gr.Markdown("# AI-Powered Video Stabilization")
-    gr.Markdown("Upload a video and select a zoom factor. The system will generate motion data using an AI model (RAFT if available, else Farneback) and then stabilize the video. Logs and progress will update during processing.")
+    gr.Markdown("Upload a video and select a zoom factor. The system will generate motion data using an AI model (RAFT if available) and then stabilize the video.")
 
     with gr.Row():
         with gr.Column():
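The rewritten process_video_ai drops the streaming generator and instead captures everything printed during processing with contextlib.redirect_stdout, returning the log text once at the end. A minimal sketch of that capture pattern:

import io
from contextlib import redirect_stdout

log_buffer = io.StringIO()
with redirect_stdout(log_buffer):
    print("[INFO] captured")  # goes into log_buffer, not the console

logs = log_buffer.getvalue()
assert logs == "[INFO] captured\n"

One consequence of this design: logs and outputs now appear only after processing finishes, whereas the old generator yielded progress updates while frames were still being processed.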
@@ -209,16 +234,12 @@ with gr.Blocks() as demo:
         with gr.Column():
             original_video = gr.Video(label="Original Video")
             stabilized_video = gr.Video(label="Stabilized Video")
-            logs_output = gr.Textbox(label="Logs", lines=15)
-            progress_bar = gr.Slider(label="Progress", minimum=0, maximum=100, value=0, interactive=False)
-
-    demo.queue()  # enable streaming
+            logs_output = gr.Textbox(label="Logs", lines=10)
 
     process_button.click(
         fn=process_video_ai,
         inputs=[video_input, zoom_slider],
-        outputs=[original_video, stabilized_video, logs_output, progress_bar],
-        stream=True  # enable streaming updates
+        outputs=[original_video, stabilized_video, logs_output]
     )
 
-    demo.launch()
+demo.launch()
 
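With this refactor, the two phases can also be run outside the Gradio UI. A hypothetical standalone driver, assuming the helpers are importable without launching the app ("input.mp4" is a placeholder path; function names and signatures are taken from the diff above):

# Generate motion data, then stabilize, without the web UI.
csv_path = generate_motion_csv("input.mp4")
stabilized_path = stabilize_video_using_csv("input.mp4", csv_path, zoom=1.1)
print(f"Stabilized video written to: {stabilized_path}")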