Update app.py
Browse files
app.py
CHANGED
@@ -15,7 +15,6 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
15 |
print(f"[INFO] Using device: {device}")
|
16 |
|
17 |
# Try to load the RAFT model from torch.hub.
|
18 |
-
# If it fails, we fall back to OpenCV optical flow.
|
19 |
try:
|
20 |
print("[INFO] Attempting to load RAFT model from torch.hub...")
|
21 |
raft_model = torch.hub.load("princeton-vl/RAFT", "raft_small", pretrained=True, trust_repo=True)
|
@@ -26,13 +25,14 @@ except Exception as e:
|
|
26 |
print("[ERROR] Error loading RAFT model:", e)
|
27 |
print("[INFO] Falling back to OpenCV Farneback optical flow.")
|
28 |
raft_model = None
|
|
|
29 |
|
30 |
def generate_motion_csv(video_file, output_csv=None, progress=gr.Progress(), progress_offset=0.0, progress_scale=0.5):
|
31 |
"""
|
32 |
Generates a CSV file with motion data (columns: frame, mag, ang, zoom) from an input video.
|
33 |
Uses RAFT if available, otherwise falls back to OpenCV's Farneback optical flow.
|
34 |
|
35 |
-
|
36 |
"""
|
37 |
start_time = time.time()
|
38 |
if output_csv is None:
|
@@ -42,7 +42,7 @@ def generate_motion_csv(video_file, output_csv=None, progress=gr.Progress(), pro
|
|
42 |
|
43 |
cap = cv2.VideoCapture(video_file)
|
44 |
if not cap.isOpened():
|
45 |
-
raise
|
46 |
|
47 |
print(f"[INFO] Generating motion CSV for video: {video_file}")
|
48 |
with open(output_csv, 'w', newline='') as csvfile:
|
@@ -52,7 +52,7 @@ def generate_motion_csv(video_file, output_csv=None, progress=gr.Progress(), pro
|
|
52 |
|
53 |
ret, first_frame = cap.read()
|
54 |
if not ret:
|
55 |
-
raise
|
56 |
|
57 |
if raft_model is not None:
|
58 |
first_frame_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
|
@@ -86,12 +86,12 @@ def generate_motion_csv(video_file, output_csv=None, progress=gr.Progress(), pro
|
|
86 |
iterations=3, poly_n=5, poly_sigma=1.2, flags=0)
|
87 |
prev_gray = curr_gray
|
88 |
|
89 |
-
# Compute median magnitude and angle
|
90 |
mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1], angleInDegrees=True)
|
91 |
median_mag = np.median(mag)
|
92 |
median_ang = np.median(ang)
|
93 |
|
94 |
-
# Compute
|
95 |
h, w = flow.shape[:2]
|
96 |
center_x, center_y = w / 2, h / 2
|
97 |
x_coords, y_coords = np.meshgrid(np.arange(w), np.arange(h))
|
@@ -110,7 +110,6 @@ def generate_motion_csv(video_file, output_csv=None, progress=gr.Progress(), pro
|
|
110 |
if frame_idx % 10 == 0 or frame_idx == total_frames:
|
111 |
print(f"[INFO] Processed frame {frame_idx}/{total_frames}")
|
112 |
|
113 |
-
# Update progress for this phase.
|
114 |
progress(progress_offset + (frame_idx / total_frames) * progress_scale, desc="Generating Motion CSV")
|
115 |
frame_idx += 1
|
116 |
|
@@ -121,11 +120,9 @@ def generate_motion_csv(video_file, output_csv=None, progress=gr.Progress(), pro
|
|
121 |
|
122 |
def read_motion_csv(csv_filename):
|
123 |
"""
|
124 |
-
Reads a motion CSV file
|
125 |
-
offset per frame for stabilization.
|
126 |
|
127 |
-
Returns
|
128 |
-
A dictionary mapping frame numbers to (dx, dy) offsets.
|
129 |
"""
|
130 |
print(f"[INFO] Reading motion CSV: {csv_filename}")
|
131 |
motion_data = {}
|
@@ -148,10 +145,10 @@ def read_motion_csv(csv_filename):
|
|
148 |
|
149 |
def stabilize_video_using_csv(video_file, csv_file, zoom=1.0, vertical_only=False, progress=gr.Progress(), progress_offset=0.5, progress_scale=0.5, output_file=None):
|
150 |
"""
|
151 |
-
Stabilizes the input video using motion data from the CSV
|
152 |
-
If vertical_only is True, only vertical motion is corrected
|
153 |
|
154 |
-
|
155 |
"""
|
156 |
start_time = time.time()
|
157 |
print(f"[INFO] Starting stabilization using CSV: {csv_file}")
|
@@ -159,7 +156,7 @@ def stabilize_video_using_csv(video_file, csv_file, zoom=1.0, vertical_only=Fals
|
|
159 |
|
160 |
cap = cv2.VideoCapture(video_file)
|
161 |
if not cap.isOpened():
|
162 |
-
raise
|
163 |
|
164 |
fps = cap.get(cv2.CAP_PROP_FPS)
|
165 |
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
@@ -192,7 +189,7 @@ def stabilize_video_using_csv(video_file, csv_file, zoom=1.0, vertical_only=Fals
|
|
192 |
|
193 |
dx, dy = motion_data.get(frame_idx, (0, 0))
|
194 |
if vertical_only:
|
195 |
-
dx = 0 #
|
196 |
transform = np.array([[1, 0, dx],
|
197 |
[0, 1, dy]], dtype=np.float32)
|
198 |
stabilized_frame = cv2.warpAffine(frame, transform, (width, height))
|
@@ -201,7 +198,6 @@ def stabilize_video_using_csv(video_file, csv_file, zoom=1.0, vertical_only=Fals
|
|
201 |
if frame_idx % 10 == 0 or frame_idx == total_frames:
|
202 |
print(f"[INFO] Stabilized frame {frame_idx}/{total_frames}")
|
203 |
|
204 |
-
# Update progress for stabilization phase.
|
205 |
progress(progress_offset + (frame_idx / total_frames) * progress_scale, desc="Stabilizing Video")
|
206 |
frame_idx += 1
|
207 |
|
@@ -214,26 +210,28 @@ def stabilize_video_using_csv(video_file, csv_file, zoom=1.0, vertical_only=Fals
|
|
214 |
def process_video_ai(video_file, zoom, vertical_only, progress=gr.Progress(track_tqdm=True)):
|
215 |
"""
|
216 |
Gradio interface function:
|
217 |
-
- Generates motion data
|
218 |
- Stabilizes the video based on the generated motion data.
|
219 |
- If vertical_only is True, only vertical stabilization is applied.
|
220 |
|
221 |
Returns:
|
222 |
-
|
223 |
"""
|
|
|
|
|
|
|
224 |
log_buffer = io.StringIO()
|
225 |
with redirect_stdout(log_buffer):
|
226 |
if isinstance(video_file, dict):
|
227 |
video_file = video_file.get("name", None)
|
228 |
if video_file is None:
|
229 |
-
raise
|
230 |
|
231 |
-
print("[INFO] Starting AI-powered video processing...")
|
232 |
-
# First half: Generate motion CSV.
|
233 |
csv_file = generate_motion_csv(video_file, progress=progress, progress_offset=0.0, progress_scale=0.5)
|
234 |
-
|
235 |
stabilized_path = stabilize_video_using_csv(video_file, csv_file, zoom=zoom, vertical_only=vertical_only,
|
236 |
progress=progress, progress_offset=0.5, progress_scale=0.5)
|
|
|
237 |
print("[INFO] Video processing complete.")
|
238 |
logs = log_buffer.getvalue()
|
239 |
return video_file, stabilized_path, logs
|
@@ -241,7 +239,7 @@ def process_video_ai(video_file, zoom, vertical_only, progress=gr.Progress(track
|
|
241 |
# Build the Gradio UI.
|
242 |
with gr.Blocks() as demo:
|
243 |
gr.Markdown("# AI-Powered Video Stabilization")
|
244 |
-
gr.Markdown("Upload a video, select a zoom factor, and choose whether to apply only vertical stabilization. The system will generate motion data using an AI model (RAFT if available) and then stabilize the video with live progress updates.")
|
245 |
|
246 |
with gr.Row():
|
247 |
with gr.Column():
|
|
|
15 |
print(f"[INFO] Using device: {device}")
|
16 |
|
17 |
# Try to load the RAFT model from torch.hub.
|
|
|
18 |
try:
|
19 |
print("[INFO] Attempting to load RAFT model from torch.hub...")
|
20 |
raft_model = torch.hub.load("princeton-vl/RAFT", "raft_small", pretrained=True, trust_repo=True)
|
|
|
25 |
print("[ERROR] Error loading RAFT model:", e)
|
26 |
print("[INFO] Falling back to OpenCV Farneback optical flow.")
|
27 |
raft_model = None
|
28 |
+
gr.Warning("Falling back to OpenCV Farneback optical flow.")
|
29 |
|
30 |
def generate_motion_csv(video_file, output_csv=None, progress=gr.Progress(), progress_offset=0.0, progress_scale=0.5):
|
31 |
"""
|
32 |
Generates a CSV file with motion data (columns: frame, mag, ang, zoom) from an input video.
|
33 |
Uses RAFT if available, otherwise falls back to OpenCV's Farneback optical flow.
|
34 |
|
35 |
+
Updates progress from progress_offset to progress_offset+progress_scale.
|
36 |
"""
|
37 |
start_time = time.time()
|
38 |
if output_csv is None:
|
|
|
42 |
|
43 |
cap = cv2.VideoCapture(video_file)
|
44 |
if not cap.isOpened():
|
45 |
+
raise gr.Error("Could not open video file for CSV generation.")
|
46 |
|
47 |
print(f"[INFO] Generating motion CSV for video: {video_file}")
|
48 |
with open(output_csv, 'w', newline='') as csvfile:
|
|
|
52 |
|
53 |
ret, first_frame = cap.read()
|
54 |
if not ret:
|
55 |
+
raise gr.Error("Cannot read first frame from video.")
|
56 |
|
57 |
if raft_model is not None:
|
58 |
first_frame_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
|
|
|
86 |
iterations=3, poly_n=5, poly_sigma=1.2, flags=0)
|
87 |
prev_gray = curr_gray
|
88 |
|
89 |
+
# Compute median magnitude and angle.
|
90 |
mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1], angleInDegrees=True)
|
91 |
median_mag = np.median(mag)
|
92 |
median_ang = np.median(ang)
|
93 |
|
94 |
+
# Compute "zoom factor": fraction of pixels moving away from center.
|
95 |
h, w = flow.shape[:2]
|
96 |
center_x, center_y = w / 2, h / 2
|
97 |
x_coords, y_coords = np.meshgrid(np.arange(w), np.arange(h))
|
|
|
110 |
if frame_idx % 10 == 0 or frame_idx == total_frames:
|
111 |
print(f"[INFO] Processed frame {frame_idx}/{total_frames}")
|
112 |
|
|
|
113 |
progress(progress_offset + (frame_idx / total_frames) * progress_scale, desc="Generating Motion CSV")
|
114 |
frame_idx += 1
|
115 |
|
|
|
120 |
|
121 |
def read_motion_csv(csv_filename):
|
122 |
"""
|
123 |
+
Reads a motion CSV file and computes cumulative offset per frame.
|
|
|
124 |
|
125 |
+
Returns a dictionary mapping frame numbers to (dx, dy) offsets.
|
|
|
126 |
"""
|
127 |
print(f"[INFO] Reading motion CSV: {csv_filename}")
|
128 |
motion_data = {}
|
|
|
145 |
|
146 |
def stabilize_video_using_csv(video_file, csv_file, zoom=1.0, vertical_only=False, progress=gr.Progress(), progress_offset=0.5, progress_scale=0.5, output_file=None):
|
147 |
"""
|
148 |
+
Stabilizes the input video using motion data from the CSV.
|
149 |
+
If vertical_only is True, only vertical motion is corrected.
|
150 |
|
151 |
+
Updates progress from progress_offset to progress_offset+progress_scale.
|
152 |
"""
|
153 |
start_time = time.time()
|
154 |
print(f"[INFO] Starting stabilization using CSV: {csv_file}")
|
|
|
156 |
|
157 |
cap = cv2.VideoCapture(video_file)
|
158 |
if not cap.isOpened():
|
159 |
+
raise gr.Error("Could not open video file for stabilization.")
|
160 |
|
161 |
fps = cap.get(cv2.CAP_PROP_FPS)
|
162 |
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
|
|
189 |
|
190 |
dx, dy = motion_data.get(frame_idx, (0, 0))
|
191 |
if vertical_only:
|
192 |
+
dx = 0 # Only vertical stabilization.
|
193 |
transform = np.array([[1, 0, dx],
|
194 |
[0, 1, dy]], dtype=np.float32)
|
195 |
stabilized_frame = cv2.warpAffine(frame, transform, (width, height))
|
|
|
198 |
if frame_idx % 10 == 0 or frame_idx == total_frames:
|
199 |
print(f"[INFO] Stabilized frame {frame_idx}/{total_frames}")
|
200 |
|
|
|
201 |
progress(progress_offset + (frame_idx / total_frames) * progress_scale, desc="Stabilizing Video")
|
202 |
frame_idx += 1
|
203 |
|
|
|
210 |
def process_video_ai(video_file, zoom, vertical_only, progress=gr.Progress(track_tqdm=True)):
|
211 |
"""
|
212 |
Gradio interface function:
|
213 |
+
- Generates motion data from the input video using an AI model (RAFT if available, else Farneback).
|
214 |
- Stabilizes the video based on the generated motion data.
|
215 |
- If vertical_only is True, only vertical stabilization is applied.
|
216 |
|
217 |
Returns:
|
218 |
+
Tuple: (original video file path, stabilized video file path, log output)
|
219 |
"""
|
220 |
+
# Display an info alert.
|
221 |
+
gr.Info("Starting AI-powered video processing...")
|
222 |
+
|
223 |
log_buffer = io.StringIO()
|
224 |
with redirect_stdout(log_buffer):
|
225 |
if isinstance(video_file, dict):
|
226 |
video_file = video_file.get("name", None)
|
227 |
if video_file is None:
|
228 |
+
raise gr.Error("Please upload a video file.")
|
229 |
|
|
|
|
|
230 |
csv_file = generate_motion_csv(video_file, progress=progress, progress_offset=0.0, progress_scale=0.5)
|
231 |
+
gr.Info("Motion CSV generated successfully.")
|
232 |
stabilized_path = stabilize_video_using_csv(video_file, csv_file, zoom=zoom, vertical_only=vertical_only,
|
233 |
progress=progress, progress_offset=0.5, progress_scale=0.5)
|
234 |
+
gr.Info("Video stabilization complete.")
|
235 |
print("[INFO] Video processing complete.")
|
236 |
logs = log_buffer.getvalue()
|
237 |
return video_file, stabilized_path, logs
|
|
|
239 |
# Build the Gradio UI.
|
240 |
with gr.Blocks() as demo:
|
241 |
gr.Markdown("# AI-Powered Video Stabilization")
|
242 |
+
gr.Markdown("Upload a video, select a zoom factor, and choose whether to apply only vertical stabilization. The system will generate motion data using an AI model (RAFT if available) and then stabilize the video with live progress updates and alerts.")
|
243 |
|
244 |
with gr.Row():
|
245 |
with gr.Column():
|