from sentence_transformers import SentenceTransformer
import torch
from transformers import AutoTokenizer


def convert_onnx():
    """Export the distiluse-base-multilingual-cased-v2 SentenceTransformer to ONNX.

    Side effects:
        - writes ``model/distiluse-base-multilingual-cased-v2.pt`` (full-module
          checkpoint; kept because downstream tooling may read it)
        - writes ``tensorRT/models/distiluse-base-multilingual-cased-v2.onnx``

    The export is traced from a single sample sentence; only the batch
    dimension of each input/output is marked dynamic, so the sequence length
    of the ONNX graph is fixed to the tokenizer's ``max_length`` padding.
    """
    model = SentenceTransformer('model/distiluse-base-multilingual-cased-v2')
    # Persist a .pt checkpoint and reload it — behavior preserved from the
    # original script; the reload is redundant for the export itself.
    torch.save(model, "model/distiluse-base-multilingual-cased-v2.pt")
    model = torch.load("model/distiluse-base-multilingual-cased-v2.pt")

    tokenizer = AutoTokenizer.from_pretrained('model/distiluse-base-multilingual-cased-v2')
    lst_input = ["Pham Minh Chinh is Vietnam's Prime Minister"]
    # padding="max_length" with no explicit max_length pads to the tokenizer's
    # model_max_length — presumably 512 for this model; TODO confirm.
    x = tokenizer(lst_input, padding="max_length", truncation=True)
    print(x)

    save_path = 'tensorRT/models/distiluse-base-multilingual-cased-v2.onnx'
    torch.onnx.export(
        model,
        (torch.tensor(x['input_ids'], dtype=torch.long),
         torch.tensor(x['attention_mask'], dtype=torch.long)),
        save_path,
        export_params=True,
        opset_version=13,
        do_constant_folding=True,
        input_names=['input_ids', 'attention_mask'],
        output_names=['output'],
        dynamic_axes={
            'input_ids': {0: 'batch_size'},
            'attention_mask': {0: 'batch_size'},
            'output': {0: 'batch_size'},
        },
    )


def convert_onnx_(model_name="model/model-sup-simcse-vn",
                  pt_model='model-sup-simcse-vn',
                  max_length=256,
                  save_path="tensorRT/models/model-sup-simcse-vn.onnx"):
    """Export a SentenceTransformer model (with token_type_ids) to ONNX.

    Args:
        model_name: local path or hub id of the SentenceTransformer model.
        pt_model: basename used for the intermediate ``model/<pt_model>.pt``
            checkpoint that is written as a side effect.
        max_length: tokenizer padding/truncation length used to trace the
            export (fixes the sequence-length dimension of the ONNX graph).
        save_path: destination of the exported ``.onnx`` file.
            NOTE: the original default duplicated the ``tensorRT/models/``
            prefix; fixed here.

    Side effects:
        - writes ``model/<pt_model>.pt``
        - writes ``save_path``
    """
    model = SentenceTransformer(model_name)
    # Persist a .pt checkpoint and reload it — behavior preserved from the
    # original script; the reload is redundant for the export itself.
    torch.save(model, f"model/{pt_model}.pt")
    model = torch.load(f"model/{pt_model}.pt")

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    lst_input = ["Pham Minh Chinh is Vietnam's Prime Minister"]
    # Bug fix: the max_length parameter was previously ignored (hard-coded 256).
    x = tokenizer(lst_input, padding="max_length", truncation=True,
                  max_length=max_length)
    print(x)

    torch.onnx.export(
        model,
        (torch.tensor(x['input_ids'], dtype=torch.long),
         torch.tensor(x['attention_mask'], dtype=torch.long),
         torch.tensor(x['token_type_ids'], dtype=torch.long)),
        save_path,
        export_params=True,
        opset_version=13,
        do_constant_folding=True,
        input_names=['input_ids', 'attention_mask', 'token_type_ids'],
        output_names=['output'],
        # Bug fix: token_type_ids was exported as an input but missing from
        # dynamic_axes, freezing its batch dimension at the traced size (1)
        # while the other inputs were dynamic.
        dynamic_axes={
            'input_ids': {0: 'batch_size'},
            'attention_mask': {0: 'batch_size'},
            'token_type_ids': {0: 'batch_size'},
            'output': {0: 'batch_size'},
        },
    )


if __name__ == '__main__':
    convert_onnx_()