# Spaces: Running on Zero
import gradio as gr | |
import spaces | |
import torch | |
import uuid | |
from PIL import Image | |
from torchvision import transforms | |
from transformers import AutoModelForImageSegmentation | |
from typing import Union, List | |
from loadimg import load_img # Your helper to load from URL or file | |
# Per the PyTorch docs, "high" lets float32 matmuls use faster,
# reduced-precision kernels where the hardware provides them.
torch.set_float32_matmul_precision("high")

# Use the GPU when torch reports one; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load BiRefNet model.
# NOTE: trust_remote_code=True executes the model code shipped in the
# "ZhengPeng7/BiRefNet" repository.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    trust_remote_code=True,
)
birefnet.to(device)

# Image transformation: fixed 1024x1024 resize, tensor conversion, then
# per-channel normalization with the mean/std below.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]
)
# On ZeroGPU Spaces, CUDA work has to run inside a function decorated with
# @spaces.GPU; per the Hugging Face docs the decorator has no effect in
# other environments.
@spaces.GPU
def process(image: Image.Image) -> Image.Image:
    """Replace the background of *image* with solid white.

    The image is run through the module-level ``birefnet`` model, the
    predicted mask is resized back to the input size and thresholded at 127,
    and the input is composited over a white canvas using that binary mask.

    Args:
        image: Input picture. Non-RGB images are converted to RGB first,
            because the 3-channel normalization and the RGB canvas used for
            compositing both require RGB data.

    Returns:
        A new RGB image of the same size as the input, with pixels kept
        where the mask is above the threshold and white everywhere else.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")
    original_size = image.size

    batch = transform_image(image).unsqueeze(0).to(device)
    with torch.no_grad():
        # Take the last element of the model output and squash it with a sigmoid.
        preds = birefnet(batch)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()

    # Back to the original resolution as an 8-bit mask, then binarize it.
    mask = transforms.ToPILImage()(pred).resize(original_size).convert("L")
    binary_mask = mask.point(lambda p: 255 if p > 127 else 0)

    white_bg = Image.new("RGB", original_size, (255, 255, 255))
    return Image.composite(image, white_bg, binary_mask)
def _save_processed(image: "Image.Image") -> str:
    """Remove the background of *image* and save the result as a PNG.

    The file is written to the current working directory under a random
    ``output_<8 hex chars>.png`` name, and that file name is returned.
    """
    result = process(image.convert("RGB"))
    filename = f"output_{uuid.uuid4().hex[:8]}.png"
    result.save(filename)
    return filename


def handler(image=None, image_url=None, batch_urls=None) -> Union[str, List[str], None]:
    """Remove the background from one or more images and fill it with white.

    The inputs are checked in priority order and only the first one that is
    provided is used: uploaded image, then single URL, then batch of URLs.

    Args:
        image: An uploaded image (PIL image), or None.
        image_url: URL (or path) of a single image to process. Blank or
            whitespace-only values are ignored.
        batch_urls: Comma-separated list of image URLs. Entries that fail to
            load or process are skipped and reported on stdout.

    Returns:
        The path of the saved PNG for a single image, a list of paths for a
        batch, or None when nothing was processed or an error occurred.
    """
    try:
        # Single image upload
        if image is not None:
            return _save_processed(image)

        # Single image from URL. Strip first so that a whitespace-only box
        # does not shadow the batch input.
        image_url = image_url.strip() if image_url else ""
        if image_url:
            return _save_processed(load_img(image_url, output_type="pil"))

        # Batch of URLs: best effort, one bad URL must not abort the rest.
        results = []
        if batch_urls:
            for url in (part.strip() for part in batch_urls.split(",")):
                if not url:
                    continue
                try:
                    results.append(
                        _save_processed(load_img(url, output_type="pil"))
                    )
                except Exception as e:
                    print(f"Error with {url}: {e}")
        return results if results else None
    except Exception as e:
        print("General error:", e)
        return None
# Interface: three alternative inputs (upload, single URL, batch of URLs)
# feeding `handler`, with the resulting file(s) offered for download.
_upload_input = gr.Image(label="Upload Image", type="pil")
_url_input = gr.Textbox(label="Paste Image URL")
_batch_input = gr.Textbox(label="Comma-separated Image URLs (Batch)")
_files_output = gr.File(label="Output File(s)", file_count="multiple")

demo = gr.Interface(
    fn=handler,
    inputs=[_upload_input, _url_input, _batch_input],
    outputs=_files_output,
    title="Background Remover (White Fill)",
    description=(
        "Upload an image, paste a URL, or send a batch of URLs to remove "
        "the background and replace it with white."
    ),
)

if __name__ == "__main__":
    # show_error surfaces exceptions in the UI; mcp_server also exposes the
    # app as an MCP server.
    demo.launch(show_error=True, mcp_server=True)