HYBA-RMBG-2.4 / app.py
HybaAI's picture
Update app.py
1c01319 verified
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
import gradio as gr
from briarmbg import BriaRMBG
import PIL
from PIL import Image
from typing import Tuple
# Carregar o modelo pré-treinado
net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net.to(device)
# Função para redimensionar a imagem para o tamanho que o modelo espera
def redimensionar_imagem(imagem):
imagem = imagem.convert('RGB')
tamanho_entrada_modelo = (1024, 1024)
imagem = imagem.resize(tamanho_entrada_modelo, Image.BILINEAR)
return imagem
# Função principal para processar a imagem
def processar(imagem):
# preparar entrada
imagem_original = Image.fromarray(imagem)
w, h = imagem_original.size
imagem = redimensionar_imagem(imagem_original)
im_np = np.array(imagem)
im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
im_tensor = torch.unsqueeze(im_tensor, 0)
im_tensor = torch.divide(im_tensor, 255.0)
im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
if torch.cuda.is_available():
im_tensor = im_tensor.cuda()
# Inferência com o modelo
resultado = net(im_tensor)
# Pós-processamento
resultado = torch.squeeze(F.interpolate(resultado[0][0], size=(h, w), mode='bilinear'), 0)
ma = torch.max(resultado)
mi = torch.min(resultado)
resultado = (resultado - mi) / (ma - mi)
# Convertendo o resultado para imagem PIL
im_array = (resultado * 255).cpu().data.numpy().astype(np.uint8)
pil_im = Image.fromarray(np.squeeze(im_array))
# Colando a máscara na imagem original
nova_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
nova_im.paste(imagem_original, mask=pil_im)
return nova_im
# Interface com Gradio
gr.Markdown("")
# Exemplos
exemplos = [['./input.jpg']]
# Configurando a interface
output = gr.Image(type="pil", label="Imagem Processada")
# Definindo a interface com inputs e outputs
demo = gr.Interface(
fn=processar,
inputs=gr.Image(type="numpy", label="Carregar Imagem"), # Mantido em português
outputs=output,
examples=exemplos
)
# Executando a interface
demo.launch(share=False)