"""
inference_onnx.py

This script leverages ONNX runtime to perform inference with a pre-trained
sentence-pair classification model hosted on the Hugging Face Hub.

Usage:
    python inference_onnx.py '[["sentence A", "sentence B"], ...]'
"""

import functools
import json
import sys

import numpy as np
import onnxruntime as rt
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer

repo_path = "govtech/jina-embeddings-v2-small-en-off-topic"

# Read the config that was just downloaded from the Hub. The previous version
# overwrote the downloaded path with a relative "config.json", discarding the
# download and making the script depend on the current working directory.
config_path = hf_hub_download(repo_id=repo_path, filename="config.json")
with open(config_path, "r", encoding="utf-8") as f:
    config = json.load(f)


@functools.lru_cache(maxsize=None)
def _load_tokenizer(model_name):
    """Load the tokenizer for ``model_name`` once and reuse it on later calls."""
    return AutoTokenizer.from_pretrained(model_name)


@functools.lru_cache(maxsize=None)
def _load_session(model_fp):
    """Download ``model_fp`` from the Hub repo and open an ONNX session once."""
    local_model_fp = hf_hub_download(repo_id=repo_path, filename=model_fp)
    return rt.InferenceSession(local_model_fp)


def predict(sentence1, sentence2):
    """
    Predicts the label for a pair of sentences using a fine-tuned ONNX model.

    This function tokenizes the input sentences, prepares them as inputs for
    an ONNX model, and performs inference to predict the label and
    probabilities for the given sentence pair. The tokenizer and the ONNX
    session are loaded on first use and cached for subsequent calls.

    Args:
    - sentence1 (str): The first input sentence.
    - sentence2 (str): The second input sentence.

    Returns:
    tuple:
    - predicted_label (int): The predicted label (e.g., 0 or 1).
    - probabilities (numpy.ndarray): Softmax probabilities over dim 1 of the
      model's first output (a single row, since one pair is scored).
    """
    # Load model configuration
    embedding_cfg = config['classifier']['embedding']
    model_name = embedding_cfg['model_name']
    max_length = embedding_cfg['max_length']
    model_fp = embedding_cfg['model_fp']

    tokenizer = _load_tokenizer(model_name)
    session = _load_session(model_fp)

    def _encode(sentence):
        # Pad to a fixed length; ONNX Runtime consumes NumPy arrays directly,
        # so there is no need to move tensors to a GPU and back.
        encoded = tokenizer(sentence, return_tensors="pt", truncation=True,
                            padding="max_length", max_length=max_length)
        return encoded['input_ids'].numpy(), encoded['attention_mask'].numpy()

    input_ids1, attention_mask1 = _encode(sentence1)
    input_ids2, attention_mask2 = _encode(sentence2)

    # Graph inputs are matched positionally: (ids1, mask1, ids2, mask2).
    model_inputs = session.get_inputs()
    onnx_inputs = {
        model_inputs[0].name: input_ids1,
        model_inputs[1].name: attention_mask1,
        model_inputs[2].name: input_ids2,
        model_inputs[3].name: attention_mask2,
    }
    outputs = session.run(None, onnx_inputs)

    probabilities = torch.softmax(torch.tensor(outputs[0]), dim=1)
    predicted_label = torch.argmax(probabilities, dim=1).item()
    return predicted_label, probabilities.cpu().numpy()


def _parse_sentence_pairs(raw):
    """Parse a JSON list of [sentence1, sentence2] pairs and validate its shape.

    Raises:
        ValueError: if the input is not a list of two-string pairs
            (json.JSONDecodeError, a ValueError subclass, on malformed JSON).
    """
    sentence_pairs = json.loads(raw)
    if not isinstance(sentence_pairs, list):
        raise ValueError("Input must be a JSON list of sentence pairs.")
    # Check length too: a 3-element pair or a bare string would otherwise
    # pass the per-element str check and fail later during unpacking.
    if not all(
        isinstance(pair, (list, tuple))
        and len(pair) == 2
        and isinstance(pair[0], str)
        and isinstance(pair[1], str)
        for pair in sentence_pairs
    ):
        raise ValueError("Each pair must contain two strings.")
    return sentence_pairs


def main():
    """Score every sentence pair passed as a JSON string in argv[1]."""
    if len(sys.argv) < 2:
        sys.exit("Usage: python inference_onnx.py "
                 "'[[\"sentence A\", \"sentence B\"], ...]'")

    sentence_pairs = _parse_sentence_pairs(sys.argv[1])

    for idx, (sentence1, sentence2) in enumerate(sentence_pairs):
        # Generate prediction and scores
        predicted_label, probabilities = predict(sentence1, sentence2)

        # Print the results
        print(f"Pair {idx + 1}:")
        print(f" Sentence 1: {sentence1}")
        print(f" Sentence 2: {sentence2}")
        print(f" Predicted Label: {predicted_label}")
        print(f" Probabilities: {probabilities}")
        print('-' * 50)


if __name__ == "__main__":
    main()