# testing/src/translate/Translate.py
# vijay399's picture
# Duplicate from BlackKakapo/ParaphraseAPI
# d120873
from nltk.tokenize import sent_tokenize
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import src.exception.Exception.Exception as ExceptionCustom
# Method tag handed to ExceptionCustom.checkForException by the request handler below.
METHOD = "TRANSLATE"

# Checkpoint identifiers for the two translation directions
# (by their names: Romanian -> English and English -> Romanian).
_RO_EN_CHECKPOINT = "BlackKakapo/opus-mt-ro-en"
_EN_RO_CHECKPOINT = "BlackKakapo/opus-mt-en-ro"

# Tokenizer/model pairs are created once at import time and shared by every call.
tokenizerROMENG = AutoTokenizer.from_pretrained(_RO_EN_CHECKPOINT)
modelROMENG = AutoModelForSeq2SeqLM.from_pretrained(_RO_EN_CHECKPOINT)
tokenizerENGROM = AutoTokenizer.from_pretrained(_EN_RO_CHECKPOINT)
modelENGROM = AutoModelForSeq2SeqLM.from_pretrained(_EN_RO_CHECKPOINT)

# Use CUDA when torch reports it as available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Both models live on the selected device; inputs are moved there per request.
for _model in (modelROMENG, modelENGROM):
    _model.to(device)
def _sample_translation(text, tokenizer, model):
    """Run *text* through one tokenizer/model pair and return the decoded output.

    Decoding uses sampling (``do_sample=True`` with ``top_k=90`` and
    ``top_p=0.97``), so repeated calls with the same input may return
    different strings.

    Args:
        text: The string to translate.
        tokenizer: Tokenizer used both to encode *text* and to decode the
            generated ids.
        model: Seq2seq model whose ``generate`` method produces the output ids.

    Returns:
        The first decoded sequence, with special tokens removed.
    """
    encoded = tokenizer(text, return_tensors='pt').to(device)
    # Forward the whole encoding so the attention mask produced by the
    # tokenizer reaches generate() together with the input ids.
    generated = model.generate(
        **encoded,
        do_sample=True,
        max_length=256,
        top_k=90,
        top_p=0.97,
        early_stopping=False,
    )
    return tokenizer.batch_decode(generated, skip_special_tokens=True)[0]


def paraphraseTranslateMethod(requestValue: str):
    """Paraphrase *requestValue* by round-trip translation, sentence by sentence.

    The input is first validated with ``ExceptionCustom.checkForException``.
    It is then split into sentences with ``sent_tokenize``; every sentence is
    passed through the ROMENG pair and the result through the ENGROM pair
    (by the checkpoint names, Romanian -> English -> Romanian).

    Args:
        requestValue: The text to paraphrase.

    Returns:
        A ``(text, error)`` pair. When validation reports a problem the pair
        is ``("", error)``; otherwise it is ``(paraphrase, "")`` where every
        paraphrased sentence, including the last one, is followed by a single
        space.
    """
    exception = ExceptionCustom.checkForException(requestValue, METHOD)
    if exception != "":
        return "", exception

    paraphrased_sentences = []
    for sentence in sent_tokenize(requestValue):
        intermediate = _sample_translation(sentence, tokenizerROMENG, modelROMENG)
        paraphrased_sentences.append(
            _sample_translation(intermediate, tokenizerENGROM, modelENGROM)
        )

    # Keep the established output format: a space after every sentence,
    # the final one included.
    return "".join(piece + " " for piece in paraphrased_sentences), ""