|
import torch |
|
from transformers import AutoModelForImageSegmentation |
|
from PIL import Image |
|
from torchvision import transforms |
|
import gradio as gr |
|
|
|
|
|
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# trust_remote_code=True lets transformers execute the modelling code stored in
# the Hub repository itself; only enable it for repositories you trust.
# The loaded model is moved to `device` and switched to inference (eval) mode.
birefnet = (
    AutoModelForImageSegmentation
    .from_pretrained('zhengpeng7/BiRefNet_lite', trust_remote_code=True)
    .to(device)
    .eval()
)
|
|
|
|
|
# Fixed spatial size every input is resized to before inference.
image_size = (1024, 1024)

# Preprocessing pipeline applied to a PIL image: resize to `image_size`,
# convert to a float tensor, then normalize each channel with the widely used
# ImageNet mean/std values.
transform_image = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
|
|
|
def extract_object(image):
    """Cut the foreground out of *image* using the BiRefNet model.

    The image is preprocessed with ``transform_image``, run through
    ``birefnet`` on ``device``, and the sigmoid of the model's last output is
    used as a soft mask. The mask is resized back to the input size and
    attached as the alpha channel of an RGBA copy of the input.

    Args:
        image: A PIL image, or ``None`` when the Gradio input is empty.

    Returns:
        A new RGBA PIL image whose alpha channel is the predicted mask, or
        ``None`` if *image* is ``None``.
    """
    # Gradio calls the function with None when no image was uploaded.
    if image is None:
        return None

    # transform_image normalizes with 3 per-channel values, and the model
    # expects 3-channel input, so RGBA / grayscale / palette uploads must be
    # converted to RGB first or preprocessing fails.
    rgb_image = image.convert("RGB")
    input_images = transform_image(rgb_image).unsqueeze(0).to(device)

    with torch.no_grad():
        # NOTE(review): assumes the model returns a sequence of predictions
        # whose last element is the final mask logits (per the [-1] index).
        preds = birefnet(input_images)[-1].sigmoid().cpu()

    pred = preds[0].squeeze()
    # ToPILImage turns a float tensor in [0, 1] into an 8-bit single-channel
    # image, which is a mode putalpha accepts.
    mask = transforms.ToPILImage()(pred).resize(image.size)

    # convert() returns a copy, so the caller's image is not modified.
    image_with_alpha = image.convert("RGBA")
    image_with_alpha.putalpha(mask)
    return image_with_alpha
|
|
|
def _build_interface():
    """Assemble the Gradio UI that exposes ``extract_object``."""
    # Components are created input-first, matching the original order.
    upload_box = gr.Image(type="pil", label="Upload Image")
    result_box = gr.Image(type="pil", label="Segmented Image")
    return gr.Interface(
        fn=extract_object,
        inputs=upload_box,
        outputs=result_box,
        title="BiRefNet Background Removal",
        description="Upload an image and get the foreground object extracted.",
    )


iface = _build_interface()


if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    iface.launch()