from fastapi import FastAPI
from pydantic import BaseModel
from transformers import SamModel, SamProcessor
import torch
from PIL import Image
import numpy as np
import io
import base64

class ImageRequest(BaseModel):
    image_base64: str
# Initialize the FastAPI app
app = FastAPI()

# Load the SAM model and processor
model = SamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
model.to("cpu")  # Use the CPU on the free tier
@app.get("/")  # health-check route (the "/" path is an assumption)
async def health_check():
    return {"status": "ok"}
def preprocess_image(image: Image.Image, size=(320, 320)):
    """Resize the image to speed up inference."""
    img = image.convert("RGB")
    img = img.resize(size, Image.LANCZOS)  # 320x320 is fast on CPU
    return img
# Segment the image with SAM
def segment_image(image: Image.Image):
    # Prepare the input for SAM
    inputs = processor(image, return_tensors="pt").to("cpu")

    # Inference
    with torch.no_grad():
        outputs = model(**inputs, multimask_output=False)

    # Post-process the mask; [0][0, 0] selects the single predicted mask
    # as a 2-D (H, W) array so it can be converted to an image
    mask = processor.image_processor.post_process_masks(
        outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"]
    )[0][0, 0].cpu().numpy()

    # Convert the boolean mask to an 8-bit grayscale image
    mask_img = Image.fromarray((mask * 255).astype(np.uint8))

    # Encode the mask as base64 for the response
    buffered = io.BytesIO()
    mask_img.save(buffered, format="PNG")
    mask_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

    # Annotations
    annotations = {"mask": mask.tolist(), "label": "object"}
    return mask_base64, annotations
# API endpoint
@app.post("/segment")
async def segment_endpoint(file: ImageRequest):
    try:
        # Decode the Base64 string
        image_data = base64.b64decode(file.image_base64)
        image = Image.open(io.BytesIO(image_data))
        image = preprocess_image(image)

        # Segment the image
        mask_base64, annotations = segment_image(image)

        # Return the response
        return {
            "mask": f"data:image/png;base64,{mask_base64}",
            "annotations": annotations
        }
    except Exception as e:
        # On error (e.g. invalid Base64), return False
        return {"output": False, "error": str(e), "debug": file}
# For compatibility with Hugging Face Spaces (Uvicorn is managed automatically)
# if __name__ == "__main__":
#     import uvicorn
#     uvicorn.run(app, host="0.0.0.0", port=7860)
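
# Example client call (a minimal sketch, not part of the app): it assumes the
# Space is publicly reachable at the placeholder URL below and that the
# "requests" package is installed on the client side.
#
# import base64, requests
#
# with open("input.jpg", "rb") as f:
#     payload = {"image_base64": base64.b64encode(f.read()).decode("utf-8")}
# resp = requests.post("https://<your-space>.hf.space/segment", json=payload)
# print(resp.json()["mask"][:60])  # "data:image/png;base64,..." data URL of the mask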