Spaces:
Running
Running
File size: 6,824 Bytes
7364060 0e3833c 077a679 7077928 0e3833c 077a679 7077928 3b62419 077a679 7364060 f6210c2 7364060 f6210c2 7364060 7077928 0e3833c 3b62419 f6210c2 7364060 f6210c2 7077928 7364060 7077928 f6210c2 3b62419 7364060 f6210c2 7077928 3b62419 f6210c2 7077928 3b62419 7077928 3b62419 4e5e0e7 3b62419 7077928 3b62419 779723a 7364060 7077928 077a679 0e3833c 077a679 0e3833c f6210c2 3b62419 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 |
import base64
import io
import logging

import numpy as np
import requests
import torch
import torch.nn as nn
from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image, UnidentifiedImageError
from pydantic import BaseModel
from transformers import (
    AutoModelForSemanticSegmentation,
    SegformerForSemanticSegmentation,
    SegformerImageProcessor,
)
# The FastAPI application instance that every endpoint below registers on.
app: FastAPI = FastAPI()
# Request body for the URL-based endpoints (/segment-url and /segment-clothes-url).
class ImageURL(BaseModel):
    # Address of the image the server downloads and segments; accepted as a
    # plain string, with no URL validation at this layer.
    url: str
# Module logger plus a root-logger configuration that emits INFO and above.
logger = logging.getLogger(__name__)
logging.basicConfig(level="INFO")
# Load the SegFormer fashion checkpoint and its matching image processor once,
# at import time. Any failure is logged and aborts startup with RuntimeError.
_FASHION_CHECKPOINT = "sayeed99/segformer-b3-fashion"

try:
    logger.info("Caricamento del modello SegFormer...")
    model = SegformerForSemanticSegmentation.from_pretrained(_FASHION_CHECKPOINT)
    processor = SegformerImageProcessor.from_pretrained(_FASHION_CHECKPOINT)
    model.to("cpu")  # keep inference on CPU (free-tier hardware)
    logger.info("Modello caricato con successo.")
except Exception as load_error:
    failure = f"Errore nel caricamento del modello: {load_error}"
    logger.error(failure)
    raise RuntimeError(failure)
# Load the clothes-segmentation checkpoint and its image processor once, at
# import time. Any failure is logged and aborts startup with RuntimeError.
_CLOTHES_CHECKPOINT = "mattmdjaga/segformer_b2_clothes"

try:
    logger.info("Loading clothes segmentation model...")
    clothes_model = AutoModelForSemanticSegmentation.from_pretrained(_CLOTHES_CHECKPOINT)
    clothes_processor = SegformerImageProcessor.from_pretrained(_CLOTHES_CHECKPOINT)
    clothes_model.to("cpu")  # keep inference on CPU
    logger.info("Clothes model loaded successfully.")
except Exception as load_error:
    failure = f"Error loading clothes model: {load_error}"
    logger.error(failure)
    raise RuntimeError(failure)
# Segment an image with the fashion SegFormer model.
def segment_image(image: Image.Image):
    """Run the fashion SegFormer model on *image* and encode the predicted mask.

    Args:
        image: PIL image to segment (the /segment endpoint passes an RGB image).

    Returns:
        A ``(mask_base64, annotations)`` tuple:

        - ``mask_base64``: base64-encoded grayscale PNG of the per-pixel class
          map, stretched so the highest predicted class id maps to 255
          (an all-zero prediction yields an all-black image).
        - ``annotations``: JSON-serializable dict with the raw class-id map
          under ``"mask"`` and the raw logits under ``"label"``, both as
          nested lists.

    Note:
        The class map is the argmax of the logits as returned by the model;
        unlike the URL endpoints, it is not upsampled to the input image size.
    """
    logger.info("Preparazione dell'immagine per l'inferenza...")
    inputs = processor(images=image, return_tensors="pt").to("cpu")

    logger.info("Esecuzione dell'inferenza...")
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits

    logger.info("Post-processing della maschera...")
    # Per-pixel class ids for the single image in the batch.
    mask = torch.argmax(logits, dim=1)[0].cpu().numpy()

    # Stretch class ids into 0-255 for a viewable PNG. Guard the all-zero
    # prediction: dividing by a zero max would produce NaNs before the cast.
    peak = mask.max()
    if peak > 0:
        pixels = (mask * 255 / peak).astype(np.uint8)
    else:
        pixels = np.zeros(mask.shape, dtype=np.uint8)
    mask_img = Image.fromarray(pixels)

    buffered = io.BytesIO()
    mask_img.save(buffered, format="PNG")
    mask_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

    # Torch tensors are not JSON-serializable; nested lists are, so the
    # endpoint's response can actually be encoded.
    annotations = {"mask": mask.tolist(), "label": logits.cpu().tolist()}
    return mask_base64, annotations
# API endpoint: segment an uploaded image file with the fashion model.
@app.post("/segment")
async def segment_endpoint(file: UploadFile = File(...)):
    """Segment an uploaded image file with the fashion SegFormer model.

    Returns a dict with ``mask`` (PNG data URI of the stretched class map)
    and ``annotations`` (as produced by ``segment_image``).

    Raises HTTPException 400 when the upload cannot be decoded as an image,
    and 500 for any other processing failure.
    """
    # NOTE(review): segment_image runs synchronously inside this coroutine,
    # so inference blocks the event loop for its whole duration.
    try:
        logger.info("Ricezione del file...")
        image_data = await file.read()
        image = Image.open(io.BytesIO(image_data)).convert("RGB")
        logger.info("Segmentazione dell'immagine...")
        mask_base64, annotations = segment_image(image)
    except UnidentifiedImageError as e:
        # An undecodable upload is a client error, not a server fault.
        logger.error(f"Errore nell'endpoint: {str(e)}")
        raise HTTPException(status_code=400, detail=f"Immagine non valida: {str(e)}") from e
    except Exception as e:
        logger.exception(f"Errore nell'endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Errore nell'elaborazione: {str(e)}") from e
    return {
        "mask": f"data:image/png;base64,{mask_base64}",
        "annotations": annotations,
    }
# API endpoint: segment an image downloaded from a URL with the fashion model.
@app.post("/segment-url")
async def segment_url_endpoint(image_data: ImageURL):
    """Download the image at ``image_data.url`` and segment it with the fashion model.

    Returns a dict with:

    - ``mask``: grayscale PNG data URI of the per-pixel class map, where class
      id k is stored as pixel value (-k) mod 256;
    - ``size``: ``(width, height)`` of the downloaded image;
    - ``labels``: per-pixel class ids as nested lists (rows x columns).

    Raises HTTPException 400 when the image cannot be downloaded or decoded,
    and 500 for any other processing failure.
    """
    # (connect, read) timeouts in seconds, so an unresponsive host cannot hang
    # the request indefinitely.
    download_timeout = (5, 30)

    logger.info("Downloading image from URL...")
    # NOTE(review): the URL is client-supplied and fetched server-side with no
    # scheme/host allowlist or size limit (SSRF / resource-exhaustion risk).
    try:
        # `with` releases the streamed connection on every path.
        with requests.get(image_data.url, stream=True, timeout=download_timeout) as response:
            if response.status_code != 200:
                # Not caught below (only RequestException is), so it stays a 400.
                raise HTTPException(status_code=400, detail="Could not download image from URL")
            # `.content` applies Content-Encoding (e.g. gzip); `.raw` would not.
            payload = response.content
    except requests.RequestException as e:
        logger.error(f"Error downloading image from URL: {e}")
        raise HTTPException(status_code=400, detail="Could not download image from URL") from e

    try:
        image = Image.open(io.BytesIO(payload)).convert("RGB")

        logger.info("Processing image...")
        inputs = processor(images=image, return_tensors="pt")
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            outputs = model(**inputs)
        logits = outputs.logits.cpu()

        # Upsample logits to the original image size. PIL's size is
        # (width, height); interpolate expects (height, width).
        upsampled_logits = nn.functional.interpolate(
            logits,
            size=image.size[::-1],
            mode="bilinear",
            align_corners=False,
        )
        class_ids = upsampled_logits.argmax(dim=1)[0].numpy()

        # NOTE: class ids are int64, so `* 255` exceeds uint8 and the cast wraps
        # modulo 256 (id 1 -> 255, id 2 -> 254, ...). Kept for compatibility;
        # use "labels" for exact ids.
        mask_img = Image.fromarray((class_ids * 255).astype(np.uint8))
        buffered = io.BytesIO()
        mask_img.save(buffered, format="PNG")
        mask_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
        # Tensors are not JSON-serializable; nested lists are.
        labels = class_ids.tolist()
    except UnidentifiedImageError as e:
        logger.error(f"Error processing URL: {str(e)}")
        raise HTTPException(status_code=400, detail="URL did not return a valid image") from e
    except Exception as e:
        logger.exception(f"Error processing URL: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}") from e

    return {
        "mask": f"data:image/png;base64,{mask_base64}",
        "size": image.size,
        "labels": labels,
    }
# API endpoint: segment an image downloaded from a URL with the clothes model.
@app.post("/segment-clothes-url")
async def segment_clothes_url_endpoint(image_data: ImageURL):
    """Download the image at ``image_data.url`` and segment it with the clothes model.

    Returns a dict with:

    - ``mask``: grayscale PNG data URI of the per-pixel class map, where class
      id k is stored as pixel value (-k) mod 256;
    - ``size``: ``(width, height)`` of the downloaded image;
    - ``predictions``: per-pixel class ids as nested lists (rows x columns).

    Raises HTTPException 400 when the image cannot be downloaded or decoded,
    and 500 for any other processing failure.
    """
    # (connect, read) timeouts in seconds, so an unresponsive host cannot hang
    # the request indefinitely.
    download_timeout = (5, 30)

    logger.info("Downloading image from URL...")
    # NOTE(review): the URL is client-supplied and fetched server-side with no
    # scheme/host allowlist or size limit (SSRF / resource-exhaustion risk).
    try:
        # `with` releases the streamed connection on every path.
        with requests.get(image_data.url, stream=True, timeout=download_timeout) as response:
            if response.status_code != 200:
                # Not caught below (only RequestException is), so it stays a 400.
                raise HTTPException(status_code=400, detail="Could not download image from URL")
            # `.content` applies Content-Encoding (e.g. gzip); `.raw` would not.
            payload = response.content
    except requests.RequestException as e:
        logger.error(f"Error downloading image from URL: {e}")
        raise HTTPException(status_code=400, detail="Could not download image from URL") from e

    try:
        image = Image.open(io.BytesIO(payload)).convert("RGB")

        logger.info("Processing image...")
        inputs = clothes_processor(images=image, return_tensors="pt")
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            outputs = clothes_model(**inputs)
        logits = outputs.logits.cpu()

        # Upsample logits to the original image size. PIL's size is
        # (width, height); interpolate expects (height, width).
        upsampled_logits = nn.functional.interpolate(
            logits,
            size=image.size[::-1],
            mode="bilinear",
            align_corners=False,
        )
        class_ids = upsampled_logits.argmax(dim=1)[0].numpy()

        # NOTE: class ids are int64, so `* 255` exceeds uint8 and the cast wraps
        # modulo 256 (id 1 -> 255, id 2 -> 254, ...). Kept for compatibility;
        # use "predictions" for exact ids.
        mask_img = Image.fromarray((class_ids * 255).astype(np.uint8))
        buffered = io.BytesIO()
        mask_img.save(buffered, format="PNG")
        mask_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
        predictions = class_ids.tolist()
    except UnidentifiedImageError as e:
        logger.error(f"Error processing URL: {str(e)}")
        raise HTTPException(status_code=400, detail="URL did not return a valid image") from e
    except Exception as e:
        logger.exception(f"Error processing URL: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}") from e

    return {
        "mask": f"data:image/png;base64,{mask_base64}",
        "size": image.size,
        "predictions": predictions,
    }
# When run as a script, serve the app with uvicorn on 0.0.0.0:7860 for
# compatibility with Hugging Face Spaces.
if __name__ == "__main__":
    from uvicorn import run as serve_app

    serve_app(app, host="0.0.0.0", port=7860)