Update app.py
app.py
CHANGED
@@ -2,18 +2,112 @@ import cv2
 import numpy as np
 import csv
 import math
+import torch
 import tempfile
 import os
 import gradio as gr
 
+# Load the RAFT model from torch.hub (uses the 'raft_small' variant)
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+print(f"Using device: {device}")
+model = torch.hub.load("princeton-vl/RAFT", "raft_small", pretrained=True)
+model = model.to(device)
+model.eval()
+
+def generate_motion_csv(video_file, output_csv=None):
+    """
+    Uses the RAFT model to compute optical flow between consecutive frames,
+    then writes a CSV file (with columns: frame, mag, ang, zoom) where:
+      - mag: median magnitude of the flow,
+      - ang: median angle (in degrees), and
+      - zoom: fraction of pixels moving away from the image center.
+
+    Args:
+        video_file (str): Path to the input video.
+        output_csv (str): Optional path for output CSV file. If None, a temporary file is used.
+
+    Returns:
+        output_csv (str): Path to the generated CSV file.
+    """
+    if output_csv is None:
+        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.csv')
+        output_csv = temp_file.name
+        temp_file.close()
+
+    cap = cv2.VideoCapture(video_file)
+    if not cap.isOpened():
+        raise ValueError("Could not open video file for CSV generation.")
+
+    # Prepare CSV file for writing
+    with open(output_csv, 'w', newline='') as csvfile:
+        fieldnames = ['frame', 'mag', 'ang', 'zoom']
+        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
+        writer.writeheader()
+
+        ret, prev_frame = cap.read()
+        if not ret:
+            raise ValueError("Cannot read first frame from video.")
+
+        # Convert the first frame to tensor
+        prev_frame_rgb = cv2.cvtColor(prev_frame, cv2.COLOR_BGR2RGB)
+        prev_tensor = torch.from_numpy(prev_frame_rgb).permute(2,0,1).float().unsqueeze(0) / 255.0
+        prev_tensor = prev_tensor.to(device)
+
+        frame_idx = 1
+        while True:
+            ret, frame = cap.read()
+            if not ret:
+                break
+
+            curr_frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+            curr_tensor = torch.from_numpy(curr_frame_rgb).permute(2,0,1).float().unsqueeze(0) / 255.0
+            curr_tensor = curr_tensor.to(device)
+
+            # Use RAFT to compute optical flow between previous and current frame.
+            with torch.no_grad():
+                # The RAFT model returns a low-resolution flow and an upsampled (high-res) flow.
+                flow_low, flow_up = model(prev_tensor, curr_tensor, iters=20, test_mode=True)
+            # Convert flow to numpy array (shape: H x W x 2)
+            flow = flow_up[0].permute(1,2,0).cpu().numpy()
+
+            # Compute median magnitude and angle for the optical flow
+            mag, ang = cv2.cartToPolar(flow[...,0], flow[...,1], angleInDegrees=True)
+            median_mag = np.median(mag)
+            median_ang = np.median(ang)
+
+            # Compute a "zoom factor": fraction of pixels moving away from the center.
+            h, w = flow.shape[:2]
+            center_x, center_y = w / 2, h / 2
+            x_coords, y_coords = np.meshgrid(np.arange(w), np.arange(h))
+            x_offset = x_coords - center_x
+            y_offset = y_coords - center_y
+            # Dot product between flow vectors and pixel offsets:
+            dot = flow[...,0] * x_offset + flow[...,1] * y_offset
+            zoom_factor = np.count_nonzero(dot > 0) / (w * h)
+
+            # Write the computed metrics to the CSV file.
+            writer.writerow({
+                'frame': frame_idx,
+                'mag': median_mag,
+                'ang': median_ang,
+                'zoom': zoom_factor
+            })
+
+            # Update for next iteration
+            prev_tensor = curr_tensor.clone()
+            frame_idx += 1
+
+    cap.release()
+    print(f"Motion CSV generated: {output_csv}")
+    return output_csv
+
 def read_motion_csv(csv_filename):
     """
-    Reads …
-    …
-    accumulates these to build a per-frame cumulative offset.
+    Reads the CSV file (columns: frame, mag, ang, zoom) and computes a cumulative
+    offset per frame to be used for stabilization.
 
     Returns:
-        A dictionary mapping frame numbers to (dx, dy) offsets.
+        A dictionary mapping frame numbers to (dx, dy) offsets (the negative cumulative displacement).
     """
     motion_data = {}
     cumulative_dx = 0.0
@@ -24,15 +118,13 @@ def read_motion_csv(csv_filename):
             frame_num = int(row['frame'])
             mag = float(row['mag'])
             ang = float(row['ang'])
-            # Convert angle (in degrees) to radians
+            # Convert angle (in degrees) to radians.
            rad = math.radians(ang)
-            # Compute displacement vector from magnitude and angle
            dx = mag * math.cos(rad)
            dy = mag * math.sin(rad)
-            # Accumulate the displacement over frames
            cumulative_dx += dx
            cumulative_dy += dy
-            # …
+            # Negative cumulative offset counteracts the detected motion.
            motion_data[frame_num] = (-cumulative_dx, -cumulative_dy)
    return motion_data
 
@@ -42,25 +134,24 @@ def stabilize_video_using_csv(video_file, csv_file, zoom=1.0, output_file=None):
 
     Args:
         video_file (str): Path to the input video.
-        csv_file (str): Path to the CSV file …
-        zoom (float): …
+        csv_file (str): Path to the motion CSV file.
+        zoom (float): Zoom factor to apply before stabilization (default: 1.0, no zoom).
         output_file (str): Path for the output stabilized video. If None, a temporary file is created.
 
     Returns:
-        output_file (str): …
+        output_file (str): Path to the stabilized video file.
     """
     # Read motion data from CSV
     motion_data = read_motion_csv(csv_file)
 
     cap = cv2.VideoCapture(video_file)
     if not cap.isOpened():
-        raise ValueError("Could not open video file.")
+        raise ValueError("Could not open video file for stabilization.")
 
     fps = cap.get(cv2.CAP_PROP_FPS)
     width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
     height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
 
-    # Create a temporary file for output if not provided
     if output_file is None:
         temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
         output_file = temp_file.name
@@ -83,10 +174,10 @@ def stabilize_video_using_csv(video_file, csv_file, zoom=1.0, output_file=None):
         start_y = max((zoomed_h - height) // 2, 0)
         frame = zoomed_frame[start_y:start_y+height, start_x:start_x+width]
 
-        # …
+        # Get the stabilization offset for the current frame (default to (0,0) if not available)
         dx, dy = motion_data.get(frame_num, (0, 0))
 
-        # Apply an affine transformation to counteract the motion
+        # Apply an affine transformation to counteract the motion.
         transform = np.array([[1, 0, dx],
                               [0, 1, dy]], dtype=np.float32)
         stabilized_frame = cv2.warpAffine(frame, transform, (width, height))
@@ -96,48 +187,47 @@ def stabilize_video_using_csv(video_file, csv_file, zoom=1.0, output_file=None):
 
     cap.release()
     out.release()
+    print(f"Stabilized video saved to: {output_file}")
     return output_file
 
-def …
-    """
-    Gradio interface function to stabilize a video.
-    Accepts an input video file, a motion CSV file, and a zoom factor.
-    Returns the original video and the stabilized video.
-    """
+def process_video_ai(video_file, zoom):
+    """
+    Gradio interface function: Given an input video and a zoom factor,
+    it uses a deep learning model (RAFT) to generate motion data (video.flow.csv)
+    and then stabilizes the video based on that data.
+
+    Returns:
+        A tuple containing the original video file path and the stabilized video file path.
+    """
+    # Ensure the input is a file path (if provided as a dict, extract the "name")
     if isinstance(video_file, dict):
         video_file = video_file.get("name", None)
-    if isinstance(csv_file, dict):
-        csv_file = csv_file.get("name", None)
-
-    # Check that both file paths are available
     if video_file is None:
-        raise ValueError("…
-    if csv_file is None:
-        raise ValueError("CSV file path is missing. Please upload a CSV file.")
+        raise ValueError("Please upload a video file.")
 
+    # Generate motion CSV using AI-based optical flow (RAFT)
+    csv_file = generate_motion_csv(video_file)
+    # Stabilize the video using the generated CSV data
     stabilized_path = stabilize_video_using_csv(video_file, csv_file, zoom=zoom)
     return video_file, stabilized_path
 
+# Build the Gradio interface
 with gr.Blocks() as demo:
-    gr.Markdown("# Video Stabilization …
+    gr.Markdown("# AI-Powered Video Stabilization")
+    gr.Markdown("Upload a video and select a zoom factor. The system will automatically use a deep learning model (RAFT) to generate motion data and then stabilize the video.")
+
     with gr.Row():
         with gr.Column():
             video_input = gr.Video(label="Input Video")
-            csv_input = gr.File(label="Motion CSV File (e.g., video.flow.csv)", file_count="single")
             zoom_slider = gr.Slider(minimum=1.0, maximum=2.0, step=0.1, value=1.0, label="Zoom Factor")
+            process_button = gr.Button("Process Video")
        with gr.Column():
            original_video = gr.Video(label="Original Video")
            stabilized_video = gr.Video(label="Stabilized Video")
 
    process_button.click(
-        fn=…
-        inputs=[video_input, …
+        fn=process_video_ai,
+        inputs=[video_input, zoom_slider],
        outputs=[original_video, stabilized_video]
    )
 
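A caveat on the model-loading lines added in the first hunk: torch.hub.load("princeton-vl/RAFT", "raft_small", pretrained=True) only works if that repository exposes a hubconf.py with a raft_small entry point, and the reference RAFT implementation normalizes [0, 255] inputs internally, so the tensors pre-divided by 255.0 here are worth double-checking against whatever model actually gets loaded. If hub loading proves unreliable, torchvision bundles RAFT with pretrained weights and a documented preprocessing transform. A minimal sketch under those assumptions (torchvision >= 0.13, frame height and width divisible by 8 as RAFT requires; the compute_flow helper is illustrative, not part of this commit):

import torch
from torchvision.models.optical_flow import raft_small, Raft_Small_Weights

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
weights = Raft_Small_Weights.DEFAULT
model = raft_small(weights=weights).to(device).eval()
preprocess = weights.transforms()  # converts uint8 frames and normalizes for RAFT

def compute_flow(prev_rgb, curr_rgb):
    # prev_rgb / curr_rgb: HxWx3 uint8 RGB numpy arrays. RAFT needs H and W
    # divisible by 8, so pad or resize the frames beforehand if necessary.
    prev = torch.from_numpy(prev_rgb).permute(2, 0, 1).unsqueeze(0)
    curr = torch.from_numpy(curr_rgb).permute(2, 0, 1).unsqueeze(0)
    prev, curr = preprocess(prev, curr)
    with torch.no_grad():
        # torchvision's RAFT returns one flow per refinement step; take the last.
        flow = model(prev.to(device), curr.to(device))[-1]
    return flow[0].permute(1, 2, 0).cpu().numpy()  # HxWx2, in pixels

The returned array has the same H x W x 2 layout as the flow fed to cv2.cartToPolar in generate_motion_csv, so it would slot into that loop directly.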
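The sign convention in read_motion_csv is the subtle part of this change: it stores the negative running sum of the per-frame displacement vectors, and stabilize_video_using_csv feeds that offset straight into the warpAffine translation. A standalone toy check of the convention, mirroring the loop in the diff (note also that the zoom column written by generate_motion_csv is never consumed by the stabilization path in this commit):

import csv, io, math

# Two frames of pure rightward drift: mag = 2 px, ang = 0 degrees.
rows = io.StringIO("frame,mag,ang,zoom\n1,2.0,0.0,0.5\n2,2.0,0.0,0.5\n")
cum_dx = cum_dy = 0.0
for row in csv.DictReader(rows):
    rad = math.radians(float(row['ang']))
    cum_dx += float(row['mag']) * math.cos(rad)
    cum_dy += float(row['mag']) * math.sin(rad)
    print(row['frame'], (-cum_dx, -cum_dy))
# Prints: 1 (-2.0, -0.0) then 2 (-4.0, -0.0) -- each frame is translated
# further left to cancel the accumulated rightward camera drift.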
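One thing the visible hunks never show is how demo gets served; that may be handled by the hosting platform or by a part of app.py outside these hunks. If the file were run as a plain script, the usual Gradio entry point would be the following (an assumption about the surrounding file, not something this diff contains):

# Hypothetical entry point; not visible in the hunks above.
if __name__ == "__main__":
    demo.launch()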