from .exec_backends.trt_loader import TrtModel, encode as encode_trt

from transformers import AutoTokenizer

# NOTE: loading happens at import time — each TrtModel() deserializes a
# TensorRT engine, so importing this module is expensive and requires the
# engine files to exist on disk.

# English-only model: paraphrase-mpnet-base-v2.
tokenizer_en = AutoTokenizer.from_pretrained("tensorRT/models/paraphrase-mpnet-base-v2")
model_en = TrtModel("tensorRT/models/paraphrase-mpnet-base-v2.engine")

# Multilingual fallback model: paraphrase-multilingual-MiniLM-L12-v2.
tokenizer_multilingual = AutoTokenizer.from_pretrained("tensorRT/models/paraphrase-multilingual-MiniLM-L12-v2")
model_multilingual = TrtModel("tensorRT/models/paraphrase-multilingual-MiniLM-L12-v2.engine")
def encode(sentences, lang, batch_size=8):
    """Embed *sentences* with the TensorRT model matching *lang*.

    Args:
        sentences: sequence of strings to embed.
        lang: 'en' selects the English mpnet model; any other value selects
            the multilingual MiniLM model.
        batch_size: requested batch size; silently capped at 8 — presumably
            the engines were built with max batch size 8 (TODO confirm).

    Returns:
        list of per-sentence embeddings, in input order (element type is
        whatever ``encode_trt`` yields per sentence).
    """
    batch_size = min(batch_size, 8)

    # The model choice is loop-invariant, so resolve it once up front.
    if lang == 'en':
        tokenizer, model = tokenizer_en, model_en
    else:
        tokenizer, model = tokenizer_multilingual, model_multilingual

    all_embs = []
    # Slice directly with a stepped range instead of computing the batch
    # count with math.ceil and multiplying indices back out.
    for start in range(0, len(sentences), batch_size):
        batch = sentences[start:start + batch_size]
        embs = encode_trt(batch, tokenizer=tokenizer, trt_model=model,
                          use_token_type_ids=False)
        all_embs.extend(embs)
    return all_embs