SpyC0der77 committed on
Commit 3aedaee · verified · 1 Parent(s): dec9875

Update app.py

Files changed (1)
  1. app.py +87 -102
app.py CHANGED
@@ -7,24 +7,14 @@ import tempfile
 import os
 import gradio as gr
 import time
-import threading
+import io
 
-# Global status and result dictionaries.
-status = {
-    "logs": "",
-    "progress": 0,  # from 0 to 100
-    "finished": False
-}
-result = {
-    "original_video": None,
-    "stabilized_video": None
-}
-
-# Set up device for torch.
+# Set up device for torch
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 print(f"[INFO] Using device: {device}")
 
-# Try to load the RAFT model. If it fails, fall back to OpenCV Farneback.
+# Try to load the RAFT model from torch.hub.
+# If it fails, fall back to OpenCV's Farneback optical flow.
 try:
     print("[INFO] Attempting to load RAFT model from torch.hub...")
     raft_model = torch.hub.load("princeton-vl/RAFT", "raft_small", pretrained=True, trust_repo=True)
@@ -36,70 +26,76 @@ except Exception as e:
     print("[INFO] Falling back to OpenCV Farneback optical flow.")
     raft_model = None
 
-def append_log(msg):
-    """Append a log message to the global status and print it."""
-    global status
-    status["logs"] += msg + "\n"
-    print(msg)
-
-def background_process(video_file, zoom):
+def process_video_ai(video_file, zoom):
     """
-    Runs the full processing: generates a motion CSV using RAFT (or Farneback)
-    and then stabilizes the video. Updates global status and result.
+    Generator function for Gradio:
+      - Generates motion data (CSV) from the input video using an AI model (RAFT if available, else Farneback)
+      - Stabilizes the video using the generated motion data.
+
+    Yields:
+      A tuple of (original_video, stabilized_video, logs, progress)
+      During processing, original_video and stabilized_video are None.
+      The final yield returns the video file paths along with final logs and progress=100.
     """
-    global status, result
-
-    status["logs"] = ""
-    status["progress"] = 0
-    status["finished"] = False
-    result["original_video"] = None
-    result["stabilized_video"] = None
-
-    append_log("[INFO] Starting AI-powered video processing...")
+    logs = []
+    def add_log(msg):
+        logs.append(msg)
+        return "\n".join(logs)
+
+    # Check and extract the file path
+    if isinstance(video_file, dict):
+        video_file = video_file.get("name", None)
+    if video_file is None:
+        yield (None, None, "[ERROR] Please upload a video file.", 0)
+        return
+
+    add_log("[INFO] Starting AI-powered video processing...")
+    yield (None, None, add_log("Starting processing..."), 0)
+
     # === CSV Generation Phase ===
-    append_log("[INFO] Starting motion CSV generation...")
+    add_log("[INFO] Starting motion CSV generation...")
+    yield (None, None, add_log("Starting CSV generation..."), 0)
+
     cap = cv2.VideoCapture(video_file)
     if not cap.isOpened():
-        append_log("[ERROR] Could not open video file for CSV generation.")
-        status["finished"] = True
+        yield (None, None, add_log("[ERROR] Could not open video file for CSV generation."), 0)
         return
-
     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-    append_log(f"[INFO] Total frames in video: {total_frames}")
+    add_log(f"[INFO] Total frames in video: {total_frames}")
+
+    # Create temporary CSV file
     csv_file = tempfile.NamedTemporaryFile(delete=False, suffix='.csv').name
     with open(csv_file, 'w', newline='') as csvfile:
         fieldnames = ['frame', 'mag', 'ang', 'zoom']
         writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
         writer.writeheader()
-
+
         ret, first_frame = cap.read()
         if not ret:
-            append_log("[ERROR] Cannot read first frame from video.")
-            status["finished"] = True
-            cap.release()
+            yield (None, None, add_log("[ERROR] Cannot read first frame from video."), 0)
            return
-
+
         if raft_model is not None:
             first_frame_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
             prev_tensor = torch.from_numpy(first_frame_rgb).permute(2, 0, 1).float().unsqueeze(0) / 255.0
             prev_tensor = prev_tensor.to(device)
-            append_log("[INFO] Using RAFT model for optical flow computation.")
+            add_log("[INFO] Using RAFT model for optical flow computation.")
         else:
             prev_gray = cv2.cvtColor(first_frame, cv2.COLOR_BGR2GRAY)
-            append_log("[INFO] Using Farneback optical flow for computation.")
-
+            add_log("[INFO] Using Farneback optical flow for computation.")
+
         frame_idx = 1
+        # Process each frame for CSV generation
         while True:
             ret, frame = cap.read()
             if not ret:
                 break
-
             if raft_model is not None:
                 curr_frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                 curr_tensor = torch.from_numpy(curr_frame_rgb).permute(2, 0, 1).float().unsqueeze(0) / 255.0
                 curr_tensor = curr_tensor.to(device)
                 with torch.no_grad():
-                    _, flow_up = raft_model(prev_tensor, curr_tensor, iters=20, test_mode=True)
+                    flow_low, flow_up = raft_model(prev_tensor, curr_tensor, iters=20, test_mode=True)
                 flow = flow_up[0].permute(1, 2, 0).cpu().numpy()
                 prev_tensor = curr_tensor.clone()
             else:
@@ -108,12 +104,12 @@ def background_process(video_file, zoom):
                     pyr_scale=0.5, levels=3, winsize=15,
                     iterations=3, poly_n=5, poly_sigma=1.2, flags=0)
                 prev_gray = curr_gray
-
-            # Compute median magnitude and angle.
-            mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1], angleInDegrees=True)
+
+            # Compute median magnitude and angle
+            mag, ang = cv2.cartToPolar(flow[...,0], flow[...,1], angleInDegrees=True)
             median_mag = np.median(mag)
             median_ang = np.median(ang)
-            # Compute zoom factor: fraction of pixels moving away from the center.
+            # Compute zoom factor: fraction of pixels moving away from center
             h, w = flow.shape[:2]
             center_x, center_y = w / 2, h / 2
             x_coords, y_coords = np.meshgrid(np.arange(w), np.arange(h))
@@ -121,25 +117,28 @@
             y_offset = y_coords - center_y
             dot = flow[..., 0] * x_offset + flow[..., 1] * y_offset
             zoom_factor = np.count_nonzero(dot > 0) / (w * h)
+
             writer.writerow({
                 'frame': frame_idx,
                 'mag': median_mag,
                 'ang': median_ang,
                 'zoom': zoom_factor
             })
-
+
             if frame_idx % 10 == 0 or frame_idx == total_frames:
-                progress_csv = (frame_idx / total_frames) * 50  # CSV phase: 0-50%
-                append_log(f"[INFO] CSV: Processed frame {frame_idx}/{total_frames}")
-                status["progress"] = progress_csv
+                progress_csv = (frame_idx / total_frames) * 50  # CSV phase is 0-50%
+                add_log(f"[INFO] CSV: Processed frame {frame_idx}/{total_frames}")
+                yield (None, None, add_log(""), progress_csv)
             frame_idx += 1
     cap.release()
-    append_log("[INFO] CSV generation complete.")
-    status["progress"] = 50
-
+    add_log("[INFO] CSV generation complete.")
+    yield (None, None, add_log(""), 50)
+
     # === Stabilization Phase ===
-    append_log("[INFO] Starting video stabilization...")
-    # Read the CSV and compute cumulative motion data.
+    add_log("[INFO] Starting video stabilization...")
+    yield (None, None, add_log("Starting stabilization..."), 51)
+
+    # Read the CSV and compute cumulative motion data
     motion_data = {}
     cumulative_dx = 0.0
     cumulative_dy = 0.0
@@ -155,10 +154,10 @@
             cumulative_dx += dx
             cumulative_dy += dy
             motion_data[frame_num] = (-cumulative_dx, -cumulative_dy)
-    append_log("[INFO] Motion CSV read complete.")
-    status["progress"] = 55
-
-    # Re-open video for stabilization.
+    add_log("[INFO] Motion CSV read complete.")
+    yield (None, None, add_log(""), 55)
+
+    # Re-open video for stabilization
     cap = cv2.VideoCapture(video_file)
     fps = cap.get(cv2.CAP_PROP_FPS)
     width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
@@ -168,7 +167,7 @@
     temp_file.close()
     fourcc = cv2.VideoWriter_fourcc(*'mp4v')
     out = cv2.VideoWriter(output_file, fourcc, fps, (width, height))
-
+
     frame_idx = 1
     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
     while True:
@@ -181,59 +180,45 @@
         start_x = max((zoomed_w - width) // 2, 0)
         start_y = max((zoomed_h - height) // 2, 0)
         frame = zoomed_frame[start_y:start_y+height, start_x:start_x+width]
+
         dx, dy = motion_data.get(frame_idx, (0, 0))
-        transform = np.array([[1, 0, dx],
-                              [0, 1, dy]], dtype=np.float32)
+        transform = np.array([[1, 0, dx], [0, 1, dy]], dtype=np.float32)
         stabilized_frame = cv2.warpAffine(frame, transform, (width, height))
         out.write(stabilized_frame)
+
         if frame_idx % 10 == 0 or frame_idx == total_frames:
-            progress_stab = 50 + (frame_idx / total_frames) * 50  # Stabilization phase: 50-100%
-            append_log(f"[INFO] Stabilization: Processed frame {frame_idx}/{total_frames}")
-            status["progress"] = progress_stab
+            progress_stab = 50 + (frame_idx / total_frames) * 50  # Stabilization phase is 50-100%
+            add_log(f"[INFO] Stabilization: Processed frame {frame_idx}/{total_frames}")
+            yield (None, None, add_log(""), progress_stab)
         frame_idx += 1
     cap.release()
     out.release()
-    append_log("[INFO] Stabilization complete.")
-    status["progress"] = 100
-    status["finished"] = True
-    result["original_video"] = video_file
-    result["stabilized_video"] = output_file
-
-def start_processing(video_file, zoom):
-    """Starts background processing in a new thread."""
-    thread = threading.Thread(target=background_process, args=(video_file, zoom), daemon=True)
-    thread.start()
-    return "[INFO] Processing started..."
+    add_log("[INFO] Stabilization complete.")
+    yield (video_file, output_file, add_log(""), 100)
 
-def poll_status():
-    """
-    Returns the current processing status:
-      - original_video: path if finished (else None)
-      - stabilized_video: path if finished (else None)
-      - logs: current logs string
-      - progress: current progress value (0 to 100)
-    """
-    return result["original_video"], result["stabilized_video"], status["logs"], status["progress"]
-
-# Build the Gradio UI.
+# Build the Gradio UI with streaming enabled.
 with gr.Blocks() as demo:
     gr.Markdown("# AI-Powered Video Stabilization")
-    gr.Markdown("Upload a video and select a zoom factor. Processing will start automatically and the UI will update every 2 seconds.")
-
+    gr.Markdown("Upload a video and select a zoom factor. The system will generate motion data using an AI model (RAFT if available, else Farneback) and then stabilize the video. Logs and progress will update during processing.")
+
     with gr.Row():
         with gr.Column():
             video_input = gr.Video(label="Input Video")
             zoom_slider = gr.Slider(minimum=1.0, maximum=2.0, step=0.1, value=1.0, label="Zoom Factor")
-            start_button = gr.Button("Process Video")
+            process_button = gr.Button("Process Video")
        with gr.Column():
             original_video = gr.Video(label="Original Video")
             stabilized_video = gr.Video(label="Stabilized Video")
             logs_output = gr.Textbox(label="Logs", lines=15)
             progress_bar = gr.Slider(label="Progress", minimum=0, maximum=100, value=0, interactive=False)
-
-    # When "Process Video" is clicked, start processing in the background.
-    start_button.click(fn=start_processing, inputs=[video_input, zoom_slider], outputs=[logs_output])
-    # Automatically poll status every 2 seconds using Blocks.load().
-    demo.load(fn=poll_status, inputs=[], outputs=[original_video, stabilized_video, logs_output, progress_bar], every=2)
-
-demo.launch()
+
+    demo.queue()  # enable streaming
+
+    process_button.click(
+        fn=process_video_ai,
+        inputs=[video_input, zoom_slider],
+        outputs=[original_video, stabilized_video, logs_output, progress_bar],
+        stream=True  # enable streaming updates
+    )
+
+demo.launch()
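The substantive change above swaps the thread-plus-polling design (background_process started by start_processing, with poll_status wired to demo.load(..., every=2)) for a single generator whose yields stream straight to the output components. A minimal sketch of that pattern, assuming a Gradio 3.x-style Blocks API; slow_task and every other name below is hypothetical and not taken from this repo:

import time
import gradio as gr

def slow_task(n):
    # Hypothetical stand-in for the app's per-frame processing loop.
    logs = []
    for i in range(int(n)):
        time.sleep(0.5)  # simulate work on one chunk/frame
        logs.append(f"[INFO] step {i + 1}/{int(n)}")
        # Each yield pushes one intermediate update per output component.
        yield "\n".join(logs), (i + 1) * 100 / int(n)
    # Final yield carries the finished state.
    yield "\n".join(logs) + "\n[INFO] Done.", 100

with gr.Blocks() as demo:
    steps = gr.Slider(1, 20, value=5, step=1, label="Steps")
    run_btn = gr.Button("Run")
    logs_box = gr.Textbox(label="Logs", lines=8)
    progress = gr.Slider(0, 100, value=0, interactive=False, label="Progress")
    run_btn.click(fn=slow_task, inputs=[steps], outputs=[logs_box, progress])

demo.queue()  # queuing is what lets generator yields stream to the browser
demo.launch()

With the queue enabled, a generator handler streams one UI update per yield, and each yield must supply exactly one value per component listed in outputs; that is why process_video_ai always yields four-tuples of (original_video, stabilized_video, logs, progress).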
 
 
 
 
 
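Context lines in the diff also carry the zoom heuristic both versions share: each pixel's flow vector is dotted with that pixel's offset from the frame center, and the fraction of positive dot products is recorded as the per-frame 'zoom' value. A standalone sanity check of that heuristic on a synthetic expanding flow field (illustrative sketch only, not part of the commit):

import numpy as np

h, w = 120, 160
y_coords, x_coords = np.mgrid[0:h, 0:w].astype(np.float32)
center_x, center_y = w / 2, h / 2

# Synthetic "zoom out" field: every pixel drifts away from the center.
flow = np.dstack([(x_coords - center_x) * 0.01, (y_coords - center_y) * 0.01])

dot = flow[..., 0] * (x_coords - center_x) + flow[..., 1] * (y_coords - center_y)
zoom_factor = np.count_nonzero(dot > 0) / (w * h)
print(zoom_factor)  # ~1.0: nearly all pixels move outward (dot == 0 only at the exact center)

Under pure translation the fraction sits near 0.5 (half the frame moves toward the center, half away), so values well above 0.5 are what signal zoom-like motion.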