habulaj commited on
Commit
6a37b07
·
verified ·
1 Parent(s): a4ccf85

Update app/main.py

Browse files
Files changed (1) hide show
  1. app/main.py +89 -1
app/main.py CHANGED
@@ -1 +1,89 @@
1
- #
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torchvision.transforms.functional import normalize
5
+ from huggingface_hub import hf_hub_download
6
+ from fastapi import FastAPI, HTTPException
7
+ from PIL import Image
8
+ import httpx
9
+ import io
10
+ import briarmbg # Importando o modelo de remoção de fundo
11
+
12
+ app = FastAPI()
13
+
14
+ # Carregar modelo BRIA RMBG
15
+ net = briarmbg.BriaRMBG()
16
+ model_path = hf_hub_download("briaai/RMBG-1.4", "model.pth")
17
+
18
+ if torch.cuda.is_available():
19
+ net.load_state_dict(torch.load(model_path))
20
+ net = net.cuda()
21
+ else:
22
+ net.load_state_dict(torch.load(model_path, map_location="cpu"))
23
+
24
+ net.eval()
25
+
26
+ # Função para redimensionar a imagem antes de processar
27
+ def resize_image(image):
28
+ image = image.convert("RGB")
29
+ model_input_size = (1024, 1024)
30
+ return image.resize(model_input_size, Image.BILINEAR)
31
+
32
+ # Função para remover o fundo da imagem
33
+ def remove_bg(image: Image.Image):
34
+ orig_image = image
35
+ w, h = orig_image.size
36
+ image = resize_image(orig_image)
37
+
38
+ im_np = np.array(image)
39
+ im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
40
+ im_tensor = torch.unsqueeze(im_tensor, 0) / 255.0
41
+ im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
42
+
43
+ if torch.cuda.is_available():
44
+ im_tensor = im_tensor.cuda()
45
+
46
+ # Inference do modelo
47
+ result = net(im_tensor)[0][0]
48
+ result = torch.squeeze(F.interpolate(result, size=(h, w), mode="bilinear"), 0)
49
+
50
+ # Normalizar para 0-255
51
+ ma, mi = torch.max(result), torch.min(result)
52
+ result = (result - mi) / (ma - mi)
53
+
54
+ im_array = (result * 255).cpu().data.numpy().astype(np.uint8)
55
+ pil_im = Image.fromarray(np.squeeze(im_array))
56
+
57
+ # Criar imagem sem fundo
58
+ new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
59
+ new_im.paste(orig_image, mask=pil_im)
60
+
61
+ return new_im
62
+
63
+ # Função para baixar imagem de uma URL
64
+ async def download_image(image_url: str) -> Image.Image:
65
+ async with httpx.AsyncClient() as client:
66
+ response = await client.get(image_url)
67
+ if response.status_code != 200:
68
+ raise HTTPException(status_code=400, detail="Erro ao baixar imagem")
69
+
70
+ return Image.open(io.BytesIO(response.content))
71
+
72
+ # Endpoint para remover fundo
73
+ @app.get("/remove-bg/")
74
+ async def remove_bg_from_url(image_url: str):
75
+ try:
76
+ image = await download_image(image_url)
77
+ output_image = remove_bg(image)
78
+
79
+ # Salvar a imagem temporariamente na memória
80
+ img_io = io.BytesIO()
81
+ output_image.save(img_io, format="PNG")
82
+ img_io.seek(0)
83
+
84
+ return {
85
+ "message": "Fundo removido com sucesso!",
86
+ "image": img_io.getvalue()
87
+ }
88
+ except Exception as e:
89
+ raise HTTPException(status_code=500, detail=str(e))