import torch
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
import gradio as gr

# Load the tokenizer and model once when the app starts
model_name = 'tuner007/pegasus_paraphrase'
torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Initialize tokenizer and model
tokenizer = PegasusTokenizer.from_pretrained(model_name)
model = PegasusForConditionalGeneration.from_pretrained(model_name).to(torch_device)


def get_response(input_text, num_return_sequences=1, num_beams=3):
    """
    Generate paraphrased text for a given input_text using the Pegasus model.

    Args:
        input_text (str): The text to paraphrase.
        num_return_sequences (int): Number of paraphrased sequences to return.
        num_beams (int): Number of beams for beam search.

    Returns:
        list: A list containing paraphrased text strings.
    """
    # Tokenize the input text (the checkpoint targets short, sentence-level inputs)
    batch = tokenizer(
        [input_text],
        truncation=True,
        padding='longest',
        max_length=60,
        return_tensors="pt"
    ).to(torch_device)

    # Generate paraphrased outputs with beam search
    translated = model.generate(
        **batch,
        max_length=60,
        num_beams=num_beams,
        num_return_sequences=num_return_sequences,
        temperature=0.7
    )

    # Decode the generated token IDs back into strings
    tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
    return tgt_text


def split_text_by_fullstop(text):
    """
    Split the input text into sentences based on full stops.

    Args:
        text (str): The text to split.

    Returns:
        list: A list of sentences.
    """
    # Filter out empty and whitespace-only fragments before stripping
    sentences = [sentence.strip() for sentence in text.split('.') if sentence.strip()]
    return sentences


def process_text_by_fullstop(text, num_return_sequences=1, num_beams=3):
    """
    Process the input text by splitting it into sentences and paraphrasing each sentence.

    Args:
        text (str): The text to paraphrase.
        num_return_sequences (int): Number of paraphrased sequences per sentence.
        num_beams (int): Number of beams for beam search.

    Returns:
        str: The paraphrased text.
    """
    sentences = split_text_by_fullstop(text)
    paraphrased_sentences = []

    for sentence in sentences:
        # Ensure each sentence ends with a period
        sentence = sentence + '.' if not sentence.endswith('.') else sentence
        paraphrases = get_response(sentence, num_return_sequences, num_beams)
        paraphrased_sentences.extend(paraphrases)

    # Join all paraphrased sentences into a single string
    return ' '.join(paraphrased_sentences)


def paraphrase(text, num_beams, num_return_sequences):
    """
    Interface function to paraphrase input text based on user parameters.

    Args:
        text (str): The input text to paraphrase.
        num_beams (int): Number of beams for beam search.
        num_return_sequences (int): Number of paraphrased sequences to return.

    Returns:
        str: The paraphrased text.
    """
    return process_text_by_fullstop(text, num_return_sequences, num_beams)


# Define the Gradio interface
iface = gr.Interface(
    fn=paraphrase,
    inputs=[
        gr.components.Textbox(
            lines=10,
            placeholder="Enter text here...",
            label="Input Text"
        ),
        gr.components.Slider(
            minimum=1,
            maximum=10,
            step=1,
            value=3,
            label="Number of Beams"
        ),
        gr.components.Slider(
            minimum=1,
            maximum=5,
            step=1,
            value=1,
            label="Number of Return Sequences"
        )
    ],
    outputs=gr.components.Textbox(
        lines=10,
        label="Paraphrased Text"
    ),
    title="Text Paraphrasing App",
    description="Enter your text and adjust the parameters to receive paraphrased versions using the Pegasus model.",
    allow_flagging="never"
)

# Launch the app
if __name__ == "__main__":
    iface.launch()
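
# Optional quick check without launching the web UI. This is a minimal usage
# sketch, not part of the app itself: the sample text below is purely
# illustrative, and it only calls the helper functions defined above.
# Uncomment to run it from a Python shell instead of (or before) iface.launch():
#
#   sample = "The quick brown fox jumps over the lazy dog. It was a warm and sunny day."
#   print(get_response("The quick brown fox jumps over the lazy dog.", num_return_sequences=2, num_beams=5))
#   print(process_text_by_fullstop(sample, num_return_sequences=1, num_beams=5))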