Spaces:
Running
Running
Alex
commited on
Commit
·
f6210c2
1
Parent(s):
7364060
switched to segformer fashion
Browse files
app.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
from fastapi import FastAPI, File, UploadFile, HTTPException
|
2 |
-
from transformers import
|
3 |
import torch
|
4 |
from PIL import Image
|
5 |
import numpy as np
|
@@ -14,11 +14,11 @@ app = FastAPI()
|
|
14 |
logging.basicConfig(level=logging.INFO)
|
15 |
logger = logging.getLogger(__name__)
|
16 |
|
17 |
-
# Carica il modello e il processore
|
18 |
try:
|
19 |
-
logger.info("Caricamento del modello
|
20 |
-
model =
|
21 |
-
processor =
|
22 |
model.to("cpu") # Usa CPU per il free tier
|
23 |
logger.info("Modello caricato con successo.")
|
24 |
except Exception as e:
|
@@ -27,23 +27,23 @@ except Exception as e:
|
|
27 |
|
28 |
# Funzione per segmentare l'immagine
|
29 |
def segment_image(image: Image.Image):
|
30 |
-
# Prepara l'input per
|
31 |
logger.info("Preparazione dell'immagine per l'inferenza...")
|
32 |
-
inputs = processor(image, return_tensors="pt").to("cpu")
|
33 |
|
34 |
# Inferenza
|
35 |
logger.info("Esecuzione dell'inferenza...")
|
36 |
with torch.no_grad():
|
37 |
-
outputs = model(**inputs
|
38 |
-
|
|
|
39 |
# Post-processa la maschera
|
40 |
logger.info("Post-processing della maschera...")
|
41 |
-
mask =
|
42 |
-
|
43 |
-
)[0][0].cpu().numpy()
|
44 |
|
45 |
# Converti la maschera in immagine
|
46 |
-
mask_img = Image.fromarray((mask * 255).astype(np.uint8))
|
47 |
|
48 |
# Converti la maschera in base64 per la risposta
|
49 |
buffered = io.BytesIO()
|
@@ -51,7 +51,7 @@ def segment_image(image: Image.Image):
|
|
51 |
mask_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
52 |
|
53 |
# Annotazioni
|
54 |
-
annotations = {"mask": mask.tolist(), "label": "
|
55 |
|
56 |
return mask_base64, annotations
|
57 |
|
@@ -74,7 +74,7 @@ async def segment_endpoint(file: UploadFile = File(...)):
|
|
74 |
logger.error(f"Errore nell'endpoint: {str(e)}")
|
75 |
raise HTTPException(status_code=500, detail=f"Errore nell'elaborazione: {str(e)}")
|
76 |
|
77 |
-
# Per compatibilità con Hugging Face Spaces
|
78 |
if __name__ == "__main__":
|
79 |
import uvicorn
|
80 |
uvicorn.run(app, host="0.0.0.0", port=7860)
|
|
|
1 |
from fastapi import FastAPI, File, UploadFile, HTTPException
|
2 |
+
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
|
3 |
import torch
|
4 |
from PIL import Image
|
5 |
import numpy as np
|
|
|
14 |
logging.basicConfig(level=logging.INFO)
|
15 |
logger = logging.getLogger(__name__)
|
16 |
|
17 |
+
# Carica il modello e il processore SegFormer
|
18 |
try:
|
19 |
+
logger.info("Caricamento del modello SegFormer...")
|
20 |
+
model = SegformerForSemanticSegmentation.from_pretrained("sayeed99/segformer-b3-fashion")
|
21 |
+
processor = SegformerImageProcessor.from_pretrained("sayeed99/segformer-b3-fashion")
|
22 |
model.to("cpu") # Usa CPU per il free tier
|
23 |
logger.info("Modello caricato con successo.")
|
24 |
except Exception as e:
|
|
|
27 |
|
28 |
# Funzione per segmentare l'immagine
|
29 |
def segment_image(image: Image.Image):
|
30 |
+
# Prepara l'input per SegFormer
|
31 |
logger.info("Preparazione dell'immagine per l'inferenza...")
|
32 |
+
inputs = processor(images=image, return_tensors="pt").to("cpu")
|
33 |
|
34 |
# Inferenza
|
35 |
logger.info("Esecuzione dell'inferenza...")
|
36 |
with torch.no_grad():
|
37 |
+
outputs = model(**inputs)
|
38 |
+
logits = outputs.logits
|
39 |
+
|
40 |
# Post-processa la maschera
|
41 |
logger.info("Post-processing della maschera...")
|
42 |
+
mask = torch.argmax(logits, dim=1)[0]
|
43 |
+
mask = mask.cpu().numpy()
|
|
|
44 |
|
45 |
# Converti la maschera in immagine
|
46 |
+
mask_img = Image.fromarray((mask * 255 / mask.max()).astype(np.uint8))
|
47 |
|
48 |
# Converti la maschera in base64 per la risposta
|
49 |
buffered = io.BytesIO()
|
|
|
51 |
mask_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
52 |
|
53 |
# Annotazioni
|
54 |
+
annotations = {"mask": mask.tolist(), "label": "fashion"}
|
55 |
|
56 |
return mask_base64, annotations
|
57 |
|
|
|
74 |
logger.error(f"Errore nell'endpoint: {str(e)}")
|
75 |
raise HTTPException(status_code=500, detail=f"Errore nell'elaborazione: {str(e)}")
|
76 |
|
77 |
+
# Per compatibilità con Hugging Face Spaces
|
78 |
if __name__ == "__main__":
|
79 |
import uvicorn
|
80 |
uvicorn.run(app, host="0.0.0.0", port=7860)
|