Spaces:
Sleeping
Sleeping
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
from torchvision.transforms.functional import normalize | |
from huggingface_hub import hf_hub_download | |
from fastapi import FastAPI, HTTPException | |
from PIL import Image | |
import httpx | |
import io | |
import briarmbg # Importando o modelo de remoção de fundo | |
app = FastAPI()

# Load the BRIA RMBG-1.4 background-removal model once, at import time.
# Choose the device a single time so checkpoint loading and model placement agree.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = briarmbg.BriaRMBG()
model_path = hf_hub_download("briaai/RMBG-1.4", "model.pth")
# NOTE(review): torch.load unpickles the downloaded checkpoint, which can run
# arbitrary code if the file is ever tampered with. If the installed torch
# version supports it, prefer torch.load(..., weights_only=True).
net.load_state_dict(torch.load(model_path, map_location=device))
net = net.to(device)
# Inference mode for layers such as dropout/batch-norm.
net.eval()
# Resize an image to the model's fixed input size before inference.
def resize_image(image):
    """Return ``image`` converted to RGB and resized to 1024x1024.

    The mode conversion happens before resizing, bilinear resampling is used,
    and the aspect ratio is not preserved.
    """
    width, height = 1024, 1024
    rgb_image = image.convert("RGB")
    return rgb_image.resize((width, height), Image.BILINEAR)
# Remove the background from an image using the module-level RMBG model.
def remove_bg(image: Image.Image):
    """Return a copy of ``image`` with its background made transparent.

    The image is resized to the model input size, run through ``net``, and the
    predicted mask is resized back to the original dimensions and used as the
    paste mask onto a fully transparent RGBA canvas.

    Args:
        image: Input PIL image (converted to RGB only for inference).

    Returns:
        A new ``RGBA`` image of the same size as ``image``. Each pixel's opacity
        comes from the min-max-normalized mask (0 = fully transparent).
    """
    orig_image = image
    w, h = orig_image.size
    model_input = resize_image(orig_image)
    im_np = np.array(model_input)
    # HWC uint8 -> NCHW float in [0, 1], then subtract 0.5 per channel
    # (mean 0.5, std 1.0), giving values in [-0.5, 0.5].
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
    im_tensor = torch.unsqueeze(im_tensor, 0) / 255.0
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    if torch.cuda.is_available():
        im_tensor = im_tensor.cuda()

    # Inference only: without no_grad, autograd would keep every intermediate
    # activation alive for a backward pass that never happens.
    with torch.no_grad():
        # assumes net(...)[0][0] is a 4-D (N, C, H, W) mask, as bilinear
        # F.interpolate requires — TODO confirm against briarmbg
        result = net(im_tensor)[0][0]
        result = torch.squeeze(
            F.interpolate(result, size=(h, w), mode="bilinear"), 0
        )

    # Min-max normalize to [0, 1]. Guard the constant-mask case, where
    # (ma - mi) == 0 would otherwise produce NaNs.
    ma, mi = torch.max(result), torch.min(result)
    span = ma - mi
    if span > 0:
        result = (result - mi) / span
    else:
        result = torch.zeros_like(result)
    im_array = (result * 255).cpu().numpy().astype(np.uint8)
    pil_im = Image.fromarray(np.squeeze(im_array))

    # Paste the original pixels onto a transparent canvas through the mask.
    new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
    new_im.paste(orig_image, mask=pil_im)
    return new_im
# Download an image from a URL and decode it with PIL.
async def download_image(image_url: str) -> Image.Image:
    """Fetch ``image_url`` and return the response body as a decoded PIL image.

    Redirects are followed.

    Raises:
        HTTPException: 400 if the request fails, the final response is not
            HTTP 200, or the body cannot be decoded as an image.
    """
    # NOTE(review): image_url comes from the client, so this issues a
    # server-side request to an arbitrary URL (SSRF risk). Consider
    # restricting schemes/hosts before exposing this publicly.
    try:
        # httpx does not follow redirects by default, which would reject
        # many CDN-hosted image URLs.
        async with httpx.AsyncClient(follow_redirects=True) as client:
            response = await client.get(image_url)
    except httpx.HTTPError as e:
        raise HTTPException(status_code=400, detail="Erro ao baixar imagem") from e
    if response.status_code != 200:
        raise HTTPException(status_code=400, detail="Erro ao baixar imagem")
    try:
        image = Image.open(io.BytesIO(response.content))
        # Image.open is lazy; force decoding now so corrupt or non-image data
        # is reported here as a client error instead of failing later.
        image.load()
    except OSError as e:
        raise HTTPException(
            status_code=400, detail="Conteúdo baixado não é uma imagem válida"
        ) from e
    return image
# Endpoint: remove the background from an image given by URL.
# NOTE(review): no @app.get/@app.post decorator is present here, so FastAPI
# does not register this coroutine as a route. Confirm the intended path and
# method and add the decorator (e.g. @app.post("/remove-bg")).
# NOTE(review): remove_bg is synchronous and CPU/GPU-heavy, so it blocks the
# event loop while it runs; consider offloading it to a worker thread.
async def remove_bg_from_url(image_url: str):
    """Download the image at ``image_url``, remove its background, return PNG.

    Returns:
        A dict with ``"message"`` and ``"image"``. ``"image"`` holds the
        resulting PNG, base64-encoded as an ASCII string. Raw bytes cannot be
        JSON-serialized: FastAPI would try to decode them as UTF-8.

    Raises:
        HTTPException: passes through client errors raised while downloading
            (e.g. 400); any other failure is reported as a 500.
    """
    import base64

    try:
        image = await download_image(image_url)
        output_image = remove_bg(image)
        # Encode the result as PNG in memory.
        img_io = io.BytesIO()
        output_image.save(img_io, format="PNG")
    except HTTPException:
        # Keep deliberate HTTP errors (such as 400 for a bad URL) intact
        # instead of rewrapping them as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "message": "Fundo removido com sucesso!",
        "image": base64.b64encode(img_io.getvalue()).decode("ascii"),
    }