# BiRefNet_demo / app.py
import os
import cv2
import numpy as np
import torch
import gradio as gr
import spaces
from PIL import Image, ImageOps
from transformers import AutoModelForImageSegmentation
from torchvision import transforms
# Allow TF32 matmul kernels for faster inference at slightly reduced precision
torch.set_float32_matmul_precision('high')
# Monkey-patch torch.jit.script to a no-op so any scripted functions run eagerly
torch.jit.script = lambda f: f

device = "cuda" if torch.cuda.is_available() else "cpu"

def refine_foreground(image, mask, r=90):
    # Refine foreground colors around the matte edges with blur-fusion estimation.
    if mask.size != image.size:
        mask = mask.resize(image.size)
    image = np.array(image) / 255.0
    mask = np.array(mask) / 255.0
    estimated_foreground = FB_blur_fusion_foreground_estimator_2(image, mask, r=r)
    image_masked = Image.fromarray((estimated_foreground * 255.0).astype(np.uint8))
    return image_masked

def FB_blur_fusion_foreground_estimator_2(image, alpha, r=90):
    # Two-pass blur-fusion: a coarse pass with a large radius, then a refinement pass with r=6.
    alpha = alpha[:, :, None]
    F, blur_B = FB_blur_fusion_foreground_estimator(image, image, image, alpha, r)
    return FB_blur_fusion_foreground_estimator(image, F, blur_B, alpha, r=6)[0]

def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
    # Estimate foreground and background colors from blurred, alpha-weighted images.
    if isinstance(image, Image.Image):
        image = np.array(image) / 255.0
    blurred_alpha = cv2.blur(alpha, (r, r))[:, :, None]
    blurred_FA = cv2.blur(F * alpha, (r, r))
    blurred_F = blurred_FA / (blurred_alpha + 1e-5)
    blurred_B1A = cv2.blur(B * (1 - alpha), (r, r))
    blurred_B = blurred_B1A / ((1 - blurred_alpha) + 1e-5)
    F = blurred_F + alpha * (image - alpha * blurred_F - (1 - alpha) * blurred_B)
    F = np.clip(F, 0, 1)
    return F, blurred_B

class ImagePreprocessor():
    def __init__(self, resolution=(1024, 1024)) -> None:
        # Resize to the model's input resolution and normalize with ImageNet statistics.
        self.transform_image = transforms.Compose([
            transforms.Resize(resolution),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def proc(self, image: Image.Image) -> torch.Tensor:
        return self.transform_image(image)
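
# Load the BiRefNet matting checkpoint from the Hugging Face Hub and prepare it for inference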
birefnet = AutoModelForImageSegmentation.from_pretrained('zhengpeng7/BiRefNet-matting', trust_remote_code=True)
birefnet.to(device)
birefnet.eval()
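
# On ZeroGPU Spaces, @spaces.GPU allocates a GPU for the duration of each call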
@spaces.GPU
def remove_background(image):
    if image is None:
        raise gr.Error("Please upload an image.")
    image_ori = Image.fromarray(image).convert('RGB')
    original_size = image_ori.size

    # Preprocess the image
    image_preprocessor = ImagePreprocessor(resolution=(1024, 1024))
    image_proc = image_preprocessor.proc(image_ori)
    image_proc = image_proc.unsqueeze(0)

    # Prediction
    with torch.no_grad():
        preds = birefnet(image_proc.to(device))[-1].sigmoid().cpu()
    pred = preds[0].squeeze()

    # Process results
    pred_pil = transforms.ToPILImage()(pred)
    pred_pil = pred_pil.resize(original_size, Image.BICUBIC)  # Resize mask to original size

    # Create reverse mask
    reverse_mask = Image.new('L', original_size)
    reverse_mask.paste(ImageOps.invert(pred_pil))

    # Create foreground image (object with transparent background)
    foreground = image_ori.copy()
    foreground.putalpha(pred_pil)

    # Create background image
    background = image_ori.copy()
    background.putalpha(reverse_mask)

    torch.cuda.empty_cache()

    # Save all images
    mask_path = "mask.png"
    pred_pil.save(mask_path)
    reverse_mask_path = "reverse_mask.png"
    reverse_mask.save(reverse_mask_path)
    foreground_path = "foreground.png"
    foreground.save(foreground_path)
    background_path = "background.png"
    background.save(background_path)

    return mask_path, reverse_mask_path, foreground_path, background_path
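
# Custom CSS: white background with animated pastel-gradient buttons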
css = """
body {
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif, "Apple Color Emoji", "Segoe UI Emoji", "Segoe UI Symbol";
}
.gradio-container {
background: white;
}
.gradio-button {
font-family: inherit;
font-size: 16px;
font-weight: bold;
color: #000000;
background: linear-gradient(
135deg,
#e0f7fa, #e8f5e9, #fff9c4, #ffebee,
#f3e5f5, #e1f5fe, #fff3e0, #e8eaf6
);
background-size: 400% 400%;
animation: gradient-animation 15s ease infinite;
border: 2px solid black;
border-radius: 10px;
}
.gradio-button:hover {
background: linear-gradient(
135deg,
#b2ebf2, #c8e6c9, #fff176, #ffcdd2,
#e1bee7, #b3e5fc, #ffe0b2, #c5cae9
);
background-size: 400% 400%;
animation: gradient-animation 15s ease infinite;
}
@keyframes gradient-animation {
0% { background-position: 0% 50%; }
50% { background-position: 100% 50%; }
100% { background-position: 0% 50%; }
}
"""
iface = gr.Interface(
    fn=remove_background,
    inputs=gr.Image(type="numpy"),
    outputs=[
        gr.Image(type="filepath", label="Mask"),
        gr.Image(type="filepath", label="Reverse Mask"),
        gr.Image(type="filepath", label="Foreground"),
        gr.Image(type="filepath", label="Background"),
    ],
    allow_flagging="never",
    css=css,
)

if __name__ == "__main__":
    iface.launch(debug=True)