"""Gradio app that removes image backgrounds with BiRefNet and fills them with white.

Input can be an uploaded image, a single image URL, or a comma-separated list
of image URLs. Each result is saved as a PNG in the current working directory
and the file path(s) are returned to the Gradio ``File`` output.
"""

import logging
import uuid
from typing import List, Union

import gradio as gr
import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

from loadimg import load_img  # Your helper to load from URL or file

logger = logging.getLogger(__name__)

torch.set_float32_matmul_precision("high")

# Load BiRefNet model
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
birefnet.to(device)
# from_pretrained already returns the model in eval mode; made explicit so
# inference never runs with training-mode layers.
birefnet.eval()

# Image transformation: fixed 1024x1024 model input, normalized with the
# commonly used ImageNet channel mean/std values.
transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def process(image: Image.Image) -> Image.Image:
    """Return a copy of *image* with everything outside the foreground painted white.

    Args:
        image: Input picture in any PIL mode; it is converted to RGB first.

    Returns:
        A new RGB image with the same size as *image*.
    """
    # Normalize the mode here so callers passing RGBA/L/P images do not hit a
    # channel-count mismatch in Normalize (3 mean/std values) or a mode
    # mismatch in Image.composite against the RGB background.
    image = image.convert("RGB")
    image_size = image.size
    input_tensor = transform_image(image).unsqueeze(0).to(device)

    with torch.no_grad():
        # The model output is indexable; the last entry is taken and passed
        # through sigmoid to get per-pixel foreground probabilities
        # (presumably the final-stage prediction — confirm against BiRefNet).
        preds = birefnet(input_tensor)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()

    # Scale the 1024x1024 soft mask back to the input size, then hard-threshold
    # it so every pixel is either kept or replaced by white.
    mask = transforms.ToPILImage()(pred).resize(image_size).convert("L")
    binary_mask = mask.point(lambda p: 255 if p > 127 else 0)

    white_bg = Image.new("RGB", image_size, (255, 255, 255))
    return Image.composite(image, white_bg, binary_mask)


def _process_and_save(image: Image.Image) -> str:
    """Run ``process`` on *image*, save the result as a PNG in the CWD, and return its path."""
    processed = process(image)
    filename = f"output_{uuid.uuid4().hex[:8]}.png"
    processed.save(filename)
    return filename


def handler(image=None, image_url=None, batch_urls=None) -> Union[str, List[str], None]:
    """Remove the background from an image (or a batch of image URLs) and fill it with white.

    Only the first non-empty input is used, in this order: uploaded image,
    single image URL, comma-separated batch of URLs.

    Args:
        image: Uploaded PIL image, or None.
        image_url: URL or path of a single image to process.
        batch_urls: Comma-separated image URLs; blank entries are skipped.

    Returns:
        The output PNG path for a single image; a list of output PNG paths for
        a batch (URLs that fail are logged and skipped); None when no input was
        given, every batch URL failed, or an unexpected error occurred.
    """
    try:
        # Single image upload
        if image is not None:
            return _process_and_save(image.convert("RGB"))

        # Single image from URL
        if image_url:
            im = load_img(image_url, output_type="pil").convert("RGB")
            return _process_and_save(im)

        # Batch of URLs
        if batch_urls:
            urls = [u.strip() for u in batch_urls.split(",") if u.strip()]
            results: List[str] = []
            for url in urls:
                try:
                    im = load_img(url, output_type="pil").convert("RGB")
                    results.append(_process_and_save(im))
                except Exception:
                    # Best-effort batch: one bad URL must not abort the others.
                    logger.exception("Error with %s", url)
            return results if results else None
    except Exception:
        # Top-level UI boundary: log with traceback and return an empty output.
        logger.exception("General error")
    return None


# Interface
demo = gr.Interface(
    fn=handler,
    inputs=[
        gr.Image(label="Upload Image", type="pil"),
        gr.Textbox(label="Paste Image URL"),
        gr.Textbox(label="Comma-separated Image URLs (Batch)"),
    ],
    outputs=gr.File(label="Output File(s)", file_count="multiple"),
    title="Background Remover (White Fill)",
    description="Upload an image, paste a URL, or send a batch of URLs to remove the background and replace it with white.",
)

if __name__ == "__main__":
    demo.launch(show_error=True, mcp_server=True)