Eemansleepdeprived commited on
Commit
923abb2
·
verified ·
1 Parent(s): 651bb25

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +117 -0
handler.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ import torch
3
+ from transformers import PegasusForConditionalGeneration, PegasusTokenizer
4
+ import re
5
+
6
+ class EndpointHandler:
7
+ def __init__(self, path=""):
8
+ """
9
+ Initialize the endpoint handler with the model and tokenizer.
10
+
11
+ :param path: Path to the model weights
12
+ """
13
+ # Determine the device
14
+ self.torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
15
+
16
+ # Load tokenizer and model
17
+ self.tokenizer = PegasusTokenizer.from_pretrained(path)
18
+ self.model = PegasusForConditionalGeneration.from_pretrained(path).to(self.torch_device)
19
+
20
+ def split_into_paragraphs(self, text: str) -> List[str]:
21
+ """
22
+ Split text into paragraphs while preserving empty lines.
23
+
24
+ :param text: Input text
25
+ :return: List of paragraphs
26
+ """
27
+ paragraphs = text.split('\n\n')
28
+ return [p.strip() for p in paragraphs if p.strip()]
29
+
30
+ def split_into_sentences(self, paragraph: str) -> List[str]:
31
+ """
32
+ Split paragraph into sentences using regex.
33
+
34
+ :param paragraph: Input paragraph
35
+ :return: List of sentences
36
+ """
37
+ sentences = re.split(r'(?<=[.!?])\s+', paragraph)
38
+ return [s.strip() for s in sentences if s.strip()]
39
+
40
+ def get_response(self, input_text: str, num_return_sequences: int = 1) -> str:
41
+ """
42
+ Generate paraphrased text for a single input.
43
+
44
+ :param input_text: Input sentence to paraphrase
45
+ :param num_return_sequences: Number of alternative paraphrases to generate
46
+ :return: Paraphrased text
47
+ """
48
+ batch = self.tokenizer.prepare_seq2seq_batch(
49
+ [input_text],
50
+ truncation=True,
51
+ padding='longest',
52
+ max_length=80,
53
+ return_tensors="pt"
54
+ ).to(self.torch_device)
55
+
56
+ translated = self.model.generate(
57
+ **batch,
58
+ num_beams=10,
59
+ num_return_sequences=num_return_sequences,
60
+ temperature=1.0,
61
+ repetition_penalty=2.8,
62
+ length_penalty=1.2,
63
+ max_length=80,
64
+ min_length=5,
65
+ no_repeat_ngram_size=3
66
+ )
67
+
68
+ tgt_text = self.tokenizer.batch_decode(translated, skip_special_tokens=True)
69
+ return tgt_text[0]
70
+
71
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
72
+ """
73
+ Process the incoming request and generate paraphrased text.
74
+
75
+ :param data: Request payload containing input text
76
+ :return: Paraphrased text
77
+ """
78
+ # Extract input text from the payload
79
+ inputs = data.pop("inputs", data)
80
+
81
+ # If input is not a string, raise an error
82
+ if not isinstance(inputs, str):
83
+ raise ValueError("Input must be a string")
84
+
85
+ # Split text into paragraphs
86
+ paragraphs = self.split_into_paragraphs(inputs)
87
+ paraphrased_paragraphs = []
88
+
89
+ # Process each paragraph
90
+ for paragraph in paragraphs:
91
+ sentences = self.split_into_sentences(paragraph)
92
+ paraphrased_sentences = []
93
+
94
+ for sentence in sentences:
95
+ # Skip very short sentences
96
+ if len(sentence.split()) < 3:
97
+ paraphrased_sentences.append(sentence)
98
+ continue
99
+
100
+ try:
101
+ # Paraphrase the sentence
102
+ paraphrased = self.get_response(sentence)
103
+
104
+ # Avoid unwanted paraphrases
105
+ if not any(phrase in paraphrased.lower() for phrase in ['it\'s like', 'in other words']):
106
+ paraphrased_sentences.append(paraphrased)
107
+ else:
108
+ paraphrased_sentences.append(sentence)
109
+ except Exception as e:
110
+ print(f"Error processing sentence: {e}")
111
+ paraphrased_sentences.append(sentence)
112
+
113
+ # Join sentences back into a paragraph
114
+ paraphrased_paragraphs.append(' '.join(paraphrased_sentences))
115
+
116
+ # Join paragraphs back into text
117
+ return {"outputs": '\n\n'.join(paraphrased_paragraphs)}