Alex commited on
Commit
f6210c2
·
1 Parent(s): 7364060

switched to segformer fashion

Browse files
Files changed (1) hide show
  1. app.py +15 -15
app.py CHANGED
@@ -1,5 +1,5 @@
1
  from fastapi import FastAPI, File, UploadFile, HTTPException
2
- from transformers import SamModel, SamProcessor
3
  import torch
4
  from PIL import Image
5
  import numpy as np
@@ -14,11 +14,11 @@ app = FastAPI()
14
  logging.basicConfig(level=logging.INFO)
15
  logger = logging.getLogger(__name__)
16
 
17
- # Carica il modello e il processore SAM
18
  try:
19
- logger.info("Caricamento del modello SAM...")
20
- model = SamModel.from_pretrained("facebook/sam-vit-base")
21
- processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
22
  model.to("cpu") # Usa CPU per il free tier
23
  logger.info("Modello caricato con successo.")
24
  except Exception as e:
@@ -27,23 +27,23 @@ except Exception as e:
27
 
28
  # Funzione per segmentare l'immagine
29
  def segment_image(image: Image.Image):
30
- # Prepara l'input per SAM
31
  logger.info("Preparazione dell'immagine per l'inferenza...")
32
- inputs = processor(image, return_tensors="pt").to("cpu")
33
 
34
  # Inferenza
35
  logger.info("Esecuzione dell'inferenza...")
36
  with torch.no_grad():
37
- outputs = model(**inputs, multimask_output=False)
38
-
 
39
  # Post-processa la maschera
40
  logger.info("Post-processing della maschera...")
41
- mask = processor.image_processor.post_process_masks(
42
- outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"]
43
- )[0][0].cpu().numpy()
44
 
45
  # Converti la maschera in immagine
46
- mask_img = Image.fromarray((mask * 255).astype(np.uint8))
47
 
48
  # Converti la maschera in base64 per la risposta
49
  buffered = io.BytesIO()
@@ -51,7 +51,7 @@ def segment_image(image: Image.Image):
51
  mask_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
52
 
53
  # Annotazioni
54
- annotations = {"mask": mask.tolist(), "label": "object"}
55
 
56
  return mask_base64, annotations
57
 
@@ -74,7 +74,7 @@ async def segment_endpoint(file: UploadFile = File(...)):
74
  logger.error(f"Errore nell'endpoint: {str(e)}")
75
  raise HTTPException(status_code=500, detail=f"Errore nell'elaborazione: {str(e)}")
76
 
77
- # Per compatibilità con Hugging Face Spaces (Uvicorn viene gestito automaticamente)
78
  if __name__ == "__main__":
79
  import uvicorn
80
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
  from fastapi import FastAPI, File, UploadFile, HTTPException
2
+ from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
3
  import torch
4
  from PIL import Image
5
  import numpy as np
 
14
  logging.basicConfig(level=logging.INFO)
15
  logger = logging.getLogger(__name__)
16
 
17
+ # Carica il modello e il processore SegFormer
18
  try:
19
+ logger.info("Caricamento del modello SegFormer...")
20
+ model = SegformerForSemanticSegmentation.from_pretrained("sayeed99/segformer-b3-fashion")
21
+ processor = SegformerImageProcessor.from_pretrained("sayeed99/segformer-b3-fashion")
22
  model.to("cpu") # Usa CPU per il free tier
23
  logger.info("Modello caricato con successo.")
24
  except Exception as e:
 
27
 
28
  # Funzione per segmentare l'immagine
29
  def segment_image(image: Image.Image):
30
+ # Prepara l'input per SegFormer
31
  logger.info("Preparazione dell'immagine per l'inferenza...")
32
+ inputs = processor(images=image, return_tensors="pt").to("cpu")
33
 
34
  # Inferenza
35
  logger.info("Esecuzione dell'inferenza...")
36
  with torch.no_grad():
37
+ outputs = model(**inputs)
38
+ logits = outputs.logits
39
+
40
  # Post-processa la maschera
41
  logger.info("Post-processing della maschera...")
42
+ mask = torch.argmax(logits, dim=1)[0]
43
+ mask = mask.cpu().numpy()
 
44
 
45
  # Converti la maschera in immagine
46
+ mask_img = Image.fromarray((mask * 255 / mask.max()).astype(np.uint8))
47
 
48
  # Converti la maschera in base64 per la risposta
49
  buffered = io.BytesIO()
 
51
  mask_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
52
 
53
  # Annotazioni
54
+ annotations = {"mask": mask.tolist(), "label": "fashion"}
55
 
56
  return mask_base64, annotations
57
 
 
74
  logger.error(f"Errore nell'endpoint: {str(e)}")
75
  raise HTTPException(status_code=500, detail=f"Errore nell'elaborazione: {str(e)}")
76
 
77
+ # Per compatibilità con Hugging Face Spaces
78
  if __name__ == "__main__":
79
  import uvicorn
80
  uvicorn.run(app, host="0.0.0.0", port=7860)