Spaces:
fantos
/
Runtime error

File size: 7,223 Bytes
dd67556
 
 
 
 
 
a119d24
dd67556
a119d24
dd67556
a119d24
dd67556
a1915c8
 
 
 
 
 
 
 
 
 
 
63e5794
 
 
 
 
 
 
 
 
 
 
 
 
a1915c8
 
63e5794
a1915c8
 
 
 
 
c384edc
 
dd67556
 
 
c384edc
dd67556
c384edc
 
dd67556
c384edc
 
dd67556
c384edc
dd67556
 
 
52b892c
c384edc
cffeaa2
 
 
 
 
 
52b892c
cffeaa2
c384edc
 
cffeaa2
 
 
 
c384edc
cffeaa2
c384edc
cffeaa2
 
c384edc
cffeaa2
c384edc
 
cffeaa2
c384edc
cffeaa2
c384edc
 
cffeaa2
c384edc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d975af1
 
 
 
 
 
c384edc
 
 
a6a92bb
c384edc
d975af1
 
 
a6a92bb
d975af1
 
 
 
c384edc
ef55172
c384edc
 
 
4421191
c384edc
8f831fa
c384edc
 
8f831fa
c384edc
 
 
 
 
 
 
 
637a6ee
 
c384edc
 
725fefe
2686b72
d975af1
 
 
 
 
 
 
c384edc
a1915c8
4421191
 
 
 
 
 
 
 
 
 
 
a1915c8
c384edc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
from huggingface_hub import hf_hub_download
import gradio as gr
from gradio_imageslider import ImageSlider
from briarmbg import BriaRMBG
import PIL
from PIL import Image
from typing import Tuple

import os
import requests
from moviepy.editor import VideoFileClip
from moviepy.audio.AudioClip import AudioClip

def extract_image_urls(data):
    """Pick one download URL per photo from a Pexels search response.

    For each photo, its ``src`` mapping is checked for the ``large2x`` size
    first, then ``large``, then ``original``; the first one present is used.
    Photos with no ``src`` mapping, or with none of those sizes, are skipped
    instead of raising.

    Args:
        data: Decoded JSON body of a Pexels ``/v1/search`` response.

    Returns:
        List of image URL strings, in the order the photos appear.
    """
    urls = []
    # `or []` / `or {}` also cover keys that are present but null.
    for photo in data.get("photos") or []:
        sources = photo.get("src") or {}
        # Prefer the high-resolution rendition, falling back to smaller/raw ones.
        for size in ("large2x", "large", "original"):
            if size in sources:
                urls.append(sources[size])
                break
    return urls


def search_pexels_images(query, per_page=80):
    """Search Pexels for *query* and return high-resolution image URLs.

    The API key is read from the ``API_KEY`` environment variable.

    Args:
        query: Free-text search string (URL-encoded by ``requests``).
        per_page: Number of results to request (default 80).

    Returns:
        List of image URLs as produced by ``extract_image_urls``.

    Raises:
        requests.HTTPError: if Pexels answers with an error status
            (e.g. a missing or invalid API key).
        requests.Timeout: if Pexels does not answer within 10 seconds.
    """
    api_key = os.getenv("API_KEY")
    response = requests.get(
        "https://api.pexels.com/v1/search",
        # `params` percent-encodes the query, so '&', '#', spaces etc. are safe.
        params={"query": query, "per_page": per_page},
        headers={"Authorization": api_key},
        # Without a timeout a stalled connection would hang the UI callback forever.
        timeout=10,
    )
    response.raise_for_status()
    return extract_image_urls(response.json())


def show_search_results(query):
    """Gradio callback for the image-search tab.

    Args:
        query: Search text typed by the user.

    Returns:
        The list of image URLs returned by ``search_pexels_images`` for
        *query*, handed straight to the Gallery output component.
    """
    return search_pexels_images(query)
    

# Build the background-removal network once at import time; `process` reuses
# this single module-level instance for every request.
net=BriaRMBG()
# Disabled alternative: load a local checkpoint instead of the Hub copy.
# model_path = "./model1.pth"
# Fetch the RMBG-1.4 checkpoint file from the Hugging Face Hub; the returned
# value is a local file path handed to torch.load below.
# NOTE(review): torch.load unpickles the file; consider weights_only=True if the
# installed torch version supports it.
model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
if torch.cuda.is_available():
    # GPU present: load the weights as stored, then move the model to CUDA.
    net.load_state_dict(torch.load(model_path))
    net=net.cuda()
else:
    # CPU only: remap any CUDA-saved tensors in the checkpoint onto the CPU.
    net.load_state_dict(torch.load(model_path,map_location="cpu"))
# Evaluation mode (affects layers such as dropout / batch-norm, if the model has any).
net.eval() 

    
def resize_image(image):
    """Prepare a PIL image for the background-removal network.

    The image is converted to RGB (dropping any alpha channel) and stretched
    to the fixed 1024x1024 model input resolution with bilinear resampling;
    the aspect ratio is not preserved.

    Args:
        image: A ``PIL.Image.Image`` in any mode.

    Returns:
        A new 1024x1024 RGB ``PIL.Image.Image``.
    """
    target_size = (1024, 1024)
    rgb = image.convert('RGB')
    return rgb.resize(target_size, Image.BILINEAR)


def process(image):
    """Remove the background from *image* using the module-level ``net``.

    Args:
        image: Input picture as a numpy array or a ``PIL.Image.Image``.
            Gradio passes ``None`` when nothing has been uploaded.

    Returns:
        An RGBA ``PIL.Image.Image`` the same size as the input. The original
        pixels are pasted through the predicted mask, which has been min-max
        normalised to 0..255, onto a fully transparent canvas. Returns
        ``None`` when *image* is ``None``.
    """
    if image is None:
        return None

    # Only numpy arrays need converting; PIL images are used as-is.
    orig_image = Image.fromarray(image) if isinstance(image, np.ndarray) else image
    w, h = orig_image.size  # PIL reports (width, height)

    # Pre-process: RGB 1024x1024 -> CHW float tensor with a batch dim,
    # scaled to [0, 1], then shifted by mean 0.5 (std 1.0).
    im_np = np.array(resize_image(orig_image))
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)
    im_tensor = normalize(im_tensor / 255.0, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    if torch.cuda.is_available():
        im_tensor = im_tensor.cuda()

    # Inference only: no_grad skips building an autograd graph, so the large
    # intermediate activations are not kept alive between requests.
    with torch.no_grad():
        result = net(im_tensor)
        # Mask comes from result[0][0] (presumably the primary side output,
        # 4-D N,C,H,W — TODO confirm against BriaRMBG); upsample to input size.
        mask = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
        # Min-max normalise to [0, 1]; the clamp avoids 0/0 (NaN) on a constant mask.
        mask_min, mask_max = torch.min(mask), torch.max(mask)
        mask = (mask - mask_min) / torch.clamp(mask_max - mask_min, min=1e-8)

    mask_array = (mask * 255).cpu().numpy().astype(np.uint8)
    pil_mask = Image.fromarray(np.squeeze(mask_array))

    # Paste the original image through the mask onto a transparent canvas.
    new_im = Image.new("RGBA", pil_mask.size, (0, 0, 0, 0))
    new_im.paste(orig_image, mask=pil_mask)
    return new_im

def calculate_position(org_size, add_size, position):
    """Return the top-left paste offset for a foreground inside a background.

    Args:
        org_size: ``(width, height)`` of the background.
        add_size: ``(width, height)`` of the foreground.
        position: One of the nine position labels offered by the Merge tab
            (top/middle/bottom x left/center/right).

    Returns:
        ``(x, y)`` offset tuple. Offsets are negative when the foreground is
        larger than the background along that axis. Returns ``None`` for an
        unrecognised label.
    """
    def _offset(axis, align):
        # "start" is always 0 and never touches the sizes; otherwise use the free space.
        if align == "start":
            return 0
        free = org_size[axis] - add_size[axis]
        return free // 2 if align == "center" else free

    layout = (
        ("์ƒ๋‹จ ์ขŒ์ธก", "start", "start"),    # top left
        ("์ƒ๋‹จ ๊ฐ€์šด๋ฐ", "center", "start"),  # top center
        ("์ƒ๋‹จ ์šฐ์ธก", "end", "start"),      # top right
        ("์ค‘์•™ ์ขŒ์ธก", "start", "center"),   # middle left
        ("์ค‘์•™ ๊ฐ€์šด๋ฐ", "center", "center"), # middle center
        ("์ค‘์•™ ์šฐ์ธก", "end", "center"),     # middle right
        ("ํ•˜๋‹จ ์ขŒ์ธก", "start", "end"),      # bottom left
        ("ํ•˜๋‹จ ๊ฐ€์šด๋ฐ", "center", "end"),    # bottom center
        ("ํ•˜๋‹จ ์šฐ์ธก", "end", "end"),        # bottom right
    )
    for label, horizontal, vertical in layout:
        if position == label:
            return (_offset(0, horizontal), _offset(1, vertical))
    return None


def parse_display_size(display_size):
    """Parse a ``"WIDTHxHEIGHT"`` string such as ``"1024x768"``.

    Surrounding whitespace, spaces around the separator and an upper-case
    ``X`` are accepted.

    Returns:
        ``(width, height)`` as positive ints.

    Raises:
        ValueError: if the text is not two positive integers separated by ``x``.
    """
    message = (
        "Display size must be WIDTHxHEIGHT with positive integers "
        f"(e.g. 1024x768), got {display_size!r}"
    )
    parts = str(display_size).strip().lower().split("x")
    if len(parts) != 2:
        raise ValueError(message)
    try:
        width, height = (int(part) for part in parts)
    except ValueError as err:
        raise ValueError(message) from err
    if width <= 0 or height <= 0:
        raise ValueError(message)
    return width, height


def merge(org_image, add_image, scale, position, display_size):
    """Composite a foreground image onto a background and resize the result.

    Args:
        org_image: Background ``PIL.Image.Image``; defines the canvas size.
        add_image: Foreground ``PIL.Image.Image``; its alpha channel is used
            as the paste mask.
        scale: Foreground scale in percent (100 keeps its original size).
        position: A position label understood by ``calculate_position``.
        display_size: Output size as a ``"WIDTHxHEIGHT"`` string.

    Returns:
        The composited RGBA image resized to *display_size*, or ``None`` when
        either input image is missing.

    Raises:
        ValueError: if *display_size* cannot be parsed.
    """
    # Gradio passes None for an empty image input; there is nothing to merge.
    if org_image is None or add_image is None:
        return None

    display_width, display_height = parse_display_size(display_size)

    # Scale the foreground, keeping each side >= 1 px: a small image at a low
    # scale would otherwise round to 0, which PIL's resize rejects.
    factor = scale / 100.0
    new_size = (
        max(1, int(add_image.width * factor)),
        max(1, int(add_image.height * factor)),
    )
    # Converting to RGBA guarantees the foreground can act as its own mask
    # (an RGB image is not a valid transparency mask for paste).
    add_image = add_image.convert("RGBA").resize(new_size, Image.Resampling.LANCZOS)

    offset = calculate_position(org_image.size, add_image.size, position)
    merged_image = Image.new("RGBA", org_image.size)
    merged_image.paste(org_image, (0, 0))
    merged_image.paste(add_image, offset, add_image)

    # Stretch to the requested output size (aspect ratio is not preserved).
    return merged_image.resize((display_width, display_height), Image.Resampling.LANCZOS)


# Gradio UI: three tabs wired to the callbacks defined above.
with gr.Blocks() as demo:
    # Tab 1: remove the background from one uploaded image via `process`.
    with gr.Tab("Background Removal"):
        with gr.Column():
            gr.Markdown("๋ˆ„๋ผ๋”ฐ๊ธฐ์˜ ์™• '๋ˆ„ํ‚น'(Nuking)")
            gr.HTML('''
                <p style="margin-bottom: 10px; font-size: 94%">
                This is a demo for BRIA RMBG 1.4 that using
                <a href="https://huggingface.co/briaai/RMBG-1.4" target="_blank">BRIA RMBG-1.4 image matting model</a> as backbone. 
                </p>
            ''')
            input_image = gr.Image(type="pil")
            output_image = gr.Image()
            process_button = gr.Button("Remove Background")
            process_button.click(fn=process, inputs=input_image, outputs=output_image)

    # Tab 2: composite a foreground onto a background via `merge`.
    with gr.Tab("Merge"):
        with gr.Column():
            # Both inputs arrive as RGBA PIL images; height=400 is the display height (example value).
            org_image = gr.Image(label="Background", type='pil', image_mode='RGBA', height=400)
            add_image = gr.Image(label="Foreground", type='pil', image_mode='RGBA', height=400)
            scale = gr.Slider(minimum=10, maximum=200, step=1, value=100, label="Scale of Foreground Image (%)")
            # Labels must match the ones recognised by `calculate_position`.
            position = gr.Radio(choices=["์ค‘์•™ ๊ฐ€์šด๋ฐ", "์ƒ๋‹จ ์ขŒ์ธก", "์ƒ๋‹จ ๊ฐ€์šด๋ฐ", "์ƒ๋‹จ ์šฐ์ธก", "์ค‘์•™ ์ขŒ์ธก", "์ค‘์•™ ์šฐ์ธก", "ํ•˜๋‹จ ์ขŒ์ธก", "ํ•˜๋‹จ ๊ฐ€์šด๋ฐ", "ํ•˜๋‹จ ์šฐ์ธก"], value="์ค‘์•™ ๊ฐ€์šด๋ฐ", label="Position of Foreground Image")
            display_size = gr.Textbox(value="1024x768", label="Display Size (Width x Height)")
            btn_merge = gr.Button("Merge Images")
            result_merge = gr.Image()

            btn_merge.click(
                fn=merge,
                inputs=[org_image, add_image, scale, position, display_size],
                outputs=result_merge,
            )

    # Tab 3: free stock-photo search (Pexels) via `show_search_results`.
    with gr.Tab("Image Search"):
        with gr.Column():
            gr.Markdown("### FREE Image Search")
            search_query = gr.Textbox(label="์‚ฌ์ง„ ๊ฒ€์ƒ‰")
            search_btn = gr.Button("๊ฒ€์ƒ‰")
            images_output = gr.Gallery(label="๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ ์ด๋ฏธ์ง€")
            search_btn.click(
                fn=show_search_results,
                inputs=search_query,
                outputs=images_output
            )

# Start the server only when run as a script, so importing this module builds
# `demo` without launching a web server as a side effect.
if __name__ == "__main__":
    demo.launch()