# Hugging Face Space (status: Running)
import base64
import io
import logging

import numpy as np
import requests
import torch
import torch.nn as nn
from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image, UnidentifiedImageError
from pydantic import BaseModel
from transformers import (
    AutoModelForSemanticSegmentation,
    SegformerForSemanticSegmentation,
    SegformerImageProcessor,
)
# Module-level FastAPI application instance.
app: FastAPI = FastAPI()
# Request body for the URL-based endpoints: a JSON object with a single
# string field `url` naming the image to download.
class ImageURL(BaseModel):
    # Image URL; declared as plain `str`, so no URL-format validation happens here.
    url: str
# Logging: a logger named after this module, with the root logger configured at INFO.
logger = logging.getLogger(__name__)
logging.basicConfig(level="INFO")
# Load the SegFormer fashion-segmentation model and its image processor.
# A failure is fatal: the RuntimeError propagates at import time, so this
# module cannot be loaded without a working model.
FASHION_MODEL_ID = "sayeed99/segformer-b3-fashion"
try:
    logger.info("Caricamento del modello SegFormer...")
    model = SegformerForSemanticSegmentation.from_pretrained(FASHION_MODEL_ID)
    processor = SegformerImageProcessor.from_pretrained(FASHION_MODEL_ID)
    model.to("cpu")  # Run on CPU (free-tier hardware).
    logger.info("Modello caricato con successo.")
except Exception as e:
    # logger.exception records the traceback; `from e` keeps the original
    # error attached as the RuntimeError's __cause__.
    logger.exception("Errore nel caricamento del modello: %s", e)
    raise RuntimeError(f"Errore nel caricamento del modello: {str(e)}") from e
# Load the SegFormer clothes-segmentation model and its image processor.
# A failure is fatal: the RuntimeError propagates at import time, so this
# module cannot be loaded without a working model.
CLOTHES_MODEL_ID = "mattmdjaga/segformer_b2_clothes"
try:
    logger.info("Loading clothes segmentation model...")
    clothes_model = AutoModelForSemanticSegmentation.from_pretrained(CLOTHES_MODEL_ID)
    clothes_processor = SegformerImageProcessor.from_pretrained(CLOTHES_MODEL_ID)
    clothes_model.to("cpu")
    logger.info("Clothes model loaded successfully.")
except Exception as e:
    # logger.exception records the traceback; `from e` keeps the original
    # error attached as the RuntimeError's __cause__.
    logger.exception("Error loading clothes model: %s", e)
    raise RuntimeError(f"Error loading clothes model: {str(e)}") from e
# Segment an image with the fashion SegFormer model.
def segment_image(image: Image.Image):
    """Run the fashion SegFormer model on ``image`` and encode the class map.

    Args:
        image: PIL image to segment.

    Returns:
        A ``(mask_base64, annotations)`` tuple.

        ``mask_base64`` is a base64-encoded PNG of the per-pixel class ids,
        stretched to 0-255 for display.

        ``annotations`` is a JSON-serialisable dict with two keys:
        ``"mask"`` holds the raw class-id map as nested lists, and
        ``"label"`` maps each class id present in the mask (as a string) to
        its name from ``model.config.id2label``.

    Note:
        The class map comes straight from the model logits and is NOT
        upsampled to the input image size.
    """
    logger.info("Preparazione dell'immagine per l'inferenza...")
    inputs = processor(images=image, return_tensors="pt").to("cpu")

    logger.info("Esecuzione dell'inferenza...")
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits

    logger.info("Post-processing della maschera...")
    # Highest-scoring class per pixel for the first (only) image in the batch.
    mask = torch.argmax(logits, dim=1)[0].cpu().numpy()

    # Stretch class ids over 0-255 so the PNG is viewable. An all-background
    # mask has max 0; guard it to avoid a division by zero (NaN pixels).
    max_class = int(mask.max())
    if max_class > 0:
        display = mask * 255 // max_class
    else:
        display = np.zeros_like(mask)
    mask_img = Image.fromarray(display.astype(np.uint8))

    buffered = io.BytesIO()
    mask_img.save(buffered, format="PNG")
    mask_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

    # The raw logits tensor cannot be JSON-encoded in the response, so report
    # the names of the classes that actually occur in the mask instead.
    id2label = model.config.id2label
    present_labels = {
        str(class_id): id2label.get(class_id, str(class_id))
        for class_id in (int(c) for c in np.unique(mask))
    }
    annotations = {"mask": mask.tolist(), "label": present_labels}
    return mask_base64, annotations
# API endpoint: segment an uploaded image file with the fashion model.
# NOTE(review): no route decorator (e.g. `@app.post(...)`) is present on this
# coroutine, so as written it is never registered on `app`; confirm the
# intended route path and add the decorator.
async def segment_endpoint(file: UploadFile = File(...)):
    """Segment an uploaded image with the fashion SegFormer model.

    Args:
        file: Uploaded image file; any format ``PIL.Image.open`` can read.

    Returns:
        A dict with ``mask`` (PNG data URI) and ``annotations``, both taken
        from ``segment_image``.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image,
            500 for any other processing failure.

    NOTE(review): ``segment_image`` runs synchronously inside this coroutine,
    so the event loop is blocked for the duration of inference.
    """
    try:
        logger.info("Ricezione del file...")
        image_data = await file.read()
        image = Image.open(io.BytesIO(image_data)).convert("RGB")
        logger.info("Segmentazione dell'immagine...")
        mask_base64, annotations = segment_image(image)
        return {
            "mask": f"data:image/png;base64,{mask_base64}",
            "annotations": annotations,
        }
    except UnidentifiedImageError as e:
        # Unreadable input is a client error, not a server fault.
        logger.warning("Uploaded file is not a readable image: %s", e)
        raise HTTPException(status_code=400, detail="Uploaded file is not a valid image") from e
    except Exception as e:
        logger.exception("Errore nell'endpoint: %s", e)
        raise HTTPException(status_code=500, detail=f"Errore nell'elaborazione: {str(e)}") from e
# API endpoint: download an image from a URL and segment it with the fashion
# SegFormer model.
# NOTE(review): no route decorator (e.g. `@app.post(...)`) is present on this
# coroutine, so as written it is never registered on `app`; confirm the
# intended route path and add the decorator.
async def segment_url_endpoint(image_data: ImageURL):
    """Download the image at ``image_data.url`` and segment it.

    Args:
        image_data: Request body holding the image URL.

    Returns:
        A dict with three keys:

        - ``mask``: PNG data URI of the class map, stretched to 0-255.
        - ``size``: the image's PIL ``(width, height)``.
        - ``labels``: per-pixel class ids at the original image resolution,
          as nested lists.

    Raises:
        HTTPException: 400 if the image cannot be downloaded or decoded,
            500 for any other processing failure.
    """
    try:
        logger.info("Downloading image from URL...")
        # NOTE(review): this fetches an arbitrary client-supplied URL from the
        # server (SSRF risk), using blocking `requests` inside a coroutine;
        # consider a host allow-list and an async HTTP client.
        # Why each piece is here:
        # - the timeout keeps an unresponsive host from hanging the request;
        # - `.content` (unlike `.raw`) decodes Content-Encoding such as gzip;
        # - the `with` block releases the connection.
        with requests.get(image_data.url, timeout=30) as response:
            if response.status_code != 200:
                raise HTTPException(status_code=400, detail="Could not download image from URL")
            payload = response.content
        image = Image.open(io.BytesIO(payload)).convert("RGB")

        logger.info("Processing image...")
        inputs = processor(images=image, return_tensors="pt")
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            outputs = model(**inputs)
        logits = outputs.logits.cpu()

        # Upsample logits to the original image size. PIL's size is
        # (width, height); interpolate expects (height, width).
        upsampled_logits = nn.functional.interpolate(
            logits,
            size=image.size[::-1],
            mode="bilinear",
            align_corners=False,
        )
        pred_seg = upsampled_logits.argmax(dim=1)[0]
        class_map = pred_seg.numpy()

        # Stretch class ids over 0-255 for display. A plain `* 255` followed
        # by a uint8 cast would wrap modulo 256 instead of scaling.
        max_class = int(class_map.max())
        if max_class > 0:
            display = class_map * 255 // max_class
        else:
            display = np.zeros_like(class_map)
        mask_img = Image.fromarray(display.astype(np.uint8))

        buffered = io.BytesIO()
        mask_img.save(buffered, format="PNG")
        mask_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

        return {
            "mask": f"data:image/png;base64,{mask_base64}",
            "size": image.size,
            # Plain nested lists: a torch tensor is not JSON-serialisable.
            "labels": class_map.tolist(),
        }
    except HTTPException:
        # Deliberate HTTP errors (the 400 above) must reach the client as-is
        # rather than being rewrapped as a 500 by the generic handler.
        raise
    except requests.RequestException as e:
        logger.warning("Could not download image from URL: %s", e)
        raise HTTPException(status_code=400, detail="Could not download image from URL") from e
    except UnidentifiedImageError as e:
        logger.warning("Downloaded content is not a readable image: %s", e)
        raise HTTPException(status_code=400, detail="Downloaded content is not a valid image") from e
    except Exception as e:
        logger.exception("Error processing URL: %s", e)
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}") from e
# API endpoint: download an image from a URL and segment it with the clothes
# SegFormer model.
# NOTE(review): no route decorator (e.g. `@app.post(...)`) is present on this
# coroutine, so as written it is never registered on `app`; confirm the
# intended route path and add the decorator.
async def segment_clothes_url_endpoint(image_data: ImageURL):
    """Download the image at ``image_data.url`` and segment it for clothes.

    Args:
        image_data: Request body holding the image URL.

    Returns:
        A dict with three keys:

        - ``mask``: PNG data URI of the class map, stretched to 0-255.
        - ``size``: the image's PIL ``(width, height)``.
        - ``predictions``: per-pixel class ids at the original image
          resolution, as nested lists.

    Raises:
        HTTPException: 400 if the image cannot be downloaded or decoded,
            500 for any other processing failure.
    """
    try:
        logger.info("Downloading image from URL...")
        # NOTE(review): this fetches an arbitrary client-supplied URL from the
        # server (SSRF risk), using blocking `requests` inside a coroutine;
        # consider a host allow-list and an async HTTP client.
        # Why each piece is here:
        # - the timeout keeps an unresponsive host from hanging the request;
        # - `.content` (unlike `.raw`) decodes Content-Encoding such as gzip;
        # - the `with` block releases the connection.
        with requests.get(image_data.url, timeout=30) as response:
            if response.status_code != 200:
                raise HTTPException(status_code=400, detail="Could not download image from URL")
            payload = response.content
        image = Image.open(io.BytesIO(payload)).convert("RGB")

        logger.info("Processing image...")
        inputs = clothes_processor(images=image, return_tensors="pt")
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            outputs = clothes_model(**inputs)
        logits = outputs.logits.cpu()

        # Upsample logits to the original image size. PIL's size is
        # (width, height); interpolate expects (height, width).
        upsampled_logits = nn.functional.interpolate(
            logits,
            size=image.size[::-1],
            mode="bilinear",
            align_corners=False,
        )
        pred_seg = upsampled_logits.argmax(dim=1)[0]
        class_map = pred_seg.numpy()

        # Stretch class ids over 0-255 for display. A plain `* 255` followed
        # by a uint8 cast would wrap modulo 256 instead of scaling.
        max_class = int(class_map.max())
        if max_class > 0:
            display = class_map * 255 // max_class
        else:
            display = np.zeros_like(class_map)
        mask_img = Image.fromarray(display.astype(np.uint8))

        buffered = io.BytesIO()
        mask_img.save(buffered, format="PNG")
        mask_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

        return {
            "mask": f"data:image/png;base64,{mask_base64}",
            "size": image.size,
            "predictions": class_map.tolist(),
        }
    except HTTPException:
        # Deliberate HTTP errors (the 400 above) must reach the client as-is
        # rather than being rewrapped as a 500 by the generic handler.
        raise
    except requests.RequestException as e:
        logger.warning("Could not download image from URL: %s", e)
        raise HTTPException(status_code=400, detail="Could not download image from URL") from e
    except UnidentifiedImageError as e:
        logger.warning("Downloaded content is not a readable image: %s", e)
        raise HTTPException(status_code=400, detail="Downloaded content is not a valid image") from e
    except Exception as e:
        logger.exception("Error processing URL: %s", e)
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}") from e
# Direct-run entry point: serve the app with uvicorn on all interfaces.
# Host and port are chosen for Hugging Face Spaces compatibility.
if __name__ == "__main__":
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)