import cv2
import numpy as np
import csv
import math
import torch
import tempfile
import gradio as gr
import time
import io
from contextlib import redirect_stdout

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[INFO] Using device: {device}")
try:
    print("[INFO] Attempting to load RAFT model from torch.hub...")
    raft_model = torch.hub.load("princeton-vl/RAFT", "raft_small", pretrained=True, trust_repo=True)
    raft_model = raft_model.to(device)
    raft_model.eval()
    print("[INFO] RAFT model loaded successfully.")
except Exception as e:
    print("[ERROR] Error loading RAFT model:", e)
    print("[INFO] Falling back to OpenCV Farneback optical flow.")
    raft_model = None
    gr.Warning("Falling back to OpenCV Farneback optical flow.")
def compress_video(video_file, target_width, target_height, progress=gr.Progress(),
                   progress_offset=0.0, progress_scale=0.2, output_file=None):
    """
    Compresses the video by resizing each frame to the specified target resolution.
    The new resolution is exactly (target_width, target_height).
    Updates progress from progress_offset to progress_offset + progress_scale.
    """
    start_time = time.time()
    cap = cv2.VideoCapture(video_file)
    if not cap.isOpened():
        raise gr.Error("Could not open video file for compression.")

    fps = cap.get(cv2.CAP_PROP_FPS)
    new_width = int(target_width)
    new_height = int(target_height)

    if output_file is None:
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
        output_file = temp_file.name
        temp_file.close()

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_file, fourcc, fps, (new_width, new_height))

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    frame_idx = 1
    print(f"[INFO] Starting video compression: {total_frames} frames, target resolution: {new_width}x{new_height}")
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        compressed_frame = cv2.resize(frame, (new_width, new_height))
        out.write(compressed_frame)
        if frame_idx % 10 == 0 or frame_idx == total_frames:
            print(f"[INFO] Compressed frame {frame_idx}/{total_frames}")
            progress(progress_offset + (frame_idx / total_frames) * progress_scale, desc="Compressing Video")
        frame_idx += 1

    cap.release()
    out.release()
    elapsed = time.time() - start_time
    print(f"[INFO] Compressed video saved to: {output_file} in {elapsed:.2f} seconds")
    return output_file


def generate_motion_csv(video_file, output_csv=None, progress=gr.Progress(),
                        progress_offset=0.0, progress_scale=0.4):
    """
    Generates a CSV file with motion data (columns: frame, mag, ang, zoom) from an input video.
    Uses RAFT if available, otherwise falls back to OpenCV's Farneback optical flow.
    Updates progress from progress_offset to progress_offset + progress_scale.
    """
    start_time = time.time()
    if output_csv is None:
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.csv')
        output_csv = temp_file.name
        temp_file.close()

    cap = cv2.VideoCapture(video_file)
    if not cap.isOpened():
        raise gr.Error("Could not open video file for CSV generation.")

    print(f"[INFO] Generating motion CSV for video: {video_file}")
    with open(output_csv, 'w', newline='') as csvfile:
        fieldnames = ['frame', 'mag', 'ang', 'zoom']
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()

        ret, first_frame = cap.read()
        if not ret:
            raise gr.Error("Cannot read first frame from video.")

        if raft_model is not None:
            first_frame_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
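            # NOTE (assumption): the official RAFT code expects image tensors
            # in [0, 255] with spatial dims padded to multiples of 8 (it ships
            # an InputPadder for this). Scaling to [0, 1] and skipping padding,
            # as done here, may degrade flow quality unless the hub entry point
            # normalizes and pads internally.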
            prev_tensor = torch.from_numpy(first_frame_rgb).permute(2, 0, 1).float().unsqueeze(0) / 255.0
            prev_tensor = prev_tensor.to(device)
            print("[INFO] Using RAFT model for optical flow computation.")
        else:
            prev_gray = cv2.cvtColor(first_frame, cv2.COLOR_BGR2GRAY)
            print("[INFO] Using OpenCV Farneback optical flow for computation.")

        frame_idx = 1
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        print(f"[INFO] Total frames to process: {total_frames}")
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            if raft_model is not None:
                curr_frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                curr_tensor = torch.from_numpy(curr_frame_rgb).permute(2, 0, 1).float().unsqueeze(0) / 255.0
                curr_tensor = curr_tensor.to(device)
                with torch.no_grad():
                    flow_low, flow_up = raft_model(prev_tensor, curr_tensor, iters=20, test_mode=True)
                flow = flow_up[0].permute(1, 2, 0).cpu().numpy()
                prev_tensor = curr_tensor.clone()
            else:
                curr_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                flow = cv2.calcOpticalFlowFarneback(prev_gray, curr_gray, None,
                                                    pyr_scale=0.5, levels=3, winsize=15,
                                                    iterations=3, poly_n=5, poly_sigma=1.2, flags=0)
                prev_gray = curr_gray
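
            # Summarize the dense flow field: the median magnitude/angle give a
            # robust per-frame translation estimate, and the fraction of pixels
            # whose flow points away from the image center (positive dot
            # product with the center offset) serves as a crude zoom indicator.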
            mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1], angleInDegrees=True)
            median_mag = np.median(mag)
            median_ang = np.median(ang)

            h, w = flow.shape[:2]
            center_x, center_y = w / 2, h / 2
            x_coords, y_coords = np.meshgrid(np.arange(w), np.arange(h))
            x_offset = x_coords - center_x
            y_offset = y_coords - center_y
            dot = flow[..., 0] * x_offset + flow[..., 1] * y_offset
            zoom_factor = np.count_nonzero(dot > 0) / (w * h)

            writer.writerow({
                'frame': frame_idx,
                'mag': median_mag,
                'ang': median_ang,
                'zoom': zoom_factor
            })

            if frame_idx % 10 == 0 or frame_idx == total_frames:
                print(f"[INFO] Processed frame {frame_idx}/{total_frames}")

            progress(progress_offset + (frame_idx / total_frames) * progress_scale, desc="Generating Motion CSV")
            frame_idx += 1

    cap.release()
    elapsed = time.time() - start_time
    print(f"[INFO] Motion CSV generated: {output_csv} in {elapsed:.2f} seconds")
    return output_csv
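

# Each CSV row holds frame-to-frame motion; integrating (mag, ang) over time
# gives the cumulative camera path, and its negation is the correction offset
# to apply to each frame.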
def read_motion_csv(csv_filename):
    """
    Reads a motion CSV file and computes cumulative offset per frame.
    Returns a dictionary mapping frame numbers to (dx, dy) offsets.
    """
    print(f"[INFO] Reading motion CSV: {csv_filename}")
    motion_data = {}
    cumulative_dx = 0.0
    cumulative_dy = 0.0
    with open(csv_filename, 'r') as csvfile:
        reader = csv.DictReader(csvfile)
        for row in reader:
            frame_num = int(row['frame'])
            mag = float(row['mag'])
            ang = float(row['ang'])
            rad = math.radians(ang)
            dx = mag * math.cos(rad)
            dy = mag * math.sin(rad)
            cumulative_dx += dx
            cumulative_dy += dy
            motion_data[frame_num] = (-cumulative_dx, -cumulative_dy)
    print("[INFO] Completed reading motion CSV.")
    return motion_data
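

# Stabilization applies the per-frame correction as a pure translation.
# Zooming in slightly before warping (zoom > 1.0) crops away the frame edges,
# which helps hide the replicated-border smear that large corrections would
# otherwise expose.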
def stabilize_video_using_csv(video_file, csv_file, zoom=1.0, vertical_only=False,
                              progress=gr.Progress(), progress_offset=0.6, progress_scale=0.4,
                              output_file=None):
    """
    Stabilizes the video using motion data from the CSV.
    If vertical_only is True, only vertical motion is corrected.
    Updates progress from progress_offset to progress_offset + progress_scale.
    Uses cv2.BORDER_REPLICATE to fill any gaps, preventing black borders.
    """
    start_time = time.time()
    print(f"[INFO] Starting stabilization using CSV: {csv_file}")
    motion_data = read_motion_csv(csv_file)

    cap = cv2.VideoCapture(video_file)
    if not cap.isOpened():
        raise gr.Error("Could not open video file for stabilization.")

    fps = cap.get(cv2.CAP_PROP_FPS)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    print(f"[INFO] Video properties - FPS: {fps}, Width: {width}, Height: {height}")

    if output_file is None:
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
        output_file = temp_file.name
        temp_file.close()

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_file, fourcc, fps, (width, height))

    frame_idx = 1
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    print(f"[INFO] Total frames to stabilize: {total_frames}")
    while True:
        ret, frame = cap.read()
        if not ret:
            break

        if zoom != 1.0:
            zoomed_frame = cv2.resize(frame, None, fx=zoom, fy=zoom, interpolation=cv2.INTER_LINEAR)
            zoomed_h, zoomed_w = zoomed_frame.shape[:2]
            start_x = max((zoomed_w - width) // 2, 0)
            start_y = max((zoomed_h - height) // 2, 0)
            frame = zoomed_frame[start_y:start_y + height, start_x:start_x + width]
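
        # Build a 2x3 affine matrix that translates the frame by the cumulative
        # correction offset; frames without CSV data default to no correction.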
        dx, dy = motion_data.get(frame_idx, (0, 0))
        if vertical_only:
            dx = 0
        transform = np.array([[1, 0, dx],
                              [0, 1, dy]], dtype=np.float32)

        stabilized_frame = cv2.warpAffine(frame, transform, (width, height), borderMode=cv2.BORDER_REPLICATE)

        out.write(stabilized_frame)
        if frame_idx % 10 == 0 or frame_idx == total_frames:
            print(f"[INFO] Stabilized frame {frame_idx}/{total_frames}")

        progress(progress_offset + (frame_idx / total_frames) * progress_scale, desc="Stabilizing Video")
        frame_idx += 1

    cap.release()
    out.release()
    elapsed = time.time() - start_time
    print(f"[INFO] Stabilized video saved to: {output_file} in {elapsed:.2f} seconds")
    return output_file


def process_video_ai(video_file, zoom, vertical_only, compress_mode, target_width, target_height,
                     auto_zoom, progress=gr.Progress(track_tqdm=True)):
    """
    Gradio interface function:
      - Optionally compresses the video to the chosen resolution if compress_mode is True.
      - Generates motion data from the (possibly compressed) video.
      - If auto_zoom is enabled, computes the optimal zoom level from the maximum cumulative displacements.
      - Stabilizes the video based on the generated motion data.
      - If vertical_only is True, only vertical stabilization is applied.

    Returns:
        Tuple: (input video file path - the compressed copy if compression was enabled,
        stabilized video file path, log output)
    """
    gr.Info("Starting AI-powered video processing...")
    log_buffer = io.StringIO()
    with redirect_stdout(log_buffer):
        if isinstance(video_file, dict):
            video_file = video_file.get("name", None)
        if video_file is None:
            raise gr.Error("Please upload a video file.")

        if compress_mode:
            gr.Info("Compressing video before processing...")
            video_file = compress_video(video_file, target_width, target_height, progress=progress,
                                        progress_offset=0.0, progress_scale=0.2)
            gr.Info("Video compression complete.")
            motion_offset = 0.2
            motion_scale = 0.4
            stabilization_offset = 0.6
            stabilization_scale = 0.4
        else:
            motion_offset = 0.0
            motion_scale = 0.5
            stabilization_offset = 0.5
            stabilization_scale = 0.5

        csv_file = generate_motion_csv(video_file, progress=progress, progress_offset=motion_offset,
                                       progress_scale=motion_scale)
        gr.Info("Motion CSV generated successfully.")

        if auto_zoom:
            gr.Info("Auto Zoom Mode enabled. Computing optimal zoom factor...")
            motion_data = read_motion_csv(csv_file)
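
            # Worst-case cumulative shifts define the largest region that stays
            # in-frame after correction; zoom just enough that this safe region
            # fills the full output resolution.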
            left_disp = max(0.0, max(v[0] for v in motion_data.values()))
            right_disp = max(0.0, -min(v[0] for v in motion_data.values()))
            top_disp = max(0.0, max(v[1] for v in motion_data.values()))
            bottom_disp = max(0.0, -min(v[1] for v in motion_data.values()))
            cap = cv2.VideoCapture(video_file)
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            cap.release()
            safe_width = width - (left_disp + right_disp)
            safe_height = height - (top_disp + bottom_disp)
            zoom_x = width / safe_width if safe_width > 0 else 1.0
            zoom_y = height / safe_height if safe_height > 0 else 1.0
            auto_zoom_factor = max(1.0, zoom_x, zoom_y)
            gr.Info(f"Auto zoom factor computed: {auto_zoom_factor:.2f}")
            zoom = auto_zoom_factor

        stabilized_path = stabilize_video_using_csv(video_file, csv_file, zoom=zoom, vertical_only=vertical_only,
                                                    progress=progress, progress_offset=stabilization_offset,
                                                    progress_scale=stabilization_scale)
        gr.Info("Video stabilization complete.")
        print("[INFO] Video processing complete.")
    logs = log_buffer.getvalue()
    return video_file, stabilized_path, logs
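

# Build the Gradio interface: processing options on the left, the resulting
# videos and captured logs on the right.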
with gr.Blocks() as demo:
    gr.Markdown("# AI-Powered Video Stabilization")
    gr.Markdown(
        "Upload a video, select a zoom factor (or use Auto Zoom Mode), choose whether to apply only "
        "vertical stabilization, and optionally compress the video before processing. "
        "If compressing, specify the target resolution (width and height) for the compressed video. "
        "The system will generate motion data using an AI model (RAFT if available) and then stabilize "
        "the video with live progress updates and alerts."
    )

    with gr.Row():
        with gr.Column():
            video_input = gr.Video(label="Input Video")
            zoom_slider = gr.Slider(minimum=1.0, maximum=3.0, step=0.1, value=1.0,
                                    label="Zoom Factor (ignored if Auto Zoom enabled)")
            auto_zoom_checkbox = gr.Checkbox(label="Auto Zoom Mode", value=False)
            vertical_checkbox = gr.Checkbox(label="Vertical Stabilization Only", value=False)
            compress_checkbox = gr.Checkbox(label="Compress Video Before Processing", value=False)
            target_width = gr.Number(label="Target Width (px)", value=640)
            target_height = gr.Number(label="Target Height (px)", value=360)
            process_button = gr.Button("Process Video")
        with gr.Column():
            original_video = gr.Video(label="Original Video")
            stabilized_video = gr.Video(label="Stabilized Video")
            logs_output = gr.Textbox(label="Logs", lines=10)

    process_button.click(
        fn=process_video_ai,
        inputs=[video_input, zoom_slider, vertical_checkbox, compress_checkbox,
                target_width, target_height, auto_zoom_checkbox],
        outputs=[original_video, stabilized_video, logs_output]
    )
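

# A __main__ guard (an idiomatic addition) keeps the demo from auto-launching
# if this module is ever imported rather than run directly.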
if __name__ == "__main__":
    demo.launch()