File size: 4,229 Bytes
923abb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
from typing import Dict, List, Any
import torch
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
import re

class EndpointHandler:
    """
    Custom inference handler that paraphrases text with a Pegasus
    sequence-to-sequence model, one sentence at a time, preserving the
    paragraph structure of the input.
    """

    def __init__(self, path: str = ""):
        """
        Load the Pegasus model and tokenizer onto the best available device.

        :param path: Path (or hub id) of the pretrained model weights
        """
        # Prefer GPU when available; generation on CPU works but is slow.
        self.torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'

        self.tokenizer = PegasusTokenizer.from_pretrained(path)
        self.model = PegasusForConditionalGeneration.from_pretrained(path).to(self.torch_device)
        # Inference only: disable dropout so generation is deterministic.
        self.model.eval()

    def split_into_paragraphs(self, text: str) -> List[str]:
        """
        Split text into paragraphs on blank lines.

        Empty/whitespace-only paragraphs are dropped; surviving paragraphs
        are stripped of surrounding whitespace.

        :param text: Input text
        :return: List of non-empty paragraphs
        """
        paragraphs = text.split('\n\n')
        return [p.strip() for p in paragraphs if p.strip()]

    def split_into_sentences(self, paragraph: str) -> List[str]:
        """
        Split a paragraph into sentences using a simple regex.

        Splits on whitespace that follows '.', '!' or '?'. This is a
        heuristic: abbreviations like "e.g." will also trigger a split.

        :param paragraph: Input paragraph
        :return: List of non-empty, stripped sentences
        """
        sentences = re.split(r'(?<=[.!?])\s+', paragraph)
        return [s.strip() for s in sentences if s.strip()]

    def get_response(self, input_text: str, num_return_sequences: int = 1) -> str:
        """
        Generate a paraphrase for a single sentence.

        :param input_text: Input sentence to paraphrase
        :param num_return_sequences: Number of alternative paraphrases to
            generate (only the first is returned)
        :return: Paraphrased text
        """
        # NOTE: tokenizer.prepare_seq2seq_batch() is deprecated/removed in
        # recent transformers releases; calling the tokenizer directly is the
        # supported equivalent and produces the same encoded batch.
        batch = self.tokenizer(
            [input_text],
            truncation=True,
            padding='longest',
            max_length=80,
            return_tensors="pt",
        ).to(self.torch_device)

        # no_grad: inference only — skip building the autograd graph.
        with torch.no_grad():
            translated = self.model.generate(
                **batch,
                num_beams=10,
                num_return_sequences=num_return_sequences,
                # temperature removed: it is ignored by pure beam search
                # (do_sample=False) and only triggers warnings.
                repetition_penalty=2.8,
                length_penalty=1.2,
                max_length=80,
                min_length=5,
                no_repeat_ngram_size=3,
            )

        tgt_text = self.tokenizer.batch_decode(translated, skip_special_tokens=True)
        return tgt_text[0]

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process an incoming request and return paraphrased text.

        :param data: Request payload; the text is read from the "inputs" key
            (falling back to the payload itself)
        :return: {"outputs": paraphrased text with paragraphs re-joined}
        :raises ValueError: if the resolved input is not a string
        """
        # get (not pop): same lookup/fallback semantics without mutating
        # the caller's payload dict.
        inputs = data.get("inputs", data)

        if not isinstance(inputs, str):
            raise ValueError("Input must be a string")

        paragraphs = self.split_into_paragraphs(inputs)
        paraphrased_paragraphs = []

        for paragraph in paragraphs:
            sentences = self.split_into_sentences(paragraph)
            paraphrased_sentences = []

            for sentence in sentences:
                # Very short sentences paraphrase poorly — pass them through.
                if len(sentence.split()) < 3:
                    paraphrased_sentences.append(sentence)
                    continue

                try:
                    paraphrased = self.get_response(sentence)

                    # Reject degenerate "explainer" paraphrases; keep the
                    # original sentence instead.
                    if not any(phrase in paraphrased.lower() for phrase in ['it\'s like', 'in other words']):
                        paraphrased_sentences.append(paraphrased)
                    else:
                        paraphrased_sentences.append(sentence)
                except Exception as e:
                    # Best-effort: a failed sentence falls back to the original
                    # rather than failing the whole request.
                    print(f"Error processing sentence: {e}")
                    paraphrased_sentences.append(sentence)

            paraphrased_paragraphs.append(' '.join(paraphrased_sentences))

        return {"outputs": '\n\n'.join(paraphrased_paragraphs)}