import logging
import re
import uuid
from typing import List, Tuple, Union

import gradio as gr
import spaces
import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

from loadimg import load_img  # Your helper to load from URL or file

logger = logging.getLogger(__name__)

torch.set_float32_matmul_precision("high")

# Load BiRefNet model
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
birefnet.to(device)

# Image transformation: resize to 1024x1024, convert to a [0, 1] tensor, and
# normalize with the standard ImageNet per-channel mean/std (3 channels only).
transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


@spaces.GPU
def process(
    image: Image.Image,
    *,
    threshold: int = 127,
    background: Tuple[int, int, int] = (255, 255, 255),
) -> Image.Image:
    """Replace the background of *image* with a solid color using BiRefNet.

    The model's prediction is passed through a sigmoid and resized back to the
    input size. It is then binarized (mask values strictly greater than
    *threshold* are kept) and used to composite the original image over a
    solid *background*.

    Args:
        image: Input image in any PIL mode. It is converted to RGB first. The
            normalization step expects exactly three channels, and
            ``Image.composite`` requires both images to share a mode.
        threshold: Cut-off on the 0-255 mask. The default of 127 matches the
            original behaviour.
        background: RGB fill color for removed regions. The default is white.

    Returns:
        An RGB image with the same size as *image*.
    """
    image = image.convert("RGB")
    image_size = image.size
    input_tensor = transform_image(image).unsqueeze(0).to(device)
    with torch.no_grad():
        # The last model output is treated as logits (presumably the final
        # refinement stage); sigmoid maps it to [0, 1].
        preds = birefnet(input_tensor)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    mask = transforms.ToPILImage()(pred).resize(image_size).convert("L")
    binary_mask = mask.point(lambda p: 255 if p > threshold else 0)
    fill = Image.new("RGB", image_size, background)
    return Image.composite(image, fill, binary_mask)


def _save_png(image: Image.Image) -> str:
    """Save *image* as a uniquely named PNG in the working directory; return the filename."""
    filename = f"output_{uuid.uuid4().hex[:8]}.png"
    image.save(filename)
    return filename


def _remove_background_to_file(image: Image.Image) -> str:
    """Run background removal on *image* (as RGB) and save the result; return the filename."""
    return _save_png(process(image.convert("RGB")))


def _split_urls(batch_urls: str) -> List[str]:
    """Split a comma- and/or newline-separated string into stripped, non-empty URLs."""
    return [u.strip() for u in re.split(r"[,\n]", batch_urls) if u.strip()]


# NOTE(review): URL downloads run inside this GPU allocation. A large batch may
# exceed the default @spaces.GPU duration; consider passing `duration=`.
@spaces.GPU
def handler(image=None, image_url=None, batch_urls=None) -> Union[str, List[str], None]:
    """Remove the background from one or more images and replace it with white.

    Exactly one input source is used, checked in this order: the uploaded
    image, then the single URL, then the batch of URLs.

    Args:
        image: An uploaded PIL image.
        image_url: URL of a single image to process.
        batch_urls: Several image URLs separated by commas (or newlines).

    Returns:
        The PNG filename for a single image. For a batch, a list of PNG
        filenames for the images that succeeded. Returns None when no input is
        given, when a single image fails, or when every batch item fails.
    """
    try:
        # Single image upload
        if image is not None:
            return _remove_background_to_file(image)

        # Single image from URL (whitespace-only input counts as empty)
        url = image_url.strip() if image_url else ""
        if url:
            return _remove_background_to_file(load_img(url, output_type="pil"))

        # Batch of URLs: best-effort, a failing URL is reported and skipped.
        if batch_urls:
            results = []
            for batch_url in _split_urls(batch_urls):
                try:
                    im = load_img(batch_url, output_type="pil")
                    results.append(_remove_background_to_file(im))
                except Exception as e:
                    logger.exception("Error with %s", batch_url)
                    gr.Warning(f"Error with {batch_url}: {e}")
            return results or None
    except Exception as e:
        logger.exception("General error")
        gr.Warning(f"General error: {e}")
    return None


# Interface
demo = gr.Interface(
    fn=handler,
    inputs=[
        gr.Image(label="Upload Image", type="pil"),
        gr.Textbox(label="Paste Image URL"),
        gr.Textbox(label="Comma-separated Image URLs (Batch)"),
    ],
    outputs=gr.File(label="Output File(s)", file_count="multiple"),
    title="Background Remover (White Fill)",
    description="Upload an image, paste a URL, or send a batch of URLs to remove the background and replace it with white.",
)

if __name__ == "__main__":
    demo.launch(show_error=True, mcp_server=True)