# language_detection / to_onnx.py
# Uploaded by dewdev (commit 8cfa1b2, verified).
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import onnxruntime as ort
import numpy as np
def convert_and_test_onnx(
    model_name,
    output_path="language_detection.onnx",
    test_text="This is a test sentence.",
    tokenizer_dir="./modified_tokenizer",
):
    """Export a Hugging Face sequence-classification model to ONNX and smoke-test it.

    Steps:
      1. Load the tokenizer and model for ``model_name``.
      2. Clear the backend normalizer's ``normalizations`` (when exposed) and
         save the tokenizer to ``tokenizer_dir``.
      3. Export the model to ``output_path`` with dynamic batch/sequence axes.
      4. Reload the saved tokenizer, run ``test_text`` through the ONNX model
         with onnxruntime, and print the predicted label.

    Any exception is reported (message on stdout, traceback on stderr) instead
    of being raised, preserving the original best-effort behaviour.

    Args:
        model_name: Hub id or local path accepted by ``from_pretrained``.
        output_path: Destination file for the exported ONNX graph.
        test_text: Sentence used to sanity-check the exported model.
        tokenizer_dir: Directory the tokenizer is saved to and reloaded from.

    Returns:
        The predicted label for ``test_text``, or ``None`` if any step failed.
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(model_name)

        _clear_normalizations(tokenizer)
        tokenizer.save_pretrained(tokenizer_dir)

        _export_to_onnx(model, tokenizer, output_path)
        print(f"Model successfully converted and saved to {output_path}")

        # Reload from disk so the test exercises exactly what was saved.
        test_tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
        predicted_label = _predict_label(
            output_path, test_tokenizer, test_text, model.config.id2label
        )
        print(f"Test text: {test_text}")
        print(f"Predicted label: {predicted_label}")
        return predicted_label
    except Exception as e:
        import traceback

        print(f"Error during conversion or testing: {e}")
        # Keep the stack trace; the message alone is often not enough to debug
        # download, export, or onnxruntime failures.
        traceback.print_exc()
        return None


def _clear_normalizations(tokenizer):
    """Empty the backend normalizer's ``normalizations`` list, if it has one."""
    # Slow (pure-Python) tokenizers have no ``backend_tokenizer``. Accessing it
    # directly raised AttributeError and aborted the whole run, so treat it as
    # "nothing to clear" instead.
    backend = getattr(tokenizer, "backend_tokenizer", None)
    normalizer = getattr(backend, "normalizer", None)
    # NOTE(review): whether ``normalizations`` exists depends on the installed
    # ``tokenizers`` version and normalizer type -- confirm. When it is absent,
    # this is a no-op and the tokenizer is saved unchanged.
    if hasattr(normalizer, "normalizations"):
        normalizer.normalizations = []


def _export_to_onnx(model, tokenizer, output_path, opset_version=14):
    """Trace ``model`` on a fixed dummy sentence and write the graph to ``output_path``."""
    # Make sure dropout is disabled during tracing, whatever mode the model is in.
    model.eval()
    dummy_input = tokenizer("This is a test sentence.", return_tensors="pt")
    with torch.no_grad():
        torch.onnx.export(
            model,
            (dummy_input["input_ids"], dummy_input["attention_mask"]),
            output_path,
            input_names=["input_ids", "attention_mask"],
            output_names=["output"],
            dynamic_axes={
                "input_ids": {0: "batch", 1: "sequence"},
                "attention_mask": {0: "batch", 1: "sequence"},
                "output": {0: "batch"},
            },
            opset_version=opset_version,
        )


def _predict_label(onnx_path, tokenizer, text, id2label):
    """Run ``text`` through the ONNX model at ``onnx_path`` and map the argmax to a label."""
    session = ort.InferenceSession(onnx_path)
    encoded = tokenizer(text, return_tensors="np", return_token_type_ids=False)
    # Feed exactly the inputs the graph declares: onnxruntime rejects unknown
    # feed names. Cast to int64 to match the LongTensor inputs used at export
    # time (numpy's default int width is platform dependent).
    ort_inputs = {
        inp.name: np.asarray(encoded[inp.name], dtype=np.int64)
        for inp in session.get_inputs()
    }
    logits = session.run(None, ort_inputs)[0]
    # id2label is keyed by Python ints; convert the numpy scalar explicitly.
    predicted_class_id = int(np.argmax(logits, axis=-1)[0])
    return id2label[predicted_class_id]
if __name__ == "__main__":
    # Script entry point: export the Hub model and smoke-test it on a Hindi sentence.
    hub_model_id = "dewdev/language_detection"
    hindi_sample = "मैंने राजा को हिंदी में एक पत्र लिखा।"
    convert_and_test_onnx(hub_model_id, test_text=hindi_sample)