|
import numpy as np |
|
import torch |
|
import torch.nn.functional as F |
|
from torchvision.transforms.functional import normalize |
|
from huggingface_hub import hf_hub_download |
|
import gradio as gr |
|
from gradio_imageslider import ImageSlider |
|
from briarmbg import BriaRMBG |
|
import PIL |
|
from PIL import Image |
|
from typing import Tuple |
|
|
|
|
|
# Instantiate the segmentation network and fetch its RMBG-1.4 weights from
# the Hugging Face Hub.
net = BriaRMBG()
model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')

# NOTE(review): torch.load unpickles the checkpoint file. If the installed
# torch supports it (>= 1.13), weights_only=True would restrict loading to
# tensors — confirm the version before enabling.
if not torch.cuda.is_available():
    # CPU-only host: remap stored tensors onto the CPU while deserializing.
    net.load_state_dict(torch.load(model_path, map_location="cpu"))
else:
    net.load_state_dict(torch.load(model_path))
    net = net.cuda()

# Inference only from here on; switch the module to evaluation mode.
net.eval()
|
|
|
def resize_image(image, model_input_size=(1024, 1024)):
    """Return an RGB copy of *image* resized to *model_input_size*.

    Args:
        image: PIL image in any mode; it is converted to RGB first.
        model_input_size: Target ``(width, height)`` passed to ``resize``.

    Returns:
        The converted, bilinearly resized PIL image.
    """
    rgb_image = image.convert('RGB')
    return rgb_image.resize(model_input_size, Image.BILINEAR)
|
|
|
def process(image, background_image=None):
    """Remove the background of *image*, optionally placing the cut-out on a backdrop.

    Args:
        image: Input picture as a NumPy array (the Gradio input component uses
            ``type="numpy"``); ``None`` when nothing was uploaded.
        background_image: Optional PIL image. When provided, the cut-out is
            composited onto it with ``merge_images``.

    Returns:
        An RGBA PIL image whose alpha channel comes from the predicted
        foreground mask, or the result of ``merge_images`` when a background
        image was supplied.

    Raises:
        gr.Error: If *image* is ``None``.
    """
    if image is None:
        # Show a readable message in the UI instead of an AttributeError
        # raised from Image.fromarray(None).
        raise gr.Error("Please upload an image to remove its background.")

    orig_image = Image.fromarray(image).convert("RGB")
    w, h = orig_image.size

    # Model input: fixed-size RGB scaled to [0, 1], normalized with
    # mean 0.5 / std 1.0 per channel, laid out as a 1 x 3 x H x W batch.
    im_np = np.array(resize_image(orig_image))
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0) / 255.0
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    if torch.cuda.is_available():
        # Mirrors the module-level setup, which moves `net` to CUDA under the
        # same condition.
        im_tensor = im_tensor.cuda()

    with torch.no_grad():
        result = net(im_tensor)

    # Upsample the primary prediction back to the original resolution.
    # Bilinear interpolation with a 2-D size needs an (N, C, H, W) tensor, so
    # take the single batch item and channel explicitly; a blanket squeeze()
    # would also drop h or w when either equals 1.
    pred = F.interpolate(result[0][0], size=(h, w), mode="bilinear", align_corners=False)[0, 0]

    # Rescale to [0, 1] with min-max normalization, as the RMBG-1.4 reference
    # inference code does, instead of applying torch.sigmoid here.
    # NOTE(review): upstream briarmbg.py already applies sigmoid in forward();
    # a second sigmoid squashes [0, 1] into ~[0.5, 0.73] and leaves the
    # background semi-opaque. Confirm against the vendored briarmbg.py.
    # Min-max yields a usable mask whether the net emits probabilities or logits.
    lo, hi = pred.min(), pred.max()
    if hi > lo:
        pred = (pred - lo) / (hi - lo)
    else:
        # Flat prediction: nothing to stretch, and (hi - lo) would be zero.
        pred = pred.clamp(0.0, 1.0)
    mask = (pred * 255).round().byte().cpu().numpy()

    # A 2-D uint8 array is read as a mode "L" image whose size matches orig_image.
    mask_image = Image.fromarray(mask)
    transparent = Image.new("RGBA", orig_image.size, (0, 0, 0, 0))
    foreground_image = Image.composite(orig_image.convert("RGBA"), transparent, mask_image)

    if background_image is not None:
        return merge_images(background_image, foreground_image)
    return foreground_image
|
|
|
|
|
def merge_images(background_image, foreground_image, scale_factor=0.3):
    """Paste a background-removed cut-out onto the centre of a background image.

    The cut-out is resized, keeping its aspect ratio, to ``scale_factor`` of
    the background's width (30% by default) and then pasted using its own
    alpha channel as the mask, so its transparent pixels let the background
    show through.

    Args:
        background_image: PIL image used as the backdrop; converted to RGBA.
        foreground_image: PIL image with transparency (e.g. the output of the
            background-removal step); converted to RGBA.
        scale_factor: Target cut-out width as a fraction of the background
            width. Defaults to 0.3.

    Returns:
        The RGBA background image with the cut-out pasted at its centre.
    """
    background = background_image.convert("RGBA")
    foreground = foreground_image.convert("RGBA")

    # Width-driven scaling that preserves the cut-out's aspect ratio; clamp to
    # at least one pixel so a tiny background never requests a 0-size resize.
    target_width = max(1, int(background.width * scale_factor))
    target_height = max(1, int(foreground.height * target_width / foreground.width))

    # A tall cut-out on a short, wide background would otherwise overflow
    # vertically and be silently cropped at the top and bottom; shrink it
    # (keeping the aspect ratio) until it fits the background height.
    if target_height > background.height:
        target_width = max(1, int(target_width * background.height / target_height))
        target_height = background.height

    foreground_resized = foreground.resize((target_width, target_height), Image.Resampling.LANCZOS)

    # Centre the cut-out; floor division keeps the offset on the pixel grid.
    x = (background.width - target_width) // 2
    y = (background.height - target_height) // 2

    # Using the RGBA cut-out as its own mask pastes only the subject's pixels.
    background.paste(foreground_resized, (x, y), foreground_resized)

    return background
|
|
|
|
|
title = "Background Removal"
description = "This is a demo for BRIA RMBG 1.4 using the BRIA RMBG-1.4 image matting model as backbone."

# Input components map positionally onto process(image, background_image):
# the picture to clean up (as a NumPy array) and an optional PIL backdrop.
_input_components = [
    gr.Image(type="numpy", label="Image to remove background"),
    gr.Image(type="pil", label="Optional: Background Image"),
]

demo = gr.Interface(
    fn=process,
    inputs=_input_components,
    outputs="image",
    title=title,
    description=description,
)

# Start the web UI only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()
|
|