habulaj commited on
Commit
4b2ecc5
·
verified ·
1 Parent(s): dcb31a8

Rename main.py to app.py

Browse files
Files changed (2) hide show
  1. app.py +27 -0
  2. main.py +0 -89
app.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ import httpx
3
+ from rembg import remove
4
+ from io import BytesIO
5
+ from fastapi.responses import StreamingResponse
6
+
7
+ app = FastAPI()
8
+
9
+ @app.get("/remove-background/")
10
+ async def remove_background(image_url: str):
11
+ try:
12
+ # Baixa a imagem da URL fornecida
13
+ async with httpx.AsyncClient() as client:
14
+ response = await client.get(image_url)
15
+ response.raise_for_status() # Verifica se a URL foi bem-sucedida
16
+ image_data = response.content
17
+
18
+ # Remover o fundo da imagem usando o rembg
19
+ output_image = remove(image_data)
20
+
21
+ # Retorna a imagem com o fundo removido como resposta
22
+ return StreamingResponse(BytesIO(output_image), media_type="image/png")
23
+
24
+ except httpx.RequestError as e:
25
+ raise HTTPException(status_code=400, detail="Erro ao baixar a imagem: " + str(e))
26
+ except Exception as e:
27
+ raise HTTPException(status_code=500, detail="Erro ao processar a imagem: " + str(e))
main.py DELETED
@@ -1,89 +0,0 @@
1
- import numpy as np
2
- import torch
3
- import torch.nn.functional as F
4
- from torchvision.transforms.functional import normalize
5
- from huggingface_hub import hf_hub_download
6
- from fastapi import FastAPI, HTTPException
7
- from PIL import Image
8
- import httpx
9
- import io
10
- import briarmbg # Importando o modelo de remoção de fundo
11
-
12
- app = FastAPI()
13
-
14
- # Carregar modelo BRIA RMBG
15
- net = briarmbg.BriaRMBG()
16
- model_path = hf_hub_download("briaai/RMBG-1.4", "model.pth")
17
-
18
- if torch.cuda.is_available():
19
- net.load_state_dict(torch.load(model_path))
20
- net = net.cuda()
21
- else:
22
- net.load_state_dict(torch.load(model_path, map_location="cpu"))
23
-
24
- net.eval()
25
-
26
- # Função para redimensionar a imagem antes de processar
27
- def resize_image(image):
28
- image = image.convert("RGB")
29
- model_input_size = (1024, 1024)
30
- return image.resize(model_input_size, Image.BILINEAR)
31
-
32
- # Função para remover o fundo da imagem
33
- def remove_bg(image: Image.Image):
34
- orig_image = image
35
- w, h = orig_image.size
36
- image = resize_image(orig_image)
37
-
38
- im_np = np.array(image)
39
- im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
40
- im_tensor = torch.unsqueeze(im_tensor, 0) / 255.0
41
- im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
42
-
43
- if torch.cuda.is_available():
44
- im_tensor = im_tensor.cuda()
45
-
46
- # Inference do modelo
47
- result = net(im_tensor)[0][0]
48
- result = torch.squeeze(F.interpolate(result, size=(h, w), mode="bilinear"), 0)
49
-
50
- # Normalizar para 0-255
51
- ma, mi = torch.max(result), torch.min(result)
52
- result = (result - mi) / (ma - mi)
53
-
54
- im_array = (result * 255).cpu().data.numpy().astype(np.uint8)
55
- pil_im = Image.fromarray(np.squeeze(im_array))
56
-
57
- # Criar imagem sem fundo
58
- new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
59
- new_im.paste(orig_image, mask=pil_im)
60
-
61
- return new_im
62
-
63
- # Função para baixar imagem de uma URL
64
- async def download_image(image_url: str) -> Image.Image:
65
- async with httpx.AsyncClient() as client:
66
- response = await client.get(image_url)
67
- if response.status_code != 200:
68
- raise HTTPException(status_code=400, detail="Erro ao baixar imagem")
69
-
70
- return Image.open(io.BytesIO(response.content))
71
-
72
- # Endpoint para remover fundo
73
- @app.get("/remove-bg/")
74
- async def remove_bg_from_url(image_url: str):
75
- try:
76
- image = await download_image(image_url)
77
- output_image = remove_bg(image)
78
-
79
- # Salvar a imagem temporariamente na memória
80
- img_io = io.BytesIO()
81
- output_image.save(img_io, format="PNG")
82
- img_io.seek(0)
83
-
84
- return {
85
- "message": "Fundo removido com sucesso!",
86
- "image": img_io.getvalue()
87
- }
88
- except Exception as e:
89
- raise HTTPException(status_code=500, detail=str(e))