import os
import time

import numpy as np
import onnx
import torch
from onnxruntime.quantization import QuantFormat, QuantType, quantize_dynamic, quantize_static
from onnxruntime.quantization.calibrate import CalibrationDataReader
from transformers import AutoModelForSequenceClassification, AutoTokenizer


def ensure_directory(path):
    """Create directory if it doesn't exist."""
    abs_path = os.path.abspath(path)
    if not os.path.exists(abs_path):
        os.makedirs(abs_path)
        print(f"Created directory: {abs_path}")
    return abs_path


def verify_file_exists(file_path, timeout=5):
    """Verify that a file exists and is not empty, polling for up to `timeout` seconds."""
    start_time = time.time()
    while time.time() - start_time < timeout:
        if os.path.exists(file_path) and os.path.getsize(file_path) > 0:
            return True
        time.sleep(0.1)
    return False


def create_calibration_dataset(tokenizer, max_length=512):
    """Generate a calibration dataset for static quantization with padding."""
    samples = [
        "This is an English sentence.",
        "Dies ist ein deutscher Satz.",
        "C'est une phrase française.",
        "Esta es una frase en español.",
        "这是一个中文句子。",
        "これは日本語の文章です。"
    ]

    encoded_samples = []
    for text in samples:
        encoded = tokenizer(
            text,
            padding='max_length',
            max_length=max_length,
            truncation=True,
            return_tensors="pt"
        )
        encoded_samples.append({
            'input_ids': encoded['input_ids'],
            'attention_mask': encoded['attention_mask']
        })

    return encoded_samples


class CalibrationLoader(CalibrationDataReader):
    """Feeds pre-tokenized samples to the static quantization calibrator."""

    def __init__(self, calibration_data):
        self.calibration_data = calibration_data
        self.current_index = 0

    def get_next(self):
        # Returning None tells the calibrator the data is exhausted.
        if self.current_index >= len(self.calibration_data):
            return None

        current_data = self.calibration_data[self.current_index]
        self.current_index += 1

        # The calibrator expects numpy arrays keyed by the model's input names.
        return {
            'input_ids': current_data['input_ids'].numpy(),
            'attention_mask': current_data['attention_mask'].numpy()
        }

    def rewind(self):
        self.current_index = 0


def export_to_onnx(model, tokenizer, save_path, max_length=512):
    """Export the model to ONNX with a fixed sequence length and a dynamic batch axis."""
    try:
        # Export in inference mode so dropout and similar layers are disabled.
        model.eval()

        dummy_input = tokenizer(
            "This is a sample input",
            padding='max_length',
            max_length=max_length,
            truncation=True,
            return_tensors="pt"
        )

        torch.onnx.export(
            model,
            (dummy_input["input_ids"], dummy_input["attention_mask"]),
            save_path,
            opset_version=14,
            input_names=["input_ids", "attention_mask"],
            output_names=["output"],
            # Only the batch dimension is dynamic; the sequence length is
            # fixed at max_length, which suits static quantization.
            dynamic_axes={
                "input_ids": {0: "batch_size"},
                "attention_mask": {0: "batch_size"},
                "output": {0: "batch_size"}
            }
        )

        if verify_file_exists(save_path):
            print(f"Successfully exported ONNX model to {save_path}")
            return True
        else:
            print(f"Failed to verify ONNX model at {save_path}")
            return False
    except Exception as e:
        print(f"Error exporting to ONNX: {str(e)}")
        return False


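# A small inspection helper (an illustrative sketch, not part of the original
# pipeline): printing the graph's input and output names is handy when the
# quantizer or an InferenceSession rejects a feed dictionary.
def print_model_io(onnx_path):
    """Print the input and output tensor names of an exported ONNX graph."""
    model = onnx.load(onnx_path)
    print("Inputs:", [inp.name for inp in model.graph.input])
    print("Outputs:", [out.name for out in model.graph.output])

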
def quantize_model(base_onnx_path, onnx_dir, config_name, calibration_dataset=None):
    """
    Quantize an ONNX model using either dynamic or static quantization.

    Args:
        base_onnx_path (str): Path to the base ONNX model
        onnx_dir (str): Directory to save quantized models
        config_name (str): Type of quantization ('dynamic' or 'static')
        calibration_dataset (list, optional): Dataset for static quantization calibration
    """
    try:
        quantized_model_path = os.path.join(onnx_dir, f"model_{config_name}_quantized.onnx")

        if config_name == "dynamic":
            print("\nPerforming dynamic quantization...")
            quantize_dynamic(
                model_input=base_onnx_path,
                model_output=quantized_model_path,
                weight_type=QuantType.QUInt8
            )

        elif config_name == "static" and calibration_dataset is not None:
            print("\nPerforming static quantization...")
            calibration_loader = CalibrationLoader(calibration_dataset)
            quantize_static(
                model_input=base_onnx_path,
                model_output=quantized_model_path,
                calibration_data_reader=calibration_loader,
                # QDQ format inserts explicit QuantizeLinear/DequantizeLinear
                # nodes; activations and weights are quantized to unsigned 8-bit.
                quant_format=QuantFormat.QDQ,
                activation_type=QuantType.QUInt8,
                weight_type=QuantType.QUInt8
            )

        else:
            print(f"Invalid quantization configuration: {config_name}")
            return False

        if verify_file_exists(quantized_model_path):
            print(f"Successfully created {config_name} quantized model at {quantized_model_path}")

            base_size = os.path.getsize(base_onnx_path) / (1024 * 1024)
            quantized_size = os.path.getsize(quantized_model_path) / (1024 * 1024)

            print(f"Original model size: {base_size:.2f} MB")
            print(f"Quantized model size: {quantized_size:.2f} MB")
            print(f"Size reduction: {((base_size - quantized_size) / base_size * 100):.2f}%")

            return True
        else:
            print(f"Failed to verify quantized model at {quantized_model_path}")
            return False

    except Exception as e:
        print(f"Error during {config_name} quantization: {str(e)}")
        return False


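# A rough latency comparison often accompanies the size comparison above. The
# helper below is a minimal sketch (it assumes onnxruntime is installed and is
# not called from main()); numbers vary with hardware and sequence length.
def benchmark_model(onnx_path, tokenizer, max_length=512, runs=20):
    """Measure average single-sample inference latency for an ONNX model."""
    import onnxruntime as ort

    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    encoded = tokenizer(
        "This is a benchmark input.",
        padding='max_length',
        max_length=max_length,
        truncation=True,
        return_tensors="np"
    )
    feeds = {
        "input_ids": encoded["input_ids"].astype(np.int64),
        "attention_mask": encoded["attention_mask"].astype(np.int64)
    }

    session.run(None, feeds)  # warm-up run, excluded from timing
    start_time = time.time()
    for _ in range(runs):
        session.run(None, feeds)
    avg_ms = (time.time() - start_time) / runs * 1000

    print(f"{onnx_path}: {avg_ms:.2f} ms per inference over {runs} runs")
    return avg_ms

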
def main():
    current_dir = os.path.abspath(os.getcwd())
    onnx_dir = ensure_directory(os.path.join(current_dir, "onnx"))
    base_onnx_path = os.path.join(onnx_dir, "model.onnx")

    print(f"Working directory: {current_dir}")
    print(f"ONNX directory: {onnx_dir}")
    print(f"Base ONNX model path: {base_onnx_path}")

    print("\nLoading model and tokenizer...")
    model_name = "alexneakameni/language_detection"
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Some tokenizers report a huge sentinel value for model_max_length,
    # so cap it at a practical sequence length.
    max_length = min(tokenizer.model_max_length, 512)

    if not export_to_onnx(model, tokenizer, base_onnx_path, max_length):
        print("Failed to export base ONNX model. Exiting.")
        return

    try:
        print(f"Verifying ONNX model at: {base_onnx_path}")
        onnx_model = onnx.load(base_onnx_path)
        onnx.checker.check_model(onnx_model)
        print("Successfully verified ONNX model")
    except Exception as e:
        print(f"Error verifying ONNX model: {str(e)}")
        return

    calibration_dataset = create_calibration_dataset(tokenizer, max_length)

    print("\nCreating quantized versions...")

    quantize_model(
        base_onnx_path=base_onnx_path,
        onnx_dir=onnx_dir,
        config_name="dynamic"
    )

    quantize_model(
        base_onnx_path=base_onnx_path,
        onnx_dir=onnx_dir,
        config_name="static",
        calibration_dataset=calibration_dataset
    )
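
    # Optional smoke test (a hedged sketch, assuming onnxruntime is installed):
    # run one input through the dynamically quantized model and report the
    # argmax class index. Failures here are reported but non-fatal.
    try:
        import onnxruntime as ort

        session = ort.InferenceSession(
            os.path.join(onnx_dir, "model_dynamic_quantized.onnx"),
            providers=["CPUExecutionProvider"]
        )
        encoded = tokenizer(
            "Bonjour tout le monde.",
            padding='max_length',
            max_length=max_length,
            truncation=True,
            return_tensors="np"
        )
        logits = session.run(None, {
            "input_ids": encoded["input_ids"].astype(np.int64),
            "attention_mask": encoded["attention_mask"].astype(np.int64)
        })[0]
        print(f"\nSmoke test logits shape: {logits.shape}, "
              f"predicted class index: {int(np.argmax(logits, axis=-1)[0])}")
    except Exception as e:
        print(f"Smoke test skipped: {str(e)}")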


if __name__ == "__main__":
    main()