# testing / src/paraphrase/Paraphrase.py
# Last commit by vijay399: "Update src/paraphrase/Paraphrase.py" (a3f4748)
from nltk.tokenize import sent_tokenize
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
import torch
import src.exception.Exception.Exception as ExceptionCustom
# Operation name handed to ExceptionCustom.checkForException for input validation.
METHOD = "PARAPHRASE"

# Hugging Face checkpoint shared by the tokenizer and the generation model.
_PEGASUS_CHECKPOINT = 'tuner007/pegasus_paraphrase'

# Both objects are loaded once, at import time, and reused by the paraphrasing code
# in this module rather than being reloaded per request.
tokenizer = PegasusTokenizer.from_pretrained(_PEGASUS_CHECKPOINT)
model = PegasusForConditionalGeneration.from_pretrained(_PEGASUS_CHECKPOINT)
def _paraphrase_sentence(sentence: str) -> str:
    """Return the first generated paraphrase of ``sentence`` that differs from it.

    Candidates come from beam search with the module-level Pegasus ``model`` and
    are decoded with ``tokenizer``. A candidate counts as "different" when it does
    not match ``sentence`` after lower-casing and stripping surrounding whitespace.
    If no candidate differs, ``sentence`` itself is returned so the text is not lost.
    """
    # NOTE(review): "paraphrase: " is a T5-style task prefix; the
    # tuner007/pegasus_paraphrase model card feeds raw text without it. Kept to
    # avoid changing outputs -- confirm whether it should be dropped.
    text = "paraphrase: " + sentence
    encoding = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    num_beams = 5
    beam_outputs = model.generate(
        input_ids=encoding["input_ids"],
        attention_mask=encoding["attention_mask"],
        max_length=512,
        num_beams=num_beams,
        # Without this, generate() returns only the single best beam, so there
        # would be no alternative to fall back on when it equals the input.
        num_return_sequences=num_beams,
        length_penalty=0.8,
        early_stopping=True,
    )
    original = sentence.lower().strip()
    for beam_output in beam_outputs:
        candidate = tokenizer.decode(
            beam_output, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        if candidate.lower().strip() != original:
            return candidate
    # Every beam reproduced the input: keep the original sentence instead of
    # silently dropping it from the overall result.
    return sentence


def paraphraseParaphraseMethod(requestValue: str):
    """Paraphrase ``requestValue`` sentence by sentence with Pegasus.

    The input is first validated with
    ``ExceptionCustom.checkForException(requestValue, METHOD)``, then split into
    sentences with NLTK's ``sent_tokenize`` and each sentence is paraphrased
    independently.

    Args:
        requestValue: The text to paraphrase.

    Returns:
        A ``(result, exception)`` tuple of strings. If the validator returns
        anything other than ``""``, the result is ``("", <validator value>)``.
        Otherwise ``result`` is the per-sentence paraphrases joined by single
        spaces and ``exception`` is ``""``.
    """
    exception = ExceptionCustom.checkForException(requestValue, METHOD)
    if exception != "":
        return "", exception
    paraphrases = [
        _paraphrase_sentence(sentence) for sentence in sent_tokenize(requestValue)
    ]
    return " ".join(paraphrases), ""