File size: 3,429 Bytes
542c815
3f8e328
542c815
 
 
a888400
d6e753e
 
 
8a357d1
542c815
3267028
018621a
3267028
b98efed
542c815
 
155e7bd
 
 
 
 
542c815
 
155e7bd
542c815
155e7bd
 
 
 
 
542c815
 
 
 
 
 
70974c3
155e7bd
 
 
 
 
 
 
 
 
542c815
155e7bd
 
 
542c815
155e7bd
542c815
 
d909bca
 
 
 
 
 
155e7bd
 
d909bca
 
 
 
155e7bd
 
 
542c815
d909bca
155e7bd
d909bca
155e7bd
542c815
d909bca
 
c530952
d909bca
c530952
 
 
155e7bd
 
c530952
 
155e7bd
 
 
 
d909bca
 
99eb030
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
import gradio as gr
from gradio_imageslider import ImageSlider
from briarmbg import BriaRMBG
import PIL
from PIL import Image
from typing import Tuple


net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net.to(device)

    
def redimensionar_imagem(imagem):
    imagem = imagem.convert('RGB')
    tamanho_entrada_modelo = (1024, 1024)
    imagem = imagem.resize(tamanho_entrada_modelo, Image.BILINEAR)
    return imagem


def processar(imagem):

    # preparar entrada
    imagem_original = Image.fromarray(imagem)
    w,h = tamanho_imagem_original = imagem_original.size
    imagem = redimensionar_imagem(imagem_original)
    im_np = np.array(imagem)
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2,0,1)
    im_tensor = torch.unsqueeze(im_tensor,0)
    im_tensor = torch.divide(im_tensor,255.0)
    im_tensor = normalize(im_tensor,[0.5,0.5,0.5],[1.0,1.0,1.0])
    if torch.cuda.is_available():
        im_tensor=im_tensor.cuda()

    # inferência
    resultado = net(im_tensor)
    # pós-processamento
    resultado = torch.squeeze(F.interpolate(resultado[0][0], size=(h,w), mode='bilinear') ,0)
    ma = torch.max(resultado)
    mi = torch.min(resultado)
    resultado = (resultado-mi)/(ma-mi)    
    # imagem para pil
    im_array = (resultado*255).cpu().data.numpy().astype(np.uint8)
    pil_im = Image.fromarray(np.squeeze(im_array))
    # colar a máscara na imagem original
    nova_im = Image.new("RGBA", pil_im.size, (0,0,0,0))
    nova_im.paste(imagem_original, mask=pil_im)

    return nova_im


# block = gr.Blocks().queue()

# with block:
#     gr.Markdown("## BRIA RMBG 1.4")
#     gr.HTML('''
#       <p style="margin-bottom: 10px; font-size: 94%">
#         Esta é uma demonstração do BRIA RMBG 1.4 que utiliza
#         <a href="https://huggingface.co/briaai/RMBG-1.4" target="_blank">o modelo de matting de imagem BRIA RMBG-1.4</a> como base. 
#       </p>
#     ''')
#     with gr.Row():
#         with gr.Column():
#             input_image = gr.Image(sources=None, type="pil") # None para upload, ctrl+v e webcam
#             # input_image = gr.Image(sources=None, type="numpy") # None para upload, ctrl+v e webcam
#             run_button = gr.Button(value="Executar")
            
#         with gr.Column():
#             result_gallery = gr.Gallery(label='Resultado', show_label=False, elem_id="gallery", columns=[1], height='auto')
#     ips = [input_image]
#     run_button.click(fn=processar, inputs=ips, outputs=[result_gallery])

# block.launch(debug = True)

# block = gr.Blocks().queue()

gr.Markdown("## BRIA RMBG 1.4")
gr.HTML('''
  <p style="margin-bottom: 10px; font-size: 94%">
    Esta é uma demonstração do BRIA RMBG 1.4 que utiliza
    <a href="https://huggingface.co/briaai/RMBG-1.4" target="_blank">o modelo de matting de imagem BRIA RMBG-1.4</a> como base. 
  </p>
''')
exemplos = [['./input.jpg'],]
# output = ImageSlider(position=0.5,label='Imagem sem fundo', type="pil", show_download_button=True)
# demo = gr.Interface(fn=processar,inputs="image", outputs=output, examples=exemplos, title=title, description=description)
demo = gr.Interface(fn=processar,inputs="image", outputs="image", examples=exemplos)

if __name__ == "__main__":
    demo.launch(share=False)