Humaneyes / main.py
Eemansleepdeprived's picture
Upload 2 files
651bb25 verified
raw
history blame
3.96 kB
import gradio as gr
import torch
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
import re
import os
def load_model(model_dir='./models'):
    """Load the Pegasus tokenizer and model and place the model on a device.

    Args:
        model_dir: Location passed to ``from_pretrained`` for both the
            tokenizer and the model. Defaults to ``'./models'``, which is
            resolved relative to the current working directory.

    Returns:
        A ``(tokenizer, model, torch_device)`` tuple, where ``torch_device``
        is ``'cuda'`` when CUDA is available and ``'cpu'`` otherwise.
    """
    torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {torch_device}")

    # Tokenizer and weights are read from the same location.
    tokenizer = PegasusTokenizer.from_pretrained(model_dir)
    model = PegasusForConditionalGeneration.from_pretrained(model_dir).to(torch_device)
    return tokenizer, model, torch_device
def split_into_paragraphs(text):
    """Split text into paragraphs separated by one or more blank lines.

    A blank line may contain spaces or tabs, and both ``\\n`` and ``\\r\\n``
    line endings are recognised. Each paragraph is stripped of surrounding
    whitespace, and empty paragraphs are dropped (blank lines themselves are
    not preserved in the result).

    Args:
        text: The full input text.

    Returns:
        A list of non-empty paragraph strings, in their original order.
    """
    # ``\s*`` between the two newlines absorbs whitespace-only lines, the
    # ``\r`` of CRLF line endings, and any run of extra blank lines.
    chunks = re.split(r'\n\s*\n', text)
    return [chunk.strip() for chunk in chunks if chunk.strip()]
def split_into_sentences(paragraph):
    """Break a paragraph into sentences.

    The paragraph is cut at every run of whitespace that directly follows
    a ``.``, ``!`` or ``?``. Each piece is stripped, and pieces that end up
    empty are discarded.
    """
    pieces = re.split(r'(?<=[.!?])\s+', paragraph)
    # filter(None, ...) drops the empty strings left after stripping.
    return list(filter(None, map(str.strip, pieces)))
def get_response(input_text, num_return_sequences, tokenizer, model, torch_device):
    """Generate text for a single input string and return the first result.

    Args:
        input_text: The text to feed to the model.
        num_return_sequences: Number of sequences requested from
            ``model.generate``. Only the first decoded sequence is returned.
        tokenizer: Callable tokenizer that also provides ``batch_decode``.
        model: Model exposing ``generate``.
        torch_device: Device the encoded batch is moved to before generation.

    Returns:
        The first decoded sequence, with special tokens removed.
    """
    # Call the tokenizer directly: ``prepare_seq2seq_batch`` is deprecated in
    # Transformers and this call produces the same source-side encoding.
    batch = tokenizer(
        [input_text],
        truncation=True,
        padding='longest',
        max_length=80,
        return_tensors="pt",
    ).to(torch_device)

    translated = model.generate(
        **batch,
        num_beams=10,
        num_return_sequences=num_return_sequences,
        temperature=1.0,
        repetition_penalty=2.8,
        length_penalty=1.2,
        max_length=80,
        min_length=5,
        no_repeat_ngram_size=3,
    )

    tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
    return tgt_text[0]
def get_response_from_text(context, tokenizer, model, torch_device):
    """Paraphrase a whole text sentence by sentence, keeping its paragraphs.

    The text is split into paragraphs and each paragraph into sentences.
    Sentences with fewer than three words are copied through unchanged.
    Every other sentence is sent to ``get_response``; the original sentence
    is kept when the result contains one of the rejected phrases or when
    processing raises. Sentences are re-joined with single spaces and
    paragraphs with a blank line.
    """
    rejected_phrases = ["it's like", 'in other words']
    rebuilt_paragraphs = []

    for block in split_into_paragraphs(context):
        rebuilt_sentences = []
        for original in split_into_sentences(block):
            # Very short fragments are passed through untouched.
            if len(original.split()) < 3:
                rebuilt_sentences.append(original)
                continue

            try:
                candidate = get_response(original, 1, tokenizer, model, torch_device)
                is_rejected = any(
                    phrase in candidate.lower() for phrase in rejected_phrases
                )
                rebuilt_sentences.append(original if is_rejected else candidate)
            except Exception as e:
                # Best effort: report the failure and keep the source sentence.
                print(f"Error processing sentence: {e}")
                rebuilt_sentences.append(original)

        rebuilt_paragraphs.append(' '.join(rebuilt_sentences))

    return '\n\n'.join(rebuilt_paragraphs)
def create_interface():
    """Build the Gradio interface for the paraphraser.

    The model is loaded once here, and the callback handed to Gradio closes
    over the tokenizer, model and device. The interface is returned without
    being launched.
    """
    tokenizer, model, torch_device = load_model()

    def greet(context):
        return get_response_from_text(context, tokenizer, model, torch_device)

    # Stylesheet applied to both text boxes; the text is kept verbatim.
    custom_css = """
.input-text, .output-text {
font-size: 16px !important;
font-family: Arial, sans-serif !important;
min-height: 300px !important;
}
"""

    input_box = gr.Textbox(
        lines=15,
        label="Input Text",
        placeholder="Enter your text here...",
        elem_classes="input-text",
    )
    output_box = gr.Textbox(
        lines=15,
        label="Paraphrased Text",
        elem_classes="output-text",
    )

    return gr.Interface(
        fn=greet,
        inputs=input_box,
        outputs=output_box,
        title="Advanced Text Paraphraser",
        description="Enter text to generate a high-quality paraphrased version while maintaining paragraph structure.",
        theme="default",
        css=custom_css,
    )
if __name__ == "__main__":
    # Script entry point: build the Gradio app and start serving it.
    create_interface().launch()