Spaces:
Running
Running
File size: 5,539 Bytes
7364060 b2702fe 077a679 7077928 b2702fe 7077928 b2702fe 3b62419 077a679 b2702fe ca8948a b2702fe ca8948a b2702fe ca8948a b2702fe ca8948a b2702fe ca8948a b2702fe 0e3833c b2702fe 0e3833c b2702fe f6210c2 b2702fe 7364060 b2702fe 7364060 b2702fe 7364060 077a679 b2702fe 077a679 f8b9c38 077a679 f8b9c38 077a679 b2702fe 077a679 f8b9c38 0e3833c f8b9c38 0e3833c b2702fe 0e3833c f8b9c38 b2702fe 3b62419 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
from fastapi import FastAPI, File, UploadFile, HTTPException
from transformers import SegformerImageProcessor
from huggingface_hub import hf_hub_download
from pydantic import BaseModel
from PIL import Image
import numpy as np
import io, base64, logging, requests, os
import onnxruntime as ort
from dotenv import load_dotenv
# Load environment variables from a local .env file at import time
# (e.g. HF_TOKEN, if the commented-out auth in ModelManager is re-enabled).
load_dotenv()
class ImageURL(BaseModel):
    """Request body for the URL-based segmentation endpoints.

    Attributes are validated by pydantic:
    - url: HTTP(S) location of the image to download and segment.
    """
    url: str
class ModelManager:
    """Load the ONNX SegFormer checkpoints once and run segmentation inference.

    Two models are served, both exported to ONNX in the
    ``alexgenovese/segformer-onnx`` repo:
    - fashion: ``sayeed99/segformer-b3-fashion`` (processor config)
    - clothes: ``mattmdjaga/segformer_b2_clothes`` (processor config)
    """

    def __init__(self):
        self.logger = logging.getLogger(__name__)
        # self.token = os.getenv("HF_TOKEN")
        # if not self.token:
        #     raise ValueError("HF_TOKEN environment variable is required")
        self._initialize_models()

    def _initialize_models(self):
        """Download both ONNX checkpoints and build sessions/processors.

        Raises:
            RuntimeError: if any download, session creation, or processor
                load fails (original exception attached as __cause__).
        """
        try:
            self.logger.info("Loading ONNX models...")
            # Download and load fashion model
            fashion_path = hf_hub_download(
                repo_id="alexgenovese/segformer-onnx",
                filename="segformer-b3-fashion.onnx",
                # token=self.token
            )
            self.fashion_model = ort.InferenceSession(fashion_path)
            self.fashion_processor = SegformerImageProcessor.from_pretrained(
                "sayeed99/segformer-b3-fashion",
                # token=self.token
            )
            # Download and load clothes model
            clothes_path = hf_hub_download(
                repo_id="alexgenovese/segformer-onnx",
                filename="segformer_b2_clothes.onnx",
                # token=self.token
            )
            self.clothes_model = ort.InferenceSession(clothes_path)
            self.clothes_processor = SegformerImageProcessor.from_pretrained(
                "mattmdjaga/segformer_b2_clothes",
                # token=self.token
            )
            self.logger.info("All models loaded successfully.")
        except Exception as e:
            self.logger.error(f"Error initializing models: {str(e)}")
            # Chain the cause so the original traceback is preserved.
            raise RuntimeError(f"Error initializing models: {str(e)}") from e

    def process_fashion_image(self, image: Image.Image):
        """Segment *image* with the fashion model; see _post_process_outputs."""
        return self._run_inference(self.fashion_model, self.fashion_processor, image)

    def process_clothes_image(self, image: Image.Image):
        """Segment *image* with the clothes model; see _post_process_outputs."""
        return self._run_inference(self.clothes_model, self.clothes_processor, image)

    def _run_inference(self, session, processor, image: Image.Image):
        """Shared preprocess -> ONNX run -> postprocess pipeline."""
        inputs = processor(images=image, return_tensors="np")
        # The ONNX exports use a single input tensor named 'input' that
        # receives the processor's pixel_values array.
        logits = session.run(None, {'input': inputs['pixel_values']})[0]
        return self._post_process_outputs(logits, image.size)

    def _post_process_outputs(self, logits, image_size):
        """Upsample class logits to the original image size and build the response.

        Args:
            logits: model output, assumed shape (1, num_classes, h, w)
                (standard SegFormer ONNX output -- TODO confirm against export).
            image_size: PIL (width, height) of the original image.

        Returns:
            dict with a base64 PNG "mask", the original "size", and the raw
            per-pixel class indices under "predictions".
        """
        logits = np.asarray(logits)
        class_logits = logits[0]  # (num_classes, h, w)
        num_classes, h, w = class_logits.shape
        # Bilinearly upsample ONLY the spatial axes. The previous
        # skimage.transform.resize(class_logits, (H, W)) call treated the
        # leading class axis as image rows and preserved the last axis as
        # "channels", producing an (H, W, w) array whose argmax over axis 0
        # indexed rows, not classes.
        from scipy.ndimage import zoom
        resized_logits = zoom(
            class_logits,
            (1, image_size[1] / h, image_size[0] / w),
            order=1,
        )
        # Get prediction: per-pixel class index, shape (height, width).
        pred_seg = np.argmax(resized_logits, axis=0)
        # Distinct grey level per class; explicit modulo instead of relying
        # on silent uint8 wraparound of pred_seg * 255 (class n -> 256 - n).
        mask_img = Image.fromarray(((pred_seg * 255) % 256).astype(np.uint8))
        # Convert to base64
        buffered = io.BytesIO()
        mask_img.save(buffered, format="PNG")
        mask_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return {
            "mask": f"data:image/png;base64,{mask_base64}",
            "size": image_size,
            "predictions": pred_seg.tolist()
        }
# Initialize FastAPI and ModelManager.
# Models are downloaded and loaded eagerly at import time, so the process
# fails fast if the checkpoints or processor configs are unavailable.
app = FastAPI()
model_manager = ModelManager()
@app.post("/segment-clothes-url")
async def segment_clothes_url_endpoint(image_data: ImageURL):
try:
response = requests.get(image_data.url, stream=True)
if response.status_code != 200:
raise HTTPException(status_code=400, detail="Could not download image from URL")
image = Image.open(response.raw).convert("RGB")
return model_manager.process_clothes_image(image)
except Exception as e:
logging.error(f"Error processing URL: {str(e)}")
raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
@app.post("/segment-fashion-url")
async def segment_url_endpoint(image_data: ImageURL):
try:
response = requests.get(image_data.url, stream=True)
if response.status_code != 200:
raise HTTPException(status_code=400, detail="Could not download image from URL")
image = Image.open(response.raw).convert("RGB")
return model_manager.process_fashion_image(image)
except Exception as e:
logging.error(f"Error processing URL: {str(e)}")
raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
@app.post("/segment-fashion-file")
async def segment_endpoint(file: UploadFile = File(...)):
try:
image_data = await file.read()
image = Image.open(io.BytesIO(image_data)).convert("RGB")
return model_manager.process_fashion_image(image)
except Exception as e:
logging.error(f"Error in endpoint: {str(e)}")
raise HTTPException(status_code=500, detail=f"Error processing: {str(e)}")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860) |