# Spaces: Runtime error
# (Status banner from the hosting page, captured together with the source; it is not part of the program.)
import gradio as gr | |
import torch | |
import spaces | |
from transformers import AutoModelForImageSegmentation | |
from torchvision import transforms | |
import moviepy.editor as mp | |
from PIL import Image | |
import numpy as np | |
import tempfile | |
import time | |
import os | |
import shutil | |
import ffmpeg | |
from concurrent.futures import ThreadPoolExecutor | |
from gradio.themes.base import Base | |
from gradio.themes.utils import colors, fonts | |
# Custom Theme Definition | |
class WhiteTheme(Base):
    """Gradio theme that keeps the interface white in both light and dark mode.

    Every ``*_dark`` variable is set to the same value as its light
    counterpart, so switching the colour scheme does not change the look.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.orange,
        font: fonts.Font | str | tuple[fonts.Font | str, ...] = (
            fonts.GoogleFont("Inter"),
            "ui-sans-serif",
            "system-ui",
            "sans-serif",
        ),
        font_mono: fonts.Font | str | tuple[fonts.Font | str, ...] = (
            fonts.GoogleFont("Inter"),
            "ui-monospace",
            "system-ui",
            "monospace",
        ),
    ):
        """Build the base theme, then apply the white/black overrides.

        Args:
            primary_hue: Accent colour forwarded to the base theme.
            font: Font stack for regular text.
            font_mono: Font stack for monospaced text.
        """
        super().__init__(primary_hue=primary_hue, font=font, font_mono=font_mono)

        # Accent-derived values and the few settings that are neither white
        # nor black.
        overrides = {
            "background_fill_primary": "*primary_50",
            "border_color_primary": "*primary_300",
            "input_border_color": "lightgray",
            "shadow_drop": "none",
        }

        # Surfaces and borders that are painted white, light and dark alike.
        white_variables = (
            "background_fill_secondary",
            "body_background_fill",
            "body_background_fill_dark",
            "block_background_fill",
            "block_background_fill_dark",
            "panel_background_fill",
            "panel_background_fill_dark",
            "input_background_fill",
            "input_background_fill_dark",
            "block_border_color",
            "panel_border_color",
        )
        for variable in white_variables:
            overrides[variable] = "white"

        # Text stays black regardless of the colour scheme.
        black_variables = (
            "body_text_color",
            "body_text_color_dark",
            "block_label_text_color",
            "block_label_text_color_dark",
        )
        for variable in black_variables:
            overrides[variable] = "black"

        self.set(**overrides)
# Set precision and device
torch.set_float32_matmul_precision("medium")
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load models: the full BiRefNet checkpoint and its lighter variant are both
# kept resident so process_frame can pick one per call.
print("Loading models...")
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    trust_remote_code=True,
)
birefnet.to(device)
birefnet_lite = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet_lite",
    trust_remote_code=True,
)
birefnet_lite.to(device)
print("Models loaded successfully!")

# Image transformation: fixed 1024x1024 resize (aspect ratio is not kept),
# tensor conversion, then per-channel mean/std normalisation.
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]
transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize(_NORMALIZE_MEAN, _NORMALIZE_STD),
])
def process_frame(frame, fast_mode=True):
    """
    Cut the foreground out of a single frame with BiRefNet.

    The frame is resized to 1024x1024 for inference (the aspect ratio is not
    kept at that stage); the predicted mask is then resized back to the
    frame's own size and attached as the alpha channel.

    Args:
        frame: Array-like image data accepted by ``Image.fromarray``.
        fast_mode: Use ``birefnet_lite`` when true, ``birefnet`` otherwise.

    Returns:
        A PIL image with an alpha channel at the frame's original size, or
        None if any step raised (the error is printed, not propagated).
    """
    try:
        source = Image.fromarray(frame).convert('RGB')

        # Model input: a batch of one, placed on the same device as the models.
        batch = transform_image(source).unsqueeze(0).to(device)

        if fast_mode:
            segmenter = birefnet_lite
        else:
            segmenter = birefnet

        with torch.no_grad():
            # The last model output is used; sigmoid turns it into a 0..1 mask.
            mask = segmenter(batch)[-1].sigmoid().cpu()[0].squeeze()

        # Bring the mask back to the frame's resolution before using it as alpha.
        mask_image = transforms.ToPILImage()(mask).resize(source.size, Image.BICUBIC)

        cutout = source.copy()
        cutout.putalpha(mask_image)
        return cutout
    except Exception as e:
        print(f"Error processing frame: {e}")
        return None
# 5-minute duration for processing | |
def process_video(video_path, fps=0, fast_mode=True, max_workers=6):
    """
    Remove the background from a video and export the result with alpha.

    This is a generator. While frames are processed it yields progress
    tuples; at the end it yields one tuple with the output paths. Each tuple
    has five items, in the order the UI wires them:
    ``(preview_frame, zip_path, mov_path, background, status_text)``.
    ``background`` is always None.

    Args:
        video_path: Path of the input video. Outputs are written to an
            ``output`` directory next to it.
        fps: Output frame rate; 0 means "use the source video's frame rate".
        fast_mode: Forwarded to ``process_frame`` (lite model when true).
        max_workers: Thread count for frame processing and for FFmpeg.

    Yields:
        Progress tuples ``(rgba_array, None, None, None, status)`` per saved
        frame, then ``(None, zip_path, mov_path, None, status)``. ``mov_path``
        is None when the ProRes encode failed. On any other failure a single
        ``(None, None, None, None, "Error processing video: ...")`` tuple is
        yielded instead of raising.
    """
    temp_dir = None
    video = None
    try:
        start_time = time.time()
        # UI sliders can deliver floats; the executor and FFmpeg need an int.
        max_workers = max(1, int(max_workers))

        video = mp.VideoFileClip(video_path)
        # Use original video FPS if not specified
        if fps == 0:
            fps = video.fps
        frames = list(video.iter_frames(fps=fps))
        total_frames = len(frames)
        print(f"Processing {total_frames} frames at {fps} FPS...")

        # Create temporary directory for PNG sequence
        temp_dir = tempfile.mkdtemp()
        png_dir = os.path.join(temp_dir, "frames")
        os.makedirs(png_dir, exist_ok=True)

        # PNGs are numbered by how many were actually written, not by source
        # index: the FFmpeg sequence pattern below needs gap-free numbering,
        # and a failed frame would otherwise truncate the encoded video.
        saved_frames = 0

        # Process frames with parallel execution
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [executor.submit(process_frame, frame, fast_mode) for frame in frames]
            for i, future in enumerate(futures):
                try:
                    result = future.result()
                    if result is None:
                        # process_frame already reported the failure.
                        continue
                    # Save frame as PNG with transparency
                    frame_path = os.path.join(png_dir, f"frame_{saved_frames:06d}.png")
                    result.save(frame_path, "PNG")
                    saved_frames += 1
                    # Update live preview (only the current frame is kept in memory)
                    elapsed_time = time.time() - start_time
                    yield np.array(result), None, None, None, f"Processing frame {i+1}/{total_frames}... Elapsed time: {elapsed_time:.2f} seconds"
                    if (i + 1) % 10 == 0:
                        print(f"Processed {i+1}/{total_frames} frames")
                except Exception as e:
                    print(f"Error processing frame {i+1}: {e}")

        if saved_frames == 0:
            # Nothing to archive or encode; report it through the outer handler
            # instead of returning an empty ZIP as a success.
            raise RuntimeError("no frames could be processed")

        print("Creating output files...")
        # Create permanent output directory
        output_dir = os.path.join(os.path.dirname(video_path), "output")
        os.makedirs(output_dir, exist_ok=True)

        # Create ZIP file of PNG sequence
        zip_filename = f"frames_{int(time.time())}.zip"
        zip_path = os.path.join(output_dir, zip_filename)
        # make_archive appends ".zip" itself, so pass the path without it.
        shutil.make_archive(os.path.splitext(zip_path)[0], 'zip', png_dir)

        # Create MOV file with ProRes 4444
        print("Creating ProRes 4444 MOV...")
        mov_filename = f"video_{int(time.time())}.mov"
        mov_path = os.path.join(output_dir, mov_filename)
        try:
            # FFmpeg settings for high-quality ProRes 4444
            stream = ffmpeg.input(
                os.path.join(png_dir, 'frame_%06d.png'),
                pattern_type='sequence',
                framerate=fps
            )
            # ProRes 4444 settings for maximum quality with alpha
            stream = ffmpeg.output(
                stream,
                mov_path,
                vcodec='prores_ks',        # ProRes codec
                pix_fmt='yuva444p10le',    # 10-bit 4:4:4:4 pixel format with alpha
                profile='4444',            # ProRes 4444 profile for alpha support
                alpha_bits=16,             # Maximum alpha bit depth
                qscale=1,                  # Highest quality setting
                vendor='apl0',             # Apple vendor tag (was misspelled 'ap10')
                bits_per_mb=8000,          # High bitrate for quality
                threads=max_workers        # Parallel processing
            )
            # Run FFmpeg command
            ffmpeg.run(stream, overwrite_output=True, capture_stdout=True, capture_stderr=True)
            print("MOV video created successfully!")
        except ffmpeg.Error as e:
            # The ZIP is still usable, so a failed encode only drops the MOV.
            print(f"Error creating MOV video: {e.stderr.decode() if e.stderr else str(e)}")
            mov_path = None

        print("Processing complete!")
        # Yield the final outputs
        yield None, zip_path, mov_path, None, f"Processing complete! Total time: {time.time() - start_time:.2f} seconds"
    except Exception as e:
        print(f"Error: {e}")
        yield None, None, None, None, f"Error processing video: {e}"
    finally:
        # Release the reader opened on the input file.
        if video is not None:
            try:
                video.close()
            except Exception as e:
                print(f"Error closing video: {e}")
        # Clean up temporary directory
        if temp_dir and os.path.exists(temp_dir):
            try:
                shutil.rmtree(temp_dir)
            except Exception as e:
                print(f"Error cleaning up temp directory: {e}")
# Match process_video duration | |
def process_wrapper(video, fps=0, fast_mode=True, max_workers=6):
    """
    UI entry point: validate the upload, then stream ``process_video`` output.

    This is a generator, so the "no video" error is raised when the first
    item is requested, not when the function is called.

    Args:
        video: Path of the uploaded video, or None when nothing was uploaded.
        fps: Forwarded to ``process_video``.
        fast_mode: Forwarded to ``process_video``.
        max_workers: Forwarded to ``process_video``.

    Yields:
        Whatever ``process_video`` yields, unchanged.

    Raises:
        gr.Error: If no video was provided, or if ``process_video`` raised.
    """
    if video is None:
        raise gr.Error("Please upload a video.")
    try:
        # Delegating with ``yield from`` forwards close() to process_video, so
        # its cleanup runs right away if the consumer stops early.
        yield from process_video(video, fps, fast_mode, max_workers)
    except Exception as e:
        # Keep the original exception attached as the cause for debugging.
        raise gr.Error(f"Error processing video: {str(e)}") from e
# Custom CSS for styling | |
# Stylesheet passed to gr.Blocks: centred animated title, gradient submit
# button, and variable overrides that keep the page white in dark mode.
custom_css = """
.title-container {
    text-align: center;
    padding: 10px 0;
}
#title {
    font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
    font-size: 36px;
    font-weight: bold;
    color: #000000;
    padding: 10px;
    border-radius: 10px;
    display: inline-block;
    background: linear-gradient(
        135deg,
        #e0f7fa, #e8f5e9, #fff9c4, #ffebee,
        #f3e5f5, #e1f5fe, #fff3e0, #e8eaf6
    );
    background-size: 400% 400%;
    animation: gradient-animation 15s ease infinite;
}
@keyframes gradient-animation {
    0% { background-position: 0% 50%; }
    50% { background-position: 100% 50%; }
    100% { background-position: 0% 50%; }
}
#submit-button {
    background: linear-gradient(
        135deg,
        #e0f7fa, #e8f5e9, #fff9c4, #ffebee,
        #f3e5f5, #e1f5fe, #fff3e0, #e8eaf6
    );
    background-size: 400% 400%;
    animation: gradient-animation 15s ease infinite;
    border-radius: 12px;
    color: black;
}
/* Force light mode styles */
:root, :root[data-theme='light'], :root[data-theme='dark'] {
    --body-background-fill: white !important;
    --background-fill-primary: white !important;
    --background-fill-secondary: white !important;
    --block-background-fill: white !important;
    --panel-background-fill: white !important;
    --body-text-color: black !important;
    --block-label-text-color: black !important;
}
/* Additional overrides for dark mode */
@media (prefers-color-scheme: dark) {
    :root {
        color-scheme: light;
    }
}
"""
# Gradio Interface | |
# Page layout: inputs and controls on the left, live preview and downloads on
# the right. The string literals below are the markup shown to the user.
with gr.Blocks(css=custom_css, theme=WhiteTheme()) as demo:
    # NOTE(review): a <script> embedded in gr.HTML may not be executed by the
    # browser; if the typing animation never runs, confirm and move the
    # script to a page-level hook instead.
    gr.HTML('''
    <div class="title-container">
        <div id="title">
            <span>{.</span><span id="typed-text"></span><span>}</span>
        </div>
    </div>
    <script>
        (function() {
            const text = "video";
            const typedTextSpan = document.getElementById("typed-text");
            let charIndex = 0;
            function type() {
                if (charIndex < text.length) {
                    typedTextSpan.textContent += text[charIndex];
                    charIndex++;
                    setTimeout(type, 150);
                }
            }
            setTimeout(type, 150);
        })();
    </script>
    ''')
    with gr.Row():
        with gr.Column():
            video_input = gr.Video(
                label="Upload Video",
                interactive=True,
                show_label=True,
                height=360,
                width=640
            )
            with gr.Row():
                fps_slider = gr.Slider(
                    minimum=0,
                    maximum=60,
                    step=1,
                    value=0,
                    label="Output FPS (0 will inherit the original fps value)",
                )
                fast_mode_checkbox = gr.Checkbox(
                    label="Fast Mode (Use BiRefNet_lite)",
                    value=True
                )
                max_workers_slider = gr.Slider(
                    minimum=1,
                    maximum=32,
                    step=1,
                    value=6,
                    label="Max Workers",
                    info="Determines how many frames to process in parallel"
                )
            btn = gr.Button("Process Video", elem_id="submit-button")
        with gr.Column():
            preview_image = gr.Image(label="Live Preview", show_label=True)
            output_foreground_zip = gr.File(label="Download PNG Sequence (ZIP)")
            output_foreground_video = gr.File(label="Download Video (ProRes 4444 MOV with transparency)")
            # Wired as an output but never filled: process_video always
            # yields None in this position.
            output_background = gr.Video(label="Background (Coming Soon)")
            time_textbox = gr.Textbox(label="Status", interactive=False)
    gr.Markdown("""
### Output Information
- MOV file uses ProRes 4444 codec for professional-grade alpha channel
- Original resolution and framerate are maintained
- PNG sequence provided for maximum compatibility
""")
    # The order of `outputs` must match the 5-tuples yielded by process_video.
    btn.click(
        fn=process_wrapper,
        inputs=[video_input, fps_slider, fast_mode_checkbox, max_workers_slider],
        outputs=[preview_image, output_foreground_zip, output_foreground_video,
                 output_background, time_textbox]
    )

if __name__ == "__main__":
    demo.launch(debug=True)