import gradio as gr
import torch
import spaces
from transformers import AutoModelForImageSegmentation
from torchvision import transforms
import moviepy.editor as mp
from PIL import Image
import numpy as np
import tempfile
import time
import os
import shutil
import ffmpeg
from concurrent.futures import ThreadPoolExecutor
from gradio.themes.base import Base
from gradio.themes.utils import colors, fonts


# Custom Theme Definition
class WhiteTheme(Base):
    """Gradio theme that paints surfaces white with black text.

    Built on ``gradio.themes.base.Base`` with an orange primary hue and an
    Inter-first font stack. For the body/block/panel backgrounds, text colors
    and input background, the ``*_dark`` variable is set to the same value as
    the light one so the page looks the same in both modes.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.orange,
        font: fonts.Font | str | tuple[fonts.Font | str, ...] = (
            fonts.GoogleFont("Inter"),
            "ui-sans-serif",
            "system-ui",
            "sans-serif",
        ),
        font_mono: fonts.Font | str | tuple[fonts.Font | str, ...] = (
            fonts.GoogleFont("Inter"),
            "ui-monospace",
            "system-ui",
            "monospace",
        ),
    ):
        super().__init__(primary_hue=primary_hue, font=font, font_mono=font_mono)

        white, black = "white", "black"

        # Light mode specific colors
        overrides = {
            "background_fill_primary": "*primary_50",
            "background_fill_secondary": white,
            "border_color_primary": "*primary_300",
        }

        # General colors that should stay constant: each variable and its
        # ``_dark`` twin receive the same value.
        for var_name, value in (
            ("body_background_fill", white),
            ("block_background_fill", white),
            ("panel_background_fill", white),
            ("body_text_color", black),
            ("block_label_text_color", black),
            ("input_background_fill", white),
        ):
            overrides[var_name] = value
            overrides[f"{var_name}_dark"] = value

        # Light-only borders and shadow.
        overrides.update(
            block_border_color=white,
            panel_border_color=white,
            input_border_color="lightgray",
            shadow_drop="none",
        )

        self.set(**overrides)


# Set precision and device
torch.set_float32_matmul_precision("medium")
device = "cuda" if torch.cuda.is_available() else "cpu"


def _load_segmentation_model(repo_id):
    """Load ``repo_id`` via ``AutoModelForImageSegmentation`` and move it to ``device``."""
    model = AutoModelForImageSegmentation.from_pretrained(repo_id, trust_remote_code=True)
    model.to(device)
    return model


# Load models
print("Loading models...")
birefnet = _load_segmentation_model("ZhengPeng7/BiRefNet")
birefnet_lite = _load_segmentation_model("ZhengPeng7/BiRefNet_lite")
print("Models loaded successfully!") # Image transformation transform_image = transforms.Compose([ transforms.Resize((1024, 1024)), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), ]) def process_frame(frame, fast_mode=True): """ Process a single frame through the BiRefNet model. Maintains original resolution throughout processing. Returns a PIL Image with alpha channel. """ try: # Preserve original resolution for final output image_ori = Image.fromarray(frame).convert('RGB') original_size = image_ori.size # Transform for model input while maintaining aspect ratio input_images = transform_image(image_ori).unsqueeze(0).to(device) # Select model based on mode model = birefnet_lite if fast_mode else birefnet with torch.no_grad(): preds = model(input_images)[-1].sigmoid().cpu() pred = preds[0].squeeze() # Resize mask back to original resolution pred_pil = transforms.ToPILImage()(pred) pred_pil = pred_pil.resize(original_size, Image.BICUBIC) # Create foreground with transparency foreground = image_ori.copy() foreground.putalpha(pred_pil) return foreground except Exception as e: print(f"Error processing frame: {e}") return None @spaces.GPU(duration=300) # 5-minute duration for processing def process_video(video_path, fps=0, fast_mode=True, max_workers=6): """ Process video to create transparent MOV file using ProRes 4444. Maintains original resolution and framerate if fps=0. 
""" temp_dir = None try: start_time = time.time() video = mp.VideoFileClip(video_path) # Use original video FPS if not specified if fps == 0: fps = video.fps frames = list(video.iter_frames(fps=fps)) total_frames = len(frames) print(f"Processing {total_frames} frames at {fps} FPS...") # Create temporary directory for PNG sequence temp_dir = tempfile.mkdtemp() png_dir = os.path.join(temp_dir, "frames") os.makedirs(png_dir, exist_ok=True) # Prepare to collect processed frames for live preview processed_frames = [] # Process frames with parallel execution with ThreadPoolExecutor(max_workers=max_workers) as executor: futures = [executor.submit(process_frame, frame, fast_mode) for frame in frames] for i, future in enumerate(futures): try: result = future.result() if result: # Save frame as PNG with transparency frame_path = os.path.join(png_dir, f"frame_{i:06d}.png") result.save(frame_path, "PNG") # Collect processed frames for live preview processed_frames.append(np.array(result)) # Update live preview elapsed_time = time.time() - start_time yield processed_frames[-1], None, None, None, f"Processing frame {i+1}/{total_frames}... 
Elapsed time: {elapsed_time:.2f} seconds" if (i + 1) % 10 == 0: print(f"Processed {i+1}/{total_frames} frames") except Exception as e: print(f"Error processing frame {i+1}: {e}") print("Creating output files...") # Create permanent output directory output_dir = os.path.join(os.path.dirname(video_path), "output") os.makedirs(output_dir, exist_ok=True) # Create ZIP file of PNG sequence zip_filename = f"frames_{int(time.time())}.zip" zip_path = os.path.join(output_dir, zip_filename) shutil.make_archive(zip_path[:-4], 'zip', png_dir) # Create MOV file with ProRes 4444 print("Creating ProRes 4444 MOV...") mov_filename = f"video_{int(time.time())}.mov" mov_path = os.path.join(output_dir, mov_filename) try: # FFmpeg settings for high-quality ProRes 4444 stream = ffmpeg.input( os.path.join(png_dir, 'frame_%06d.png'), pattern_type='sequence', framerate=fps ) # ProRes 4444 settings for maximum quality with alpha stream = ffmpeg.output( stream, mov_path, vcodec='prores_ks', # ProRes codec pix_fmt='yuva444p10le', # 10-bit 4:4:4:4 pixel format with alpha profile='4444', # ProRes 4444 profile for alpha support alpha_bits=16, # Maximum alpha bit depth qscale=1, # Highest quality setting vendor='ap10', # Standard ProRes vendor tag bits_per_mb=8000, # High bitrate for quality threads=max_workers # Parallel processing ) # Run FFmpeg command ffmpeg.run(stream, overwrite_output=True, capture_stdout=True, capture_stderr=True) print("MOV video created successfully!") except ffmpeg.Error as e: print(f"Error creating MOV video: {e.stderr.decode() if e.stderr else str(e)}") mov_path = None print("Processing complete!") # Yield the final outputs yield None, zip_path, mov_path, None, f"Processing complete! 
Total time: {time.time() - start_time:.2f} seconds" except Exception as e: print(f"Error: {e}") yield None, None, None, None, f"Error processing video: {e}" finally: # Clean up temporary directory if temp_dir and os.path.exists(temp_dir): try: shutil.rmtree(temp_dir) except Exception as e: print(f"Error cleaning up temp directory: {e}") @spaces.GPU(duration=300) # Match process_video duration def process_wrapper(video, fps=0, fast_mode=True, max_workers=6): if video is None: raise gr.Error("Please upload a video.") try: for outputs in process_video(video, fps, fast_mode, max_workers): yield outputs except Exception as e: raise gr.Error(f"Error processing video: {str(e)}") # Custom CSS for styling custom_css = """ .title-container { text-align: center; padding: 10px 0; } #title { font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif; font-size: 36px; font-weight: bold; color: #000000; padding: 10px; border-radius: 10px; display: inline-block; background: linear-gradient( 135deg, #e0f7fa, #e8f5e9, #fff9c4, #ffebee, #f3e5f5, #e1f5fe, #fff3e0, #e8eaf6 ); background-size: 400% 400%; animation: gradient-animation 15s ease infinite; } @keyframes gradient-animation { 0% { background-position: 0% 50%; } 50% { background-position: 100% 50%; } 100% { background-position: 0% 50%; } } #submit-button { background: linear-gradient( 135deg, #e0f7fa, #e8f5e9, #fff9c4, #ffebee, #f3e5f5, #e1f5fe, #fff3e0, #e8eaf6 ); background-size: 400% 400%; animation: gradient-animation 15s ease infinite; border-radius: 12px; color: black; } /* Force light mode styles */ :root, :root[data-theme='light'], :root[data-theme='dark'] { --body-background-fill: white !important; --background-fill-primary: white !important; --background-fill-secondary: white !important; --block-background-fill: white !important; --panel-background-fill: white !important; --body-text-color: black !important; --block-label-text-color: black !important; } /* Additional overrides for 
dark mode */ @media (prefers-color-scheme: dark) { :root { color-scheme: light; } } """ # Gradio Interface with gr.Blocks(css=custom_css, theme=WhiteTheme()) as demo: gr.HTML('''