import logging
import re
from typing import Dict, List, Any

import torch
from transformers import PegasusForConditionalGeneration, PegasusTokenizer


class EndpointHandler:
    """Callable request handler that paraphrases text with a Pegasus model.

    The input text is split into paragraphs and then sentences; each
    sentence is paraphrased on its own and the pieces are stitched back
    together (sentences joined by a space, paragraphs by a blank line).
    """

    # Passed as ``max_length`` to both the tokenizer (with truncation) and
    # ``generate``.
    MAX_LENGTH = 80
    # Sentences with fewer whitespace-separated words are passed through
    # unchanged instead of being sent to the model.
    MIN_WORDS = 3
    # A generated paraphrase containing any of these phrases (compared in
    # lower case) is discarded in favour of the original sentence.
    REJECTED_PHRASES = ("it's like", "in other words")

    # A paragraph break is a blank line; the line may contain whitespace,
    # which also covers Windows-style "\r\n\r\n" breaks.
    _PARAGRAPH_BREAK = re.compile(r'\n\s*\n')
    # Split on whitespace that follows sentence-ending punctuation, keeping
    # the punctuation attached to the preceding sentence.
    _SENTENCE_BREAK = re.compile(r'(?<=[.!?])\s+')

    _logger = logging.getLogger(__name__)

    def __init__(self, path=""):
        """
        Initialize the endpoint handler with the model and tokenizer.

        :param path: Path to the model weights
        """
        # Use the GPU when one is available, otherwise fall back to CPU.
        self.torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.tokenizer = PegasusTokenizer.from_pretrained(path)
        self.model = PegasusForConditionalGeneration.from_pretrained(path).to(self.torch_device)

    def split_into_paragraphs(self, text: str) -> List[str]:
        """
        Split text into paragraphs on blank lines.

        Each paragraph is stripped of surrounding whitespace and empty
        paragraphs are dropped, so runs of blank lines collapse into a
        single break.

        :param text: Input text
        :return: List of non-empty, stripped paragraphs
        """
        stripped = (chunk.strip() for chunk in self._PARAGRAPH_BREAK.split(text))
        return [chunk for chunk in stripped if chunk]

    def split_into_sentences(self, paragraph: str) -> List[str]:
        """
        Split paragraph into sentences using regex.

        A sentence boundary is any whitespace run preceded by ``.``, ``!``
        or ``?``.

        :param paragraph: Input paragraph
        :return: List of non-empty, stripped sentences
        """
        stripped = (piece.strip() for piece in self._SENTENCE_BREAK.split(paragraph))
        return [piece for piece in stripped if piece]

    def get_response(self, input_text: str, num_return_sequences: int = 1) -> str:
        """
        Generate paraphrased text for a single input.

        :param input_text: Input sentence to paraphrase
        :param num_return_sequences: Number of alternative paraphrases to
            generate; only the first decoded sequence is returned
        :return: Paraphrased text
        """
        # Call the tokenizer directly: ``prepare_seq2seq_batch`` is
        # deprecated, and for source-only input it produced this encoding.
        batch = self.tokenizer(
            [input_text],
            truncation=True,
            padding='longest',
            max_length=self.MAX_LENGTH,
            return_tensors="pt",
        ).to(self.torch_device)

        generated = self.model.generate(
            **batch,
            num_beams=10,
            num_return_sequences=num_return_sequences,
            temperature=1.0,
            repetition_penalty=2.8,
            length_penalty=1.2,
            max_length=self.MAX_LENGTH,
            min_length=5,
            no_repeat_ngram_size=3,
        )

        candidates = self.tokenizer.batch_decode(generated, skip_special_tokens=True)
        return candidates[0]

    def _paraphrase_sentence(self, sentence: str) -> str:
        """Return a paraphrase of ``sentence``, or ``sentence`` itself as a fallback."""
        if len(sentence.split()) < self.MIN_WORDS:
            return sentence

        try:
            paraphrased = self.get_response(sentence)
        except Exception:
            # Best effort: one failing sentence must not fail the whole
            # request, so log the error and keep the original wording.
            self._logger.exception("Error processing sentence")
            return sentence

        lowered = paraphrased.lower()
        if any(phrase in lowered for phrase in self.REJECTED_PHRASES):
            return sentence
        return paraphrased

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process the incoming request and generate paraphrased text.

        :param data: Request payload; the text is read from its ``"inputs"``
            key (the payload itself is used when the key is absent)
        :return: Dict with the paraphrased text under the ``"outputs"`` key
        :raises ValueError: If the resolved input is not a string
        """
        # ``get`` rather than ``pop`` so the caller's payload is not mutated.
        inputs = data.get("inputs", data)
        if not isinstance(inputs, str):
            raise ValueError("Input must be a string")

        paraphrased_paragraphs = []
        for paragraph in self.split_into_paragraphs(inputs):
            sentences = self.split_into_sentences(paragraph)
            paraphrased_paragraphs.append(
                ' '.join(self._paraphrase_sentence(sentence) for sentence in sentences)
            )

        return {"outputs": '\n\n'.join(paraphrased_paragraphs)}