# Import order kept from the original file. NOTE(review): `spaces` is imported ahead of
# transformers/torch; on ZeroGPU it appears to need loading before CUDA initialises —
# confirm before reordering.
import gradio as gr
from loadimg import load_img
import spaces
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
import moviepy.editor as mp
from pydub import AudioSegment
from PIL import Image, ImageSequence
import numpy as np
import os
import tempfile
import uuid
import time
from concurrent.futures import ThreadPoolExecutor
import base64
import io
from gradio_imageslider import ImageSlider

# Allow faster, slightly lower-precision float32 matmuls.
torch.set_float32_matmul_precision("high")

device = "cuda" if torch.cuda.is_available() else "cpu"

# Disable Pillow's decompression-bomb pixel limit so very large images can be opened.
# NOTE(review): this also removes protection against maliciously huge uploads in a public
# app; consider a generous finite limit instead of None.
Image.MAX_IMAGE_PIXELS = None

# Load both BiRefNet models: the full model, and the "lite" one used when fast_mode is on.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
)
birefnet.to(device)
birefnet_lite = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet_lite", trust_remote_code=True
)
birefnet_lite.to(device)

# Model input pipeline: resize to 1024x1024, convert to a tensor, normalise with the
# ImageNet channel means/stds.
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)


# ---------------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------------


def hex_to_rgba(hex_color):
    """Convert a colour string to an ``(r, g, b, a)`` tuple of ints in 0..255.

    Accepts ``#RGB``, ``#RGBA``, ``#RRGGBB`` and ``#RRGGBBAA`` (the ``#`` is optional),
    plus ``rgb(r, g, b)`` / ``rgba(r, g, b, a)`` strings whose alpha is in 0..1.
    Alpha defaults to 255 when the string does not specify one.

    Raises:
        ValueError: if the string is not one of the supported forms.
    """
    color = hex_color.strip()
    if color.lower().startswith("rgb"):
        # Tolerate functional notation (possibly with float channels) in case the colour
        # picker reports its value that way.
        inner = color[color.index("(") + 1 : color.rindex(")")]
        parts = [float(part) for part in inner.split(",")]
        if len(parts) not in (3, 4):
            raise ValueError(f"Unsupported colour value: {hex_color!r}")
        rgb = tuple(max(0, min(255, round(channel))) for channel in parts[:3])
        alpha = max(0, min(255, round(parts[3] * 255))) if len(parts) == 4 else 255
        return rgb + (alpha,)

    digits = color.lstrip("#")
    if len(digits) in (3, 4):
        # Short form: each digit is doubled, e.g. "0f0" -> "00ff00".
        digits = "".join(ch * 2 for ch in digits)
    if len(digits) == 6:
        digits += "ff"
    if len(digits) != 8:
        raise ValueError(f"Unsupported colour value: {hex_color!r}")
    return tuple(int(digits[i : i + 2], 16) for i in range(0, 8, 2))


def _is_color_string(value):
    """Return True if *value* is a colour string (``#...``, ``rgb(...)`` or ``rgba(...)``)."""
    return isinstance(value, str) and value.strip().lower().startswith(("#", "rgb(", "rgba("))


def _predict_mask(image, model):
    """Run *model* on *image* and return its sigmoid mask as a PIL image of *image*'s size."""
    rgb_image = image if image.mode == "RGB" else image.convert("RGB")
    input_images = transform_image(rgb_image).unsqueeze(0).to(device)
    with torch.no_grad():
        # The last element of the model output is used as the prediction.
        preds = model(input_images)[-1].sigmoid().cpu()
    return transforms.ToPILImage()(preds[0].squeeze()).resize(image.size)


def _load_pil_image(source):
    """Return *source* as a PIL image, or ``None`` when *source* is ``None``.

    Accepts PIL images (returned as-is), numpy arrays, filesystem paths, and binary
    file-like objects exposing ``read()``. File contents are read fully into memory, so
    no file handle stays open and the decoded image keeps its original ``format``.
    """
    if source is None:
        return None
    if isinstance(source, Image.Image):
        return source
    if isinstance(source, np.ndarray):
        return Image.fromarray(source)
    if hasattr(source, "read"):
        data = source.read()
    else:
        with open(source, "rb") as handle:
            data = handle.read()
    return Image.open(io.BytesIO(data))


# ---------------------------------------------------------------------------
# Video processing
# ---------------------------------------------------------------------------


def process(image, bg, fast_mode=False):
    """Replace the background of *image* with *bg* using BiRefNet.

    Args:
        image: foreground PIL image.
        bg: a colour string (see ``hex_to_rgba``; its alpha is ignored), a PIL image,
            or a path Pillow can open. Image backgrounds are resized to *image*'s size.
        fast_mode: use ``birefnet_lite`` instead of the full ``birefnet`` model.

    Returns:
        An RGBA PIL image where the predicted foreground is pasted over the background.
    """
    image_size = image.size
    model = birefnet_lite if fast_mode else birefnet
    mask = _predict_mask(image, model)

    if _is_color_string(bg):
        background = Image.new("RGBA", image_size, hex_to_rgba(bg)[:3] + (255,))
    elif isinstance(bg, Image.Image):
        background = bg.convert("RGBA").resize(image_size)
    else:
        with Image.open(bg) as source:
            background = source.convert("RGBA").resize(image_size)

    return Image.composite(image, background, mask)


def process_frame(
    frame, bg_type, bg, fast_mode, bg_frame_index, background_frames, color
):
    """Replace the background of a single video frame.

    Args:
        frame: the frame as a numpy array.
        bg_type: "Color", "Image" or "Video"; any other value leaves the frame unchanged.
        bg: background image (PIL image or path), used when bg_type == "Image".
        fast_mode: forwarded to ``process``.
        bg_frame_index: index into *background_frames*, used when bg_type == "Video".
            It is clamped to the last background frame if the background is short.
        background_frames: list of background frames (numpy arrays) for "Video".
        color: colour string used when bg_type == "Color".

    Returns:
        ``(rgb_frame_array, next_bg_frame_index)``. On any error the exception is
        printed and the original *frame* is returned, so one bad frame does not abort
        the whole video.
    """
    try:
        pil_image = Image.fromarray(frame)
        if bg_type == "Color":
            processed_image = process(pil_image, color, fast_mode)
        elif bg_type == "Image":
            processed_image = process(pil_image, bg, fast_mode)
        elif bg_type == "Video":
            # Clamp instead of raising IndexError when frame counts differ by rounding.
            safe_index = min(bg_frame_index, len(background_frames) - 1)
            background_image = Image.fromarray(background_frames[safe_index])
            bg_frame_index += 1
            processed_image = process(pil_image, background_image, fast_mode)
        else:
            # Default to the original image if no background type is selected.
            processed_image = pil_image
        # Normalise to RGB so every frame (including error fallbacks) has 3 channels.
        return np.array(processed_image.convert("RGB")), bg_frame_index
    except Exception as e:
        print(f"Error processing frame: {e}")
        return frame, bg_frame_index


@spaces.GPU
def remove_bg_video(
    vid,
    bg_type="Color",
    bg_image=None,
    bg_video=None,
    color="#00FF00",
    fps=0,
    video_handling="slow_down",
    fast_mode=True,
    max_workers=6,
):
    """Replace the background of every frame of *vid*, streaming progress to the UI.

    A generator that yields ``(stream_image, out_video, status_text)`` triples: first a
    visibility update, then each processed frame, then the final mp4 path. Errors are
    reported through the status text, and both media outputs are cleared.

    Args:
        vid: path of the input video.
        bg_type: "Color", "Image" or "Video".
        bg_image: background image path or PIL image (for "Image").
        bg_video: background video path (for "Video").
        color: background colour string (for "Color").
        fps: output frame rate; 0 (or a falsy value) inherits the input's fps.
        video_handling: how to fit a shorter background video. "slow_down" stretches it
            to the input's duration; anything else loops it.
        fast_mode: use the lite BiRefNet model.
        max_workers: number of frames processed in parallel.
    """
    start_time = time.time()  # Start the timer before anything can fail.
    video = None
    background_video = None
    try:
        video = mp.VideoFileClip(vid)
        if not fps:
            fps = video.fps
        audio = video.audio
        frames = list(video.iter_frames(fps=fps))
        if not frames:
            raise ValueError("the input video contains no frames")

        yield gr.update(visible=True), gr.update(
            visible=False
        ), "Processing started... Elapsed time: 0 seconds"

        background_frames = None
        if bg_type == "Video":
            background_video = mp.VideoFileClip(bg_video)
            background_clip = background_video
            if background_clip.duration < video.duration:
                if video_handling == "slow_down":
                    # Slow the shorter background down so it lasts exactly as long as
                    # the input video.
                    background_clip = background_clip.fx(
                        mp.vfx.speedx, final_duration=video.duration
                    )
                else:  # video_handling == "loop"
                    repeats = int(video.duration / background_clip.duration + 1)
                    background_clip = mp.concatenate_videoclips(
                        [background_clip] * repeats
                    )
            background_frames = list(background_clip.iter_frames(fps=fps))
        elif bg_type == "Image" and isinstance(bg_image, str):
            # Decode the background once instead of re-opening the file for every frame.
            with Image.open(bg_image) as source:
                bg_image = source.convert("RGBA")

        processed_frames = []
        with ThreadPoolExecutor(max_workers=int(max_workers)) as executor:
            # Frame i uses background frame i, so the index is passed in up front.
            futures = [
                executor.submit(
                    process_frame,
                    frame,
                    bg_type,
                    bg_image,
                    fast_mode,
                    index,
                    background_frames,
                    color,
                )
                for index, frame in enumerate(frames)
            ]
            for number, future in enumerate(futures, start=1):
                result, _ = future.result()
                processed_frames.append(result)
                elapsed_time = time.time() - start_time
                yield result, None, f"Processing frame {number}/{len(frames)}... Elapsed time: {elapsed_time:.2f} seconds"

        processed_video = mp.ImageSequenceClip(processed_frames, fps=fps)
        processed_video = processed_video.set_audio(audio)

        # Reserve a unique path, then let moviepy write to it after the handle is closed.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
            temp_filepath = temp_file.name
        processed_video.write_videofile(temp_filepath, codec="libx264")

        elapsed_time = time.time() - start_time
        yield gr.update(visible=False), gr.update(
            visible=True
        ), f"Processing complete! Elapsed time: {elapsed_time:.2f} seconds"
        yield processed_frames[
            -1
        ], temp_filepath, f"Processing complete! Elapsed time: {elapsed_time:.2f} seconds"
    except Exception as e:
        print(f"Error: {e}")
        elapsed_time = time.time() - start_time
        message = f"Error processing video: {e}. Elapsed time: {elapsed_time:.2f} seconds"
        yield gr.update(visible=False), gr.update(visible=True), message
        # Clear both media outputs; the error text belongs only in the status box, not
        # in the video component (which would treat it as a file path).
        yield None, None, message
    finally:
        # Release the ffmpeg readers held by the source clips.
        for clip in (background_video, video):
            if clip is not None:
                clip.close()


# ---------------------------------------------------------------------------
# Image processing
# ---------------------------------------------------------------------------


def remove_bg_fn(image):
    """Remove the background of *image* for the "Remove Image Background" tab.

    The input is loaded via ``load_img`` and converted to RGB, so animated images are
    reduced to a single frame. The ImageSlider output shows one still image anyway.

    Returns:
        ``(processed_rgba_image, original_rgb_image)`` for the ImageSlider.
    """
    im = load_img(image, output_type="pil").convert("RGB")
    origin = im.copy()
    processed_image = process_image(im)
    return (processed_image, origin)


@spaces.GPU
def process_image(image):
    """Attach the full BiRefNet model's foreground mask to *image* as its alpha channel.

    Modifies *image* in place (via ``putalpha``) and also returns it.
    """
    mask = _predict_mask(image, birefnet)
    image.putalpha(mask)
    return image


def apply_background(image, background):
    """Alpha-composite *image* over *background* and return the RGBA result.

    Both images are converted to RGBA if needed. They must already be the same size,
    because ``Image.alpha_composite`` requires it. This is pure PIL work, so it is not
    wrapped in ``spaces.GPU``.
    """
    if background.mode != "RGBA":
        background = background.convert("RGBA")
    image = image.convert("RGBA")
    return Image.alpha_composite(background, image)


def apply_bg_image(image, background_file=None, background_color=None, bg_type="Color"):
    """Composite *image* (typically one with transparency) over a colour or image.

    Args:
        image: input as a PIL image, numpy array, file path or binary file-like object.
        background_file: background image in any of the same forms (for "Image").
        background_color: colour string for "Color"; falls back to "#00FF00" if empty.
        bg_type: "Color" uses *background_color*; anything else uses *background_file*.

    Returns:
        ``(result_image, original_image)`` for the ImageSlider. GIF inputs have each
        frame composited and are re-encoded as an animated GIF.

    Raises:
        gr.Error: if an input is missing or processing fails, so the UI shows the message.
    """
    try:
        input_image = _load_pil_image(image)
        if input_image is None:
            raise gr.Error("Please upload an image.")
        origin = input_image.copy()
        color_profile = input_image.info.get("icc_profile")
        is_gif = input_image.format == "GIF"

        if bg_type == "Color":
            background_image = Image.new(
                "RGBA", input_image.size, hex_to_rgba(background_color or "#00FF00")
            )
        else:
            background_image = _load_pil_image(background_file)
            if background_image is None:
                raise gr.Error("Please upload a background image.")
            if background_image.size != input_image.size:
                background_image = background_image.resize(input_image.size)

        buffer = io.BytesIO()
        if is_gif:
            frames = [
                apply_background(frame.convert("RGBA"), background_image)
                for frame in ImageSequence.Iterator(input_image)
            ]
            frames[0].save(
                buffer,
                format="GIF",
                save_all=True,
                append_images=frames[1:],
                loop=0,
                icc_profile=color_profile,
            )
        else:
            output_image = apply_background(input_image, background_image)
            output_image.save(
                buffer, format="PNG", optimize=True, icc_profile=color_profile
            )
        # Re-open the encoded bytes so the returned image carries the saved metadata.
        buffer.seek(0)
        return (Image.open(buffer), origin)
    except gr.Error:
        raise
    except Exception as e:
        raise gr.Error(f"Error applying background: {e}") from e


# ---------------------------------------------------------------------------
# Gradio interface
# ---------------------------------------------------------------------------

with gr.Blocks(theme=gr.themes.Ocean()) as demo:
    gr.Markdown("# Image and Video Background Remover & Changer\n\nRemove or apply background to images and videos.")

    with gr.Tab("Remove Image Background"):
        with gr.Row():
            image_input = gr.Image(label="Upload Image", interactive=True)
            slider = ImageSlider(label="Processed Image", type="pil")
        remove_button = gr.Button("Remove Image Background", interactive=True)

        examples = gr.Examples(
            [
                load_img(
                    "https://images.rawpixel.com/image_800/cHJpdmF0ZS9sci9pbWFnZXMvd2Vic2l0ZS8yMDIzLTA4L3Jhd3BpeGVsX29mZmljZV8yX3Bob3RvX29mX2FfbGlvbl9pc29sYXRlZF9vbl9jb2xvcl9iYWNrZ3JvdW5kXzJhNzgwMjM1LWRlYTgtNDMyOS04OWVjLTY3ZWMwNjcxZDhiMV8xLmpwZw.jpg",
                    output_type="pil",
                )
            ],
            inputs=image_input,
            fn=remove_bg_fn,
            outputs=slider,
            cache_examples=True,
            cache_mode="eager",
        )

        remove_button.click(remove_bg_fn, inputs=image_input, outputs=slider)

    with gr.Tab("Apply Background to Image"):
        with gr.Row():
            # PIL + RGBA so the uploaded image's transparency survives preprocessing;
            # compositing over a background is meaningless without it.
            image_input = gr.Image(
                label="Upload Image", type="pil", image_mode="RGBA", interactive=True
            )
            slider = ImageSlider(label="Processed Image", type="pil")

        apply_button = gr.Button("Apply Background", interactive=True)

        with gr.Row():
            bg_type = gr.Radio(
                ["Color", "Image"],
                label="Background Type",
                value="Color",
                interactive=True,
            )
            color_picker = gr.ColorPicker(
                label="Background Color",
                value="#00FF00",
                visible=True,
                interactive=True,
            )
            bg_image = gr.Image(
                label="Background Image",
                type="filepath",
                visible=False,
                interactive=True,
            )

        def update_image_bg_visibility(selected_type):
            """Show only the control matching *selected_type*: (color_picker, bg_image)."""
            return (
                gr.update(visible=selected_type == "Color"),
                gr.update(visible=selected_type == "Image"),
            )

        bg_type.change(
            update_image_bg_visibility,
            inputs=bg_type,
            outputs=[color_picker, bg_image],
        )

        examples = gr.Examples(
            [
                ["https://pngimg.com/d/mario_PNG125.png", None, "#0cfa38", "Color"],
                [
                    "https://pngimg.com/d/mario_PNG125.png",
                    "https://cdn.photoroom.com/v2/image-cache?path=gs://background-7ef44.appspot.com/backgrounds_v3/black/47_-_black.jpg",
                    None,
                    "Image",
                ],
            ],
            inputs=[image_input, bg_image, color_picker, bg_type],
            fn=apply_bg_image,
            outputs=slider,
            cache_examples=True,
            cache_mode="eager",
        )

        apply_button.click(
            apply_bg_image,
            inputs=[image_input, bg_image, color_picker, bg_type],
            outputs=slider,
        )

    with gr.Tab("Remove Video Background"):
        with gr.Row():
            in_video = gr.Video(label="Input Video", interactive=True)
            stream_image = gr.Image(label="Streaming Output", visible=False)
            out_video = gr.Video(label="Final Output Video")

        submit_button = gr.Button("Change Background", interactive=True)

        with gr.Row():
            fps_slider = gr.Slider(
                minimum=0,
                maximum=60,
                step=1,
                value=0,
                label="Output FPS (0 will inherit the original fps value)",
                interactive=True,
            )
            bg_type = gr.Radio(
                ["Color", "Image", "Video"],
                label="Background Type",
                value="Color",
                interactive=True,
            )
            color_picker = gr.ColorPicker(
                label="Background Color",
                value="#00FF00",
                visible=True,
                interactive=True,
            )
            bg_image = gr.Image(
                label="Background Image",
                type="filepath",
                visible=False,
                interactive=True,
            )
            bg_video = gr.Video(
                label="Background Video", visible=False, interactive=True
            )

        with gr.Column(visible=False) as video_handling_options:
            video_handling_radio = gr.Radio(
                ["slow_down", "loop"],
                label="Video Handling",
                value="slow_down",
                interactive=True,
            )

        fast_mode_checkbox = gr.Checkbox(
            label="Fast Mode (Use BiRefNet_lite)", value=True, interactive=True
        )
        max_workers_slider = gr.Slider(
            minimum=1,
            maximum=32,
            step=1,
            value=6,
            label="Max Workers",
            info="Determines how many frames to process in parallel",
            interactive=True,
        )

        time_textbox = gr.Textbox(label="Time Elapsed", interactive=False)

        def update_video_bg_visibility(selected_type):
            """Return one update per output: (color_picker, bg_image, bg_video, handling).

            Always returns exactly four updates to match the four wired outputs; the
            video-handling options are shown only for a "Video" background.
            """
            return (
                gr.update(visible=selected_type == "Color"),
                gr.update(visible=selected_type == "Image"),
                gr.update(visible=selected_type == "Video"),
                gr.update(visible=selected_type == "Video"),
            )

        bg_type.change(
            update_video_bg_visibility,
            inputs=bg_type,
            outputs=[color_picker, bg_image, bg_video, video_handling_options],
        )

        examples = gr.Examples(
            [
                [
                    "https://www.w3schools.com/html/mov_bbb.mp4",
                    "Video",
                    None,
                    "https://www.w3schools.com/howto/rain.mp4",
                ],
                [
                    "https://www.w3schools.com/html/mov_bbb.mp4",
                    "Image",
                    "https://cdn.photoroom.com/v2/image-cache?path=gs://background-7ef44.appspot.com/backgrounds_v3/black/47_-_black.jpg",
                    None,
                ],
                ["https://www.w3schools.com/html/mov_bbb.mp4", "Color", None, None],
            ],
            inputs=[in_video, bg_type, bg_image, bg_video],
            outputs=[stream_image, out_video, time_textbox],
            fn=remove_bg_video,
            cache_examples=True,
            cache_mode="eager",
        )

        submit_button.click(
            remove_bg_video,
            inputs=[
                in_video,
                bg_type,
                bg_image,
                bg_video,
                color_picker,
                fps_slider,
                video_handling_radio,
                fast_mode_checkbox,
                max_workers_slider,
            ],
            outputs=[stream_image, out_video, time_textbox],
        )

if __name__ == "__main__":
    demo.launch(show_error=True)