File size: 2,311 Bytes
dd67556
 
 
 
3ccd256
dd67556
 
 
 
c384edc
dd67556
c384edc
dd67556
 
 
c384edc
cffeaa2
 
 
 
 
 
52b892c
cffeaa2
c384edc
 
cffeaa2
 
 
 
c384edc
cffeaa2
c384edc
815368c
 
 
 
 
 
 
 
 
 
cffeaa2
c384edc
cffeaa2
c384edc
 
cffeaa2
c384edc
cffeaa2
c384edc
 
cffeaa2
c384edc
 
 
 
3ccd256
 
 
 
 
 
 
 
815368c
 
 
3ccd256
815368c
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import functools

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from briarmbg import BriaRMBG
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision.transforms.functional import normalize

def resize_image(image):
    """Return an RGB version of ``image`` scaled to the model input size.

    The image is converted to RGB and then stretched to a fixed 1024x1024
    square with bilinear resampling; the aspect ratio is not preserved.
    """
    target_size = (1024, 1024)
    rgb_image = image.convert('RGB')
    return rgb_image.resize(target_size, Image.BILINEAR)

@functools.lru_cache(maxsize=1)
def _load_model():
    """Build the RMBG-1.4 network once and return ``(net, device)``.

    The result is cached so the checkpoint is fetched and deserialized only
    on the first request instead of on every call to ``process``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = BriaRMBG()
    model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
    # NOTE(review): torch.load unpickles the downloaded checkpoint; consider
    # weights_only=True where the installed torch version supports it.
    # Weights are loaded on CPU first and moved afterwards, so the same code
    # path works with and without CUDA.
    net.load_state_dict(torch.load(model_path, map_location="cpu"))
    net.to(device)
    net.eval()
    return net, device


def process(image):
    """Remove the background of ``image`` using the BriaRMBG model.

    Args:
        image: A ``PIL.Image.Image`` or a numpy array accepted by
            ``Image.fromarray``. ``None`` (no image supplied) is tolerated.

    Returns:
        A new RGBA ``PIL.Image.Image`` with the same size as the input, in
        which the original pixels are pasted through the predicted mask, or
        ``None`` when ``image`` is ``None``.
    """
    # Nothing to do when the handler is triggered without an image.
    if image is None:
        return None

    # Convert to a PIL.Image only when the input is a numpy array; an
    # existing PIL.Image is used as-is.
    if isinstance(image, np.ndarray):
        orig_image = Image.fromarray(image)
    else:
        orig_image = image

    w, h = orig_image.size
    net, device = _load_model()

    # Pre-process: HWC uint8 -> 1xCxHxW float in [0, 1], then shift by 0.5.
    im_np = np.array(resize_image(orig_image))
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
    im_tensor = torch.unsqueeze(im_tensor, 0)
    im_tensor = torch.divide(im_tensor, 255.0)
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    im_tensor = im_tensor.to(device)

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        result = net(im_tensor)
        # Post-process: scale the first output back to the original size.
        result = torch.squeeze(
            F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0
        )

    # Min-max normalize the mask to [0, 1].
    ma = torch.max(result)
    mi = torch.min(result)
    span = ma - mi
    if span > 0:
        result = (result - mi) / span
    else:
        # A constant mask would divide by zero and yield NaNs.
        # NOTE(review): assumes raw scores are already in [0, 1] — confirm.
        result = torch.clamp(result, 0.0, 1.0)

    # Mask tensor -> 8-bit PIL image. The explicit reshape (instead of a
    # blanket squeeze) keeps 1-pixel-wide or 1-pixel-tall images 2-D.
    im_array = (result * 255).detach().cpu().numpy().astype(np.uint8)
    pil_im = Image.fromarray(im_array.reshape(h, w))

    # Paste the original image onto a transparent canvas through the mask.
    new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
    new_im.paste(orig_image, mask=pil_im)

    return new_im


# Custom CSS passed to the Blocks app: hides the page's `footer` element.
css = """
footer { 
    visibility: hidden; 
}
"""

# UI layout: a single column holding the input image, the output image and
# the button that triggers background removal.
with gr.Blocks(css=css) as demo:
    with gr.Column():
        # type="pil" requests that the uploaded image reach `process` as a
        # PIL image rather than a numpy array.
        input_image = gr.Image(type="pil")
        output_image = gr.Image()
        process_button = gr.Button("Remove Background Image")
        # Clicking the button runs `process` on the input image and shows the
        # returned image in the output component.
        process_button.click(fn=process, inputs=input_image, outputs=output_image)

# Launch the app; this runs whenever the module is executed or imported
# because there is no `__main__` guard.
demo.launch()