from fastapi import FastAPI, File, UploadFile from pydantic import BaseModel from transformers import SamModel, SamProcessor import torch from PIL import Image import numpy as np import io import base64 class ImageRequest(BaseModel): image_base64: str # Inizializza l'app FastAPI app = FastAPI() # Carica il modello e il processore SAM model = SamModel.from_pretrained("facebook/sam-vit-base") processor = SamProcessor.from_pretrained("facebook/sam-vit-base") model.to("cpu") # Usa CPU per il free tier @app.get("/health") async def health_check(): return {"status": "ok"} def preprocess_image(image: Image.Image, size=(320, 320)): """Ridimensiona l'immagine per velocizzare l'inferenza""" img = image.convert("RGB") img = img.resize(size, Image.LANCZOS) # 320x320 è veloce su CPU return img # Funzione per segmentare l'immagine def segment_image(image: Image.Image): # Prepara l'input per SAM inputs = processor(image, return_tensors="pt").to("cpu") # Inferenza with torch.no_grad(): outputs = model(**inputs, multimask_output=False) # Post-processa la maschera mask = processor.image_processor.post_process_masks( outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"] )[0][0].cpu().numpy() # Converti la maschera in immagine mask_img = Image.fromarray((mask * 255).astype(np.uint8)) # Converti la maschera in base64 per la risposta buffered = io.BytesIO() mask_img.save(buffered, format="PNG") mask_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8") # Annotazioni annotations = {"mask": mask.tolist(), "label": "object"} return mask_base64, annotations # Endpoint API # @app.post("/segment") async def segment_endpoint(file: ImageRequest): try: # Decodifica la stringa Base64 image_data = base64.b64decode(file.image_base64) image = Image.open(io.BytesIO(image_data)) image = preprocess_image(image) # Segmenta l'immagine mask_base64, annotations = segment_image(image) # Restituisci la risposta return { "mask": f"data:image/png;base64,{mask_base64}", "annotations": annotations } except Exception as e: # In caso di errore (es. Base64 non valido), restituisci False return {"output": False, "error": str(e), "debug": file} # Per compatibilità con Hugging Face Spaces (Uvicorn viene gestito automaticamente) # if __name__ == "__main__": # import uvicorn # uvicorn.run(app, host="0.0.0.0", port=7860)