import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
import gradio as gr
from briarmbg import BriaRMBG
import PIL
from PIL import Image
import requests
from io import BytesIO

# Load the pretrained RMBG-1.4 weights once at startup and move the model to the GPU if one is available.
net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net.to(device)
net.eval()  # run in inference mode

def resize_image(image):
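    # RMBG-1.4 expects a fixed 1024x1024 RGB input, so convert and resize with bilinear filtering.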
    image = image.convert('RGB')
    model_input_size = (1024, 1024)
    image = image.resize(model_input_size, Image.BILINEAR)
    return image

def process(image=None, url=None):
    # Accept either an uploaded image (numpy array) or an image URL.
    if image is None and not url:
        raise gr.Error("Please upload an image or provide an image URL.")
    if url:
        response = requests.get(url, timeout=30)
        orig_image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        orig_image = Image.fromarray(image)

    w, h = orig_image.size
    image = resize_image(orig_image)
    im_np = np.array(image)
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
    im_tensor = torch.unsqueeze(im_tensor, 0)
    im_tensor = torch.divide(im_tensor, 255.0)
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    im_tensor = im_tensor.to(device)

    with torch.no_grad():
        result = net(im_tensor)
    # Resize the predicted mask back to the original resolution; PIL reports (w, h),
    # while F.interpolate expects (h, w).
    result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
    ma = torch.max(result)
    mi = torch.min(result)
    result = (result - mi) / (ma - mi)
    im_array = (result * 255).cpu().data.numpy().astype(np.uint8)
    pil_im = Image.fromarray(np.squeeze(im_array))
    # Use the normalized mask as the alpha channel and paste the original
    # full-resolution image (not the resized one) so the sizes match.
    new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
    new_im.paste(orig_image, mask=pil_im)
    return new_im

title = "Background Removal"
description = r"""Background removal model developed by <a href='https://BRIA.AI' target='_blank'><b>BRIA.AI</b></a>, trained on a carefully selected dataset and released as an open-source model for non-commercial use.<br>
To try it, upload an image (or paste an image URL) and wait a moment. Read more on the model card <a href='https://huggingface.co/briaai/RMBG-1.4' target='_blank'><b>briaai/RMBG-1.4</b></a>.<br>
"""
examples = [['./input.jpg', None]]  # one value per input component

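# Gradio UI: accept either an uploaded image or an image URL, and return the cutout
# with a transparent background.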
inputs = [
    gr.Image(source="upload", tool="editor", type="numpy", label="Upload Image"),
    gr.Textbox(label="Image URL", placeholder="Enter the URL of an image")
]
output = gr.Image(type="pil", label="Image without background", show_download_button=True)

demo = gr.Interface(fn=process, inputs=inputs, outputs=output, examples=examples, title=title, description=description)

if __name__ == "__main__":
    demo.launch(share=False)
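
# Assuming this file is saved as app.py alongside briarmbg.py, the demo can be started
# with `python app.py`; Gradio then prints a local URL to open in the browser.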