# NOTE: removed non-code extraction artifacts (file-size line and git-blame
# hash/line-number dump) that are not valid Python.
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
from huggingface_hub import hf_hub_download
import gradio as gr
from briarmbg import BriaRMBG
from PIL import Image
# Model initialization and weight loading.
net = BriaRMBG()
# Fetch the pretrained checkpoint from the Hugging Face Hub (cached locally).
model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
if torch.cuda.is_available():
    net.load_state_dict(torch.load(model_path))
    net = net.cuda()
else:
    # No GPU available: remap the checkpoint tensors onto the CPU.
    net.load_state_dict(torch.load(model_path, map_location="cpu"))
# Inference only: switch off dropout / batch-norm training behavior.
net.eval()
def resize_image(image, model_input_size=(1024, 1024)):
    """Convert *image* to RGB and resize it to the network's input size.

    Args:
        image: Source PIL image (any mode).
        model_input_size: ``(width, height)`` expected by the model;
            defaults to 1024x1024 as required by RMBG-1.4.

    Returns:
        A new RGB PIL image of ``model_input_size``.
    """
    image = image.convert('RGB')
    # Use the Image.Resampling namespace (as merge_images already does)
    # instead of the deprecated top-level Image.BILINEAR alias.
    image = image.resize(model_input_size, Image.Resampling.BILINEAR)
    return image
def process(image, background_image=None):
    """Remove the background from *image*; optionally composite onto a backdrop.

    Args:
        image: Input image as a numpy array (H, W, C) — Gradio ``type="numpy"``.
        background_image: Optional PIL image. When given, the cutout is scaled,
            centered and pasted onto it via ``merge_images``.

    Returns:
        An RGBA PIL image: the foreground over a transparent background, or the
        merged composite when *background_image* is provided.
    """
    # Prepare the input: keep the full-resolution original for compositing,
    # feed a resized, normalized copy to the network.
    orig_image = Image.fromarray(image).convert("RGBA")
    w, h = orig_image.size
    im_np = np.array(resize_image(orig_image))
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0) / 255.0
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    if torch.cuda.is_available():
        im_tensor = im_tensor.cuda()

    # Inference
    with torch.no_grad():
        result = net(im_tensor)

    # Post-processing: upsample the predicted matte back to the original
    # resolution, then map it to a 0-255 alpha mask.
    result = torch.squeeze(
        F.interpolate(result[0][0], size=(h, w), mode='bilinear', align_corners=False), 0
    )
    result = torch.sigmoid(result)
    mask = (result * 255).byte().cpu().numpy()
    # Collapse any leftover singleton dimensions so Image.fromarray
    # receives a plain 2-D array.
    if mask.ndim > 2:
        mask = mask.squeeze()
    mask = mask.astype(np.uint8)
    mask_image = Image.fromarray(mask, 'L')  # 'L' mode = 8-bit grayscale

    # Use the mask as the alpha channel of the original image.
    final_image = Image.new("RGBA", orig_image.size)
    final_image.paste(orig_image, mask=mask_image)

    # Optional background compositing.
    if background_image is not None:
        final_image = merge_images(background_image, final_image)
    return final_image
def merge_images(background_image, foreground_image):
    """Paste a background-removed image onto a backdrop image.

    The foreground is shrunk to 30% of the backdrop's width (aspect ratio
    preserved), centered on the backdrop, and pasted using its own alpha
    channel as the transparency mask.
    """
    backdrop = background_image.convert("RGBA")
    cutout = foreground_image.convert("RGBA")

    # Scale the cutout to 30% of the backdrop width, keeping its aspect ratio.
    scale_factor = 0.3
    target_w = int(backdrop.width * scale_factor)
    target_h = int(cutout.height * target_w / cutout.width)
    scaled = cutout.resize((target_w, target_h), Image.Resampling.LANCZOS)

    # Center the scaled cutout; passing it as the third argument makes its
    # alpha channel act as the paste mask.
    offset = ((backdrop.width - target_w) // 2, (backdrop.height - target_h) // 2)
    backdrop.paste(scaled, offset, scaled)
    return backdrop
# Gradio UI metadata and interface wiring.
title = "Background Removal"
description = "This is a demo for BRIA RMBG 1.4 using the BRIA RMBG-1.4 image matting model as backbone."
demo = gr.Interface(
    fn=process,
    inputs=[
        gr.Image(type="numpy", label="Image to remove background"),
        # 'optional=True' may not be needed depending on the Gradio version.
        gr.Image(type="pil", label="Optional: Background Image")
    ],
    outputs="image",
    title=title,
    description=description
)

if __name__ == "__main__":
    demo.launch()