SpyC0der77 commited on
Commit
8653b6e
·
verified ·
1 Parent(s): 89bc003

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -45
app.py CHANGED
@@ -6,37 +6,33 @@ import torch
6
  import tempfile
7
  import os
8
  import gradio as gr
 
 
 
9
 
10
  # Set up device for torch
11
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
- print(f"Using device: {device}")
13
 
14
  # Try to load the RAFT model from torch.hub.
15
- # If it fails (e.g. due to repository structure changes), we will fall back to OpenCV optical flow.
16
  try:
17
- # The trust_repo parameter might prompt for confirmation; set it to True.
18
  raft_model = torch.hub.load("princeton-vl/RAFT", "raft_small", pretrained=True, trust_repo=True)
19
  raft_model = raft_model.to(device)
20
  raft_model.eval()
21
- print("RAFT model loaded successfully.")
22
  except Exception as e:
23
- print("Error loading RAFT model:", e)
24
- print("Falling back to OpenCV optical flow for motion CSV generation.")
25
  raft_model = None
26
 
27
  def generate_motion_csv(video_file, output_csv=None):
28
  """
29
  Generates a CSV file with motion data (columns: frame, mag, ang, zoom) from an input video.
30
- If the RAFT model is available, it uses it to compute optical flow; otherwise, it falls back to
31
- OpenCV's Farneback optical flow.
32
-
33
- Args:
34
- video_file (str): Path to the input video.
35
- output_csv (str): Optional output CSV file path. If None, a temporary file is created.
36
-
37
- Returns:
38
- output_csv (str): Path to the generated CSV file.
39
  """
 
40
  if output_csv is None:
41
  temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.csv')
42
  output_csv = temp_file.name
@@ -44,8 +40,9 @@ def generate_motion_csv(video_file, output_csv=None):
44
 
45
  cap = cv2.VideoCapture(video_file)
46
  if not cap.isOpened():
47
- raise ValueError("Could not open video file for CSV generation.")
48
 
 
49
  with open(output_csv, 'w', newline='') as csvfile:
50
  fieldnames = ['frame', 'mag', 'ang', 'zoom']
51
  writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
@@ -53,17 +50,20 @@ def generate_motion_csv(video_file, output_csv=None):
53
 
54
  ret, first_frame = cap.read()
55
  if not ret:
56
- raise ValueError("Cannot read first frame from video.")
57
 
58
  if raft_model is not None:
59
- # Convert the first frame to RGB and then to a torch tensor.
60
  first_frame_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
61
  prev_tensor = torch.from_numpy(first_frame_rgb).permute(2, 0, 1).float().unsqueeze(0) / 255.0
62
  prev_tensor = prev_tensor.to(device)
 
63
  else:
64
  prev_gray = cv2.cvtColor(first_frame, cv2.COLOR_BGR2GRAY)
 
65
 
66
  frame_idx = 1
 
 
67
  while True:
68
  ret, frame = cap.read()
69
  if not ret:
@@ -105,20 +105,24 @@ def generate_motion_csv(video_file, output_csv=None):
105
  'zoom': zoom_factor
106
  })
107
 
 
 
108
  frame_idx += 1
109
 
110
  cap.release()
111
- print(f"Motion CSV generated: {output_csv}")
 
112
  return output_csv
113
 
114
  def read_motion_csv(csv_filename):
115
  """
116
  Reads a motion CSV file (with columns: frame, mag, ang, zoom) and computes a cumulative
117
- offset per frame (the negative cumulative displacement) for stabilization.
118
 
119
  Returns:
120
  A dictionary mapping frame numbers to (dx, dy) offsets.
121
  """
 
122
  motion_data = {}
123
  cumulative_dx = 0.0
124
  cumulative_dy = 0.0
@@ -134,30 +138,25 @@ def read_motion_csv(csv_filename):
134
  cumulative_dx += dx
135
  cumulative_dy += dy
136
  motion_data[frame_num] = (-cumulative_dx, -cumulative_dy)
 
137
  return motion_data
138
 
139
  def stabilize_video_using_csv(video_file, csv_file, zoom=1.0, output_file=None):
140
  """
141
  Stabilizes the input video using motion data from the CSV file.
142
-
143
- Args:
144
- video_file (str): Path to the input video.
145
- csv_file (str): Path to the motion CSV file.
146
- zoom (float): Zoom factor to apply before stabilization (default: 1.0).
147
- output_file (str): Path for the output stabilized video. If None, a temporary file is created.
148
-
149
- Returns:
150
- output_file (str): Path to the stabilized video file.
151
  """
 
 
152
  motion_data = read_motion_csv(csv_file)
153
 
154
  cap = cv2.VideoCapture(video_file)
155
  if not cap.isOpened():
156
- raise ValueError("Could not open video file for stabilization.")
157
 
158
  fps = cap.get(cv2.CAP_PROP_FPS)
159
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
160
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
 
161
 
162
  if output_file is None:
163
  temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
@@ -168,11 +167,14 @@ def stabilize_video_using_csv(video_file, csv_file, zoom=1.0, output_file=None):
168
  out = cv2.VideoWriter(output_file, fourcc, fps, (width, height))
169
 
170
  frame_idx = 1
 
 
171
  while True:
172
  ret, frame = cap.read()
173
  if not ret:
174
  break
175
 
 
176
  if zoom != 1.0:
177
  zoomed_frame = cv2.resize(frame, None, fx=zoom, fy=zoom, interpolation=cv2.INTER_LINEAR)
178
  zoomed_h, zoomed_w = zoomed_frame.shape[:2]
@@ -186,37 +188,43 @@ def stabilize_video_using_csv(video_file, csv_file, zoom=1.0, output_file=None):
186
  stabilized_frame = cv2.warpAffine(frame, transform, (width, height))
187
 
188
  out.write(stabilized_frame)
 
 
189
  frame_idx += 1
190
 
191
  cap.release()
192
  out.release()
193
- print(f"Stabilized video saved to: {output_file}")
 
194
  return output_file
195
 
196
  def process_video_ai(video_file, zoom):
197
  """
198
  Gradio interface function:
199
- - Generates motion data (CSV) from the input video using an AI model (RAFT, if available).
200
  - Stabilizes the video based on the generated motion data.
201
 
202
  Returns:
203
- Tuple containing the original video file path and the stabilized video file path.
204
  """
205
- if isinstance(video_file, dict):
206
- video_file = video_file.get("name", None)
207
- if video_file is None:
208
- raise ValueError("Please upload a video file.")
209
-
210
- # Generate motion CSV using the AI model (or fallback) for optical flow.
211
- csv_file = generate_motion_csv(video_file)
212
- # Stabilize the video using the generated CSV.
213
- stabilized_path = stabilize_video_using_csv(video_file, csv_file, zoom=zoom)
214
- return video_file, stabilized_path
 
 
 
215
 
216
  # Build the Gradio UI.
217
  with gr.Blocks() as demo:
218
  gr.Markdown("# AI-Powered Video Stabilization")
219
- gr.Markdown("Upload a video and select a zoom factor. The system will automatically generate motion data (video.flow.csv) using an AI model (RAFT, if available) and then stabilize the video.")
220
 
221
  with gr.Row():
222
  with gr.Column():
@@ -226,11 +234,12 @@ with gr.Blocks() as demo:
226
  with gr.Column():
227
  original_video = gr.Video(label="Original Video")
228
  stabilized_video = gr.Video(label="Stabilized Video")
 
229
 
230
  process_button.click(
231
  fn=process_video_ai,
232
  inputs=[video_input, zoom_slider],
233
- outputs=[original_video, stabilized_video]
234
  )
235
 
236
  demo.launch()
 
6
  import tempfile
7
  import os
8
  import gradio as gr
9
+ import time
10
+ import io
11
+ from contextlib import redirect_stdout
12
 
13
  # Set up device for torch
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+ print(f"[INFO] Using device: {device}")
16
 
17
  # Try to load the RAFT model from torch.hub.
18
+ # If it fails, we fall back to OpenCV optical flow.
19
  try:
20
+ print("[INFO] Attempting to load RAFT model from torch.hub...")
21
  raft_model = torch.hub.load("princeton-vl/RAFT", "raft_small", pretrained=True, trust_repo=True)
22
  raft_model = raft_model.to(device)
23
  raft_model.eval()
24
+ print("[INFO] RAFT model loaded successfully.")
25
  except Exception as e:
26
+ print("[ERROR] Error loading RAFT model:", e)
27
+ print("[INFO] Falling back to OpenCV Farneback optical flow.")
28
  raft_model = None
29
 
30
  def generate_motion_csv(video_file, output_csv=None):
31
  """
32
  Generates a CSV file with motion data (columns: frame, mag, ang, zoom) from an input video.
33
+ Uses RAFT if available, otherwise falls back to OpenCV's Farneback optical flow.
 
 
 
 
 
 
 
 
34
  """
35
+ start_time = time.time()
36
  if output_csv is None:
37
  temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.csv')
38
  output_csv = temp_file.name
 
40
 
41
  cap = cv2.VideoCapture(video_file)
42
  if not cap.isOpened():
43
+ raise ValueError("[ERROR] Could not open video file for CSV generation.")
44
 
45
+ print(f"[INFO] Generating motion CSV for video: {video_file}")
46
  with open(output_csv, 'w', newline='') as csvfile:
47
  fieldnames = ['frame', 'mag', 'ang', 'zoom']
48
  writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
 
50
 
51
  ret, first_frame = cap.read()
52
  if not ret:
53
+ raise ValueError("[ERROR] Cannot read first frame from video.")
54
 
55
  if raft_model is not None:
 
56
  first_frame_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
57
  prev_tensor = torch.from_numpy(first_frame_rgb).permute(2, 0, 1).float().unsqueeze(0) / 255.0
58
  prev_tensor = prev_tensor.to(device)
59
+ print("[INFO] Using RAFT model for optical flow computation.")
60
  else:
61
  prev_gray = cv2.cvtColor(first_frame, cv2.COLOR_BGR2GRAY)
62
+ print("[INFO] Using OpenCV Farneback optical flow for computation.")
63
 
64
  frame_idx = 1
65
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
66
+ print(f"[INFO] Total frames to process: {total_frames}")
67
  while True:
68
  ret, frame = cap.read()
69
  if not ret:
 
105
  'zoom': zoom_factor
106
  })
107
 
108
+ if frame_idx % 10 == 0 or frame_idx == total_frames:
109
+ print(f"[INFO] Processed frame {frame_idx}/{total_frames}")
110
  frame_idx += 1
111
 
112
  cap.release()
113
+ elapsed = time.time() - start_time
114
+ print(f"[INFO] Motion CSV generated: {output_csv} in {elapsed:.2f} seconds")
115
  return output_csv
116
 
117
  def read_motion_csv(csv_filename):
118
  """
119
  Reads a motion CSV file (with columns: frame, mag, ang, zoom) and computes a cumulative
120
+ offset per frame for stabilization.
121
 
122
  Returns:
123
  A dictionary mapping frame numbers to (dx, dy) offsets.
124
  """
125
+ print(f"[INFO] Reading motion CSV: {csv_filename}")
126
  motion_data = {}
127
  cumulative_dx = 0.0
128
  cumulative_dy = 0.0
 
138
  cumulative_dx += dx
139
  cumulative_dy += dy
140
  motion_data[frame_num] = (-cumulative_dx, -cumulative_dy)
141
+ print("[INFO] Completed reading motion CSV.")
142
  return motion_data
143
 
144
  def stabilize_video_using_csv(video_file, csv_file, zoom=1.0, output_file=None):
145
  """
146
  Stabilizes the input video using motion data from the CSV file.
 
 
 
 
 
 
 
 
 
147
  """
148
+ start_time = time.time()
149
+ print(f"[INFO] Starting stabilization using CSV: {csv_file}")
150
  motion_data = read_motion_csv(csv_file)
151
 
152
  cap = cv2.VideoCapture(video_file)
153
  if not cap.isOpened():
154
+ raise ValueError("[ERROR] Could not open video file for stabilization.")
155
 
156
  fps = cap.get(cv2.CAP_PROP_FPS)
157
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
158
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
159
+ print(f"[INFO] Video properties - FPS: {fps}, Width: {width}, Height: {height}")
160
 
161
  if output_file is None:
162
  temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
 
167
  out = cv2.VideoWriter(output_file, fourcc, fps, (width, height))
168
 
169
  frame_idx = 1
170
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
171
+ print(f"[INFO] Total frames to stabilize: {total_frames}")
172
  while True:
173
  ret, frame = cap.read()
174
  if not ret:
175
  break
176
 
177
+ # Optionally apply zoom (resize and center-crop)
178
  if zoom != 1.0:
179
  zoomed_frame = cv2.resize(frame, None, fx=zoom, fy=zoom, interpolation=cv2.INTER_LINEAR)
180
  zoomed_h, zoomed_w = zoomed_frame.shape[:2]
 
188
  stabilized_frame = cv2.warpAffine(frame, transform, (width, height))
189
 
190
  out.write(stabilized_frame)
191
+ if frame_idx % 10 == 0 or frame_idx == total_frames:
192
+ print(f"[INFO] Stabilized frame {frame_idx}/{total_frames}")
193
  frame_idx += 1
194
 
195
  cap.release()
196
  out.release()
197
+ elapsed = time.time() - start_time
198
+ print(f"[INFO] Stabilized video saved to: {output_file} in {elapsed:.2f} seconds")
199
  return output_file
200
 
201
  def process_video_ai(video_file, zoom):
202
  """
203
  Gradio interface function:
204
+ - Generates motion data (CSV) from the input video using an AI model (RAFT if available, else Farneback).
205
  - Stabilizes the video based on the generated motion data.
206
 
207
  Returns:
208
+ Tuple containing the original video file path, the stabilized video file path, and log output.
209
  """
210
+ log_buffer = io.StringIO()
211
+ with redirect_stdout(log_buffer):
212
+ if isinstance(video_file, dict):
213
+ video_file = video_file.get("name", None)
214
+ if video_file is None:
215
+ raise ValueError("[ERROR] Please upload a video file.")
216
+
217
+ print("[INFO] Starting AI-powered video processing...")
218
+ csv_file = generate_motion_csv(video_file)
219
+ stabilized_path = stabilize_video_using_csv(video_file, csv_file, zoom=zoom)
220
+ print("[INFO] Video processing complete.")
221
+ logs = log_buffer.getvalue()
222
+ return video_file, stabilized_path, logs
223
 
224
  # Build the Gradio UI.
225
  with gr.Blocks() as demo:
226
  gr.Markdown("# AI-Powered Video Stabilization")
227
+ gr.Markdown("Upload a video and select a zoom factor. The system will generate motion data using an AI model (RAFT if available) and then stabilize the video.")
228
 
229
  with gr.Row():
230
  with gr.Column():
 
234
  with gr.Column():
235
  original_video = gr.Video(label="Original Video")
236
  stabilized_video = gr.Video(label="Stabilized Video")
237
+ logs_output = gr.Textbox(label="Logs", lines=10)
238
 
239
  process_button.click(
240
  fn=process_video_ai,
241
  inputs=[video_input, zoom_slider],
242
+ outputs=[original_video, stabilized_video, logs_output]
243
  )
244
 
245
  demo.launch()