|
import gradio as gr |
|
import torch |
|
from transformers import PegasusForConditionalGeneration, PegasusTokenizer |
|
import re |
|
import os |
|
|
|
def load_model(model_dir='./models'):
    """Load the Pegasus tokenizer and model from local storage.

    Args:
        model_dir: Path handed to ``from_pretrained`` for both the tokenizer
            and the model. The default ``'./models'`` is a relative path, so
            it is resolved against the current working directory, not
            against the location of this file.

    Returns:
        A ``(tokenizer, model, torch_device)`` tuple. ``torch_device`` is
        ``'cuda'`` when ``torch.cuda.is_available()`` is true and ``'cpu'``
        otherwise; the model has already been moved to that device.
    """
    torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {torch_device}")

    # Tokenizer and model must come from the same checkpoint directory.
    tokenizer = PegasusTokenizer.from_pretrained(model_dir)
    model = PegasusForConditionalGeneration.from_pretrained(model_dir).to(torch_device)

    return tokenizer, model, torch_device
|
|
|
def split_into_paragraphs(text):
    """Split text into non-empty paragraphs separated by blank lines.

    A paragraph break is one or more blank lines. A "blank" line may contain
    whitespace, and both LF and CRLF line endings are recognised, so
    Windows-style input and lines holding stray spaces still split. Empty
    paragraphs are dropped and each paragraph is stripped of surrounding
    whitespace; single newlines inside a paragraph are kept.

    Args:
        text: The full input text. ``None`` or ``""`` yields ``[]``.

    Returns:
        A list of stripped, non-empty paragraph strings.
    """
    if not text:
        return []
    # newline, optional whitespace (spaces, tabs, \r, extra newlines), newline
    paragraphs = re.split(r'\n\s*\n', text)
    return [p.strip() for p in paragraphs if p.strip()]
|
|
|
def split_into_sentences(paragraph):
    """Break a paragraph into sentences.

    A sentence boundary is any run of whitespace that directly follows
    '.', '!' or '?'. Because every such position is a boundary,
    abbreviations like "e.g. " also cause a split. Pieces are stripped and
    empty pieces are discarded.
    """
    pieces = (chunk.strip() for chunk in re.split(r'(?<=[.!?])\s+', paragraph))
    return [piece for piece in pieces if piece]
|
|
|
def get_response(input_text, num_return_sequences, tokenizer, model, torch_device):
    """Generate a paraphrase of a single piece of text using beam search.

    Args:
        input_text: Text to paraphrase. The tokenised input is truncated to
            80 tokens, so the tail of a very long sentence is never seen by
            the model.
        num_return_sequences: Passed to ``model.generate``; must not exceed
            ``num_beams`` (10). Only the first decoded sequence is returned,
            so values above 1 cost extra work without changing the result.
        tokenizer: Tokenizer as returned by ``load_model``.
        model: Seq2seq model as returned by ``load_model``.
        torch_device: Device string the model lives on ('cuda' or 'cpu').

    Returns:
        The first decoded output sequence, with special tokens removed.
    """
    # Call the tokenizer directly instead of the deprecated
    # ``prepare_seq2seq_batch`` helper; ``.to`` moves input_ids and
    # attention_mask onto the model's device.
    batch = tokenizer(
        [input_text],
        truncation=True,
        padding='longest',
        max_length=80,
        return_tensors="pt",
    ).to(torch_device)

    translated = model.generate(
        **batch,
        num_beams=10,
        num_return_sequences=num_return_sequences,
        temperature=1.0,  # default value; only affects sampling, not beam search
        repetition_penalty=2.8,
        length_penalty=1.2,
        max_length=80,
        min_length=5,
        no_repeat_ngram_size=3,
    )

    tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
    return tgt_text[0]
|
|
|
# Explanatory filler the paraphraser sometimes adds; outputs that introduce
# one of these (when the source sentence did not contain it) are rejected.
_FILLER_PHRASES = ("it's like", "in other words")


def _paraphrase_sentence(sentence, tokenizer, model, torch_device):
    """Paraphrase one sentence, falling back to the original on any problem.

    The original sentence is returned unchanged when it has fewer than three
    words, when generation raises, when the output is blank, or when the
    output introduces a filler phrase absent from the input.
    """
    if len(sentence.split()) < 3:
        return sentence

    try:
        paraphrased = get_response(sentence, 1, tokenizer, model, torch_device)
    except Exception as e:
        # Best-effort: one failing sentence must not abort the whole text.
        print(f"Error processing sentence: {e}")
        return sentence

    candidate = paraphrased.strip()
    if not candidate:
        return sentence

    original_lower = sentence.lower()
    candidate_lower = candidate.lower()
    # Only reject filler the model added; a source sentence that already says
    # "in other words" may legitimately keep it in the paraphrase.
    if any(p in candidate_lower and p not in original_lower for p in _FILLER_PHRASES):
        return sentence
    return candidate


def get_response_from_text(context, tokenizer, model, torch_device):
    """Paraphrase a whole text sentence by sentence, keeping paragraph breaks.

    Args:
        context: The input text. ``None`` or ``""`` yields ``""``.
        tokenizer, model, torch_device: As returned by ``load_model``.

    Returns:
        The paraphrased text. Sentences within a paragraph are joined with a
        single space and paragraphs are separated by a blank line.
    """
    if not context:
        return ""

    paraphrased_paragraphs = []
    for paragraph in split_into_paragraphs(context):
        sentences = split_into_sentences(paragraph)
        paraphrased_paragraphs.append(
            ' '.join(
                _paraphrase_sentence(s, tokenizer, model, torch_device)
                for s in sentences
            )
        )
    return '\n\n'.join(paraphrased_paragraphs)
|
|
|
def create_interface():
    """Build the Gradio UI for the paraphraser.

    The tokenizer and model are loaded once via ``load_model`` and captured by
    the request handler, which forwards each input to
    ``get_response_from_text``.

    Returns:
        The configured, not yet launched, ``gr.Interface``.
    """
    tokenizer, model, torch_device = load_model()

    def paraphrase_text(context):
        # Reuse the already-loaded model for every request.
        return get_response_from_text(context, tokenizer, model, torch_device)

    input_box = gr.Textbox(
        lines=15,
        label="Input Text",
        placeholder="Enter your text here...",
        elem_classes="input-text",
    )
    output_box = gr.Textbox(
        lines=15,
        label="Paraphrased Text",
        elem_classes="output-text",
    )
    custom_css = """
    .input-text, .output-text {
        font-size: 16px !important;
        font-family: Arial, sans-serif !important;
        min-height: 300px !important;
    }
    """

    return gr.Interface(
        fn=paraphrase_text,
        inputs=input_box,
        outputs=output_box,
        title="Advanced Text Paraphraser",
        description="Enter text to generate a high-quality paraphrased version while maintaining paragraph structure.",
        theme="default",
        css=custom_css,
    )
|
|
|
if __name__ == "__main__":
    # Build the UI (this loads the model) and start the Gradio server.
    demo = create_interface()
    demo.launch()