File size: 5,539 Bytes
7364060
b2702fe
 
077a679
7077928
 
b2702fe
 
 
7077928
b2702fe
 
3b62419
077a679
 
 
b2702fe
 
 
ca8948a
 
 
b2702fe
 
 
 
 
 
 
 
 
 
 
ca8948a
b2702fe
 
 
 
ca8948a
b2702fe
 
 
 
 
 
ca8948a
b2702fe
 
 
 
ca8948a
b2702fe
 
 
 
 
 
 
 
 
 
 
 
 
 
0e3833c
b2702fe
 
 
 
 
 
 
0e3833c
b2702fe
 
 
f6210c2
b2702fe
 
 
 
 
 
 
 
 
7364060
b2702fe
 
 
 
 
 
 
 
7364060
 
 
b2702fe
 
7364060
077a679
b2702fe
 
 
077a679
f8b9c38
 
 
077a679
 
 
 
 
 
f8b9c38
077a679
 
b2702fe
077a679
 
f8b9c38
 
 
0e3833c
 
 
 
 
 
f8b9c38
0e3833c
 
b2702fe
0e3833c
 
f8b9c38
 
b2702fe
 
 
 
 
 
 
 
 
3b62419
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
from fastapi import FastAPI, File, UploadFile, HTTPException
from transformers import SegformerImageProcessor
from huggingface_hub import hf_hub_download
from pydantic import BaseModel
from PIL import Image
import numpy as np
import io, base64, logging, requests, os
import onnxruntime as ort
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

class ImageURL(BaseModel):
    """Request body for the URL-based segmentation endpoints."""

    # Address of the image to segment; the server downloads it with an HTTP GET.
    url: str

class ModelManager:
    """Load the SegFormer ONNX models and run clothing segmentation on PIL images.

    Two model/processor pairs are loaded once, at construction time:

    * fashion: ``segformer-b3-fashion.onnx`` (from ``alexgenovese/segformer-onnx``)
      with the ``sayeed99/segformer-b3-fashion`` image processor;
    * clothes: ``segformer_b2_clothes.onnx`` (same repo) with the
      ``mattmdjaga/segformer_b2_clothes`` image processor.
    """

    # Hugging Face repo that hosts both exported ONNX graphs.
    _ONNX_REPO_ID = "alexgenovese/segformer-onnx"

    def __init__(self):
        self.logger = logging.getLogger(__name__)
        # NOTE(review): explicit HF_TOKEN authentication used to be commented out
        # here. The models load without a token; if private/gated repos are ever
        # needed, pass `token=` to hf_hub_download/from_pretrained.
        self._initialize_models()

    def _initialize_models(self):
        """Download both ONNX models and their image processors.

        Raises:
            RuntimeError: if any download or load step fails; the original
                exception is chained as ``__cause__``.
        """
        try:
            self.logger.info("Loading ONNX models...")

            self.fashion_model, self.fashion_processor = self._load_model(
                "segformer-b3-fashion.onnx", "sayeed99/segformer-b3-fashion"
            )
            self.clothes_model, self.clothes_processor = self._load_model(
                "segformer_b2_clothes.onnx", "mattmdjaga/segformer_b2_clothes"
            )

            self.logger.info("All models loaded successfully.")
        except Exception as e:
            self.logger.exception(f"Error initializing models: {str(e)}")
            raise RuntimeError(f"Error initializing models: {str(e)}") from e

    def _load_model(self, onnx_filename, processor_repo):
        """Download ``onnx_filename`` and return ``(InferenceSession, image processor)``."""
        model_path = hf_hub_download(repo_id=self._ONNX_REPO_ID, filename=onnx_filename)
        session = ort.InferenceSession(model_path)
        processor = SegformerImageProcessor.from_pretrained(processor_repo)
        return session, processor

    def process_fashion_image(self, image: Image.Image):
        """Segment *image* with the fashion model; see ``_post_process_outputs``."""
        return self._segment(self.fashion_model, self.fashion_processor, image)

    def process_clothes_image(self, image: Image.Image):
        """Segment *image* with the clothes model; see ``_post_process_outputs``."""
        return self._segment(self.clothes_model, self.clothes_processor, image)

    def _segment(self, session, processor, image: Image.Image):
        """Preprocess *image*, run *session* on it and post-process the logits."""
        inputs = processor(images=image, return_tensors="np")
        # Read the graph's input name rather than hard-coding "input", so this
        # keeps working whatever name the ONNX export chose.
        input_name = session.get_inputs()[0].name
        logits = session.run(None, {input_name: inputs["pixel_values"]})[0]
        return self._post_process_outputs(logits, image.size)

    def _post_process_outputs(self, logits, image_size):
        """Convert raw model logits into a PNG mask and per-pixel class ids.

        Args:
            logits: first output of the ONNX session. Indexed here as
                ``(batch, num_classes, h, w)``; this assumes the export keeps
                SegFormer's NCHW logits layout (TODO confirm). Only batch
                item 0 is used.
            image_size: ``(width, height)`` of the original image, as given
                by ``PIL.Image.size``.

        Returns:
            dict with ``"mask"`` (PNG data URL: 255 where the predicted class
            is non-zero, 0 elsewhere), ``"size"`` (``image_size`` unchanged)
            and ``"predictions"`` (``height x width`` nested list of class ids).
        """
        from skimage.transform import resize

        width, height = image_size
        class_logits = np.asarray(logits)[0]  # (num_classes, h, w)

        # skimage.resize treats axes beyond len(output_shape) as channels, so the
        # class axis must be last. Resizing the (C, h, w) array to (height, width)
        # directly interpolates across classes and rows instead, which produces
        # a mask of the wrong shape.
        channels_last = np.transpose(class_logits, (1, 2, 0))  # (h, w, num_classes)
        resized_logits = resize(
            channels_last,
            (height, width),
            order=1,
            preserve_range=True,
            mode="reflect",
        )

        # Per-pixel class id at the original resolution: (height, width).
        pred_seg = np.argmax(resized_logits, axis=-1)

        # Class ids go above 1, so `pred_seg * 255` overflowed and wrapped when
        # cast to uint8. Build an explicit binary mask instead; this assumes
        # class 0 is the background label for both models (TODO confirm).
        mask = np.where(pred_seg > 0, 255, 0).astype(np.uint8)
        mask_img = Image.fromarray(mask)

        buffered = io.BytesIO()
        mask_img.save(buffered, format="PNG")
        mask_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

        return {
            "mask": f"data:image/png;base64,{mask_base64}",
            "size": image_size,
            "predictions": pred_seg.tolist(),
        }

# Initialize FastAPI and ModelManager
app = FastAPI()
model_manager = ModelManager()


@app.post("/segment-clothes-url")
async def segment_clothes_url_endpoint(image_data: ImageURL):
    """Download the image at ``image_data.url`` and segment it with the clothes model.

    Returns the dict produced by ``ModelManager.process_clothes_image``.

    Raises:
        HTTPException(400): the URL answered with a non-200 status.
        HTTPException(500): any other failure (network error, undecodable
            image, inference error).
    """
    # Local import keeps this edit self-contained; fastapi is already a dependency.
    from fastapi.concurrency import run_in_threadpool

    def _download_and_segment():
        # NOTE(review): the URL is caller-controlled, so this lets clients make the
        # server fetch arbitrary addresses (SSRF). Consider an allow-list.
        # The timeout bounds the connect and per-read waits so a stalled remote
        # host cannot hang a worker forever.
        response = requests.get(image_data.url, timeout=30)
        if response.status_code != 200:
            raise HTTPException(status_code=400, detail="Could not download image from URL")

        # `response.content` is fully read and content-decoded (unlike the raw
        # stream), and leaves no open connection behind.
        image = Image.open(io.BytesIO(response.content)).convert("RGB")
        return model_manager.process_clothes_image(image)

    try:
        # The download and ONNX inference are blocking; run them in the threadpool
        # so they do not stall the event loop for every other request.
        return await run_in_threadpool(_download_and_segment)
    except HTTPException:
        # Pass deliberate HTTP errors (e.g. the 400 above) through unchanged
        # instead of re-wrapping them as a 500.
        raise
    except Exception as e:
        logging.exception(f"Error processing URL: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}") from e

        
@app.post("/segment-fashion-url")
async def segment_url_endpoint(image_data: ImageURL):
    """Download the image at ``image_data.url`` and segment it with the fashion model.

    Returns the dict produced by ``ModelManager.process_fashion_image``.

    Raises:
        HTTPException(400): the URL answered with a non-200 status.
        HTTPException(500): any other failure (network error, undecodable
            image, inference error).
    """
    # Local import keeps this edit self-contained; fastapi is already a dependency.
    from fastapi.concurrency import run_in_threadpool

    def _download_and_segment():
        # NOTE(review): the URL is caller-controlled, so this lets clients make the
        # server fetch arbitrary addresses (SSRF). Consider an allow-list.
        # The timeout bounds the connect and per-read waits so a stalled remote
        # host cannot hang a worker forever.
        response = requests.get(image_data.url, timeout=30)
        if response.status_code != 200:
            raise HTTPException(status_code=400, detail="Could not download image from URL")

        # `response.content` is fully read and content-decoded (unlike the raw
        # stream), and leaves no open connection behind.
        image = Image.open(io.BytesIO(response.content)).convert("RGB")
        return model_manager.process_fashion_image(image)

    try:
        # The download and ONNX inference are blocking; run them in the threadpool
        # so they do not stall the event loop for every other request.
        return await run_in_threadpool(_download_and_segment)
    except HTTPException:
        # Pass deliberate HTTP errors (e.g. the 400 above) through unchanged
        # instead of re-wrapping them as a 500.
        raise
    except Exception as e:
        logging.exception(f"Error processing URL: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}") from e


@app.post("/segment-fashion-file")
async def segment_endpoint(file: UploadFile = File(...)):
    """Segment an uploaded image file with the fashion model.

    Returns the dict produced by ``ModelManager.process_fashion_image``.
    Any failure (unreadable upload, undecodable image, inference error) is
    logged and reported as HTTP 500.
    """
    # Local import keeps this edit self-contained; fastapi is already a dependency.
    from fastapi.concurrency import run_in_threadpool

    def _decode_and_segment(raw: bytes):
        image = Image.open(io.BytesIO(raw)).convert("RGB")
        return model_manager.process_fashion_image(image)

    try:
        image_data = await file.read()
        # Image decoding and ONNX inference are CPU-bound and blocking; run them in
        # the threadpool so they do not stall the event loop for other requests.
        return await run_in_threadpool(_decode_and_segment, image_data)
    except Exception as e:
        logging.exception(f"Error in endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error processing: {str(e)}") from e

if __name__ == "__main__":
    # Serve the app directly when this file is executed as a script;
    # uvicorn is imported lazily because it is only needed in that case.
    import uvicorn as _uvicorn

    _uvicorn.run(app, host="0.0.0.0", port=7860)