# Source: Hugging Face upload by cuongnguyen910
# ("Upload folder using huggingface_hub"), revision 5120311 (verified).
from .exec_backends.trt_loader import TrtModel, encode as encode_trt
from transformers import AutoTokenizer
import math
# Model loading happens once, at import time: importing this module loads both
# tokenizers and both TensorRT engines. All paths are relative
# ("tensorRT/models/..."), so they presumably resolve against the current
# working directory. NOTE(review): confirm callers always run from the repo root.

# English model: Hugging Face tokenizer plus TensorRT engine file for
# paraphrase-mpnet-base-v2. encode() uses this pair when lang == "en".
tokenizer_en = AutoTokenizer.from_pretrained("tensorRT/models/paraphrase-mpnet-base-v2")
model_en = TrtModel("tensorRT/models/paraphrase-mpnet-base-v2.engine")
# Multilingual model: tokenizer plus TensorRT engine file for
# paraphrase-multilingual-MiniLM-L12-v2. encode() uses this pair for every
# lang value other than "en".
tokenizer_multilingual = AutoTokenizer.from_pretrained("tensorRT/models/paraphrase-multilingual-MiniLM-L12-v2")
model_multilingual= TrtModel("tensorRT/models/paraphrase-multilingual-MiniLM-L12-v2.engine")
def encode(sentences, lang: str, batch_size: int = 8) -> list:
    """Embed ``sentences`` in batches with the TensorRT engine chosen by ``lang``.

    Args:
        sentences: An indexable, sliceable sequence of input texts (e.g. a list
            of str). Its length is read with ``len()``.
        lang: ``"en"`` selects the English paraphrase-mpnet-base-v2 model; any
            other value falls back to the multilingual
            paraphrase-multilingual-MiniLM-L12-v2 model.
        batch_size: Number of sentences sent to the engine per call. Values
            above 8 are clamped to 8.

    Returns:
        A flat list built by extending with each batch's ``encode_trt`` result,
        presumably one embedding per input sentence, in input order. An empty
        input yields ``[]``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1. Previously 0 raised
            ZeroDivisionError and negative values silently returned ``[]``.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")
    # Hard cap of 8 per call; presumably the maximum batch size the TensorRT
    # engines were built for. TODO confirm against the engine's optimization profile.
    batch_size = min(batch_size, 8)

    total = len(sentences)
    if total == 0:
        return []

    # The model choice does not change between batches, so make it once.
    if lang == "en":
        tokenizer, trt_model = tokenizer_en, model_en
    else:
        tokenizer, trt_model = tokenizer_multilingual, model_multilingual

    all_embs = []
    for start in range(0, total, batch_size):
        batch = sentences[start:start + batch_size]
        embs = encode_trt(
            batch,
            tokenizer=tokenizer,
            trt_model=trt_model,
            use_token_type_ids=False,
        )
        all_embs.extend(embs)
    return all_embs