import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import onnxruntime as ort
import numpy as np


def convert_and_test_onnx(model_name, output_path="language_detection.onnx", test_text="This is a test sentence."):
    """
    Converts a Hugging Face model to ONNX, modifies the tokenizer, and tests the ONNX model.
    """
    try:
        # Load the tokenizer and model
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(model_name)
        model.eval()  # disable dropout etc. before export

        # Strip the tokenizer's normalization steps (only if the normalizer exposes them)
        # and save the modified tokenizer so it can be reloaded at inference time.
        if hasattr(tokenizer.backend_tokenizer.normalizer, "normalizations"):
            tokenizer.backend_tokenizer.normalizer.normalizations = []
        tokenizer.save_pretrained("./modified_tokenizer")

        # Export the model to ONNX; the dummy input is only used to trace the graph,
        # and dynamic_axes keeps batch and sequence dimensions flexible.
        dummy_input = tokenizer("This is a test sentence.", return_tensors="pt")
        torch.onnx.export(
            model,
            (dummy_input["input_ids"], dummy_input["attention_mask"]),
            output_path,
            input_names=["input_ids", "attention_mask"],
            output_names=["output"],
            dynamic_axes={
                "input_ids": {0: "batch", 1: "sequence"},
                "attention_mask": {0: "batch", 1: "sequence"},
                "output": {0: "batch"},
            },
            opset_version=14,
        )
        print(f"Model successfully converted and saved to {output_path}")

        # Test the ONNX model with the modified tokenizer
        ort_session = ort.InferenceSession(output_path)
        tokenizer_test = AutoTokenizer.from_pretrained("./modified_tokenizer")

        # Explicitly set return_token_type_ids=False so the inputs match the exported
        # graph, which only expects input_ids and attention_mask.
        inputs = tokenizer_test(test_text, return_tensors="np", return_token_type_ids=False)
        ort_inputs = {k: v for k, v in inputs.items()}
        ort_outputs = ort_session.run(None, ort_inputs)

        logits = ort_outputs[0]
        predicted_class_id = np.argmax(logits, axis=-1)
        # id2label is keyed by Python ints, so cast the numpy index before lookup
        predicted_label = model.config.id2label[int(predicted_class_id[0])]

        print(f"Test text: {test_text}")
        print(f"Predicted label: {predicted_label}")

    except Exception as e:
        print(f"Error during conversion or testing: {e}")


if __name__ == "__main__":
    model_name = "dewdev/language_detection"
    test_text = "मैंने राजा को हिंदी में एक पत्र लिखा।"
    convert_and_test_onnx(model_name, test_text=test_text)