import logging
import re
from typing import Any, Dict, List

import torch
from transformers import PegasusForConditionalGeneration, PegasusTokenizer

logger = logging.getLogger(__name__)


class EndpointHandler:
    """Paraphrase free text sentence-by-sentence with a Pegasus model.

    ``__call__`` splits the request text into paragraphs, splits each
    paragraph into sentences, and paraphrases every sentence of at least
    ``MIN_WORDS`` words independently. Paragraph breaks are re-emitted as
    blank lines; sentences within a paragraph are re-joined with one space.
    """

    # Maximum token length for both the encoded input and the generated output.
    MAX_LENGTH = 80
    # Default beam width for generation.
    NUM_BEAMS = 10
    # Sentences with fewer whitespace-separated words are passed through as-is.
    MIN_WORDS = 3
    # Paraphrases containing any of these (lower-cased) phrases are rejected
    # and the original sentence is kept instead.
    REJECT_PHRASES = ("it's like", "in other words")

    # A blank line separates paragraphs. ``\s*`` also absorbs stray spaces
    # and the ``\r`` of CRLF line endings on the "empty" line.
    _PARAGRAPH_BREAK = re.compile(r"\n\s*\n")
    # Split on whitespace that follows sentence-ending punctuation; the
    # lookbehind keeps the punctuation attached to its sentence.
    _SENTENCE_BREAK = re.compile(r"(?<=[.!?])\s+")

    def __init__(self, path=""):
        """
        Load the tokenizer and model and move the model to the chosen device.

        :param path: Path to the model weights and tokenizer files
        """
        # Prefer the GPU when one is visible to torch.
        self.torch_device = "cuda" if torch.cuda.is_available() else "cpu"

        self.tokenizer = PegasusTokenizer.from_pretrained(path)
        self.model = PegasusForConditionalGeneration.from_pretrained(path).to(self.torch_device)
        # Inference only: make the evaluation mode explicit.
        self.model.eval()

    def split_into_paragraphs(self, text: str) -> List[str]:
        """
        Split text into paragraphs on blank lines.

        Each paragraph is stripped of surrounding whitespace and empty
        paragraphs are dropped, so a run of several blank lines counts as a
        single break.

        :param text: Input text
        :return: List of non-empty paragraphs
        """
        paragraphs = self._PARAGRAPH_BREAK.split(text)
        return [p.strip() for p in paragraphs if p.strip()]

    def split_into_sentences(self, paragraph: str) -> List[str]:
        """
        Split a paragraph into sentences after '.', '!' or '?'.

        :param paragraph: Input paragraph
        :return: List of non-empty, stripped sentences
        """
        sentences = self._SENTENCE_BREAK.split(paragraph)
        return [s.strip() for s in sentences if s.strip()]

    def get_response(self, input_text: str, num_return_sequences: int = 1) -> str:
        """
        Generate a paraphrase for a single input sentence.

        :param input_text: Input sentence to paraphrase
        :param num_return_sequences: Number of candidate sequences to
            generate; only the first returned sequence is used
        :return: Paraphrased text
        """
        # Calling the tokenizer directly replaces ``prepare_seq2seq_batch``,
        # which is deprecated and scheduled for removal in transformers v5.
        batch = self.tokenizer(
            [input_text],
            truncation=True,
            padding="longest",
            max_length=self.MAX_LENGTH,
            return_tensors="pt",
        ).to(self.torch_device)

        # generate() rejects num_return_sequences > num_beams, so widen the
        # beam when more candidates than the default width are requested.
        num_beams = max(self.NUM_BEAMS, num_return_sequences)

        # Pure inference: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            translated = self.model.generate(
                **batch,
                num_beams=num_beams,
                num_return_sequences=num_return_sequences,
                temperature=1.0,
                repetition_penalty=2.8,
                length_penalty=1.2,
                max_length=self.MAX_LENGTH,
                min_length=5,
                no_repeat_ngram_size=3,
            )

        tgt_text = self.tokenizer.batch_decode(translated, skip_special_tokens=True)
        return tgt_text[0]

    def _paraphrase_sentence(self, sentence: str) -> str:
        """
        Paraphrase one sentence, falling back to the original when needed.

        The original sentence is returned unchanged when it is shorter than
        ``MIN_WORDS`` words, when generation raises, or when the paraphrase
        contains one of ``REJECT_PHRASES``.

        :param sentence: A single sentence
        :return: The paraphrase, or the original sentence
        """
        if len(sentence.split()) < self.MIN_WORDS:
            return sentence

        try:
            paraphrased = self.get_response(sentence)
        except Exception:
            # Best effort: one failing sentence must not fail the whole request.
            logger.exception("Error processing sentence")
            return sentence

        lowered = paraphrased.lower()
        if any(phrase in lowered for phrase in self.REJECT_PHRASES):
            return sentence
        return paraphrased

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process an incoming request and return the paraphrased text.

        :param data: Request payload. The text is read from its "inputs"
            key; if that key is absent the payload itself is used.
        :return: ``{"outputs": <paraphrased text>}`` with paragraphs joined
            by a blank line
        :raises ValueError: if the extracted input is not a string
        """
        # Read without popping so the caller's payload is left untouched.
        inputs = data.get("inputs", data)

        if not isinstance(inputs, str):
            raise ValueError("Input must be a string")

        paraphrased_paragraphs = []
        for paragraph in self.split_into_paragraphs(inputs):
            sentences = self.split_into_sentences(paragraph)
            paraphrased_paragraphs.append(
                " ".join(self._paraphrase_sentence(s) for s in sentences)
            )

        return {"outputs": "\n\n".join(paraphrased_paragraphs)}