import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import onnxruntime as ort
import numpy as np


def convert_and_test_onnx(model_name, output_path="language_detection.onnx", test_text="This is a test sentence."):
    """
    Converts a Hugging Face model to ONNX, modifies the tokenizer, and tests the ONNX model.
    """
    try:
        # Load the tokenizer and model
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(model_name)

        # Drop all normalization steps from the fast tokenizer's normalizer (if it exposes
        # a `normalizations` list) and save the modified tokenizer for use at inference time
        if hasattr(tokenizer.backend_tokenizer.normalizer, "normalizations"):
            tokenizer.backend_tokenizer.normalizer.normalizations = []
        tokenizer.save_pretrained("./modified_tokenizer")

        # Export the model to ONNX with dynamic batch and sequence dimensions
        dummy_input = tokenizer("This is a test sentence.", return_tensors="pt")
        torch.onnx.export(
            model,
            (dummy_input["input_ids"], dummy_input["attention_mask"]),
            output_path,
            input_names=["input_ids", "attention_mask"],
            output_names=["output"],
            dynamic_axes={
                "input_ids": {0: "batch", 1: "sequence"},
                "attention_mask": {0: "batch", 1: "sequence"},
                "output": {0: "batch"},
            },
            opset_version=14,
        )
print(f"Model successfully converted and saved to {output_path}")
# Test the ONNX model
ort_session = ort.InferenceSession(output_path)
tokenizer_test = AutoTokenizer.from_pretrained("./modified_tokenizer")
# Explicitly set return_token_type_ids=False
inputs = tokenizer_test(test_text, return_tensors="np", return_token_type_ids=False)
ort_inputs = {k: v for k, v in inputs.items()}
ort_outputs = ort_session.run(None, ort_inputs)
logits = ort_outputs[0]
predicted_class_id = np.argmax(logits, axis=-1)
label_list = model.config.id2label
predicted_label = label_list[predicted_class_id[0]]
print(f"Test text: {test_text}")
print(f"Predicted label: {predicted_label}")
except Exception as e:
print(f"Error during conversion or testing: {e}")


if __name__ == "__main__":
    model_name = "dewdev/language_detection"
    test_text = "मैंने राजा को हिंदी में एक पत्र लिखा।"  # Hindi: "I wrote a letter to the king in Hindi."
    convert_and_test_onnx(model_name, test_text=test_text)