"""FastAPI service that removes image backgrounds with the BRIA RMBG-1.4 model."""

import base64
import io
import logging

import httpx
import numpy as np
import torch
import torch.nn.functional as F
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision.transforms.functional import normalize

import briarmbg  # Background-removal model definition

logger = logging.getLogger(__name__)

app = FastAPI()

# (width, height) every image is resized to before being fed to the model.
MODEL_INPUT_SIZE = (1024, 1024)

# Pick the device once instead of re-querying CUDA on every request.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the BRIA RMBG model weights from the Hugging Face Hub at import time.
# NOTE(review): torch.load unpickles the checkpoint; consider weights_only=True
# if the installed torch version supports it.
net = briarmbg.BriaRMBG()
model_path = hf_hub_download("briaai/RMBG-1.4", "model.pth")
# Loading onto the CPU first and then moving works on both CPU and GPU hosts.
net.load_state_dict(torch.load(model_path, map_location="cpu"))
net = net.to(device)
net.eval()


def resize_image(image):
    """Return *image* converted to RGB and resized to ``MODEL_INPUT_SIZE``."""
    image = image.convert("RGB")
    return image.resize(MODEL_INPUT_SIZE, Image.BILINEAR)


def remove_bg(image: Image.Image) -> Image.Image:
    """Remove the background from *image* using the RMBG model.

    Returns a new RGBA image of the same size as *image*. The alpha of each
    pixel comes from the model's predicted mask, min-max normalized to 0-255.
    """
    orig_image = image
    # Make the paste target's source mode explicit. Keep alpha/transparency
    # when the input has it; otherwise use RGB (the same conversion
    # resize_image already relies on).
    if orig_image.mode not in ("RGB", "RGBA"):
        has_alpha = "A" in orig_image.getbands() or "transparency" in orig_image.info
        orig_image = orig_image.convert("RGBA" if has_alpha else "RGB")
    w, h = orig_image.size

    # Preprocess: HWC uint8 -> 1xCxHxW float in [0, 1], then shift to [-0.5, 0.5].
    im_np = np.array(resize_image(orig_image))
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
    im_tensor = torch.unsqueeze(im_tensor, 0) / 255.0
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    im_tensor = im_tensor.to(device)

    # Inference only: no_grad avoids building an autograd graph per request.
    with torch.no_grad():
        # NOTE(review): the [0][0] indexing follows the upstream RMBG-1.4 usage
        # example. It is assumed to yield a (1, 1, H, W) mask; confirm against briarmbg.
        result = net(im_tensor)[0][0]
        result = torch.squeeze(F.interpolate(result, size=(h, w), mode="bilinear"), 0)

    # Min-max normalize to [0, 1]. A constant mask would divide by zero and
    # yield NaN, so in that case fall back to clamping the raw values.
    ma, mi = torch.max(result), torch.min(result)
    if ma > mi:
        result = (result - mi) / (ma - mi)
    else:
        result = torch.clamp(result, 0.0, 1.0)
    im_array = (result * 255).cpu().numpy().astype(np.uint8)
    # reshape (not squeeze) so 1-pixel-wide/tall images keep a 2-D mask.
    mask = Image.fromarray(im_array.reshape(h, w))

    # Start fully transparent and paste the original through the mask.
    new_im = Image.new("RGBA", mask.size, (0, 0, 0, 0))
    new_im.paste(orig_image, mask=mask)
    return new_im


async def download_image(image_url: str) -> Image.Image:
    """Fetch *image_url* and decode the response body as a PIL image.

    Raises:
        HTTPException: status 400 if the request fails, the server does not
            answer 200, or the body is not a decodable image.
    """
    # SECURITY NOTE(review): image_url comes straight from the client, so the
    # server will fetch arbitrary URLs (SSRF risk). Restrict schemes/hosts if
    # this service is exposed publicly.
    try:
        async with httpx.AsyncClient(follow_redirects=True) as client:
            response = await client.get(image_url)
    except (httpx.HTTPError, httpx.InvalidURL) as exc:
        raise HTTPException(status_code=400, detail="Erro ao baixar imagem") from exc
    if response.status_code != 200:
        raise HTTPException(status_code=400, detail="Erro ao baixar imagem")

    try:
        image = Image.open(io.BytesIO(response.content))
        # Image.open is lazy; force decoding so bad data fails here, as a 400.
        image.load()
    except (OSError, Image.DecompressionBombError) as exc:
        raise HTTPException(
            status_code=400, detail="URL não contém uma imagem válida"
        ) from exc
    return image


def _remove_bg_to_png(image: Image.Image) -> bytes:
    """Run background removal on *image* and return the result as PNG bytes."""
    output_image = remove_bg(image)
    img_io = io.BytesIO()
    output_image.save(img_io, format="PNG")
    return img_io.getvalue()


@app.get("/remove-bg/")
async def remove_bg_from_url(image_url: str):
    """Download the image at *image_url* and return it with the background removed.

    The response is JSON with a status message and, under ``"image"``, the
    resulting PNG encoded as a base64 ASCII string. Raw bytes cannot be
    serialized to JSON.
    """
    # download_image reports its own client errors as HTTPException(400);
    # they must not be converted into 500s below.
    image = await download_image(image_url)
    try:
        # Inference and PNG encoding are synchronous and CPU/GPU-bound; run
        # them in the thread pool so the event loop stays responsive.
        png_bytes = await run_in_threadpool(_remove_bg_to_png, image)
    except Exception as e:
        logger.exception("Background removal failed for %s", image_url)
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {
        "message": "Fundo removido com sucesso!",
        "image": base64.b64encode(png_bytes).decode("ascii"),
    }