import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
from huggingface_hub import hf_hub_download
import gradio as gr
from briarmbg import BriaRMBG
from PIL import Image

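# Load the RMBG-1.4 matting network. The checkpoint is downloaded from the
# Hugging Face Hub on first run and served from the local cache afterwards.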
net = BriaRMBG()
# model_path = "./model1.pth"
model_path = hf_hub_download("briaai/RMBG-1.4", "model.pth")
if torch.cuda.is_available():
    net.load_state_dict(torch.load(model_path))
    net = net.cuda()
else:
    net.load_state_dict(torch.load(model_path, map_location="cpu"))
net.eval()  # switch to inference mode


def resize_image(image):
    image = image.convert('RGB')
    model_input_size = (1024, 1024)
    image = image.resize(model_input_size, Image.BILINEAR)
    return image
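# Note: the resize above ignores the aspect ratio; the model consumes a fixed
# 1024x1024 input, and process() below resizes the predicted mask back to the
# original dimensions, so the final cutout is not distorted.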


def process(image):
    # Convert to a PIL.Image only when the input is a numpy array
    if isinstance(image, np.ndarray):
        orig_image = Image.fromarray(image)
    else:
        # Already a PIL.Image.Image; use it without conversion
        orig_image = image

    w, h = orig_image.size
    image = resize_image(orig_image)
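    # Preprocessing: HWC uint8 image -> 1x3x1024x1024 float tensor, scaled to
    # [0, 1] and then shifted by mean 0.5 with std 1.0, so the network sees
    # values in [-0.5, 0.5].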
    im_np = np.array(image)
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
    im_tensor = torch.unsqueeze(im_tensor, 0)
    im_tensor = torch.divide(im_tensor, 255.0)
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    if torch.cuda.is_available():
        im_tensor = im_tensor.cuda()

    # inference (no gradient tracking needed at test time)
    with torch.no_grad():
        result = net(im_tensor)
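    # result is assumed to be a list of side outputs, as in the BriaRMBG
    # reference code; result[0][0] is the full-resolution matte of shape
    # (1, 1, 1024, 1024), which is resized back to (h, w) below.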
    # post process
    result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
    ma = torch.max(result)
    mi = torch.min(result)
    result = (result - mi) / (ma - mi)
    # convert the mask tensor to an 8-bit grayscale PIL image
    im_array = (result * 255).cpu().numpy().astype(np.uint8)
    pil_im = Image.fromarray(np.squeeze(im_array))
    # composite: keep the original pixels where the mask is opaque,
    # transparent elsewhere
    new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
    new_im.paste(orig_image, mask=pil_im)

    return new_im
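# Hypothetical standalone use (file names are placeholders, not part of the app):
#   cutout = process(Image.open("photo.jpg"))
#   cutout.save("cutout.png")  # RGBA PNG with a transparent background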

def calculate_position(org_size, add_size, position):
    if position == "Top Left":
        return (0, 0)
    elif position == "Top Center":
        return ((org_size[0] - add_size[0]) // 2, 0)
    elif position == "Top Right":
        return (org_size[0] - add_size[0], 0)
    elif position == "Middle Left":
        return (0, (org_size[1] - add_size[1]) // 2)
    elif position == "Middle Center":
        return ((org_size[0] - add_size[0]) // 2, (org_size[1] - add_size[1]) // 2)
    elif position == "Middle Right":
        return (org_size[0] - add_size[0], (org_size[1] - add_size[1]) // 2)
    elif position == "Bottom Left":
        return (0, org_size[1] - add_size[1])
    elif position == "Bottom Center":
        return ((org_size[0] - add_size[0]) // 2, org_size[1] - add_size[1])
    elif position == "Bottom Right":
        return (org_size[0] - add_size[0], org_size[1] - add_size[1])
    # fall back to centering for any unrecognized label
    return ((org_size[0] - add_size[0]) // 2, (org_size[1] - add_size[1]) // 2)
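# Worked example: a 200x100 foreground on an 800x600 background at
# "Middle Center" is placed at ((800 - 200) // 2, (600 - 100) // 2) = (300, 250).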

def merge(org_image, add_image, scale, position):
    # scale the foreground, then paste it onto the background using the
    # foreground's own alpha channel as the paste mask
    scale_percentage = scale / 100.0
    new_size = (int(add_image.width * scale_percentage), int(add_image.height * scale_percentage))
    add_image = add_image.resize(new_size, Image.Resampling.LANCZOS)

    paste_xy = calculate_position(org_image.size, add_image.size, position)
    org_image.paste(add_image, paste_xy, add_image)

    return org_image
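# Hypothetical standalone use, combining both steps (paths are placeholders):
#   bg = Image.open("background.png").convert("RGBA")
#   fg = process(Image.open("subject.jpg"))  # RGBA cutout from above
#   merge(bg, fg, 50, "Bottom Right").save("composite.png")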



with gr.Blocks() as demo:
    with gr.Tab("Background Removal"):
        with gr.Column():
            gr.Markdown("## BRIA RMBG 1.4")
            gr.HTML('''
                <p style="margin-bottom: 10px; font-size: 94%">
                This is a demo for BRIA RMBG 1.4 that uses the
                <a href="https://huggingface.co/briaai/RMBG-1.4" target="_blank">BRIA RMBG-1.4 image matting model</a> as its backbone.
                </p>
            ''')
            input_image = gr.Image(type="pil")
            output_image = gr.Image()
            process_button = gr.Button("Remove Background")
            process_button.click(fn=process, inputs=input_image, outputs=output_image)

    with gr.Tab("Merge"):
        with gr.Column():
            org_image = gr.Image(label="Background", type='pil', image_mode='RGBA', height="40vh")
            add_image = gr.Image(label="Foreground", type='pil', image_mode='RGBA', height="40vh")
            scale = gr.Slider(minimum=10, maximum=200, step=1, value=100, label="Scale of Foreground Image (%)")
            position = gr.Radio(choices=["Middle Center", "Top Left", "Top Center", "Top Right", "Middle Left", "Middle Right", "Bottom Left", "Bottom Center", "Bottom Right"], value="Middle Center", label="Position of Foreground Image")
            btn_merge = gr.Button("Merge Images")
            result_merge = gr.Image(height="40vh")
            btn_merge.click(fn=merge, inputs=[org_image, add_image, scale, position], outputs=result_merge)

demo.launch()
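# The default launch() is sufficient on Hugging Face Spaces; for a local run,
# demo.launch(share=True) would additionally create a temporary public URL.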