# user-agent's picture
# Update app.py
# c93abb9 verified
import gradio as gr
import spaces
import torch
import uuid
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
from typing import Union, List
from loadimg import load_img # Your helper to load from URL or file
# Let float32 matmuls use the faster "high" precision setting instead of the
# strictest one (see the torch documentation for the exact trade-off).
torch.set_float32_matmul_precision("high")

# Choose the compute device once; the model and every input tensor share it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# BiRefNet segmentation model. `trust_remote_code=True` executes the modelling
# code shipped in the "ZhengPeng7/BiRefNet" repository.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    trust_remote_code=True,
)
birefnet.to(device)

# Preprocessing applied to every input image: resize to a fixed 1024x1024,
# convert to a tensor, then normalize each channel.
# NOTE(review): the mean/std values look like the usual ImageNet statistics.
transform_image = transforms.Compose(
    [
        transforms.Resize(size=(1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
@spaces.GPU
def process(image: Image.Image) -> Image.Image:
    """Replace the background of an image with solid white.

    The image is run through the module-level ``birefnet`` model, the
    prediction is turned into a binary foreground mask at the original
    resolution, and the original pixels are composited over a white canvas
    using that mask.

    Args:
        image: Input picture. Any PIL mode is accepted; it is converted to
            RGB before inference.

    Returns:
        A new RGB image of the same size as the input, with pixels outside
        the predicted mask set to white.
    """
    # The normalization in `transform_image` is defined for three channels and
    # `Image.composite` requires both images to share a mode, so force RGB
    # here instead of relying on every caller to do it.
    if image.mode != "RGB":
        image = image.convert("RGB")

    image_size = image.size
    input_tensor = transform_image(image).unsqueeze(0).to(device)

    with torch.no_grad():
        # The last element of the model output is used as the prediction;
        # sigmoid maps it into [0, 1].
        preds = birefnet(input_tensor)[-1].sigmoid().cpu()

    # Drop the batch (and any singleton) dimension, then scale the mask back
    # up to the original image size as an 8-bit grayscale image.
    pred = preds[0].squeeze()
    mask = transforms.ToPILImage()(pred).resize(image_size).convert("L")

    # Hard threshold at mid-gray: above 127 is kept, everything else is white.
    binary_mask = mask.point(lambda p: 255 if p > 127 else 0)

    white_bg = Image.new("RGB", image_size, (255, 255, 255))
    return Image.composite(image, white_bg, binary_mask)
def _process_and_save(image: Image.Image) -> str:
    """Run ``process`` on an RGB copy of ``image`` and save the result as PNG.

    Returns the name of the written file, ``output_<8 hex chars>.png``, which
    is created relative to the current working directory.
    """
    processed = process(image.convert("RGB"))
    filename = f"output_{uuid.uuid4().hex[:8]}.png"
    processed.save(filename)
    return filename


@spaces.GPU
def handler(image=None, image_url=None, batch_urls=None) -> Union[str, List[str], None]:
    """Remove the background from one or more images and fill it with white.

    Exactly one input source is used, checked in this order: uploaded image,
    single URL, batch of URLs.

    Args:
        image: An uploaded PIL image, or None.
        image_url: URL (or path) of a single image to load, or None/empty.
            Surrounding whitespace is ignored.
        batch_urls: Comma-separated list of image URLs, or None/empty.

    Returns:
        The output PNG filename for a single image, a list of filenames for a
        batch (entries that failed are skipped), or None when no input was
        given, every batch entry failed, or an error occurred.
    """
    try:
        # Single image upload
        if image is not None:
            return _process_and_save(image)

        # Single image from URL. Strip first so a whitespace-only textbox does
        # not count as a URL and hide the batch input behind a failed load.
        if isinstance(image_url, str):
            image_url = image_url.strip()
        if image_url:
            return _process_and_save(load_img(image_url, output_type="pil"))

        # Batch of URLs
        if batch_urls:
            results = []
            for url in (part.strip() for part in batch_urls.split(",")):
                if not url:
                    continue
                try:
                    results.append(
                        _process_and_save(load_img(url, output_type="pil"))
                    )
                except Exception as e:
                    # Best effort: one bad URL must not abort the whole batch.
                    print(f"Error with {url}: {e}")
            return results if results else None
    except Exception as e:
        print("General error:", e)
    return None
# Gradio UI: three alternative inputs (upload, single URL, batch of URLs),
# passed positionally to `handler`, and a file output for the result(s).
input_components = [
    gr.Image(type="pil", label="Upload Image"),
    gr.Textbox(label="Paste Image URL"),
    gr.Textbox(label="Comma-separated Image URLs (Batch)"),
]
output_component = gr.File(file_count="multiple", label="Output File(s)")

demo = gr.Interface(
    title="Background Remover (White Fill)",
    description="Upload an image, paste a URL, or send a batch of URLs to remove the background and replace it with white.",
    fn=handler,
    inputs=input_components,
    outputs=output_component,
)

# Start the app only when run as a script, with error display and the MCP
# server option switched on.
if __name__ == "__main__":
    demo.launch(mcp_server=True, show_error=True)