|
from sentence_transformers import SentenceTransformer
|
|
import torch
|
|
from transformers import AutoTokenizer
|
|
def convert_onnx(model_name='model/distiluse-base-multilingual-cased-v2',
                 pt_model='distiluse-base-multilingual-cased-v2',
                 save_path='tensorRT/models/distiluse-base-multilingual-cased-v2.onnx'):
    """Export a SentenceTransformer model to ONNX for TensorRT deployment.

    Loads the model, checkpoints it to ``model/<pt_model>.pt``, traces it on a
    single sample sentence, and writes the ONNX graph to ``save_path``.

    Args:
        model_name: Local path (or hub id) of the SentenceTransformer model.
        pt_model: Basename for the ``.pt`` checkpoint written under ``model/``.
        save_path: Destination path of the exported ``.onnx`` file.
    """
    model = SentenceTransformer(model_name)

    # Checkpoint to disk for parity with the existing pipeline artifacts.
    # The in-memory `model` is used directly for the export below, so the
    # previous save-then-immediately-reload round trip was redundant and
    # has been dropped.
    torch.save(model, f"model/{pt_model}.pt")

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # A single sample input, used only to trace the graph during export.
    lst_input = ["Pham Minh Chinh is Vietnam's Prime Minister"]

    # NOTE(review): no max_length is given, so padding="max_length" pads to
    # the tokenizer's model_max_length (typically 512 for DistilBERT) —
    # confirm this matches the sequence length the TensorRT engine expects.
    x = tokenizer(lst_input, padding="max_length", truncation=True)
    print(x)

    torch.onnx.export(
        model,
        (torch.tensor(x['input_ids'], dtype=torch.long),
         torch.tensor(x['attention_mask'], dtype=torch.long)),
        save_path,
        export_params=True,
        opset_version=13,
        do_constant_folding=True,
        input_names=['input_ids', 'attention_mask'],
        output_names=['output'],
        # Only the batch dimension is dynamic; the sequence length stays
        # fixed at the padded length baked in by the tokenizer call above.
        dynamic_axes={'input_ids': {0: 'batch_size'},
                      'attention_mask': {0: 'batch_size'},
                      'output': {0: 'batch_size'}},
    )
|
|
|
|
|
|
def convert_onnx_(model_name="model/model-sup-simcse-vn",
                  pt_model='model-sup-simcse-vn',
                  max_length=256,
                  save_path="tensorRT/models/model-sup-simcse-vn.onnx"):
    """Export a SimCSE-style SentenceTransformer model to ONNX.

    Like ``convert_onnx`` but for models whose tokenizer also produces
    ``token_type_ids`` (BERT-style), which are passed as a third input.

    Args:
        model_name: Local path (or hub id) of the SentenceTransformer model.
        pt_model: Basename for the ``.pt`` checkpoint written under ``model/``.
        max_length: Padded/truncated sequence length used for the trace input.
        save_path: Destination path of the exported ``.onnx`` file.
            Note: the previous default contained a duplicated path segment
            ("tensorRT/models/tensorRT/models/...") — fixed here.
    """
    model = SentenceTransformer(model_name)

    # Checkpoint to disk for parity with the existing pipeline artifacts.
    # The in-memory `model` is used directly for the export below, so the
    # previous save-then-immediately-reload round trip was redundant and
    # has been dropped.
    torch.save(model, f"model/{pt_model}.pt")

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # A single sample input, used only to trace the graph during export.
    lst_input = ["Pham Minh Chinh is Vietnam's Prime Minister"]

    # Bug fix: the max_length parameter was previously ignored (the call
    # hard-coded 256); honor the parameter so callers can change it.
    x = tokenizer(lst_input, padding="max_length", truncation=True,
                  max_length=max_length)
    print(x)

    torch.onnx.export(
        model,
        (torch.tensor(x['input_ids'], dtype=torch.long),
         torch.tensor(x['attention_mask'], dtype=torch.long),
         torch.tensor(x['token_type_ids'], dtype=torch.long)),
        save_path,
        export_params=True,
        opset_version=13,
        do_constant_folding=True,
        input_names=['input_ids', 'attention_mask', 'token_type_ids'],
        output_names=['output'],
        # Bug fix: token_type_ids was listed as an input but missing from
        # dynamic_axes, pinning its batch dimension while the other inputs
        # were dynamic; all three inputs now share a dynamic batch axis.
        dynamic_axes={'input_ids': {0: 'batch_size'},
                      'attention_mask': {0: 'batch_size'},
                      'token_type_ids': {0: 'batch_size'},
                      'output': {0: 'batch_size'}},
    )
|
|
# Script entry point: export the SimCSE model using its default settings.
if __name__ == '__main__':
    convert_onnx_()