Spaces:
fantos
/
Runtime error

File size: 3,773 Bytes
dd67556
 
 
 
 
 
 
 
 
fc9e8ce
 
dd67556
 
 
fc9e8ce
dd67556
fc9e8ce
 
dd67556
fc9e8ce
dd67556
 
 
 
5801493
e84892b
dc146a8
 
 
 
 
fc9e8ce
dd67556
fc9e8ce
dc146a8
e84892b
fc9e8ce
dc146a8
 
a6a92bb
dc146a8
e84892b
a6a92bb
fc9e8ce
a6a92bb
 
 
 
 
 
 
 
 
 
dc146a8
 
fc9e8ce
e84892b
 
 
dd67556
e84892b
ca70732
a6a92bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e84892b
 
597ebe6
ca70732
 
 
e84892b
 
ca70732
 
5801493
 
ca70732
e84892b
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
from huggingface_hub import hf_hub_download
import gradio as gr
from briarmbg import BriaRMBG
from PIL import Image

# Model setup: build the BriaRMBG network, download the RMBG-1.4 checkpoint
# from the Hugging Face Hub, load its weights and switch to inference mode.
net = BriaRMBG()
model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
_on_gpu = torch.cuda.is_available()
# Without a GPU, remap every stored tensor onto the CPU while deserializing.
# NOTE(review): torch.load unpickles the file; acceptable only because the
# checkpoint comes from a trusted Hub repo.
_load_kwargs = {} if _on_gpu else {"map_location": "cpu"}
net.load_state_dict(torch.load(model_path, **_load_kwargs))
if _on_gpu:
    net = net.cuda()
net.eval()

def resize_image(image, model_input_size=(1024, 1024)):
    """Return *image* as an RGB copy resized to *model_input_size*.

    Args:
        image: A PIL image in any mode.
        model_input_size: Target ``(width, height)`` in pixels (PIL's
            ``resize`` takes width first).

    Returns:
        A new RGB PIL image of exactly ``model_input_size``, resampled
        bilinearly; the aspect ratio is not preserved.
    """
    rgb = image.convert('RGB')
    return rgb.resize(model_input_size, Image.BILINEAR)

def process(image, background_image=None):
    """Remove the background from *image*, optionally compositing it onto another.

    Args:
        image: Input picture as a NumPy array, as delivered by the
            ``gr.Image(type="numpy")`` input. Presumably ``(H, W, 3)`` uint8;
            TODO confirm for the Gradio version in use.
        background_image: Optional PIL image. When given, the cut-out is
            scaled and centred on it via ``merge_images``.

    Returns:
        An ``RGBA`` PIL image the size of *image* whose alpha channel is the
        predicted foreground mask, or the composite produced by
        ``merge_images`` when *background_image* is provided.

    Raises:
        gr.Error: if no input image was supplied.
    """
    if image is None:
        # Gradio passes None when the input box is empty; report it clearly
        # instead of failing with an AttributeError inside PIL.
        raise gr.Error("Please upload an image to remove its background.")

    # Prepare the input: keep the original for the final paste, and build a
    # 1024x1024 RGB tensor scaled to [0, 1] then shifted to [-0.5, 0.5].
    orig_image = Image.fromarray(image).convert("RGBA")
    w, h = orig_image.size
    model_input = resize_image(orig_image)
    im_np = np.array(model_input)
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0) / 255.0
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    if torch.cuda.is_available():
        im_tensor = im_tensor.cuda()

    # Inference.
    with torch.no_grad():
        result = net(im_tensor)

    # Post-processing: upsample the primary prediction back to the original
    # resolution. Assumes result[0][0] is (1, 1, H, W) (TODO confirm); indexing
    # [0, 0] drops exactly the batch and channel axes, so 1-pixel-high/wide
    # images still give a 2-D (h, w) mask (a bare squeeze() would drop those too).
    pred = F.interpolate(result[0][0], size=(h, w), mode='bilinear', align_corners=False)[0, 0]
    # BriaRMBG (briarmbg.py in the RMBG-1.4 repo) already returns
    # sigmoid-activated maps, so a second sigmoid squashed every value into
    # [0.5, 0.73] and the background never became transparent. Rescale to
    # [0, 1] with min-max normalisation, as the reference RMBG-1.4 example
    # does; the clamp guards against a constant prediction (division by zero).
    lo, hi = torch.min(pred), torch.max(pred)
    pred = (pred - lo) / torch.clamp(hi - lo, min=1e-8)
    mask = (pred * 255).round().to(torch.uint8).cpu().numpy()

    # 8-bit grayscale ('L' mode) image used as the alpha mask for the paste.
    mask_image = Image.fromarray(mask, 'L')

    final_image = Image.new("RGBA", orig_image.size)
    final_image.paste(orig_image, mask=mask_image)

    # Optional background compositing.
    if background_image is not None:
        final_image = merge_images(background_image, final_image)

    return final_image

def merge_images(background_image, foreground_image, scale_factor=0.3):
    """Composite a background-removed image onto the centre of a background.

    The foreground is resized so that its width equals ``scale_factor`` times
    the background's width (30% by default), keeping the foreground's own
    aspect ratio. It is then pasted centred, using its alpha channel as the
    mask. Any part extending past the background's edges is clipped by
    ``paste``.

    Args:
        background_image: PIL image to paste onto; converted to RGBA (the
            caller's image is not modified).
        foreground_image: PIL image whose alpha channel marks the foreground;
            converted to RGBA.
        scale_factor: Foreground width as a fraction of the background width.
            Must be positive.

    Returns:
        A new RGBA PIL image the size of *background_image*.

    Raises:
        ValueError: if *scale_factor* is not positive.
    """
    if scale_factor <= 0:
        raise ValueError(f"scale_factor must be positive, got {scale_factor!r}")

    background = background_image.convert("RGBA")
    foreground = foreground_image.convert("RGBA")

    # Scale by width and derive the height from the foreground's aspect ratio.
    # Clamp both to >= 1 px: PIL's resize rejects zero-sized targets.
    foreground_width = max(1, int(background.width * scale_factor))
    foreground_height = max(1, int(foreground.height * foreground_width / foreground.width))
    foreground_resized = foreground.resize(
        (foreground_width, foreground_height), Image.Resampling.LANCZOS
    )

    # Top-left corner that centres the foreground. It may be negative when
    # the scaled foreground is taller than the background; paste clips it.
    x = (background.width - foreground_width) // 2
    y = (background.height - foreground_height) // 2

    # The third argument uses the foreground's own alpha as the paste mask.
    background.paste(foreground_resized, (x, y), foreground_resized)

    return background


title = "Background Removal"
description = (
    "This is a demo for BRIA RMBG 1.4 using the BRIA RMBG-1.4 image matting model as backbone."
)

# Gradio UI: the image to cut out plus an optional background image; the
# output shows whatever `process` returns.
demo = gr.Interface(
    fn=process,
    inputs=[
        gr.Image(type="numpy", label="Image to remove background"),
        # 'optional=True' may not be needed (or accepted), depending on the Gradio version.
        gr.Image(type="pil", label="Optional: Background Image"),
    ],
    outputs="image",
    description=description,
    title=title,
)

if __name__ == "__main__":
    demo.launch()