File size: 1,481 Bytes
02dbb6b
f2a57a2
d120873
 
 
 
 
38b6bd6
 
 
d120873
 
cfbbe17
 
d120873
cfbbe17
 
 
d120873
cfbbe17
d120873
cfbbe17
 
d120873
a3f4748
cfbbe17
d120873
 
 
 
 
f2a57a2
 
 
d120873
 
cfbbe17
 
 
 
 
 
d120873
cfbbe17
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from nltk.tokenize import sent_tokenize
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
import torch
import src.exception.Exception.Exception as ExceptionCustom

# Identifier handed to the shared exception checker so it can apply the
# validation rules specific to the paraphrase endpoint.
METHOD = "PARAPHRASE"

# Load the Pegasus paraphrase checkpoint once at import time so all requests
# share the same tokenizer/model instances (first import downloads and caches
# the weights; subsequent imports read from the local cache).
tokenizer = PegasusTokenizer.from_pretrained('tuner007/pegasus_paraphrase')
model = PegasusForConditionalGeneration.from_pretrained('tuner007/pegasus_paraphrase')


def paraphraseParaphraseMethod(requestValue : str):
    """Paraphrase *requestValue* sentence-by-sentence with the Pegasus model.

    Each sentence is rewritten by beam search; the first beam candidate that
    differs from the input (case/whitespace-insensitive) is kept. If no
    candidate differs, the original sentence is kept so no text is lost.

    :param requestValue: raw text to paraphrase.
    :return: tuple ``(paraphrased_text, error_message)`` — exactly one of the
             two is non-empty.
    """
    # Validate input first; on failure return the error with no result.
    exception = ExceptionCustom.checkForException(requestValue, METHOD)
    if exception != "":
        return "", exception

    paraphrased_sentences = []

    for sentence in sent_tokenize(requestValue):
        text = "paraphrase: " + sentence

        encoding = tokenizer(text, return_tensors="pt", padding=True, truncation=True)

        # Inference only — no_grad() skips autograd bookkeeping (the original
        # built an unused computation graph on every call).
        with torch.no_grad():
            beam_outputs = model.generate(
                input_ids=encoding["input_ids"],
                attention_mask=encoding["attention_mask"],
                max_length=512,
                num_beams=5,
                # BUG FIX: default num_return_sequences is 1, so the candidate
                # loop below could never see an alternate beam; return all 5.
                num_return_sequences=5,
                length_penalty=0.8,
                early_stopping=True,
            )

        for beam_output in beam_outputs:
            text_para = tokenizer.decode(beam_output, skip_special_tokens=True, clean_up_tokenization_spaces=True)

            # Keep the first candidate that is actually a rewording.
            if sentence.lower().strip() != text_para.lower().strip():
                paraphrased_sentences.append(text_para)
                break
        else:
            # BUG FIX: previously a sentence with no differing candidate was
            # silently dropped from the output; preserve it verbatim instead.
            paraphrased_sentences.append(sentence)

    # join() also fixes the trailing space the original += loop produced.
    return " ".join(paraphrased_sentences), ""