Alex commited on
Commit
b2702fe
·
1 Parent(s): 0e3833c

updated to onnx

Browse files
Files changed (6) hide show
  1. .gitignore +5 -0
  2. README.md +26 -1
  3. app.py +102 -147
  4. hf_onnx_converter.py +202 -0
  5. requirements.txt +4 -1
  6. response.json +0 -0
.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ .env
2
+ .DS_Store
3
+ models/
4
+ model_cache
5
+ onnx_models
README.md CHANGED
@@ -51,4 +51,29 @@ curl -X POST "https://alexgenovese-segmentation.hf.space/segment-url" \
51
  -d '{
52
  "url": "https://plus.unsplash.com/premium_photo-1673210886161-bfcc40f54d1f?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxzZWFyY2h8MXx8cGVyc29uJTIwc3RhbmRpbmd8ZW58MHx8MHx8&w=1000&q=80"
53
  }' \
54
- -o response.json
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  -d '{
52
  "url": "https://plus.unsplash.com/premium_photo-1673210886161-bfcc40f54d1f?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxzZWFyY2h8MXx8cGVyc29uJTIwc3RhbmRpbmd8ZW58MHx8MHx8&w=1000&q=80"
53
  }' \
54
+ -o response.json
55
+
56
+
57
+ # Segment-clothes-url
58
+
59
+ curl -X POST "https://alexgenovese-segmentation.hf.space/segment-clothes-url" \
60
+ -H "Content-Type: application/json" \
61
+ -d '{
62
+ "url": "https://plus.unsplash.com/premium_photo-1673210886161-bfcc40f54d1f?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxzZWFyY2h8MXx8cGVyc29uJTIwc3RhbmRpbmd8ZW58MHx8MHx8&w=1000&q=80"
63
+ }' \
64
+ -o response.json
65
+
66
+ # Convert to ONNX file
67
+
68
+ # For the fashion segmentation model:
69
+ python convert_to_onnx.py --model "sayeed99/segformer-b3-fashion" --output "models/fashion_segformer.onnx"
70
+
71
+ # For the clothes segmentation model:
72
+ python convert_to_onnx.py --model "mattmdjaga/segformer_b2_clothes" --output "models/clothes_segformer.onnx"
73
+
74
+
75
+ # Convert To Onnx file
76
+
77
+ python3 hf_onnx_converter.py \
78
+ --source "mattmdjaga/segformer_b2_clothes" \
79
+ --target "alexgenovese/segformer-onnx"
app.py CHANGED
@@ -1,193 +1,148 @@
1
  from fastapi import FastAPI, File, UploadFile, HTTPException
2
- from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor, AutoModelForSemanticSegmentation
 
3
  from pydantic import BaseModel
4
  from PIL import Image
5
  import numpy as np
6
- import io, base64, logging, requests, torch
7
- import torch.nn as nn
 
8
 
9
- # Inizializza l'app FastAPI
10
- app = FastAPI()
11
 
12
- # Add this class for the request body
13
  class ImageURL(BaseModel):
14
  url: str
15
 
16
- # Configura il logging
17
- logging.basicConfig(level=logging.INFO)
18
- logger = logging.getLogger(__name__)
19
-
20
- # Carica il modello e il processore SegFormer
21
- try:
22
- logger.info("Caricamento del modello SegFormer...")
23
- model = SegformerForSemanticSegmentation.from_pretrained("sayeed99/segformer-b3-fashion")
24
- processor = SegformerImageProcessor.from_pretrained("sayeed99/segformer-b3-fashion")
25
- model.to("cpu") # Usa CPU per il free tier
26
- logger.info("Modello caricato con successo.")
27
- except Exception as e:
28
- logger.error(f"Errore nel caricamento del modello: {str(e)}")
29
- raise RuntimeError(f"Errore nel caricamento del modello: {str(e)}")
30
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
- # Add new model and processor initialization after existing ones
33
- try:
34
- logger.info("Loading clothes segmentation model...")
35
- clothes_model = AutoModelForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes")
36
- clothes_processor = SegformerImageProcessor.from_pretrained("mattmdjaga/segformer_b2_clothes")
37
- clothes_model.to("cpu")
38
- logger.info("Clothes model loaded successfully.")
39
- except Exception as e:
40
- logger.error(f"Error loading clothes model: {str(e)}")
41
- raise RuntimeError(f"Error loading clothes model: {str(e)}")
42
 
43
- # Funzione per segmentare l'immagine
44
- def segment_image(image: Image.Image):
45
- # Prepara l'input per SegFormer
46
- logger.info("Preparazione dell'immagine per l'inferenza...")
47
- inputs = processor(images=image, return_tensors="pt").to("cpu")
48
-
49
- # Inferenza
50
- logger.info("Esecuzione dell'inferenza...")
51
- with torch.no_grad():
52
- outputs = model(**inputs)
53
- logits = outputs.logits
54
 
55
- # Post-processa la maschera
56
- logger.info("Post-processing della maschera...")
57
- mask = torch.argmax(logits, dim=1)[0]
58
- mask = mask.cpu().numpy()
59
-
60
- # Converti la maschera in immagine
61
- mask_img = Image.fromarray((mask * 255 / mask.max()).astype(np.uint8))
62
-
63
- # Converti la maschera in base64 per la risposta
64
- buffered = io.BytesIO()
65
- mask_img.save(buffered, format="PNG")
66
- mask_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
67
-
68
- # Annotazioni
69
- annotations = {"mask": mask.tolist(), "label": logits }
70
-
71
- return mask_base64, annotations
72
-
73
- # Endpoint API
74
- @app.post("/segment")
75
- async def segment_endpoint(file: UploadFile = File(...)):
76
- try:
77
- logger.info("Ricezione del file...")
78
- image_data = await file.read()
79
- image = Image.open(io.BytesIO(image_data)).convert("RGB")
80
 
81
- logger.info("Segmentazione dell'immagine...")
82
- mask_base64, annotations = segment_image(image)
 
 
 
 
 
 
83
 
84
  return {
85
  "mask": f"data:image/png;base64,{mask_base64}",
86
- "annotations": annotations
 
87
  }
88
- except Exception as e:
89
- logger.error(f"Errore nell'endpoint: {str(e)}")
90
- raise HTTPException(status_code=500, detail=f"Errore nell'elaborazione: {str(e)}")
91
-
92
 
 
 
 
93
 
94
- # Add new endpoint
95
  @app.post("/segment-url")
96
  async def segment_url_endpoint(image_data: ImageURL):
97
  try:
98
- logger.info("Downloading image from URL...")
99
  response = requests.get(image_data.url, stream=True)
100
  if response.status_code != 200:
101
  raise HTTPException(status_code=400, detail="Could not download image from URL")
102
 
103
- # Open image from URL
104
  image = Image.open(response.raw).convert("RGB")
105
-
106
- # Process image with SegFormer
107
- logger.info("Processing image...")
108
- inputs = processor(images=image, return_tensors="pt")
109
- outputs = model(**inputs)
110
- logits = outputs.logits.cpu()
111
-
112
- # Upsample logits to match original image size
113
- upsampled_logits = nn.functional.interpolate(
114
- logits,
115
- size=image.size[::-1],
116
- mode="bilinear",
117
- align_corners=False,
118
- )
119
-
120
- # Get prediction
121
- pred_seg = upsampled_logits.argmax(dim=1)[0]
122
-
123
- # Convert to image
124
- mask_img = Image.fromarray((pred_seg.numpy() * 255).astype(np.uint8))
125
-
126
- # Convert to base64
127
- buffered = io.BytesIO()
128
- mask_img.save(buffered, format="PNG")
129
- mask_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
130
-
131
- return {
132
- "mask": f"data:image/png;base64,{mask_base64}",
133
- "size": image.size,
134
- "labels" : pred_seg
135
- }
136
 
137
  except Exception as e:
138
- logger.error(f"Error processing URL: {str(e)}")
139
  raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
140
 
141
-
142
-
143
- # Add new endpoint
144
  @app.post("/segment-clothes-url")
145
  async def segment_clothes_url_endpoint(image_data: ImageURL):
146
  try:
147
- logger.info("Downloading image from URL...")
148
  response = requests.get(image_data.url, stream=True)
149
  if response.status_code != 200:
150
  raise HTTPException(status_code=400, detail="Could not download image from URL")
151
 
152
- # Open image from URL
153
  image = Image.open(response.raw).convert("RGB")
154
-
155
- # Process image with SegFormer
156
- logger.info("Processing image...")
157
- inputs = clothes_processor(images=image, return_tensors="pt")
158
- outputs = clothes_model(**inputs)
159
- logits = outputs.logits.cpu()
160
-
161
- # Upsample logits to match original image size
162
- upsampled_logits = nn.functional.interpolate(
163
- logits,
164
- size=image.size[::-1],
165
- mode="bilinear",
166
- align_corners=False,
167
- )
168
-
169
- # Get prediction
170
- pred_seg = upsampled_logits.argmax(dim=1)[0]
171
-
172
- # Convert to image
173
- mask_img = Image.fromarray((pred_seg.numpy() * 255).astype(np.uint8))
174
-
175
- # Convert to base64
176
- buffered = io.BytesIO()
177
- mask_img.save(buffered, format="PNG")
178
- mask_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
179
-
180
- return {
181
- "mask": f"data:image/png;base64,{mask_base64}",
182
- "size": image.size,
183
- "predictions": pred_seg.numpy().tolist()
184
- }
185
 
186
  except Exception as e:
187
- logger.error(f"Error processing URL: {str(e)}")
188
  raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
189
 
190
- # Per compatibilità con Hugging Face Spaces
 
 
 
 
 
 
 
 
 
191
  if __name__ == "__main__":
192
  import uvicorn
193
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
  from fastapi import FastAPI, File, UploadFile, HTTPException
2
+ from transformers import SegformerImageProcessor
3
+ from huggingface_hub import hf_hub_download
4
  from pydantic import BaseModel
5
  from PIL import Image
6
  import numpy as np
7
+ import io, base64, logging, requests, os
8
+ import onnxruntime as ort
9
+ from dotenv import load_dotenv
10
 
11
+ # Load environment variables
12
+ load_dotenv()
13
 
 
14
  class ImageURL(BaseModel):
15
  url: str
16
 
17
+ class ModelManager:
18
+ def __init__(self):
19
+ self.logger = logging.getLogger(__name__)
20
+ self.token = os.getenv("HF_TOKEN")
21
+ if not self.token:
22
+ raise ValueError("HF_TOKEN environment variable is required")
23
+ self._initialize_models()
24
+
25
+ def _initialize_models(self):
26
+ try:
27
+ # Initialize ONNX runtime sessions
28
+ self.logger.info("Loading ONNX models...")
29
+
30
+ # Download and load fashion model
31
+ fashion_path = hf_hub_download(
32
+ repo_id="alexgenovese/segformer-onnx",
33
+ filename="segformer-b3-fashion.onnx",
34
+ token=self.token
35
+ )
36
+ self.fashion_model = ort.InferenceSession(fashion_path)
37
+ self.fashion_processor = SegformerImageProcessor.from_pretrained(
38
+ "sayeed99/segformer-b3-fashion",
39
+ token=self.token
40
+ )
41
+
42
+ # Download and load clothes model
43
+ clothes_path = hf_hub_download(
44
+ repo_id="alexgenovese/segformer-onnx",
45
+ filename="segformer_b2_clothes.onnx",
46
+ token=self.token
47
+ )
48
+ self.clothes_model = ort.InferenceSession(clothes_path)
49
+ self.clothes_processor = SegformerImageProcessor.from_pretrained(
50
+ "mattmdjaga/segformer_b2_clothes",
51
+ token=self.token
52
+ )
53
+
54
+ self.logger.info("All models loaded successfully.")
55
+ except Exception as e:
56
+ self.logger.error(f"Error initializing models: {str(e)}")
57
+ raise RuntimeError(f"Error initializing models: {str(e)}")
58
+
59
+ def process_fashion_image(self, image: Image.Image):
60
+ inputs = self.fashion_processor(images=image, return_tensors="np")
61
+ onnx_inputs = {
62
+ 'input': inputs['pixel_values']
63
+ }
64
+ logits = self.fashion_model.run(None, onnx_inputs)[0]
65
+ return self._post_process_outputs(logits, image.size)
66
 
67
+ def process_clothes_image(self, image: Image.Image):
68
+ inputs = self.clothes_processor(images=image, return_tensors="np")
69
+ onnx_inputs = {
70
+ 'input': inputs['pixel_values']
71
+ }
72
+ logits = self.clothes_model.run(None, onnx_inputs)[0]
73
+ return self._post_process_outputs(logits, image.size)
 
 
 
74
 
75
+ def _post_process_outputs(self, logits, image_size):
76
+ # Convert logits to proper shape for processing
77
+ logits = np.array(logits)
 
 
 
 
 
 
 
 
78
 
79
+ # Resize prediction to match original image size
80
+ from skimage.transform import resize
81
+ resized_logits = resize(
82
+ logits[0],
83
+ (image_size[1], image_size[0]),
84
+ order=1,
85
+ preserve_range=True,
86
+ mode='reflect'
87
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
+ # Get prediction
90
+ pred_seg = np.argmax(resized_logits, axis=0)
91
+ mask_img = Image.fromarray((pred_seg * 255).astype(np.uint8))
92
+
93
+ # Convert to base64
94
+ buffered = io.BytesIO()
95
+ mask_img.save(buffered, format="PNG")
96
+ mask_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
97
 
98
  return {
99
  "mask": f"data:image/png;base64,{mask_base64}",
100
+ "size": image_size,
101
+ "predictions": pred_seg.tolist()
102
  }
 
 
 
 
103
 
104
+ # Initialize FastAPI and ModelManager
105
+ app = FastAPI()
106
+ model_manager = ModelManager()
107
 
 
108
  @app.post("/segment-url")
109
  async def segment_url_endpoint(image_data: ImageURL):
110
  try:
 
111
  response = requests.get(image_data.url, stream=True)
112
  if response.status_code != 200:
113
  raise HTTPException(status_code=400, detail="Could not download image from URL")
114
 
 
115
  image = Image.open(response.raw).convert("RGB")
116
+ return model_manager.process_fashion_image(image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
  except Exception as e:
119
+ logging.error(f"Error processing URL: {str(e)}")
120
  raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
121
 
 
 
 
122
  @app.post("/segment-clothes-url")
123
  async def segment_clothes_url_endpoint(image_data: ImageURL):
124
  try:
 
125
  response = requests.get(image_data.url, stream=True)
126
  if response.status_code != 200:
127
  raise HTTPException(status_code=400, detail="Could not download image from URL")
128
 
 
129
  image = Image.open(response.raw).convert("RGB")
130
+ return model_manager.process_clothes_image(image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
  except Exception as e:
133
+ logging.error(f"Error processing URL: {str(e)}")
134
  raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
135
 
136
+ @app.post("/segment")
137
+ async def segment_endpoint(file: UploadFile = File(...)):
138
+ try:
139
+ image_data = await file.read()
140
+ image = Image.open(io.BytesIO(image_data)).convert("RGB")
141
+ return model_manager.process_fashion_image(image)
142
+ except Exception as e:
143
+ logging.error(f"Error in endpoint: {str(e)}")
144
+ raise HTTPException(status_code=500, detail=f"Error processing: {str(e)}")
145
+
146
  if __name__ == "__main__":
147
  import uvicorn
148
  uvicorn.run(app, host="0.0.0.0", port=7860)
hf_onnx_converter.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForSemanticSegmentation, SegformerImageProcessor
3
+ from huggingface_hub import HfApi, create_repo, upload_file, model_info
4
+ import os
5
+ from dotenv import load_dotenv
6
+ from pathlib import Path
7
+ import logging
8
+ import argparse
9
+ import tempfile
10
+
11
+ # Setup logging
12
+ logging.basicConfig(level=logging.INFO)
13
+ logger = logging.getLogger(__name__)
14
+
15
+ # Load environment variables
16
+ load_dotenv()
17
+
18
+ class ConfigurationError(Exception):
19
+ """Raised when required environment variables are missing"""
20
+ pass
21
+
22
+ class HFOnnxConverter:
23
+ def __init__(self, token=None):
24
+ # Load configuration from environment
25
+ self.token = token or os.getenv("HF_TOKEN")
26
+ self.model_cache_dir = os.getenv("MODEL_CACHE_DIR")
27
+ self.onnx_output_dir = os.getenv("ONNX_OUTPUT_DIR")
28
+
29
+ # Validate configuration
30
+ if not self.token:
31
+ raise ConfigurationError("HF_TOKEN is required in environment variables")
32
+
33
+ # Create directories if they don't exist
34
+ for directory in [self.model_cache_dir, self.onnx_output_dir]:
35
+ if directory:
36
+ Path(directory).mkdir(parents=True, exist_ok=True)
37
+
38
+ self.api = HfApi()
39
+
40
+ # Login to Hugging Face
41
+ try:
42
+ self.api.whoami(token=self.token)
43
+ logger.info("Successfully authenticated with Hugging Face")
44
+ except Exception as e:
45
+ raise ConfigurationError(f"Failed to authenticate with Hugging Face: {str(e)}")
46
+
47
+ def setup_repository(self, repo_name: str) -> str:
48
+ """Create or get repository on Hugging Face Hub"""
49
+ try:
50
+ create_repo(
51
+ repo_name,
52
+ token=self.token,
53
+ private=False,
54
+ exist_ok=True
55
+ )
56
+ logger.info(f"Repository {repo_name} is ready")
57
+ return repo_name
58
+ except Exception as e:
59
+ logger.error(f"Error setting up repository: {e}")
60
+ raise
61
+
62
+ def verify_model_exists(self, model_name: str) -> bool:
63
+ """Verify if the model exists and is accessible"""
64
+ try:
65
+ model_info(model_name, token=self.token)
66
+ return True
67
+ except Exception as e:
68
+ logger.error(f"Model verification failed: {str(e)}")
69
+ return False
70
+
71
+ def convert_and_push(self, source_model: str, target_repo: str):
72
+ """Convert model to ONNX and push to Hugging Face Hub"""
73
+ try:
74
+ # Verify model exists and is accessible
75
+ if not self.verify_model_exists(source_model):
76
+ raise ValueError(f"Model {source_model} is not accessible. Check if the model exists and you have proper permissions.")
77
+
78
+ # Use model cache directory if specified
79
+ model_kwargs = {
80
+ "token": self.token
81
+ }
82
+ if self.model_cache_dir:
83
+ model_kwargs["cache_dir"] = self.model_cache_dir
84
+
85
+ # Create working directory
86
+ working_dir = self.onnx_output_dir or tempfile.mkdtemp()
87
+ tmp_path = Path(working_dir) / f"{target_repo.split('/')[-1]}.onnx"
88
+
89
+ logger.info(f"Loading model {source_model}...")
90
+ model = AutoModelForSemanticSegmentation.from_pretrained(
91
+ source_model,
92
+ **model_kwargs
93
+ )
94
+ processor = SegformerImageProcessor.from_pretrained(
95
+ source_model,
96
+ **model_kwargs
97
+ )
98
+
99
+ # Set model to evaluation mode
100
+ model.eval()
101
+
102
+ # Create dummy input
103
+ dummy_input = processor(
104
+ images=torch.zeros(1, 3, 224, 224),
105
+ return_tensors="pt"
106
+ )
107
+
108
+ # Export to ONNX
109
+ logger.info(f"Converting to ONNX format... Output path: {tmp_path}")
110
+ torch.onnx.export(
111
+ model,
112
+ (dummy_input['pixel_values'],),
113
+ tmp_path,
114
+ input_names=['input'],
115
+ output_names=['output'],
116
+ dynamic_axes={
117
+ 'input': {0: 'batch_size', 2: 'height', 3: 'width'},
118
+ 'output': {0: 'batch_size'}
119
+ },
120
+ opset_version=12,
121
+ do_constant_folding=True
122
+ )
123
+
124
+ # Create model card with environment info
125
+ model_card = f"""---
126
+ base_model: {source_model}
127
+ tags:
128
+ - onnx
129
+ - semantic-segmentation
130
+ ---
131
+
132
+ # ONNX Model converted from {source_model}
133
+
134
+ This is an ONNX version of the model {source_model}, converted automatically.
135
+
136
+ ## Model Information
137
+ - Original Model: {source_model}
138
+ - ONNX Opset Version: 12
139
+ - Input Shape: Dynamic (batch_size, 3, height, width)
140
+
141
+ ## Usage
142
+
143
+ ```python
144
+ import onnxruntime as ort
145
+ import numpy as np
146
+
147
+ # Load ONNX model
148
+ session = ort.InferenceSession("model.onnx")
149
+
150
+ # Prepare input
151
+ input_data = np.zeros((1, 3, 224, 224), dtype=np.float32)
152
+
153
+ # Run inference
154
+ outputs = session.run(None, {{"input": input_data}})
155
+ ```
156
+ """
157
+ # Save model card
158
+ readme_path = Path(working_dir) / "README.md"
159
+ with open(readme_path, "w") as f:
160
+ f.write(model_card)
161
+
162
+ # Push files to hub
163
+ logger.info(f"Pushing files to {target_repo}...")
164
+ self.api.upload_file(
165
+ path_or_fileobj=str(tmp_path),
166
+ path_in_repo="model.onnx",
167
+ repo_id=target_repo,
168
+ token=self.token
169
+ )
170
+ self.api.upload_file(
171
+ path_or_fileobj=str(readme_path),
172
+ path_in_repo="README.md",
173
+ repo_id=target_repo,
174
+ token=self.token
175
+ )
176
+
177
+ logger.info(f"Successfully pushed ONNX model to {target_repo}")
178
+ return True
179
+
180
+ except Exception as e:
181
+ logger.error(f"Error during conversion and upload: {e}")
182
+ return False
183
+
184
+ def main():
185
+ parser = argparse.ArgumentParser(description='Convert and push model to ONNX format on Hugging Face Hub')
186
+ parser.add_argument('--source', type=str, required=True,
187
+ help='Source model name (e.g., "sayeed99/segformer-b3-fashion")')
188
+ parser.add_argument('--target', type=str, required=True,
189
+ help='Target repository name (e.g., "your-username/model-name-onnx")')
190
+ parser.add_argument('--token', type=str, help='Hugging Face token (optional)')
191
+
192
+ args = parser.parse_args()
193
+
194
+ converter = HFOnnxConverter(token=args.token)
195
+ converter.setup_repository(args.target)
196
+ success = converter.convert_and_push(args.source, args.target)
197
+
198
+ if not success:
199
+ exit(1)
200
+
201
+ if __name__ == "__main__":
202
+ main()
requirements.txt CHANGED
@@ -4,4 +4,7 @@ torch
4
  torchvision
5
  transformers
6
  pillow
7
- numpy
 
 
 
 
4
  torchvision
5
  transformers
6
  pillow
7
+ numpy
8
+ torch
9
+ dotenv
10
+ onnx
response.json CHANGED
The diff for this file is too large to render. See raw diff