"""FastAPI service running SegFormer fashion/clothes segmentation via ONNX Runtime.

Endpoints accept an image (by URL or multipart upload) and return the per-pixel
argmax class map as a PNG data URL together with the raw prediction grid.
"""

import base64
import io
import logging

import numpy as np
import onnxruntime as ort
import requests
from dotenv import load_dotenv
from fastapi import FastAPI, File, HTTPException, UploadFile
from huggingface_hub import hf_hub_download
from PIL import Image, UnidentifiedImageError
from pydantic import BaseModel
from transformers import SegformerImageProcessor

# Load environment variables from a local .env file, if present. huggingface_hub
# reads HF_TOKEN from the environment by default, so a token placed in .env is
# used by the downloads below without being passed explicitly.
load_dotenv()

logger = logging.getLogger(__name__)

# Hugging Face repo that hosts both ONNX exports.
ONNX_REPO_ID = "alexgenovese/segformer-onnx"
# Seconds to wait on the remote server when fetching an image by URL; without a
# timeout a stalled server would hang the worker indefinitely.
DOWNLOAD_TIMEOUT_SECONDS = 30


class ImageURL(BaseModel):
    """Request body for the ``*-url`` endpoints."""

    url: str


class ModelManager:
    """Download the two SegFormer ONNX models and run inference with them.

    ``fashion_*`` attributes hold the segformer-b3-fashion export and
    ``clothes_*`` the segformer_b2_clothes export; each ONNX session is paired
    with the image processor of the corresponding upstream checkpoint.
    """

    def __init__(self):
        self.logger = logging.getLogger(__name__)
        self._initialize_models()

    def _initialize_models(self):
        """Download both ONNX files and load their sessions and processors.

        Raises:
            RuntimeError: if any download or session/processor construction fails.
        """
        try:
            self.logger.info("Loading ONNX models...")

            fashion_path = hf_hub_download(
                repo_id=ONNX_REPO_ID,
                filename="segformer-b3-fashion.onnx",
            )
            self.fashion_model = ort.InferenceSession(fashion_path)
            self.fashion_processor = SegformerImageProcessor.from_pretrained(
                "sayeed99/segformer-b3-fashion"
            )

            clothes_path = hf_hub_download(
                repo_id=ONNX_REPO_ID,
                filename="segformer_b2_clothes.onnx",
            )
            self.clothes_model = ort.InferenceSession(clothes_path)
            self.clothes_processor = SegformerImageProcessor.from_pretrained(
                "mattmdjaga/segformer_b2_clothes"
            )

            self.logger.info("All models loaded successfully.")
        except Exception as e:
            self.logger.exception("Error initializing models: %s", e)
            raise RuntimeError(f"Error initializing models: {e}") from e

    def process_fashion_image(self, image: Image.Image):
        """Segment ``image`` with the fashion model; see ``_post_process_outputs``."""
        return self._run_segmentation(self.fashion_model, self.fashion_processor, image)

    def process_clothes_image(self, image: Image.Image):
        """Segment ``image`` with the clothes model; see ``_post_process_outputs``."""
        return self._run_segmentation(self.clothes_model, self.clothes_processor, image)

    def _run_segmentation(self, session, processor, image: Image.Image):
        """Pre-process ``image``, run ``session`` on it, and post-process the logits."""
        inputs = processor(images=image, return_tensors="np")
        # Read the graph's input name instead of hard-coding it, so an export
        # with a differently named input still works.
        input_name = session.get_inputs()[0].name
        logits = session.run(None, {input_name: inputs["pixel_values"]})[0]
        return self._post_process_outputs(logits, image.size)

    def _post_process_outputs(self, logits, image_size):
        """Upsample class logits to the image size and encode the argmax map.

        Args:
            logits: first ONNX output; assumed shape (1, num_classes, h, w), the
                usual SegFormer layout — confirm against the exported graph.
            image_size: PIL ``Image.size``, i.e. (width, height).

        Returns:
            dict with ``mask`` (PNG data URL), ``size`` (``image_size`` as given)
            and ``predictions`` (nested list of class ids, ``height`` rows of
            ``width`` entries).
        """
        from skimage.transform import resize

        width, height = image_size
        # (num_classes, h, w) -> (h, w, num_classes). When output_shape has one
        # fewer dimension than the input, skimage treats the LAST axis as
        # channels; channels-first logits would make it resize the class axis.
        class_logits = np.moveaxis(np.asarray(logits)[0], 0, -1)
        resized_logits = resize(
            class_logits,
            (height, width),
            order=1,
            preserve_range=True,
            mode="reflect",
        )
        pred_seg = np.argmax(resized_logits, axis=-1)

        # NOTE: the uint8 cast wraps modulo 256, so class k is stored as
        # (-k) % 256 (0->0, 1->255, 2->254, ...). Ids stay distinct for fewer
        # than 256 classes; kept unchanged since clients may decode this encoding.
        mask_img = Image.fromarray((pred_seg * 255).astype(np.uint8))

        buffered = io.BytesIO()
        mask_img.save(buffered, format="PNG")
        mask_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

        return {
            "mask": f"data:image/png;base64,{mask_base64}",
            "size": image_size,
            "predictions": pred_seg.tolist(),
        }


def _decode_image(data: bytes) -> Image.Image:
    """Decode raw image bytes into an RGB PIL image.

    Raises:
        HTTPException: 400 if PIL cannot identify the bytes as an image.
    """
    try:
        image = Image.open(io.BytesIO(data))
    except UnidentifiedImageError as e:
        raise HTTPException(status_code=400, detail="Could not decode image") from e
    return image.convert("RGB")


def _fetch_image(url: str) -> Image.Image:
    """Download ``url`` and decode the body as an RGB image.

    NOTE(review): the URL is caller-supplied and fetched server-side with no
    scheme/host allow-list or size cap (SSRF and memory risk); restrict it if
    this service is publicly exposed.

    Raises:
        HTTPException: 400 if the request fails, the status is not 200, or the
            body is not a decodable image.
    """
    try:
        response = requests.get(url, timeout=DOWNLOAD_TIMEOUT_SECONDS)
    except requests.RequestException as e:
        raise HTTPException(
            status_code=400, detail="Could not download image from URL"
        ) from e
    if response.status_code != 200:
        raise HTTPException(status_code=400, detail="Could not download image from URL")
    # ``.content`` (unlike ``.raw``) is already decoded from any Content-Encoding
    # such as gzip, and reading it releases the connection.
    return _decode_image(response.content)


# Initialize FastAPI and ModelManager (models are downloaded at import time).
app = FastAPI()
model_manager = ModelManager()


# The endpoints are plain ``def`` rather than ``async def``: their bodies do
# blocking network I/O and CPU-bound inference, which FastAPI runs in its
# threadpool for ``def`` handlers instead of stalling the event loop.


@app.post("/segment-clothes-url")
def segment_clothes_url_endpoint(image_data: ImageURL):
    """Segment clothing in the image at ``image_data.url`` with the clothes model."""
    try:
        image = _fetch_image(image_data.url)
        return model_manager.process_clothes_image(image)
    except HTTPException:
        # Client errors (bad URL, non-200, undecodable image) keep their 400.
        raise
    except Exception as e:
        logger.exception("Error processing URL: %s", e)
        raise HTTPException(status_code=500, detail=f"Error processing image: {e}") from e


@app.post("/segment-fashion-url")
def segment_url_endpoint(image_data: ImageURL):
    """Segment the image at ``image_data.url`` with the fashion model."""
    try:
        image = _fetch_image(image_data.url)
        return model_manager.process_fashion_image(image)
    except HTTPException:
        # Client errors (bad URL, non-200, undecodable image) keep their 400.
        raise
    except Exception as e:
        logger.exception("Error processing URL: %s", e)
        raise HTTPException(status_code=500, detail=f"Error processing image: {e}") from e


@app.post("/segment-fashion-file")
def segment_endpoint(file: UploadFile = File(...)):
    """Segment an uploaded image file with the fashion model."""
    try:
        # ``file.file`` is the underlying file object; read it synchronously
        # since this handler runs in the threadpool.
        image = _decode_image(file.file.read())
        return model_manager.process_fashion_image(image)
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Error in endpoint: %s", e)
        raise HTTPException(status_code=500, detail=f"Error processing: {e}") from e


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)