Humaneyes / handler.py
from typing import Dict, List, Any
import torch
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
import re


class EndpointHandler:
    def __init__(self, path: str = ""):
        """
        Initialize the endpoint handler with the model and tokenizer.

        :param path: Path to the model weights
        """
        # Select GPU when available, otherwise fall back to CPU
        self.torch_device = "cuda" if torch.cuda.is_available() else "cpu"
        # Load tokenizer and model from the given path
        self.tokenizer = PegasusTokenizer.from_pretrained(path)
        self.model = PegasusForConditionalGeneration.from_pretrained(path).to(self.torch_device)

    def split_into_paragraphs(self, text: str) -> List[str]:
        """
        Split text into paragraphs on blank lines, dropping empty paragraphs.

        :param text: Input text
        :return: List of non-empty paragraphs
        """
        paragraphs = text.split('\n\n')
        return [p.strip() for p in paragraphs if p.strip()]

    def split_into_sentences(self, paragraph: str) -> List[str]:
        """
        Split a paragraph into sentences using a regex on end punctuation.

        :param paragraph: Input paragraph
        :return: List of sentences
        """
        sentences = re.split(r'(?<=[.!?])\s+', paragraph)
        return [s.strip() for s in sentences if s.strip()]

    def get_response(self, input_text: str, num_return_sequences: int = 1) -> str:
        """
        Generate paraphrased text for a single input sentence.

        :param input_text: Input sentence to paraphrase
        :param num_return_sequences: Number of alternative paraphrases to generate
        :return: Paraphrased text
        """
        # Tokenize the input (prepare_seq2seq_batch is deprecated; call the tokenizer directly)
        batch = self.tokenizer(
            [input_text],
            truncation=True,
            padding='longest',
            max_length=80,
            return_tensors="pt",
        ).to(self.torch_device)
        # Beam-search generation with penalties to discourage repetition
        translated = self.model.generate(
            **batch,
            num_beams=10,
            num_return_sequences=num_return_sequences,
            repetition_penalty=2.8,
            length_penalty=1.2,
            max_length=80,
            min_length=5,
            no_repeat_ngram_size=3,
        )
        tgt_text = self.tokenizer.batch_decode(translated, skip_special_tokens=True)
        return tgt_text[0]

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process the incoming request and generate paraphrased text.

        :param data: Request payload containing the input text under "inputs"
        :return: Dictionary with the paraphrased text under "outputs"
        """
        # Extract the input text from the payload
        inputs = data.pop("inputs", data)

        # Only plain strings are supported
        if not isinstance(inputs, str):
            raise ValueError("Input must be a string")

        # Split the text into paragraphs and paraphrase it sentence by sentence
        paragraphs = self.split_into_paragraphs(inputs)
        paraphrased_paragraphs = []

        for paragraph in paragraphs:
            sentences = self.split_into_sentences(paragraph)
            paraphrased_sentences = []
            for sentence in sentences:
                # Pass very short sentences through unchanged
                if len(sentence.split()) < 3:
                    paraphrased_sentences.append(sentence)
                    continue
                try:
                    # Paraphrase the sentence
                    paraphrased = self.get_response(sentence)
                    # Fall back to the original if the paraphrase contains unwanted filler phrases
                    if not any(phrase in paraphrased.lower() for phrase in ("it's like", "in other words")):
                        paraphrased_sentences.append(paraphrased)
                    else:
                        paraphrased_sentences.append(sentence)
                except Exception as e:
                    print(f"Error processing sentence: {e}")
                    paraphrased_sentences.append(sentence)
            # Join sentences back into a paragraph
            paraphrased_paragraphs.append(' '.join(paraphrased_sentences))

        # Join paragraphs back into the full text
        return {"outputs": '\n\n'.join(paraphrased_paragraphs)}