File size: 3,876 Bytes
dd67556 a119d24 dd67556 a119d24 dd67556 a119d24 dd67556 fc9e8ce dd67556 fc9e8ce dd67556 fc9e8ce dd67556 fc9e8ce dd67556 5801493 e84892b a119d24 dc146a8 fc9e8ce dd67556 fc9e8ce dc146a8 e84892b fc9e8ce dc146a8 a119d24 dc146a8 e84892b a119d24 fc9e8ce a119d24 a6a92bb a119d24 a6a92bb a119d24 a6a92bb a119d24 fc9e8ce e84892b a119d24 dd67556 e84892b ca70732 a119d24 a6a92bb e84892b 597ebe6 ca70732 e84892b ca70732 5801493 ca70732 e84892b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
from huggingface_hub import hf_hub_download
import gradio as gr
from gradio_imageslider import ImageSlider
from briarmbg import BriaRMBG
import PIL
from PIL import Image
from typing import Tuple
# Model initialization and loading: fetch the RMBG-1.4 checkpoint from the
# Hugging Face Hub, load it into BriaRMBG, and move the network to the GPU
# when CUDA is available.
# NOTE(review): torch.load unpickles the downloaded file; only load weights
# from a trusted repo (consider weights_only=True on torch versions that
# support it).
net = BriaRMBG()
model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
if not torch.cuda.is_available():
    # CPU-only host: remap any CUDA-saved tensors onto the CPU while loading.
    net.load_state_dict(torch.load(model_path, map_location="cpu"))
else:
    net.load_state_dict(torch.load(model_path))
    net = net.cuda()
# Switch to evaluation mode for inference.
net.eval()
def resize_image(image, model_input_size=(1024, 1024)):
    """Prepare a PIL image for the network.

    Converts *image* to RGB and resizes it to *model_input_size*
    (width, height) with bilinear resampling; the aspect ratio is not
    preserved.

    Returns:
        A new RGB PIL image of size *model_input_size*.
    """
    rgb_image = image.convert('RGB')
    return rgb_image.resize(model_input_size, Image.Resampling.BILINEAR)
def _predict_mask(orig_image):
    """Predict the foreground matte of an RGB PIL image as a uint8 (h, w) array."""
    w, h = orig_image.size

    # Preprocess: resize to the network input size, scale to [0, 1], then
    # shift by -0.5 per channel (std 1.0), exactly as before.
    im_np = np.array(resize_image(orig_image))
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0) / 255.0
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    # Send the input to whichever device the network's weights live on.
    im_tensor = im_tensor.to(next(net.parameters()).device)

    with torch.no_grad():
        result = net(im_tensor)

    # Upsample the prediction back to the original resolution and select
    # batch 0 / channel 0 explicitly so the matte is exactly (h, w), even for
    # 1-pixel-wide or -tall inputs (a blanket squeeze() would drop those axes).
    pred = F.interpolate(
        result[0][0], size=(h, w), mode='bilinear', align_corners=False
    )[0, 0]

    # NOTE(review): BriaRMBG.forward (briarmbg.py in briaai/RMBG-1.4) already
    # returns sigmoid-activated maps, so a second torch.sigmoid squashed the
    # matte into roughly [0.5, 0.73] and left the background half opaque.
    # Stretch to [0, 1] with min-max normalization instead, as the reference
    # RMBG-1.4 demo does; confirm against the briarmbg.py actually deployed.
    lo, hi = pred.min(), pred.max()
    if hi > lo:
        pred = (pred - lo) / (hi - lo)
    else:
        # Constant prediction: nothing to stretch, just keep it in range.
        pred = pred.clamp(0.0, 1.0)
    return (pred * 255).round().to(torch.uint8).cpu().numpy()


def process(image, background_image=None):
    """Remove the background from *image* with the RMBG-1.4 network.

    Args:
        image: Input picture as a numpy array; converted to RGB before
            inference.
        background_image: Optional PIL image. When given, the cut-out is
            pasted onto it with ``merge_images``; otherwise the cut-out is
            returned with a transparent background.

    Returns:
        An RGBA PIL image: the original pixels with the predicted matte as
        the alpha channel, or the background with that cut-out pasted on it.

    Raises:
        gr.Error: if no input image was supplied.
    """
    if image is None:
        # Surface a readable message instead of an opaque PIL/numpy error.
        raise gr.Error("Please upload an image to remove its background.")

    orig_image = Image.fromarray(image).convert("RGB")
    mask = _predict_mask(orig_image)

    # Use the matte directly as the alpha band. Compositing onto a
    # (0, 0, 0, 0) canvas would also multiply the RGB values by the mask,
    # darkening soft edges (dark halos) in the straight-alpha output.
    foreground_image = orig_image.convert("RGBA")
    foreground_image.putalpha(Image.fromarray(mask))  # uint8 2-D array -> mode "L"

    # Optional background handling.
    if background_image is not None:
        return merge_images(background_image, foreground_image)
    return foreground_image
def merge_images(background_image, foreground_image, scale_factor=0.3):
    """Paste a background-removed cut-out, centered, onto a background image.

    The cut-out is resized, keeping its aspect ratio, so that its width is
    *scale_factor* times the background width. It is then pasted at the
    center, using its own alpha channel as the paste mask.

    Args:
        background_image: PIL image to paste onto. It is converted to RGBA
            (a new image), so the caller's object is not modified.
        foreground_image: PIL image of the cut-out; converted to RGBA.
        scale_factor: Target cut-out width as a fraction of the background
            width. Defaults to 0.3 (30%), the previously hard-coded value.

    Returns:
        A new RGBA PIL image.

    Raises:
        ValueError: if *scale_factor* is not positive.
    """
    if scale_factor <= 0:
        raise ValueError(f"scale_factor must be positive, got {scale_factor!r}")

    background = background_image.convert("RGBA")
    foreground = foreground_image.convert("RGBA")

    # Scale relative to the background width, preserving the cut-out's aspect
    # ratio. Clamp both sides to >= 1 px: Pillow's resize rejects zero-sized
    # targets, which very narrow backgrounds or very wide cut-outs produce.
    new_width = max(1, int(background.width * scale_factor))
    new_height = max(1, int(foreground.height * new_width / foreground.width))
    foreground_resized = foreground.resize((new_width, new_height), Image.Resampling.LANCZOS)

    # Top-left corner that centers the cut-out on the background.
    x = (background.width - new_width) // 2
    y = (background.height - new_height) // 2

    # Third argument: the cut-out itself acts as the mask (its alpha band).
    background.paste(foreground_resized, (x, y), foreground_resized)
    return background
# Gradio UI metadata.
title = "Background Removal"
description = "This is a demo for BRIA RMBG 1.4 using the BRIA RMBG-1.4 image matting model as backbone."

# The inputs map positionally onto process(image, background_image): a numpy
# array for the picture to cut out, and a PIL image for the optional backdrop.
# (Depending on the Gradio version, `optional=True` may not be needed here.)
demo = gr.Interface(
    process,
    [
        gr.Image(type="numpy", label="Image to remove background"),
        gr.Image(type="pil", label="Optional: Background Image"),
    ],
    "image",
    description=description,
    title=title,
)

if __name__ == "__main__":
    # Launch the Gradio server only when run as a script, not on import.
    demo.launch()
|