Spaces:
Running
Running
from fastapi import FastAPI, File, UploadFile, HTTPException | |
from transformers import SegformerImageProcessor | |
from huggingface_hub import hf_hub_download | |
from pydantic import BaseModel | |
from PIL import Image | |
import numpy as np | |
import io, base64, logging, requests, os | |
import onnxruntime as ort | |
from dotenv import load_dotenv | |
# Pull settings from a local .env file (if one is found) into the process
# environment before the rest of the module is configured.
load_dotenv(override=False)
class ImageURL(BaseModel):
    """JSON request body for the URL-based segmentation endpoints.

    The single ``url`` field is the address of the image the server downloads
    and segments.
    """

    url: str
class ModelManager:
    """Download, hold, and run the two SegFormer ONNX segmentation models.

    Construction downloads both ONNX graphs from the Hub repository
    ``alexgenovese/segformer-onnx`` and pairs each with the image processor of
    the checkpoint it was exported from:

    * fashion: ``segformer-b3-fashion.onnx`` + ``sayeed99/segformer-b3-fashion``
    * clothes: ``segformer_b2_clothes.onnx`` + ``mattmdjaga/segformer_b2_clothes``

    Raises:
        RuntimeError: from the constructor if any download/load step fails.
    """

    # Hub repository that holds both exported ONNX graphs.
    _ONNX_REPO_ID = "alexgenovese/segformer-onnx"

    def __init__(self):
        self.logger = logging.getLogger(__name__)
        # No Hub token is passed explicitly: with token=None, huggingface_hub
        # uses its own credential lookup (e.g. an HF_TOKEN environment
        # variable or a saved login).
        self._initialize_models()

    def _initialize_models(self):
        """Load the fashion and clothes model/processor pairs onto ``self``.

        Sets ``fashion_model``/``fashion_processor`` and
        ``clothes_model``/``clothes_processor``.

        Raises:
            RuntimeError: wrapping (and chaining) whatever the download or
                load step raised.
        """
        try:
            self.logger.info("Loading ONNX models...")
            self.fashion_model, self.fashion_processor = self._load_model_pair(
                "segformer-b3-fashion.onnx", "sayeed99/segformer-b3-fashion"
            )
            self.clothes_model, self.clothes_processor = self._load_model_pair(
                "segformer_b2_clothes.onnx", "mattmdjaga/segformer_b2_clothes"
            )
            self.logger.info("All models loaded successfully.")
        except Exception as e:
            # logger.exception also records the traceback of the original error.
            self.logger.exception("Error initializing models: %s", e)
            raise RuntimeError(f"Error initializing models: {str(e)}") from e

    def _load_model_pair(self, onnx_filename, processor_repo_id):
        """Download one ONNX graph and load its matching image processor.

        Returns:
            ``(session, processor)``: an ``ort.InferenceSession`` for the
            downloaded graph and a ``SegformerImageProcessor``.
        """
        onnx_path = hf_hub_download(repo_id=self._ONNX_REPO_ID, filename=onnx_filename)
        session = ort.InferenceSession(onnx_path)
        processor = SegformerImageProcessor.from_pretrained(processor_repo_id)
        return session, processor

    def process_fashion_image(self, image: Image.Image):
        """Segment *image* with the fashion model; see ``_post_process_outputs``."""
        return self._run_segmentation(self.fashion_model, self.fashion_processor, image)

    def process_clothes_image(self, image: Image.Image):
        """Segment *image* with the clothes model; see ``_post_process_outputs``."""
        return self._run_segmentation(self.clothes_model, self.clothes_processor, image)

    def _run_segmentation(self, session, processor, image):
        """Preprocess *image*, run *session* on it, and post-process the logits."""
        inputs = processor(images=image, return_tensors="np")
        # The exported graphs are fed through an input named "input".
        logits = session.run(None, {"input": inputs["pixel_values"]})[0]
        return self._post_process_outputs(logits, image.size)

    def _post_process_outputs(self, logits, image_size):
        """Turn raw model logits into a full-resolution label map and PNG mask.

        Args:
            logits: model output, indexed here as (batch, num_classes, h, w);
                only the first batch item is used.
            image_size: ``(width, height)`` of the original image, as returned
                by PIL's ``Image.size``.

        Returns:
            dict with
              ``"mask"``: PNG data URL of the label map (pixel values below),
              ``"size"``: *image_size* unchanged,
              ``"predictions"``: nested list (height x width) of class ids.
        """
        from skimage.transform import resize

        width, height = image_size
        # skimage's resize only carries over *trailing* axes that output_shape
        # omits, so the class axis must be last: (h, w, C) -> (height, width, C).
        # Resizing the channels-first array with a 2-D output_shape would
        # interpolate the class axis to `height` and the row axis to `width`.
        # NOTE: memory for the resized scores grows with C * height * width.
        class_scores = np.moveaxis(np.asarray(logits)[0], 0, -1)
        resized_scores = resize(
            class_scores,
            (height, width),
            order=1,
            preserve_range=True,
            mode="reflect",
        )
        pred_seg = np.argmax(resized_scores, axis=-1)

        # NOTE: label k is encoded as (k * 255) cast to uint8, which wraps
        # modulo 256: background (0) stays 0 and label k (1 <= k < 256)
        # becomes 256 - k. Every non-background class is therefore near-white,
        # and the id is recoverable as 256 - pixel.
        mask_img = Image.fromarray((pred_seg * 255).astype(np.uint8))

        buffered = io.BytesIO()
        mask_img.save(buffered, format="PNG")
        mask_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return {
            "mask": f"data:image/png;base64,{mask_base64}",
            "size": image_size,
            "predictions": pred_seg.tolist(),
        }
# Without a configured handler the root logger only emits WARNING and above,
# so ModelManager's INFO progress messages would be dropped. basicConfig does
# nothing if the root logger already has handlers.
logging.basicConfig(level=logging.INFO)

# Initialize FastAPI and ModelManager. The models are downloaded and loaded
# once, at import time; if that fails, ModelManager raises RuntimeError and the
# module (and therefore the app) does not load.
app = FastAPI()
model_manager = ModelManager()
# Seconds to wait on the remote server when downloading an image by URL.
_URL_DOWNLOAD_TIMEOUT = 30


def _download_image(url):
    """Fetch *url* and decode the response body as an RGB PIL image.

    Raises:
        HTTPException(400): the request failed, returned a non-200 status, or
            the body is not a recognizable image.
    """
    # NOTE(review): the server fetches arbitrary client-supplied URLs with no
    # size cap (SSRF / memory risk); consider restricting hosts and limiting
    # the body size if this is exposed publicly.
    from PIL import UnidentifiedImageError

    try:
        # Non-streamed so requests decodes any Content-Encoding (gzip, ...)
        # and releases the connection; response.raw does neither.
        response = requests.get(url, timeout=_URL_DOWNLOAD_TIMEOUT)
    except requests.RequestException as exc:
        raise HTTPException(status_code=400, detail="Could not download image from URL") from exc
    if response.status_code != 200:
        raise HTTPException(status_code=400, detail="Could not download image from URL")
    try:
        return Image.open(io.BytesIO(response.content)).convert("RGB")
    except UnidentifiedImageError as exc:
        raise HTTPException(status_code=400, detail=f"Error processing image: {str(exc)}") from exc


async def _segment_from_url(image_data, process):
    """Download ``image_data.url`` and run *process* on it off the event loop.

    Client errors raised by ``_download_image`` propagate unchanged (400);
    anything else is logged and reported as a 500.
    """
    from fastapi.concurrency import run_in_threadpool

    def _fetch_and_segment():
        return process(_download_image(image_data.url))

    try:
        # The download blocks and inference is CPU-bound, so neither may run
        # directly inside the async handler.
        return await run_in_threadpool(_fetch_and_segment)
    except HTTPException:
        # Keep deliberate 4xx responses instead of rewrapping them as 500.
        raise
    except Exception as e:
        logging.exception("Error processing URL: %s", e)
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}") from e


# NOTE(review): no @app.post(...) decorators are visible for the handlers
# below, so as shown they are not registered on `app`; they were presumably
# lost when this file was captured. Confirm the intended route paths.
async def segment_clothes_url_endpoint(image_data: ImageURL):
    """Segment the image at ``image_data.url`` with the clothes model."""
    return await _segment_from_url(image_data, model_manager.process_clothes_image)


async def segment_url_endpoint(image_data: ImageURL):
    """Segment the image at ``image_data.url`` with the fashion model."""
    return await _segment_from_url(image_data, model_manager.process_fashion_image)
async def segment_endpoint(file: UploadFile = File(...)):
    """Segment an uploaded image file with the fashion model.

    NOTE(review): no route decorator is visible for this handler, so as shown
    it is not registered on `app`; confirm the intended path.

    Returns:
        The dict produced by ``ModelManager.process_fashion_image``.

    Raises:
        HTTPException(400): the upload is not a recognizable image.
        HTTPException(500): any other failure while decoding or segmenting.
    """
    from fastapi.concurrency import run_in_threadpool
    from PIL import UnidentifiedImageError

    def _decode_and_segment(raw_bytes):
        image = Image.open(io.BytesIO(raw_bytes)).convert("RGB")
        return model_manager.process_fashion_image(image)

    try:
        image_data = await file.read()
        # Decoding and inference are CPU-bound; keep them off the event loop.
        return await run_in_threadpool(_decode_and_segment, image_data)
    except UnidentifiedImageError as e:
        # A bad upload is a client error, not a server failure.
        raise HTTPException(status_code=400, detail=f"Error processing: {str(e)}") from e
    except Exception as e:
        logging.exception("Error in endpoint: %s", e)
        raise HTTPException(status_code=500, detail=f"Error processing: {str(e)}") from e
if __name__ == "__main__":
    # Direct-execution entry point: serve `app` on every interface, port 7860.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=7860)