# Snapshot of app.py taken from a Hugging Face Space file view
# (commit "Update app.py", 461565d, verified; 3.66 kB).
import gradio as gr
import spaces
import torch
import uuid
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
from typing import Union, List
from loadimg import load_img
# Let float32 matmuls use the faster, slightly lower-precision kernels.
torch.set_float32_matmul_precision("high")

# Load RMBG v1.4 model. The architecture code lives in the model repository,
# which is why trust_remote_code is required.
model = AutoModelForImageSegmentation.from_pretrained(
    "briaai/RMBG-1.4",
    trust_remote_code=True,
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Inference only: make sure the network is in eval mode (a no-op if the
# loader already did this).
model.eval()

# Transform for RMBG v1.4.
# The model card's reference preprocessing resizes to 1024x1024, scales pixel
# values to [0, 1] and normalizes with mean 0.5 / std 1.0. The ImageNet
# mean/std used previously is the RMBG-2.0 recipe and feeds v1.4 inputs in a
# range it was not trained on.
transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [1.0, 1.0, 1.0]),
])
@spaces.GPU
def process(image: Image.Image) -> Image.Image:
    """Remove the background of ``image`` and replace it with white.

    Args:
        image: Input picture. It is converted to RGB if it is in another mode.

    Returns:
        A new RGB image with the same size as the input, where every pixel the
        model classifies as background has been replaced by white.
    """
    # The normalization step expects three channels, and Image.composite needs
    # the picture and the white canvas to share the same mode.
    if image.mode != "RGB":
        image = image.convert("RGB")
    image_size = image.size
    input_tensor = transform_image(image).unsqueeze(0).to(device)

    with torch.no_grad():
        preds = model(input_tensor)

    # Handle different return types: first element of a tuple, last element of
    # a list, or the tensor itself.
    if isinstance(preds, tuple):
        pred = preds[0]
    elif isinstance(preds, list):
        pred = preds[-1]
    else:
        pred = preds
    # RMBG v1.4 nests its outputs (the model card reads the mask as
    # ``result[0][0]``), so the selected element can still be a container.
    # Keep taking the first entry until an actual tensor is reached; calling
    # ``.sigmoid()`` on a list would raise AttributeError.
    while isinstance(pred, (list, tuple)):
        pred = pred[0]

    pred = pred.detach().float().cpu()
    # Only squash raw logits. When the model already returns probabilities in
    # [0, 1], a second sigmoid would compress them into [0.5, 0.73] and the
    # threshold below would mark every pixel as foreground.
    if pred.min().item() < 0.0 or pred.max().item() > 1.0:
        pred = pred.sigmoid()

    # Drop the batch/channel dimensions and scale the mask back to the
    # original picture size.
    mask_tensor = pred[0].squeeze()
    mask_pil = transforms.ToPILImage()(mask_tensor).resize(image_size).convert("L")

    # Create binary mask with threshold: 255 keeps the source pixel, 0 takes
    # the background pixel.
    binary_mask = mask_pil.point(lambda p: 255 if p > 127 else 0)

    # Apply mask with white background.
    white_bg = Image.new("RGB", image_size, (255, 255, 255))
    return Image.composite(image, white_bg, binary_mask)
def _process_and_save(image: Image.Image) -> str:
    """Run ``process`` on ``image`` and save the result as a uniquely named PNG.

    Returns the name of the written file (relative to the working directory).
    """
    processed = process(image.convert("RGB"))
    filename = f"output_{uuid.uuid4().hex[:8]}.png"
    processed.save(filename)
    return filename


@spaces.GPU
def handler(image=None, image_url=None, batch_urls=None) -> Union[str, List[str], None]:
    """Remove the background from an uploaded image, an image URL, or a batch of URLs.

    The inputs are checked in this order and the first one provided wins:
    ``image``, then ``image_url``, then ``batch_urls``.

    Args:
        image: Uploaded PIL image, or None.
        image_url: URL of a single image, or None/empty.
        batch_urls: Comma-separated image URLs, or None/empty.

    Returns:
        The output file name for a single image, a list of output file names
        for a batch (URLs that fail are skipped), or None when nothing was
        provided or processing failed.
    """
    try:
        if image is not None:
            return _process_and_save(image)

        # Treat a whitespace-only URL as "not provided" so it neither triggers
        # a failing download nor hides the batch input.
        image_url = image_url.strip() if image_url else ""
        if image_url:
            return _process_and_save(load_img(image_url, output_type="pil"))

        if batch_urls:
            results = []
            for url in (part.strip() for part in batch_urls.split(",")):
                if not url:
                    continue
                # Best effort: one bad URL must not abort the whole batch.
                try:
                    results.append(
                        _process_and_save(load_img(url, output_type="pil"))
                    )
                except Exception as e:
                    print(f"Error with {url}: {e}")
            return results if results else None
    except Exception as e:
        print("General error:", e)
        # Add debug info
        import traceback
        traceback.print_exc()
    return None
# Gradio front end: three alternative inputs (upload, single URL, batch of
# URLs) are passed positionally to ``handler``; its result is shown as
# downloadable file(s).
demo = gr.Interface(
    title="Background Remover (RMBG v1.4)",
    description="Upload an image, paste a URL, or send a batch of URLs to remove the background and replace it with white.",
    fn=handler,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Textbox(label="Paste Image URL"),
        gr.Textbox(label="Comma-separated Image URLs (Batch)"),
    ],
    outputs=gr.File(file_count="multiple", label="Output File(s)"),
)

# Start the app only when this file is executed as a script; the launch call
# also turns on the MCP server option.
if __name__ == "__main__":
    demo.launch(mcp_server=True, show_error=True)