"""
inference_onnx.py

This script leverages ONNX runtime to perform inference with a pre-trained model.

Usage:
    python inference_onnx.py '[["sentence one", "sentence two"], ...]'

The single command-line argument is a JSON array of two-string arrays; one
prediction is printed per pair.
"""
import json
import os
import sys
from functools import lru_cache

import numpy as np
import onnxruntime as rt
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer

repo_path = "govtech/stsb-roberta-base-off-topic"

# Use a config.json from the working directory when one is present; otherwise
# fall back to the copy published in the hub repository. (Previously the hub
# copy was always downloaded and then ignored, and a missing local file
# crashed the script.)
config_path = "config.json"
if not os.path.isfile(config_path):
    config_path = hf_hub_download(repo_id=repo_path, filename="config.json")

with open(config_path, "r", encoding="utf-8") as f:
    config = json.load(f)


@lru_cache(maxsize=1)
def _load_tokenizer():
    """Build the tokenizer named in the config once and reuse it across calls."""
    model_name = config['classifier']['embedding']['model_name']
    return AutoTokenizer.from_pretrained(model_name)


@lru_cache(maxsize=1)
def _load_session():
    """Download the ONNX classifier from the hub and open one shared session."""
    model_fp = config['classifier']['embedding']['model_fp']
    local_model_fp = hf_hub_download(repo_id=repo_path, filename=model_fp)
    return rt.InferenceSession(local_model_fp)


def predict(sentence1, sentence2):
    """Classify a sentence pair with the ONNX model.

    Args:
        sentence1: First sentence of the pair.
        sentence2: Second sentence of the pair.

    Returns:
        A tuple ``(predicted_label, probabilities)`` where ``predicted_label``
        is the integer index of the highest-probability class and
        ``probabilities`` is a NumPy array holding the softmax (taken along
        dim 1) of the model's first output.
    """
    max_length = config['classifier']['embedding']['max_length']
    tokenizer = _load_tokenizer()
    session = _load_session()

    # Encode both sentences together as a single pair, padded/truncated to
    # the configured maximum length.
    encoding = tokenizer(
        sentence1,
        sentence2,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_token_type_ids=False,
    )

    # ONNX Runtime is fed NumPy arrays, so the tensors are converted directly;
    # moving them to a GPU first would only add a round-trip.
    session_inputs = session.get_inputs()
    onnx_inputs = {
        session_inputs[0].name: encoding["input_ids"].numpy(),
        session_inputs[1].name: encoding["attention_mask"].numpy(),
    }
    outputs = session.run(None, onnx_inputs)

    probabilities = torch.softmax(torch.tensor(outputs[0]), dim=1)
    predicted_label = torch.argmax(probabilities, dim=1).item()
    return predicted_label, probabilities.numpy()


def _parse_sentence_pairs(raw):
    """Decode the JSON argument and check it is a list of two-string pairs.

    Raises:
        ValueError: If the decoded value is not a list whose every element is
            a two-item list of strings (``json.JSONDecodeError``, raised for
            malformed JSON, is itself a ``ValueError``).
    """
    pairs = json.loads(raw)
    well_formed = isinstance(pairs, list) and all(
        isinstance(pair, (list, tuple))
        and len(pair) == 2
        and all(isinstance(sentence, str) for sentence in pair)
        for pair in pairs
    )
    if not well_formed:
        raise ValueError("Each pair must contain two strings.")
    return pairs


def main():
    """Read sentence pairs from ``sys.argv[1]`` and print one prediction each."""
    if len(sys.argv) < 2:
        sys.exit(
            "Usage: python inference_onnx.py "
            "'[[\"sentence one\", \"sentence two\"], ...]'"
        )

    sentence_pairs = _parse_sentence_pairs(sys.argv[1])

    for idx, (sentence1, sentence2) in enumerate(sentence_pairs, start=1):
        predicted_label, probabilities = predict(sentence1, sentence2)

        print(f"Pair {idx}:")
        print(f" Sentence 1: {sentence1}")
        print(f" Sentence 2: {sentence2}")
        print(f" Predicted Label: {predicted_label}")
        print(f" Probabilities: {probabilities}")
        print('-' * 50)


if __name__ == "__main__":
    main()