Alex committed
Commit · b2702fe
Parent(s): 0e3833c

updated to onnx

Files changed:
- .gitignore +5 -0
- README.md +26 -1
- app.py +102 -147
- hf_onnx_converter.py +202 -0
- requirements.txt +4 -1
- response.json +0 -0
.gitignore
ADDED
@@ -0,0 +1,5 @@
+.env
+.DS_Store
+models/
+model_cache
+onnx_models
README.md
CHANGED
@@ -51,4 +51,29 @@ curl -X POST "https://alexgenovese-segmentation.hf.space/segment-url" \
   -d '{
     "url": "https://plus.unsplash.com/premium_photo-1673210886161-bfcc40f54d1f?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxzZWFyY2h8MXx8cGVyc29uJTIwc3RhbmRpbmd8ZW58MHx8MHx8&w=1000&q=80"
   }' \
-  -o response.json
+  -o response.json
+
+# Segment-clothes-url
+
+curl -X POST "https://alexgenovese-segmentation.hf.space/segment-clothes-url" \
+  -H "Content-Type: application/json" \
+  -d '{
+    "url": "https://plus.unsplash.com/premium_photo-1673210886161-bfcc40f54d1f?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxzZWFyY2h8MXx8cGVyc29uJTIwc3RhbmRpbmd8ZW58MHx8MHx8&w=1000&q=80"
+  }' \
+  -o response.json
+
+# Convert to ONNX file
+
+# For the fashion segmentation model:
+python convert_to_onnx.py --model "sayeed99/segformer-b3-fashion" --output "models/fashion_segformer.onnx"
+
+# For the clothes segmentation model:
+python convert_to_onnx.py --model "mattmdjaga/segformer_b2_clothes" --output "models/clothes_segformer.onnx"
+
+# Convert To Onnx file
+
+python3 hf_onnx_converter.py \
+  --source "mattmdjaga/segformer_b2_clothes" \
+  --target "alexgenovese/segformer-onnx"
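After the request completes, response.json holds the mask as a base64 data URL. A minimal sketch for turning that field back into a PNG file, assuming the response layout produced by app.py:

```python
import base64
import json

# Load the response saved by the curl call above
with open("response.json") as f:
    payload = json.load(f)

# The mask is a data URL: "data:image/png;base64,<payload>"; keep only the payload
b64_data = payload["mask"].split(",", 1)[1]

with open("mask.png", "wb") as f:
    f.write(base64.b64decode(b64_data))
```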
app.py
CHANGED
@@ -1,193 +1,148 @@
 from fastapi import FastAPI, File, UploadFile, HTTPException
-from transformers import
+from transformers import SegformerImageProcessor
+from huggingface_hub import hf_hub_download
 from pydantic import BaseModel
 from PIL import Image
 import numpy as np
-import io, base64, logging, requests,
-import
+import io, base64, logging, requests, os
+import onnxruntime as ort
+from dotenv import load_dotenv
 
+# Load environment variables
+load_dotenv()
 
-# Add this class for the request body
 class ImageURL(BaseModel):
     url: str
 
-logger = logging.getLogger(__name__)
-
-# [model and processor loading; removed lines not recovered from the extract]
-
-    except Exception as e:
-        logger.error(f"Error loading clothes model: {str(e)}")
-        raise RuntimeError(f"Error loading clothes model: {str(e)}")
-
-    logger.info("Preparing the image for inference...")
-    inputs = processor(images=image, return_tensors="pt").to("cpu")
-
-    # Inference
-    logger.info("Running inference...")
-    with torch.no_grad():
-        outputs = model(**inputs)
-        logits = outputs.logits
-
-# [post-processing; removed lines not recovered from the extract]
-
-    buffered = io.BytesIO()
-    mask_img.save(buffered, format="PNG")
-    mask_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
-
-    # Annotations
-    annotations = {"mask": mask.tolist(), "label": logits}
-
-    return mask_base64, annotations
-
-# API endpoint
-@app.post("/segment")
-async def segment_endpoint(file: UploadFile = File(...)):
-    try:
-        logger.info("Receiving the file...")
-        image_data = await file.read()
-        image = Image.open(io.BytesIO(image_data)).convert("RGB")
-
-# [processing call; removed lines not recovered from the extract]
-
-        return {
-            "mask": f"data:image/png;base64,{mask_base64}",
-        }
-    except Exception as e:
-        logger.error(f"Error in endpoint: {str(e)}")
-        raise HTTPException(status_code=500, detail=f"Error during processing: {str(e)}")
+class ModelManager:
+    def __init__(self):
+        self.logger = logging.getLogger(__name__)
+        self.token = os.getenv("HF_TOKEN")
+        if not self.token:
+            raise ValueError("HF_TOKEN environment variable is required")
+        self._initialize_models()
+
+    def _initialize_models(self):
+        try:
+            # Initialize ONNX runtime sessions
+            self.logger.info("Loading ONNX models...")
+
+            # Download and load fashion model
+            fashion_path = hf_hub_download(
+                repo_id="alexgenovese/segformer-onnx",
+                filename="segformer-b3-fashion.onnx",
+                token=self.token
+            )
+            self.fashion_model = ort.InferenceSession(fashion_path)
+            self.fashion_processor = SegformerImageProcessor.from_pretrained(
+                "sayeed99/segformer-b3-fashion",
+                token=self.token
+            )
+
+            # Download and load clothes model
+            clothes_path = hf_hub_download(
+                repo_id="alexgenovese/segformer-onnx",
+                filename="segformer_b2_clothes.onnx",
+                token=self.token
+            )
+            self.clothes_model = ort.InferenceSession(clothes_path)
+            self.clothes_processor = SegformerImageProcessor.from_pretrained(
+                "mattmdjaga/segformer_b2_clothes",
+                token=self.token
+            )
+
+            self.logger.info("All models loaded successfully.")
+        except Exception as e:
+            self.logger.error(f"Error initializing models: {str(e)}")
+            raise RuntimeError(f"Error initializing models: {str(e)}")
+
+    def process_fashion_image(self, image: Image.Image):
+        inputs = self.fashion_processor(images=image, return_tensors="np")
+        onnx_inputs = {
+            'input': inputs['pixel_values']
+        }
+        logits = self.fashion_model.run(None, onnx_inputs)[0]
+        return self._post_process_outputs(logits, image.size)
+
+    def process_clothes_image(self, image: Image.Image):
+        inputs = self.clothes_processor(images=image, return_tensors="np")
+        onnx_inputs = {
+            'input': inputs['pixel_values']
+        }
+        logits = self.clothes_model.run(None, onnx_inputs)[0]
+        return self._post_process_outputs(logits, image.size)
+
+    def _post_process_outputs(self, logits, image_size):
+        # Convert logits to proper shape for processing
+        logits = np.array(logits)
+
+        # Resize prediction to match original image size
+        from skimage.transform import resize
+        resized_logits = resize(
+            logits[0],
+            (image_size[1], image_size[0]),
+            order=1,
+            preserve_range=True,
+            mode='reflect'
+        )
+
+        # Get prediction
+        pred_seg = np.argmax(resized_logits, axis=0)
+        mask_img = Image.fromarray((pred_seg * 255).astype(np.uint8))
+
+        # Convert to base64
+        buffered = io.BytesIO()
+        mask_img.save(buffered, format="PNG")
+        mask_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
+
+        return {
+            "mask": f"data:image/png;base64,{mask_base64}",
+            "size": image_size,
+            "predictions": pred_seg.tolist()
+        }
+
+# Initialize FastAPI and ModelManager
+app = FastAPI()
+model_manager = ModelManager()
 
-# Add new endpoint
 @app.post("/segment-url")
 async def segment_url_endpoint(image_data: ImageURL):
     try:
-        logger.info("Downloading image from URL...")
         response = requests.get(image_data.url, stream=True)
         if response.status_code != 200:
             raise HTTPException(status_code=400, detail="Could not download image from URL")
 
-        # Open image from URL
         image = Image.open(response.raw).convert("RGB")
-
-        # Process image with SegFormer
-        logger.info("Processing image...")
-        inputs = processor(images=image, return_tensors="pt")
-        outputs = model(**inputs)
-        logits = outputs.logits.cpu()
-
-        # Upsample logits to match original image size
-        upsampled_logits = nn.functional.interpolate(
-            logits,
-            size=image.size[::-1],
-            mode="bilinear",
-            align_corners=False,
-        )
-
-        # Get prediction
-        pred_seg = upsampled_logits.argmax(dim=1)[0]
-
-        # Convert to image
-        mask_img = Image.fromarray((pred_seg.numpy() * 255).astype(np.uint8))
-
-        # Convert to base64
-        buffered = io.BytesIO()
-        mask_img.save(buffered, format="PNG")
-        mask_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
-
-        return {
-            "mask": f"data:image/png;base64,{mask_base64}",
-            "size": image.size,
-            "labels": pred_seg
-        }
+        return model_manager.process_fashion_image(image)
 
     except Exception as e:
+        logging.error(f"Error processing URL: {str(e)}")
         raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
 
-# Add new endpoint
 @app.post("/segment-clothes-url")
 async def segment_clothes_url_endpoint(image_data: ImageURL):
     try:
-        logger.info("Downloading image from URL...")
         response = requests.get(image_data.url, stream=True)
         if response.status_code != 200:
             raise HTTPException(status_code=400, detail="Could not download image from URL")
 
-        # Open image from URL
         image = Image.open(response.raw).convert("RGB")
-
-        # Process image with SegFormer
-        logger.info("Processing image...")
-        inputs = clothes_processor(images=image, return_tensors="pt")
-        outputs = clothes_model(**inputs)
-        logits = outputs.logits.cpu()
-
-        # Upsample logits to match original image size
-        upsampled_logits = nn.functional.interpolate(
-            logits,
-            size=image.size[::-1],
-            mode="bilinear",
-            align_corners=False,
-        )
-
-        # Get prediction
-        pred_seg = upsampled_logits.argmax(dim=1)[0]
-
-        # Convert to image
-        mask_img = Image.fromarray((pred_seg.numpy() * 255).astype(np.uint8))
-
-        # Convert to base64
-        buffered = io.BytesIO()
-        mask_img.save(buffered, format="PNG")
-        mask_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
-
-        return {
-            "mask": f"data:image/png;base64,{mask_base64}",
-            "size": image.size,
-            "predictions": pred_seg.numpy().tolist()
-        }
+        return model_manager.process_clothes_image(image)
 
     except Exception as e:
+        logging.error(f"Error processing URL: {str(e)}")
         raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
 
+@app.post("/segment")
+async def segment_endpoint(file: UploadFile = File(...)):
+    try:
+        image_data = await file.read()
+        image = Image.open(io.BytesIO(image_data)).convert("RGB")
+        return model_manager.process_fashion_image(image)
+    except Exception as e:
+        logging.error(f"Error in endpoint: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"Error processing: {str(e)}")
+
 if __name__ == "__main__":
     import uvicorn
     uvicorn.run(app, host="0.0.0.0", port=7860)
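The new app.py feeds each ONNX session a feed dict keyed 'input', which matches the input_names=['input'] set by the torch.onnx.export call in the converter below; if those names ever drift apart, onnxruntime rejects the feed at run time. A quick inspection sketch, assuming a local copy of one of the exported files:

```python
import onnxruntime as ort

# Inspect the tensor names and shapes baked into the exported graph
session = ort.InferenceSession("segformer_b2_clothes.onnx")
print([i.name for i in session.get_inputs()])    # expected: ['input']
print([o.name for o in session.get_outputs()])   # expected: ['output']
print(session.get_inputs()[0].shape)             # e.g. ['batch_size', 3, 'height', 'width']
```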
hf_onnx_converter.py
ADDED
@@ -0,0 +1,202 @@
+import torch
+from transformers import AutoModelForSemanticSegmentation, SegformerImageProcessor
+from huggingface_hub import HfApi, create_repo, upload_file, model_info
+import os
+from dotenv import load_dotenv
+from pathlib import Path
+import logging
+import argparse
+import tempfile
+
+# Setup logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# Load environment variables
+load_dotenv()
+
+class ConfigurationError(Exception):
+    """Raised when required environment variables are missing"""
+    pass
+
+class HFOnnxConverter:
+    def __init__(self, token=None):
+        # Load configuration from environment
+        self.token = token or os.getenv("HF_TOKEN")
+        self.model_cache_dir = os.getenv("MODEL_CACHE_DIR")
+        self.onnx_output_dir = os.getenv("ONNX_OUTPUT_DIR")
+
+        # Validate configuration
+        if not self.token:
+            raise ConfigurationError("HF_TOKEN is required in environment variables")
+
+        # Create directories if they don't exist
+        for directory in [self.model_cache_dir, self.onnx_output_dir]:
+            if directory:
+                Path(directory).mkdir(parents=True, exist_ok=True)
+
+        self.api = HfApi()
+
+        # Login to Hugging Face
+        try:
+            self.api.whoami(token=self.token)
+            logger.info("Successfully authenticated with Hugging Face")
+        except Exception as e:
+            raise ConfigurationError(f"Failed to authenticate with Hugging Face: {str(e)}")
+
+    def setup_repository(self, repo_name: str) -> str:
+        """Create or get repository on Hugging Face Hub"""
+        try:
+            create_repo(
+                repo_name,
+                token=self.token,
+                private=False,
+                exist_ok=True
+            )
+            logger.info(f"Repository {repo_name} is ready")
+            return repo_name
+        except Exception as e:
+            logger.error(f"Error setting up repository: {e}")
+            raise
+
+    def verify_model_exists(self, model_name: str) -> bool:
+        """Verify if the model exists and is accessible"""
+        try:
+            model_info(model_name, token=self.token)
+            return True
+        except Exception as e:
+            logger.error(f"Model verification failed: {str(e)}")
+            return False
+
+    def convert_and_push(self, source_model: str, target_repo: str):
+        """Convert model to ONNX and push to Hugging Face Hub"""
+        try:
+            # Verify model exists and is accessible
+            if not self.verify_model_exists(source_model):
+                raise ValueError(f"Model {source_model} is not accessible. Check if the model exists and you have proper permissions.")
+
+            # Use model cache directory if specified
+            model_kwargs = {
+                "token": self.token
+            }
+            if self.model_cache_dir:
+                model_kwargs["cache_dir"] = self.model_cache_dir
+
+            # Create working directory
+            working_dir = self.onnx_output_dir or tempfile.mkdtemp()
+            tmp_path = Path(working_dir) / f"{target_repo.split('/')[-1]}.onnx"
+
+            logger.info(f"Loading model {source_model}...")
+            model = AutoModelForSemanticSegmentation.from_pretrained(
+                source_model,
+                **model_kwargs
+            )
+            processor = SegformerImageProcessor.from_pretrained(
+                source_model,
+                **model_kwargs
+            )
+
+            # Set model to evaluation mode
+            model.eval()
+
+            # Create dummy input
+            dummy_input = processor(
+                images=torch.zeros(1, 3, 224, 224),
+                return_tensors="pt"
+            )
+
+            # Export to ONNX
+            logger.info(f"Converting to ONNX format... Output path: {tmp_path}")
+            torch.onnx.export(
+                model,
+                (dummy_input['pixel_values'],),
+                tmp_path,
+                input_names=['input'],
+                output_names=['output'],
+                dynamic_axes={
+                    'input': {0: 'batch_size', 2: 'height', 3: 'width'},
+                    'output': {0: 'batch_size'}
+                },
+                opset_version=12,
+                do_constant_folding=True
+            )
+
+            # Create model card with environment info
+            model_card = f"""---
+base_model: {source_model}
+tags:
+- onnx
+- semantic-segmentation
+---
+
+# ONNX Model converted from {source_model}
+
+This is an ONNX version of the model {source_model}, converted automatically.
+
+## Model Information
+- Original Model: {source_model}
+- ONNX Opset Version: 12
+- Input Shape: Dynamic (batch_size, 3, height, width)
+
+## Usage
+
+```python
+import onnxruntime as ort
+import numpy as np
+
+# Load ONNX model
+session = ort.InferenceSession("model.onnx")
+
+# Prepare input
+input_data = np.zeros((1, 3, 224, 224), dtype=np.float32)
+
+# Run inference
+outputs = session.run(None, {{"input": input_data}})
+```
+"""
+            # Save model card
+            readme_path = Path(working_dir) / "README.md"
+            with open(readme_path, "w") as f:
+                f.write(model_card)
+
+            # Push files to hub
+            logger.info(f"Pushing files to {target_repo}...")
+            self.api.upload_file(
+                path_or_fileobj=str(tmp_path),
+                path_in_repo="model.onnx",
+                repo_id=target_repo,
+                token=self.token
+            )
+            self.api.upload_file(
+                path_or_fileobj=str(readme_path),
+                path_in_repo="README.md",
+                repo_id=target_repo,
+                token=self.token
+            )
+
+            logger.info(f"Successfully pushed ONNX model to {target_repo}")
+            return True
+
+        except Exception as e:
+            logger.error(f"Error during conversion and upload: {e}")
+            return False
+
+def main():
+    parser = argparse.ArgumentParser(description='Convert and push model to ONNX format on Hugging Face Hub')
+    parser.add_argument('--source', type=str, required=True,
+                        help='Source model name (e.g., "sayeed99/segformer-b3-fashion")')
+    parser.add_argument('--target', type=str, required=True,
+                        help='Target repository name (e.g., "your-username/model-name-onnx")')
+    parser.add_argument('--token', type=str, help='Hugging Face token (optional)')
+
+    args = parser.parse_args()
+
+    converter = HFOnnxConverter(token=args.token)
+    converter.setup_repository(args.target)
+    success = converter.convert_and_push(args.source, args.target)
+
+    if not success:
+        exit(1)
+
+if __name__ == "__main__":
+    main()
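Before serving a converted file, it is worth checking that the exported graph reproduces the PyTorch outputs on the same input. A sanity-check sketch: the model id is one of the two converted above, while the local .onnx path and the tolerance are illustrative assumptions (the converter names the file after the target repo and writes it to ONNX_OUTPUT_DIR).

```python
import numpy as np
import onnxruntime as ort
import torch
from transformers import AutoModelForSemanticSegmentation

# Original PyTorch model
model = AutoModelForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes")
model.eval()

pixel_values = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    torch_logits = model(pixel_values=pixel_values).logits.numpy()

# Exported ONNX model (path assumed)
session = ort.InferenceSession("onnx_models/segformer-onnx.onnx")
onnx_logits = session.run(None, {"input": pixel_values.numpy()})[0]

# Small numerical drift is normal after export with constant folding
print("max abs diff:", np.abs(torch_logits - onnx_logits).max())
assert np.allclose(torch_logits, onnx_logits, atol=1e-3)
```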
requirements.txt
CHANGED
@@ -4,4 +4,7 @@ torch
 torchvision
 transformers
 pillow
-numpy
+numpy
+torch
+dotenv
+onnx
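A caveat on this list: app.py and hf_onnx_converter.py also import fastapi, uvicorn, requests, onnxruntime, skimage (scikit-image), and huggingface_hub, none of which are pinned here; torch already appears earlier in the file, so the new pin duplicates it; and `from dotenv import load_dotenv` is shipped by the PyPI distribution python-dotenv, which the bare dotenv pin may not resolve to. A fuller list might look like this sketch:

```
fastapi
uvicorn
requests
torch
torchvision
transformers
pillow
numpy
python-dotenv
onnx
onnxruntime
scikit-image
huggingface_hub
```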
response.json
CHANGED
The diff for this file is too large to render.