# language_detection/to_onnx.py
import os
import time

import numpy as np
import onnx
import torch
from onnxruntime.quantization import QuantFormat, QuantType, quantize_dynamic, quantize_static
from onnxruntime.quantization.calibrate import CalibrationDataReader
from transformers import AutoModelForSequenceClassification, AutoTokenizer

def ensure_directory(path):
    """Create directory if it doesn't exist"""
    abs_path = os.path.abspath(path)
    if not os.path.exists(abs_path):
        os.makedirs(abs_path)
        print(f"Created directory: {abs_path}")
    return abs_path

def verify_file_exists(file_path, timeout=5):
    """Verify that a file exists and is not empty, polling until the timeout expires"""
    start_time = time.time()
    while time.time() - start_time < timeout:
        if os.path.exists(file_path) and os.path.getsize(file_path) > 0:
            return True
        time.sleep(0.1)
    return False

def create_calibration_dataset(tokenizer, max_length=512):
    """Generate calibration dataset for static quantization with padding"""
    samples = [
        "This is an English sentence.",
        "Dies ist ein deutscher Satz.",
        "C'est une phrase française.",
        "Esta es una frase en español.",
        "这是一个中文句子。",
        "これは日本語の文章です。"
    ]
    # Tokenize with padding and truncation
    encoded_samples = []
    for text in samples:
        encoded = tokenizer(
            text,
            padding='max_length',
            max_length=max_length,
            truncation=True,
            return_tensors="pt"
        )
        encoded_samples.append({
            'input_ids': encoded['input_ids'],
            'attention_mask': encoded['attention_mask']
        })
    return encoded_samples

class CalibrationLoader(CalibrationDataReader):
    def __init__(self, calibration_data):
        self.calibration_data = calibration_data
        self.current_index = 0

    def get_next(self):
        if self.current_index >= len(self.calibration_data):
            return None
        current_data = self.calibration_data[self.current_index]
        self.current_index += 1
        # Ensure we're returning numpy arrays with the correct shape
        return {
            'input_ids': current_data['input_ids'].numpy(),
            'attention_mask': current_data['attention_mask'].numpy()
        }

    def rewind(self):
        self.current_index = 0

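# Note: quantize_static drives this reader itself. It calls get_next() once per
# calibration batch until None is returned, and rewind() resets the reader so
# the samples can be replayed. A minimal sketch of that consumption loop, for
# reference only:
#
#   reader = CalibrationLoader(create_calibration_dataset(tokenizer))
#   batch = reader.get_next()
#   while batch is not None:
#       ...  # the calibrator runs the model on `batch` and records ranges
#       batch = reader.get_next()
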
def export_to_onnx(model, tokenizer, save_path, max_length=512):
    """Export model to ONNX with a fixed sequence length and a dynamic batch axis"""
    try:
        # Create a dummy input with fixed dimensions
        dummy_input = tokenizer(
            "This is a sample input",
            padding='max_length',
            max_length=max_length,
            truncation=True,
            return_tensors="pt"
        )
        # Export the model to ONNX
        torch.onnx.export(
            model,
            (dummy_input["input_ids"], dummy_input["attention_mask"]),
            save_path,
            opset_version=14,
            input_names=["input_ids", "attention_mask"],
            output_names=["output"],
            dynamic_axes={
                "input_ids": {0: "batch_size"},
                "attention_mask": {0: "batch_size"},
                "output": {0: "batch_size"}
            }
        )
        if verify_file_exists(save_path):
            print(f"Successfully exported ONNX model to {save_path}")
            return True
        else:
            print(f"Failed to verify ONNX model at {save_path}")
            return False
    except Exception as e:
        print(f"Error exporting to ONNX: {str(e)}")
        return False

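# A minimal sanity-check sketch (not called by main): run the exported graph
# through onnxruntime and compare its logits against the PyTorch model. This
# assumes onnxruntime is installed; the sample text is an arbitrary placeholder.
def check_onnx_export(model, tokenizer, onnx_path, max_length=512):
    import onnxruntime as ort

    encoded = tokenizer(
        "Quick sanity check.",
        padding='max_length',
        max_length=max_length,
        truncation=True,
        return_tensors="pt"
    )
    with torch.no_grad():
        torch_logits = model(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"]
        ).logits.numpy()
    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    onnx_logits = session.run(None, {
        "input_ids": encoded["input_ids"].numpy(),
        "attention_mask": encoded["attention_mask"].numpy()
    })[0]
    # Small numeric drift between the two runtimes is expected
    print(f"Max logit difference: {np.abs(torch_logits - onnx_logits).max():.6f}")
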
def quantize_model(base_onnx_path, onnx_dir, config_name, calibration_dataset=None):
    """
    Quantize ONNX model using either dynamic or static quantization.

    Args:
        base_onnx_path (str): Path to the base ONNX model
        onnx_dir (str): Directory to save quantized models
        config_name (str): Type of quantization ('dynamic' or 'static')
        calibration_dataset (list, optional): Dataset for static quantization calibration
    """
    try:
        quantized_model_path = os.path.join(onnx_dir, f"model_{config_name}_quantized.onnx")
        if config_name == "dynamic":
            print("\nPerforming dynamic quantization...")
            quantize_dynamic(
                model_input=base_onnx_path,
                model_output=quantized_model_path,
                weight_type=QuantType.QUInt8
            )
        elif config_name == "static" and calibration_dataset is not None:
            print("\nPerforming static quantization...")
            calibration_loader = CalibrationLoader(calibration_dataset)
            quantize_static(
                model_input=base_onnx_path,
                model_output=quantized_model_path,
                calibration_data_reader=calibration_loader,
                # quant_format expects a QuantFormat, not a QuantType; QDQ inserts
                # QuantizeLinear/DequantizeLinear nodes around quantized tensors
                quant_format=QuantFormat.QDQ,
                activation_type=QuantType.QUInt8,
                weight_type=QuantType.QUInt8
            )
        else:
            print(f"Invalid quantization configuration: {config_name}")
            return False
        # Verify the quantized model exists
        if verify_file_exists(quantized_model_path):
            print(f"Successfully created {config_name} quantized model at {quantized_model_path}")
            # Print file sizes for comparison
            base_size = os.path.getsize(base_onnx_path) / (1024 * 1024)  # Convert to MB
            quantized_size = os.path.getsize(quantized_model_path) / (1024 * 1024)  # Convert to MB
            print(f"Original model size: {base_size:.2f} MB")
            print(f"Quantized model size: {quantized_size:.2f} MB")
            print(f"Size reduction: {((base_size - quantized_size) / base_size * 100):.2f}%")
            return True
        else:
            print(f"Failed to verify quantized model at {quantized_model_path}")
            return False
    except Exception as e:
        print(f"Error during {config_name} quantization: {str(e)}")
        return False

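# An illustrative latency probe (not called by main): time one ONNX file with
# onnxruntime so the base and quantized models can be compared. Assumes
# onnxruntime is installed; sample text and run count are arbitrary placeholders.
def benchmark_onnx(onnx_path, tokenizer, max_length=512, runs=20):
    import onnxruntime as ort

    encoded = tokenizer(
        "Benchmark sample.",
        padding='max_length',
        max_length=max_length,
        truncation=True,
        return_tensors="np"
    )
    feeds = {
        "input_ids": encoded["input_ids"].astype(np.int64),
        "attention_mask": encoded["attention_mask"].astype(np.int64)
    }
    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    session.run(None, feeds)  # warm-up run, excluded from timing
    start = time.time()
    for _ in range(runs):
        session.run(None, feeds)
    print(f"{onnx_path}: {(time.time() - start) / runs * 1000:.1f} ms per run")
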
def main():
    # Get absolute paths
    current_dir = os.path.abspath(os.getcwd())
    onnx_dir = ensure_directory(os.path.join(current_dir, "onnx"))
    base_onnx_path = os.path.join(onnx_dir, "model.onnx")
    print(f"Working directory: {current_dir}")
    print(f"ONNX directory: {onnx_dir}")
    print(f"Base ONNX model path: {base_onnx_path}")

    # Step 1: Load model and tokenizer
    print("\nLoading model and tokenizer...")
    model_name = "alexneakameni/language_detection"
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Use the tokenizer's maximum sequence length; some tokenizers report a
    # very large sentinel value here, so cap it at 512
    max_length = min(tokenizer.model_max_length, 512)

    # Step 2: Export base ONNX model
    if not export_to_onnx(model, tokenizer, base_onnx_path, max_length):
        print("Failed to export base ONNX model. Exiting.")
        return
    # Verify the ONNX model loads and passes the structural checker
    try:
        print(f"Verifying ONNX model at: {base_onnx_path}")
        onnx_model = onnx.load(base_onnx_path)
        onnx.checker.check_model(onnx_model)
        print("Successfully verified ONNX model")
    except Exception as e:
        print(f"Error verifying ONNX model: {str(e)}")
        return
    # Step 3: Create calibration dataset
    calibration_dataset = create_calibration_dataset(tokenizer, max_length)

    # Step 4: Create quantized versions
    print("\nCreating quantized versions...")

    # Dynamic quantization
    quantize_model(
        base_onnx_path=base_onnx_path,
        onnx_dir=onnx_dir,
        config_name="dynamic"
    )

    # Static quantization
    quantize_model(
        base_onnx_path=base_onnx_path,
        onnx_dir=onnx_dir,
        config_name="static",
        calibration_dataset=calibration_dataset
    )

if __name__ == "__main__":
    main()
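
# Example usage of a quantized file produced above (illustrative sketch; the
# label lookup assumes the checkpoint ships an id2label table in its config):
#
#   import onnxruntime as ort
#   session = ort.InferenceSession("onnx/model_dynamic_quantized.onnx")
#   enc = tokenizer("Bonjour tout le monde", padding='max_length',
#                   max_length=512, truncation=True, return_tensors="np")
#   logits = session.run(None, {"input_ids": enc["input_ids"],
#                               "attention_mask": enc["attention_mask"]})[0]
#   print(model.config.id2label[int(logits.argmax(axis=-1)[0])])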