# segmentation / app.py
# Last commit: 0e3833c by Alex ("clothes segmentation"), 6.82 kB
import base64
import io
import logging

import numpy as np
import requests
import torch
import torch.nn as nn
from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image, UnidentifiedImageError
from pydantic import BaseModel
from transformers import (
    AutoModelForSemanticSegmentation,
    SegformerForSemanticSegmentation,
    SegformerImageProcessor,
)
# Initialize the FastAPI application; the endpoints below register on it via @app.post.
app = FastAPI()
# Request body model for the URL-based segmentation endpoints.
class ImageURL(BaseModel):
    """Request body carrying the URL of the image to segment."""

    # Image URL; handed to requests.get as-is (no format validation here).
    url: str
# Logging: a module-level logger plus INFO-level root configuration.
logger = logging.getLogger(__name__)
logging.basicConfig(level="INFO")
# Load the fashion SegFormer model and its image processor at import time.
# Any failure aborts startup with a RuntimeError.
try:
    logger.info("Caricamento del modello SegFormer...")
    model = SegformerForSemanticSegmentation.from_pretrained("sayeed99/segformer-b3-fashion")
    processor = SegformerImageProcessor.from_pretrained("sayeed99/segformer-b3-fashion")
    model.to("cpu")  # Run on CPU (free tier).
    logger.info("Modello caricato con successo.")
except Exception as e:
    logger.error(f"Errore nel caricamento del modello: {str(e)}")
    # Chain explicitly so the loading error is reported as the direct cause.
    raise RuntimeError(f"Errore nel caricamento del modello: {str(e)}") from e
# Load the clothes segmentation model and its image processor at import time.
# Any failure aborts startup with a RuntimeError.
try:
    logger.info("Loading clothes segmentation model...")
    clothes_model = AutoModelForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes")
    clothes_processor = SegformerImageProcessor.from_pretrained("mattmdjaga/segformer_b2_clothes")
    clothes_model.to("cpu")  # Run on CPU, like the fashion model.
    logger.info("Clothes model loaded successfully.")
except Exception as e:
    logger.error(f"Error loading clothes model: {str(e)}")
    # Chain explicitly so the loading error is reported as the direct cause.
    raise RuntimeError(f"Error loading clothes model: {str(e)}") from e
# Segment an image with the fashion SegFormer model (`model` / `processor`).
def segment_image(image: Image.Image):
    """Run the fashion SegFormer model on ``image`` and encode the result.

    The mask is taken directly from the model logits; it is NOT upsampled
    to the original image size.

    Args:
        image: RGB PIL image to segment.

    Returns:
        ``(mask_base64, annotations)``:

        - ``mask_base64``: base64-encoded PNG of the mask, with class ids
          stretched so the largest id present maps to 255.
        - ``annotations``: a JSON-serializable dict with ``"mask"`` (the
          per-pixel class ids as nested lists) and ``"label"`` (the sorted
          list of class ids present in the mask).
    """
    logger.info("Preparazione dell'immagine per l'inferenza...")
    inputs = processor(images=image, return_tensors="pt").to("cpu")

    logger.info("Esecuzione dell'inferenza...")
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits

    logger.info("Post-processing della maschera...")
    # Per-pixel class id: argmax over dim 1 (the class axis) of the first batch item.
    mask = torch.argmax(logits, dim=1)[0].cpu().numpy()

    # Stretch ids to 0-255 for a viewable PNG. An all-background mask
    # (max 0) stays black instead of dividing by zero and producing NaNs.
    peak = mask.max()
    if peak > 0:
        scaled = mask * 255 / peak
    else:
        scaled = np.zeros_like(mask)
    mask_img = Image.fromarray(scaled.astype(np.uint8))

    # Encode the mask image as base64 PNG for the JSON response.
    buffered = io.BytesIO()
    mask_img.save(buffered, format="PNG")
    mask_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

    # A torch.Tensor is not JSON-serializable, so the response carries the
    # class ids present in the mask rather than the raw logits.
    annotations = {"mask": mask.tolist(), "label": np.unique(mask).tolist()}
    return mask_base64, annotations
# API endpoint: segment an uploaded image file with the fashion model.
@app.post("/segment")
async def segment_endpoint(file: UploadFile = File(...)):
    """Segment an uploaded image with the fashion SegFormer model.

    Args:
        file: Multipart upload containing the image.

    Returns:
        A dict with ``"mask"`` (the mask as a PNG data URL) and
        ``"annotations"`` (as returned by ``segment_image``).

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image,
            500 for any other processing failure.
    """
    try:
        logger.info("Ricezione del file...")
        image_data = await file.read()
        image = Image.open(io.BytesIO(image_data)).convert("RGB")
        logger.info("Segmentazione dell'immagine...")
        # NOTE(review): segment_image is synchronous CPU work inside an async
        # handler, so it blocks the event loop while it runs.
        mask_base64, annotations = segment_image(image)
    except UnidentifiedImageError as e:
        # A file Pillow cannot decode is a client error, not a server fault.
        logger.warning(f"File non riconosciuto come immagine: {str(e)}")
        raise HTTPException(
            status_code=400, detail="Il file caricato non è un'immagine valida"
        ) from e
    except Exception as e:
        logger.exception(f"Errore nell'endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Errore nell'elaborazione: {str(e)}") from e
    return {
        "mask": f"data:image/png;base64,{mask_base64}",
        "annotations": annotations,
    }
# API endpoint: fetch an image by URL and segment it with the fashion model.
@app.post("/segment-url")
async def segment_url_endpoint(image_data: ImageURL):
    """Download the image at ``image_data.url`` and segment it with ``model``.

    Logits are upsampled to the original image resolution before the
    per-pixel argmax, so the mask and ``labels`` match ``image.size``.

    Args:
        image_data: Request body holding the image URL.

    Returns:
        A dict with:

        - ``"mask"``: PNG data URL of the mask, with class ids stretched so
          the largest id present maps to 255.
        - ``"size"``: ``image.size`` as ``(width, height)``.
        - ``"labels"``: per-pixel class ids as nested lists (rows x columns).

    Raises:
        HTTPException: 400 if the URL cannot be fetched or is not a readable
            image, 500 for any other processing failure.
    """
    try:
        logger.info("Downloading image from URL...")
        # NOTE(review): the URL is client-supplied and fetched server-side, so
        # any host reachable from this server can be requested (SSRF). Restrict
        # schemes/hosts if the service is public. The download and inference
        # are also blocking calls inside an async handler.
        # (connect, read) timeouts in seconds, so a stalled remote server
        # cannot hang the request forever.
        response = requests.get(image_data.url, timeout=(5, 30))
        if response.status_code != 200:
            raise HTTPException(status_code=400, detail="Could not download image from URL")
        # Decode from the fully-read body: `response.content` applies any
        # Content-Encoding and releases the connection, unlike a streamed
        # `response.raw`.
        image = Image.open(io.BytesIO(response.content)).convert("RGB")

        logger.info("Processing image...")
        inputs = processor(images=image, return_tensors="pt")
        # Inference only: no autograd bookkeeping needed.
        with torch.no_grad():
            outputs = model(**inputs)
        logits = outputs.logits.cpu()

        # Upsample logits to the original (height, width); PIL sizes are (w, h).
        upsampled_logits = nn.functional.interpolate(
            logits,
            size=image.size[::-1],
            mode="bilinear",
            align_corners=False,
        )
        pred_seg = upsampled_logits.argmax(dim=1)[0]

        # Stretch class ids to 0-255. Multiplying ids by 255 directly would
        # wrap around in uint8 (id 2 -> 254, ...). An all-background mask
        # (max 0) stays black instead of dividing by zero.
        mask = pred_seg.numpy()
        peak = mask.max()
        scaled = mask * 255 / peak if peak > 0 else np.zeros_like(mask)
        mask_img = Image.fromarray(scaled.astype(np.uint8))

        buffered = io.BytesIO()
        mask_img.save(buffered, format="PNG")
        mask_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
    except HTTPException:
        # Propagate deliberate HTTP errors (the 400 above) unchanged instead
        # of rewrapping them as 500s.
        raise
    except requests.RequestException as e:
        logger.warning(f"Error downloading URL: {str(e)}")
        raise HTTPException(status_code=400, detail="Could not download image from URL") from e
    except UnidentifiedImageError as e:
        logger.warning(f"URL content is not an image: {str(e)}")
        raise HTTPException(status_code=400, detail="URL does not point to a valid image") from e
    except Exception as e:
        logger.exception(f"Error processing URL: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}") from e

    return {
        "mask": f"data:image/png;base64,{mask_base64}",
        "size": image.size,
        # Nested Python lists: a torch.Tensor is not JSON-serializable.
        "labels": pred_seg.tolist(),
    }
# API endpoint: fetch an image by URL and segment it with the clothes model.
@app.post("/segment-clothes-url")
async def segment_clothes_url_endpoint(image_data: ImageURL):
    """Download the image at ``image_data.url`` and segment it with ``clothes_model``.

    Logits are upsampled to the original image resolution before the
    per-pixel argmax, so the mask and ``predictions`` match ``image.size``.

    Args:
        image_data: Request body holding the image URL.

    Returns:
        A dict with:

        - ``"mask"``: PNG data URL of the mask, with class ids stretched so
          the largest id present maps to 255.
        - ``"size"``: ``image.size`` as ``(width, height)``.
        - ``"predictions"``: per-pixel class ids as nested lists
          (rows x columns).

    Raises:
        HTTPException: 400 if the URL cannot be fetched or is not a readable
            image, 500 for any other processing failure.
    """
    try:
        logger.info("Downloading image from URL...")
        # NOTE(review): the URL is client-supplied and fetched server-side, so
        # any host reachable from this server can be requested (SSRF). Restrict
        # schemes/hosts if the service is public. The download and inference
        # are also blocking calls inside an async handler.
        # (connect, read) timeouts in seconds, so a stalled remote server
        # cannot hang the request forever.
        response = requests.get(image_data.url, timeout=(5, 30))
        if response.status_code != 200:
            raise HTTPException(status_code=400, detail="Could not download image from URL")
        # Decode from the fully-read body: `response.content` applies any
        # Content-Encoding and releases the connection, unlike a streamed
        # `response.raw`.
        image = Image.open(io.BytesIO(response.content)).convert("RGB")

        logger.info("Processing image...")
        inputs = clothes_processor(images=image, return_tensors="pt")
        # Inference only: no autograd bookkeeping needed.
        with torch.no_grad():
            outputs = clothes_model(**inputs)
        logits = outputs.logits.cpu()

        # Upsample logits to the original (height, width); PIL sizes are (w, h).
        upsampled_logits = nn.functional.interpolate(
            logits,
            size=image.size[::-1],
            mode="bilinear",
            align_corners=False,
        )
        pred_seg = upsampled_logits.argmax(dim=1)[0]

        # Stretch class ids to 0-255. Multiplying ids by 255 directly would
        # wrap around in uint8 (id 2 -> 254, ...). An all-background mask
        # (max 0) stays black instead of dividing by zero.
        mask = pred_seg.numpy()
        peak = mask.max()
        scaled = mask * 255 / peak if peak > 0 else np.zeros_like(mask)
        mask_img = Image.fromarray(scaled.astype(np.uint8))

        buffered = io.BytesIO()
        mask_img.save(buffered, format="PNG")
        mask_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
    except HTTPException:
        # Propagate deliberate HTTP errors (the 400 above) unchanged instead
        # of rewrapping them as 500s.
        raise
    except requests.RequestException as e:
        logger.warning(f"Error downloading URL: {str(e)}")
        raise HTTPException(status_code=400, detail="Could not download image from URL") from e
    except UnidentifiedImageError as e:
        logger.warning(f"URL content is not an image: {str(e)}")
        raise HTTPException(status_code=400, detail="URL does not point to a valid image") from e
    except Exception as e:
        logger.exception(f"Error processing URL: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}") from e

    return {
        "mask": f"data:image/png;base64,{mask_base64}",
        "size": image.size,
        "predictions": pred_seg.numpy().tolist(),
    }
# Direct-run entry point, kept for Hugging Face Spaces compatibility:
# serves `app` on all interfaces, port 7860.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app=app, host="0.0.0.0", port=7860)