# segmentation/hf_onnx_converter.py
# Author: Alex
# updated to onnx (commit b2702fe)
import argparse
import logging
import os
import shutil
import tempfile
from pathlib import Path

import torch
from dotenv import load_dotenv
from huggingface_hub import HfApi, create_repo, model_info, upload_file
from transformers import AutoModelForSemanticSegmentation, SegformerImageProcessor
# Configure root logging once at import time so module loggers emit INFO and above.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Load a .env file (HF_TOKEN, MODEL_CACHE_DIR, ONNX_OUTPUT_DIR) into os.environ
# before HFOnnxConverter reads those variables in __init__.
load_dotenv()
class ConfigurationError(Exception):
    """Signal a missing or invalid runtime configuration value.

    Raised when a required environment variable (e.g. HF_TOKEN) is absent,
    or when authentication against the Hugging Face Hub fails.
    """
class HFOnnxConverter:
    """Convert a Hugging Face semantic-segmentation model to ONNX and push it to the Hub.

    Configuration comes from the environment (loaded via dotenv at import time):
      - HF_TOKEN        (required) Hugging Face access token
      - MODEL_CACHE_DIR (optional) cache directory for downloaded models
      - ONNX_OUTPUT_DIR (optional) directory that receives the exported .onnx file
    """

    def __init__(self, token=None):
        """Load configuration, prepare working directories and authenticate.

        Args:
            token: Hugging Face token; falls back to the HF_TOKEN env var.

        Raises:
            ConfigurationError: if no token is available or authentication fails.
        """
        self.token = token or os.getenv("HF_TOKEN")
        self.model_cache_dir = os.getenv("MODEL_CACHE_DIR")
        self.onnx_output_dir = os.getenv("ONNX_OUTPUT_DIR")

        if not self.token:
            raise ConfigurationError("HF_TOKEN is required in environment variables")

        # Create optional working directories up front so later writes cannot
        # fail on a missing path.
        for directory in (self.model_cache_dir, self.onnx_output_dir):
            if directory:
                Path(directory).mkdir(parents=True, exist_ok=True)

        self.api = HfApi()
        # Fail fast on a bad token instead of surfacing the error mid-conversion.
        try:
            self.api.whoami(token=self.token)
            logger.info("Successfully authenticated with Hugging Face")
        except Exception as e:
            raise ConfigurationError(
                f"Failed to authenticate with Hugging Face: {str(e)}"
            ) from e

    def setup_repository(self, repo_name: str) -> str:
        """Create (or reuse) a public repository on the Hugging Face Hub.

        Args:
            repo_name: Target repo id, e.g. "user/model-onnx".

        Returns:
            The repo name, unchanged, once the repo is known to exist.

        Raises:
            Exception: re-raises any hub error after logging it.
        """
        try:
            # exist_ok makes this idempotent across repeated runs.
            create_repo(
                repo_name,
                token=self.token,
                private=False,
                exist_ok=True
            )
            logger.info(f"Repository {repo_name} is ready")
            return repo_name
        except Exception as e:
            logger.error(f"Error setting up repository: {e}")
            raise

    def verify_model_exists(self, model_name: str) -> bool:
        """Return True if `model_name` exists on the Hub and is readable with our token."""
        try:
            model_info(model_name, token=self.token)
            return True
        except Exception as e:
            # Any hub error (404, 401, network) is treated as "not accessible".
            logger.error(f"Model verification failed: {str(e)}")
            return False

    def convert_and_push(self, source_model: str, target_repo: str):
        """Convert `source_model` to ONNX and upload model + README to `target_repo`.

        Args:
            source_model: Hub id of the segmentation model to convert.
            target_repo: Hub id of the destination repository (must already exist,
                see setup_repository).

        Returns:
            True on success; False on any failure (errors are logged, not raised).
        """
        created_tmp_dir = None  # temp dir we own and must clean up, if any
        try:
            if not self.verify_model_exists(source_model):
                raise ValueError(f"Model {source_model} is not accessible. Check if the model exists and you have proper permissions.")

            model_kwargs = {
                "token": self.token
            }
            if self.model_cache_dir:
                model_kwargs["cache_dir"] = self.model_cache_dir

            # Work in the configured output directory, or a throwaway temp dir.
            if self.onnx_output_dir:
                working_dir = self.onnx_output_dir
            else:
                created_tmp_dir = tempfile.mkdtemp()
                working_dir = created_tmp_dir
            tmp_path = Path(working_dir) / f"{target_repo.split('/')[-1]}.onnx"

            logger.info(f"Loading model {source_model}...")
            model = AutoModelForSemanticSegmentation.from_pretrained(
                source_model,
                **model_kwargs
            )
            processor = SegformerImageProcessor.from_pretrained(
                source_model,
                **model_kwargs
            )

            # Inference mode: disables dropout/batch-norm updates for a stable trace.
            model.eval()

            # Dummy input only fixes the trace shapes; dynamic_axes below keeps
            # batch/height/width flexible at inference time.
            dummy_input = processor(
                images=torch.zeros(1, 3, 224, 224),
                return_tensors="pt"
            )

            logger.info(f"Converting to ONNX format... Output path: {tmp_path}")
            # no_grad avoids tracking gradients during tracing; the path is passed
            # as str because older torch exporters do not accept os.PathLike.
            with torch.no_grad():
                torch.onnx.export(
                    model,
                    (dummy_input['pixel_values'],),
                    str(tmp_path),
                    input_names=['input'],
                    output_names=['output'],
                    dynamic_axes={
                        'input': {0: 'batch_size', 2: 'height', 3: 'width'},
                        'output': {0: 'batch_size'}
                    },
                    opset_version=12,
                    do_constant_folding=True
                )

            # Model card content is deliberately left-aligned: the f-string body is
            # written verbatim into README.md.
            model_card = f"""---
base_model: {source_model}
tags:
- onnx
- semantic-segmentation
---
# ONNX Model converted from {source_model}
This is an ONNX version of the model {source_model}, converted automatically.
## Model Information
- Original Model: {source_model}
- ONNX Opset Version: 12
- Input Shape: Dynamic (batch_size, 3, height, width)
## Usage
```python
import onnxruntime as ort
import numpy as np
# Load ONNX model
session = ort.InferenceSession("model.onnx")
# Prepare input
input_data = np.zeros((1, 3, 224, 224), dtype=np.float32)
# Run inference
outputs = session.run(None, {{"input": input_data}})
```
"""
            readme_path = Path(working_dir) / "README.md"
            with open(readme_path, "w", encoding="utf-8") as f:
                f.write(model_card)

            logger.info(f"Pushing files to {target_repo}...")
            self.api.upload_file(
                path_or_fileobj=str(tmp_path),
                path_in_repo="model.onnx",
                repo_id=target_repo,
                token=self.token
            )
            self.api.upload_file(
                path_or_fileobj=str(readme_path),
                path_in_repo="README.md",
                repo_id=target_repo,
                token=self.token
            )
            logger.info(f"Successfully pushed ONNX model to {target_repo}")
            return True
        except Exception as e:
            # Best-effort contract: callers get a boolean, details go to the log.
            logger.error(f"Error during conversion and upload: {e}")
            return False
        finally:
            # Remove only the temp dir we created; a user-configured
            # ONNX_OUTPUT_DIR keeps its artifacts.
            if created_tmp_dir:
                shutil.rmtree(created_tmp_dir, ignore_errors=True)
def main():
    """CLI entry point: parse arguments, then convert and push the model.

    Exits with status 1 (via SystemExit) when the conversion fails.
    """
    parser = argparse.ArgumentParser(description='Convert and push model to ONNX format on Hugging Face Hub')
    parser.add_argument('--source', type=str, required=True,
                        help='Source model name (e.g., "sayeed99/segformer-b3-fashion")')
    parser.add_argument('--target', type=str, required=True,
                        help='Target repository name (e.g., "your-username/model-name-onnx")')
    parser.add_argument('--token', type=str, help='Hugging Face token (optional)')
    args = parser.parse_args()

    converter = HFOnnxConverter(token=args.token)
    converter.setup_repository(args.target)
    success = converter.convert_and_push(args.source, args.target)
    if not success:
        # raise SystemExit instead of the site-injected exit() builtin, which is
        # absent under `python -S` and in frozen/embedded interpreters.
        raise SystemExit(1)


if __name__ == "__main__":
    main()