import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
from huggingface_hub import hf_hub_download
import gradio as gr
from briarmbg import BriaRMBG
from PIL import Image

# ๋ชจ๋ธ ์ดˆ๊ธฐํ™” ๋ฐ ๋กœ๋“œ
net = BriaRMBG()
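# hf_hub_download fetches the checkpoint from the Hub (cached locally after the first run)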
model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
if torch.cuda.is_available():
    net.load_state_dict(torch.load(model_path))
    net = net.cuda()
else:
    net.load_state_dict(torch.load(model_path, map_location="cpu"))
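# eval mode disables dropout and freezes batch-norm statistics for inference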
net.eval()

def resize_image(image, model_input_size=(1024, 1024)):
    # RMBG-1.4 expects a fixed-size RGB input; non-square images are stretched
    # here, and the predicted mask is stretched back in process().
    image = image.convert('RGB')
    image = image.resize(model_input_size, Image.BILINEAR)
    return image

def process(image, background_image=None):
    # Prepare the input image
    orig_image = Image.fromarray(image).convert("RGB")
    w, h = orig_image.size
    image = resize_image(orig_image)
    im_np = np.array(image)
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0) / 255.0
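    # normalize with mean 0.5 / std 1.0 shifts the [0, 1] tensor into [-0.5, 0.5]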
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    if torch.cuda.is_available():
        im_tensor = im_tensor.cuda()

    # Inference
    with torch.no_grad():
        result = net(im_tensor)
    
    # ํ›„์ฒ˜๋ฆฌ
    result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear', align_corners=False), 0)
    result = torch.sigmoid(result)
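    # Scale the [0, 1] probability map to 0-255 so it can serve as an 8-bit alpha mask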
    mask = (result * 255).byte().cpu().numpy()

    # mask ๋ฐฐ์—ด์ด ์˜ˆ์ƒ๋Œ€๋กœ 2์ฐจ์›์ธ์ง€ ํ™•์ธํ•˜๊ณ , ์•„๋‹ˆ๋ผ๋ฉด ์กฐ์ •
    if mask.ndim > 2:
        mask = mask.squeeze()

    # mask ๋ฐฐ์—ด์„ ๋ช…ํ™•ํžˆ uint8๋กœ ๋ณ€ํ™˜
    mask = mask.astype(np.uint8)

    # Use the mask as an alpha channel to composite the foreground onto a transparent canvas
    orig_image = orig_image.convert("RGBA")
    final_image = Image.new("RGBA", orig_image.size, (0, 0, 0, 0))
    mask_image = Image.fromarray(mask, mode='L')
    foreground_image = Image.composite(orig_image, final_image, mask_image)

    # Optionally composite the cutout onto a user-supplied background
    if background_image is not None:
        final_image = merge_images(background_image, foreground_image)
    else:
        final_image = foreground_image

    return final_image


def merge_images(background_image, foreground_image):
    """
    ๋ฐฐ๊ฒฝ ์ด๋ฏธ์ง€์— ๋ฐฐ๊ฒฝ์ด ์ œ๊ฑฐ๋œ ์ด๋ฏธ์ง€๋ฅผ ํˆฌ๋ช…ํ•˜๊ฒŒ ์‚ฝ์ž…ํ•ฉ๋‹ˆ๋‹ค.
    ๋ฐฐ๊ฒฝ์ด ์ œ๊ฑฐ๋œ ์ด๋ฏธ์ง€๋Š” ๋ฐฐ๊ฒฝ ์ด๋ฏธ์ง€ ์ค‘์•™์— 30% ํฌ๊ธฐ๋กœ ์ถ•์†Œ๋˜์–ด ์‚ฝ์ž…๋ฉ๋‹ˆ๋‹ค.
    """
    background = background_image.convert("RGBA")
    foreground = foreground_image.convert("RGBA")
    
    # Scale the foreground to 30% of the background's width, preserving its aspect ratio
    scale_factor = 0.3
    foreground_width = int(background.width * scale_factor)
    foreground_height = int(foreground.height * foreground_width / foreground.width)
    new_size = (foreground_width, foreground_height)
    foreground_resized = foreground.resize(new_size, Image.Resampling.LANCZOS)
    
    # Compute the top-left coordinates that center the foreground on the background
    x = (background.width - foreground_width) // 2
    y = (background.height - foreground_height) // 2
    
    # Paste the foreground over the background, using its alpha channel as the mask
    background.paste(foreground_resized, (x, y), foreground_resized)
    
    return background


title = "Background Removal"
description = "Demo of background removal using the BRIA RMBG-1.4 image matting model."

demo = gr.Interface(
    fn=process,
    inputs=[
        gr.Image(type="numpy", label="Image to remove background"),
        gr.Image(type="pil", label="Optional: Background Image")  # 'optional=True' may be unnecessary depending on the Gradio version
    ],
    outputs="image",
    title=title,
    description=description
)

if __name__ == "__main__":
    demo.launch()
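
# A minimal sketch of programmatic use, bypassing the Gradio UI
# ("photo.jpg" is a hypothetical input path):
#   img = np.array(Image.open("photo.jpg"))
#   cutout = process(img)   # RGBA image with a transparent background
#   cutout.save("cutout.png")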