segmentation / app.py
Alex
update endpoint path
f8b9c38
import io, base64, logging, requests, os

import numpy as np
import onnxruntime as ort
from dotenv import load_dotenv
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.concurrency import run_in_threadpool
from huggingface_hub import hf_hub_download
from PIL import Image
from pydantic import BaseModel
from transformers import SegformerImageProcessor
# Populate os.environ from a .env file if one is found; variables that are
# already set in the process environment keep their values (override=False).
load_dotenv(override=False)
class ImageURL(BaseModel):
    """Request body for the URL-based segmentation endpoints."""

    # Address of the image the server should download and segment.
    url: str
class ModelManager:
    """Load the two SegFormer ONNX models and run semantic segmentation on images.

    Both models and their image processors are downloaded and instantiated once,
    in ``__init__``; ``process_fashion_image`` / ``process_clothes_image`` can
    then be called for each request.
    """

    # Hub repository holding both exported ONNX graphs.
    _ONNX_REPO = "alexgenovese/segformer-onnx"
    # Name of the graph input the exported models are fed through.
    _INPUT_NAME = "input"

    def __init__(self):
        self.logger = logging.getLogger(__name__)
        # No token is passed explicitly: huggingface_hub resolves credentials
        # itself (e.g. an HF_TOKEN env var or a cached login). NOTE(review):
        # confirm for the pinned huggingface_hub version if a private repo is used.
        self._initialize_models()

    def _initialize_models(self):
        """Download both ONNX models and their matching image processors.

        Raises:
            RuntimeError: if any download, ONNX session or processor creation fails.
        """
        try:
            self.logger.info("Loading ONNX models...")
            self.fashion_model, self.fashion_processor = self._load_model(
                "segformer-b3-fashion.onnx", "sayeed99/segformer-b3-fashion"
            )
            self.clothes_model, self.clothes_processor = self._load_model(
                "segformer_b2_clothes.onnx", "mattmdjaga/segformer_b2_clothes"
            )
            self.logger.info("All models loaded successfully.")
        except Exception as e:
            # .exception keeps the traceback; "from e" keeps the cause for callers.
            self.logger.exception(f"Error initializing models: {str(e)}")
            raise RuntimeError(f"Error initializing models: {str(e)}") from e

    def _load_model(self, onnx_filename, processor_repo):
        """Return ``(InferenceSession, SegformerImageProcessor)`` for one model."""
        onnx_path = hf_hub_download(repo_id=self._ONNX_REPO, filename=onnx_filename)
        session = ort.InferenceSession(onnx_path)
        processor = SegformerImageProcessor.from_pretrained(processor_repo)
        return session, processor

    def process_fashion_image(self, image: "Image.Image"):
        """Segment ``image`` with the fashion model (result format: ``_post_process_outputs``)."""
        return self._segment(self.fashion_model, self.fashion_processor, image)

    def process_clothes_image(self, image: "Image.Image"):
        """Segment ``image`` with the clothes model (result format: ``_post_process_outputs``)."""
        return self._segment(self.clothes_model, self.clothes_processor, image)

    def _segment(self, session, processor, image):
        """Preprocess ``image``, run ``session`` on it and post-process the first output."""
        inputs = processor(images=image, return_tensors="np")
        logits = session.run(None, {self._INPUT_NAME: inputs["pixel_values"]})[0]
        return self._post_process_outputs(logits, image.size)

    def _post_process_outputs(self, logits, image_size):
        """Turn model logits into a per-pixel class map at the original image size.

        Args:
            logits: first model output. Assumed layout (batch, num_classes, h, w),
                as implied by taking ``logits[0]`` and reducing over its first
                axis. Only the first batch item is used.
            image_size: PIL ``Image.size``, i.e. ``(width, height)``.

        Returns:
            dict with ``mask`` (PNG data-URL visualising the class map),
            ``size`` (``image_size`` unchanged) and ``predictions`` (nested
            list of class ids, ``height`` rows by ``width`` columns).
        """
        from skimage.transform import resize

        class_scores = np.asarray(logits)[0]
        num_classes = class_scores.shape[0]
        width, height = image_size

        # skimage.transform.resize treats the LAST axis as channels when
        # output_shape has one fewer dimension than the input. Feeding the
        # channels-first array stretched the class axis to the image height,
        # so move the classes to the end before resizing.
        channels_last = np.moveaxis(class_scores, 0, -1)
        resized_logits = resize(
            channels_last,
            (height, width),
            order=1,
            preserve_range=True,
            mode='reflect',
        )
        pred_seg = np.argmax(resized_logits, axis=-1)

        mask_img = Image.fromarray(self._labels_to_mask(pred_seg, num_classes))
        buffered = io.BytesIO()
        mask_img.save(buffered, format="PNG")
        mask_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return {
            "mask": f"data:image/png;base64,{mask_base64}",
            "size": image_size,
            "predictions": pred_seg.tolist(),
        }

    @staticmethod
    def _labels_to_mask(pred_seg, num_classes):
        """Spread class ids evenly over 0..255 as a uint8 grey-level image.

        ``(pred_seg * 255).astype(np.uint8)`` wraps modulo 256 for every class
        id >= 2 (class 2 -> 510 -> 254), which made all foreground classes
        near-identical. Here class 0 stays 0 and the highest class id maps to
        255, so a two-class model still yields a plain 0/255 mask.
        """
        scale = 255.0 / max(num_classes - 1, 1)
        return np.clip(np.rint(pred_seg * scale), 0, 255).astype(np.uint8)
# Both ONNX models are loaded exactly once, at import time, and shared by every
# request; if loading fails, ModelManager raises RuntimeError and the import aborts.
model_manager = ModelManager()
app = FastAPI()
@app.post("/segment-clothes-url")
async def segment_clothes_url_endpoint(image_data: ImageURL):
    """Download the image at ``image_data.url`` and segment it with the clothes model.

    Returns:
        The dict built by ``ModelManager.process_clothes_image``.

    Raises:
        HTTPException: 400 when the URL cannot be fetched or does not answer
            with status 200; 500 for any other decoding/inference failure.
    """
    # NOTE(review): the server fetches any caller-supplied URL, which is an SSRF
    # vector (internal hosts, cloud metadata endpoints). Consider an allow-list
    # or rejecting private address ranges before exposing this publicly.

    def _download_and_segment():
        # Without a timeout a slow or silent host would hold a worker forever.
        try:
            response = requests.get(image_data.url, timeout=30)
        except requests.RequestException as e:
            raise HTTPException(status_code=400, detail="Could not download image from URL") from e
        if response.status_code != 200:
            raise HTTPException(status_code=400, detail="Could not download image from URL")
        # .content is fully read (the connection goes back to the pool) and has
        # any Content-Encoding such as gzip decoded, unlike .raw.
        image = Image.open(io.BytesIO(response.content)).convert("RGB")
        return model_manager.process_clothes_image(image)

    try:
        # requests and ONNX inference block; run them in a worker thread so the
        # event loop keeps serving other requests meanwhile.
        return await run_in_threadpool(_download_and_segment)
    except HTTPException:
        # Let deliberate client errors (400) through instead of re-wrapping them as 500.
        raise
    except Exception as e:
        logging.exception("Error processing URL: %s", e)
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}") from e
@app.post("/segment-fashion-url")
async def segment_url_endpoint(image_data: ImageURL):
    """Download the image at ``image_data.url`` and segment it with the fashion model.

    Returns:
        The dict built by ``ModelManager.process_fashion_image``.

    Raises:
        HTTPException: 400 when the URL cannot be fetched or does not answer
            with status 200; 500 for any other decoding/inference failure.
    """
    # NOTE(review): the server fetches any caller-supplied URL, which is an SSRF
    # vector (internal hosts, cloud metadata endpoints). Consider an allow-list
    # or rejecting private address ranges before exposing this publicly.

    def _download_and_segment():
        # Without a timeout a slow or silent host would hold a worker forever.
        try:
            response = requests.get(image_data.url, timeout=30)
        except requests.RequestException as e:
            raise HTTPException(status_code=400, detail="Could not download image from URL") from e
        if response.status_code != 200:
            raise HTTPException(status_code=400, detail="Could not download image from URL")
        # .content is fully read (the connection goes back to the pool) and has
        # any Content-Encoding such as gzip decoded, unlike .raw.
        image = Image.open(io.BytesIO(response.content)).convert("RGB")
        return model_manager.process_fashion_image(image)

    try:
        # requests and ONNX inference block; run them in a worker thread so the
        # event loop keeps serving other requests meanwhile.
        return await run_in_threadpool(_download_and_segment)
    except HTTPException:
        # Let deliberate client errors (400) through instead of re-wrapping them as 500.
        raise
    except Exception as e:
        logging.exception("Error processing URL: %s", e)
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}") from e
@app.post("/segment-fashion-file")
async def segment_endpoint(file: UploadFile = File(...)):
    """Segment an uploaded image file with the fashion model.

    Returns:
        The dict built by ``ModelManager.process_fashion_image``.

    Raises:
        HTTPException: 500 if the upload cannot be read, decoded or segmented.
    """

    def _decode_and_segment(raw_bytes):
        image = Image.open(io.BytesIO(raw_bytes)).convert("RGB")
        return model_manager.process_fashion_image(image)

    try:
        image_data = await file.read()
        # Image decoding and ONNX inference are CPU-bound; run them in a worker
        # thread so the event loop is not stalled for other requests.
        return await run_in_threadpool(_decode_and_segment, image_data)
    except Exception as e:
        logging.exception("Error in endpoint: %s", e)
        raise HTTPException(status_code=500, detail=f"Error processing: {str(e)}") from e
if __name__ == "__main__":
    # Serve the app directly with uvicorn when this file is executed as a
    # script; importing the module (e.g. via an external ASGI server) skips this.
    import uvicorn

    # 0.0.0.0 listens on every network interface, not just localhost.
    uvicorn.run(app, port=7860, host="0.0.0.0")