HybaAI commited on
Commit
155e7bd
·
verified ·
1 Parent(s): 99eb030

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -37
app.py CHANGED
@@ -15,20 +15,20 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
  net.to(device)
16
 
17
 
18
- def resize_image(image):
19
- image = image.convert('RGB')
20
- model_input_size = (1024, 1024)
21
- image = image.resize(model_input_size, Image.BILINEAR)
22
- return image
23
 
24
 
25
- def process(image):
26
 
27
- # prepare input
28
- orig_image = Image.fromarray(image)
29
- w,h = orig_im_size = orig_image.size
30
- image = resize_image(orig_image)
31
- im_np = np.array(image)
32
  im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2,0,1)
33
  im_tensor = torch.unsqueeze(im_tensor,0)
34
  im_tensor = torch.divide(im_tensor,255.0)
@@ -36,21 +36,21 @@ def process(image):
36
  if torch.cuda.is_available():
37
  im_tensor=im_tensor.cuda()
38
 
39
- #inference
40
- result=net(im_tensor)
41
- # post process
42
- result = torch.squeeze(F.interpolate(result[0][0], size=(h,w), mode='bilinear') ,0)
43
- ma = torch.max(result)
44
- mi = torch.min(result)
45
- result = (result-mi)/(ma-mi)
46
- # image to pil
47
- im_array = (result*255).cpu().data.numpy().astype(np.uint8)
48
  pil_im = Image.fromarray(np.squeeze(im_array))
49
- # paste the mask on the original image
50
- new_im = Image.new("RGBA", pil_im.size, (0,0,0,0))
51
- new_im.paste(orig_image, mask=pil_im)
52
 
53
- return new_im
54
 
55
 
56
  # block = gr.Blocks().queue()
@@ -59,20 +59,20 @@ def process(image):
59
  # gr.Markdown("## BRIA RMBG 1.4")
60
  # gr.HTML('''
61
  # <p style="margin-bottom: 10px; font-size: 94%">
62
- # This is a demo for BRIA RMBG 1.4 that using
63
- # <a href="https://huggingface.co/briaai/RMBG-1.4" target="_blank">BRIA RMBG-1.4 image matting model</a> as backbone.
64
  # </p>
65
  # ''')
66
  # with gr.Row():
67
  # with gr.Column():
68
- # input_image = gr.Image(sources=None, type="pil") # None for upload, ctrl+v and webcam
69
- # # input_image = gr.Image(sources=None, type="numpy") # None for upload, ctrl+v and webcam
70
- # run_button = gr.Button(value="Run")
71
 
72
  # with gr.Column():
73
- # result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery", columns=[1], height='auto')
74
  # ips = [input_image]
75
- # run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
76
 
77
  # block.launch(debug = True)
78
 
@@ -81,14 +81,14 @@ def process(image):
81
  gr.Markdown("## BRIA RMBG 1.4")
82
  gr.HTML('''
83
  <p style="margin-bottom: 10px; font-size: 94%">
84
- This is a demo for BRIA RMBG 1.4 that using
85
- <a href="https://huggingface.co/briaai/RMBG-1.4" target="_blank">BRIA RMBG-1.4 image matting model</a> as backbone.
86
  </p>
87
  ''')
88
- examples = [['./input.jpg'],]
89
- # output = ImageSlider(position=0.5,label='Image without background', type="pil", show_download_button=True)
90
- # demo = gr.Interface(fn=process,inputs="image", outputs=output, examples=examples, title=title, description=description)
91
- demo = gr.Interface(fn=process,inputs="image", outputs="image", examples=examples)
92
 
93
  if __name__ == "__main__":
94
  demo.launch(share=False)
 
15
  net.to(device)
16
 
17
 
18
+ def redimensionar_imagem(imagem):
19
+ imagem = imagem.convert('RGB')
20
+ tamanho_entrada_modelo = (1024, 1024)
21
+ imagem = imagem.resize(tamanho_entrada_modelo, Image.BILINEAR)
22
+ return imagem
23
 
24
 
25
+ def processar(imagem):
26
 
27
+ # preparar entrada
28
+ imagem_original = Image.fromarray(imagem)
29
+ w,h = tamanho_imagem_original = imagem_original.size
30
+ imagem = redimensionar_imagem(imagem_original)
31
+ im_np = np.array(imagem)
32
  im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2,0,1)
33
  im_tensor = torch.unsqueeze(im_tensor,0)
34
  im_tensor = torch.divide(im_tensor,255.0)
 
36
  if torch.cuda.is_available():
37
  im_tensor=im_tensor.cuda()
38
 
39
+ # inferência
40
+ resultado = net(im_tensor)
41
+ # pós-processamento
42
+ resultado = torch.squeeze(F.interpolate(resultado[0][0], size=(h,w), mode='bilinear') ,0)
43
+ ma = torch.max(resultado)
44
+ mi = torch.min(resultado)
45
+ resultado = (resultado-mi)/(ma-mi)
46
+ # imagem para pil
47
+ im_array = (resultado*255).cpu().data.numpy().astype(np.uint8)
48
  pil_im = Image.fromarray(np.squeeze(im_array))
49
+ # colar a máscara na imagem original
50
+ nova_im = Image.new("RGBA", pil_im.size, (0,0,0,0))
51
+ nova_im.paste(imagem_original, mask=pil_im)
52
 
53
+ return nova_im
54
 
55
 
56
  # block = gr.Blocks().queue()
 
59
  # gr.Markdown("## BRIA RMBG 1.4")
60
  # gr.HTML('''
61
  # <p style="margin-bottom: 10px; font-size: 94%">
62
+ # Esta é uma demonstração do BRIA RMBG 1.4 que utiliza
63
+ # <a href="https://huggingface.co/briaai/RMBG-1.4" target="_blank">o modelo de matting de imagem BRIA RMBG-1.4</a> como base.
64
  # </p>
65
  # ''')
66
  # with gr.Row():
67
  # with gr.Column():
68
+ # input_image = gr.Image(sources=None, type="pil") # None para upload, ctrl+v e webcam
69
+ # # input_image = gr.Image(sources=None, type="numpy") # None para upload, ctrl+v e webcam
70
+ # run_button = gr.Button(value="Executar")
71
 
72
  # with gr.Column():
73
+ # result_gallery = gr.Gallery(label='Resultado', show_label=False, elem_id="gallery", columns=[1], height='auto')
74
  # ips = [input_image]
75
+ # run_button.click(fn=processar, inputs=ips, outputs=[result_gallery])
76
 
77
  # block.launch(debug = True)
78
 
 
81
  gr.Markdown("## BRIA RMBG 1.4")
82
  gr.HTML('''
83
  <p style="margin-bottom: 10px; font-size: 94%">
84
+ Esta é uma demonstração do BRIA RMBG 1.4 que utiliza
85
+ <a href="https://huggingface.co/briaai/RMBG-1.4" target="_blank">o modelo de matting de imagem BRIA RMBG-1.4</a> como base.
86
  </p>
87
  ''')
88
+ exemplos = [['./input.jpg'],]
89
+ # output = ImageSlider(position=0.5,label='Imagem sem fundo', type="pil", show_download_button=True)
90
+ # demo = gr.Interface(fn=processar,inputs="image", outputs=output, examples=exemplos, title=title, description=description)
91
+ demo = gr.Interface(fn=processar,inputs="image", outputs="image", examples=exemplos)
92
 
93
  if __name__ == "__main__":
94
  demo.launch(share=False)