import re

import gradio as gr
import torch
from transformers import PegasusForConditionalGeneration, PegasusTokenizer


def load_model():
    """Load the tokenizer and model from local storage."""
    torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {torch_device}")

    # Load tokenizer and model from the local ./models directory
    tokenizer = PegasusTokenizer.from_pretrained('./models')
    model = PegasusForConditionalGeneration.from_pretrained('./models').to(torch_device)
    return tokenizer, model, torch_device


def split_into_paragraphs(text):
    """Split text into paragraphs on blank lines, dropping empty ones."""
    paragraphs = text.split('\n\n')
    return [p.strip() for p in paragraphs if p.strip()]


def split_into_sentences(paragraph):
    """Split a paragraph into sentences using a regex on end punctuation."""
    sentences = re.split(r'(?<=[.!?])\s+', paragraph)
    return [s.strip() for s in sentences if s.strip()]


def get_response(input_text, num_return_sequences, tokenizer, model, torch_device):
    """Paraphrase a single sentence and return the top candidate."""
    # Tokenize directly; prepare_seq2seq_batch is deprecated in recent transformers releases.
    batch = tokenizer(
        [input_text],
        truncation=True,
        padding='longest',
        max_length=80,
        return_tensors="pt"
    ).to(torch_device)

    # Beam search with penalties to discourage repetition and overly short outputs
    translated = model.generate(
        **batch,
        num_beams=10,
        num_return_sequences=num_return_sequences,
        repetition_penalty=2.8,
        length_penalty=1.2,
        max_length=80,
        min_length=5,
        no_repeat_ngram_size=3
    )
    tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
    return tgt_text[0]


def get_response_from_text(context, tokenizer, model, torch_device):
    """Paraphrase an entire text while preserving its paragraph structure."""
    paragraphs = split_into_paragraphs(context)
    paraphrased_paragraphs = []

    for paragraph in paragraphs:
        sentences = split_into_sentences(paragraph)
        paraphrased_sentences = []

        for sentence in sentences:
            # Leave very short sentences unchanged; they rarely paraphrase well.
            if len(sentence.split()) < 3:
                paraphrased_sentences.append(sentence)
                continue
            try:
                paraphrased = get_response(sentence, 1, tokenizer, model, torch_device)
                # Fall back to the original sentence if the output contains filler phrases.
                if not any(phrase in paraphrased.lower() for phrase in ["it's like", "in other words"]):
                    paraphrased_sentences.append(paraphrased)
                else:
                    paraphrased_sentences.append(sentence)
            except Exception as e:
                print(f"Error processing sentence: {e}")
                paraphrased_sentences.append(sentence)

        paraphrased_paragraphs.append(' '.join(paraphrased_sentences))

    return '\n\n'.join(paraphrased_paragraphs)


def create_interface():
    """Create and configure the Gradio interface."""
    # Load model and tokenizer once at startup
    tokenizer, model, torch_device = load_model()

    def greet(context):
        return get_response_from_text(context, tokenizer, model, torch_device)

    # Create the interface with custom styling
    iface = gr.Interface(
        fn=greet,
        inputs=gr.Textbox(
            lines=15,
            label="Input Text",
            placeholder="Enter your text here...",
            elem_classes="input-text"
        ),
        outputs=gr.Textbox(
            lines=15,
            label="Paraphrased Text",
            elem_classes="output-text"
        ),
        title="Advanced Text Paraphraser",
        description="Enter text to generate a high-quality paraphrased version while maintaining paragraph structure.",
        theme="default",
        css="""
        .input-text, .output-text {
            font-size: 16px !important;
            font-family: Arial, sans-serif !important;
            min-height: 300px !important;
        }
        """
    )
    return iface


if __name__ == "__main__":
    # Create and launch the interface
    interface = create_interface()
    interface.launch()
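# Optional setup sketch: load_model() expects a Pegasus checkpoint saved locally
# under ./models. One way to populate that directory, assuming the publicly
# available "tuner007/pegasus_paraphrase" checkpoint (substitute your own
# fine-tuned model if you have one), is a one-time download such as:
#
#     from transformers import PegasusForConditionalGeneration, PegasusTokenizer
#     PegasusTokenizer.from_pretrained("tuner007/pegasus_paraphrase").save_pretrained("./models")
#     PegasusForConditionalGeneration.from_pretrained("tuner007/pegasus_paraphrase").save_pretrained("./models")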