# paraphrase / app.py
# Hosting-page header residue kept for reference: "Jaane's picture";
# commit message "adding changes"; revision 8ec3ac3 (verified).
import torch
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
import gradio as gr
# One-time startup work: pick the checkpoint and device, then load the
# tokenizer and model so every request reuses the same objects.
model_name = 'tuner007/pegasus_paraphrase'

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU
if torch.cuda.is_available():
    torch_device = 'cuda'
else:
    torch_device = 'cpu'

# Tokenizer and model are both loaded from the same checkpoint name
tokenizer = PegasusTokenizer.from_pretrained(model_name)
model = PegasusForConditionalGeneration.from_pretrained(model_name)
model = model.to(torch_device)
def get_response(input_text, num_return_sequences=1, num_beams=3):
    """
    Generate paraphrased text for a given input_text using the Pegasus model.

    Args:
        input_text (str): The text to paraphrase. Input longer than 60 tokens
            is truncated by the tokenizer call below.
        num_return_sequences (int): Number of paraphrased sequences to return.
            Values below 1 are treated as 1.
        num_beams (int): Number of beams for beam search. Raised to
            num_return_sequences when smaller, because beam search cannot
            return more sequences than it has beams.

    Returns:
        list: A list containing paraphrased text strings.
    """
    # Callers (e.g. UI sliders) may hand over floats; coerce to ints and make
    # sure the beam count is large enough for the requested number of outputs
    # instead of letting generate() reject the combination.
    num_return_sequences = max(1, int(num_return_sequences))
    num_beams = max(1, int(num_beams), num_return_sequences)

    # Same token budget for the encoded input and the generated output
    max_length = 60

    # Tokenize the input text as a batch of one and move it to the model's device
    batch = tokenizer(
        [input_text],
        truncation=True,
        padding='longest',
        max_length=max_length,
        return_tensors="pt",
    ).to(torch_device)

    # Generate paraphrased outputs.
    # NOTE(review): do_sample is not enabled, so temperature is presumably
    # ignored by beam search - confirm against the installed transformers
    # version before relying on it.
    translated = model.generate(
        **batch,
        max_length=max_length,
        num_beams=num_beams,
        num_return_sequences=num_return_sequences,
        temperature=0.7,
    )

    # Decode the generated tokens back into plain strings
    return tokenizer.batch_decode(translated, skip_special_tokens=True)
def split_text_by_fullstop(text):
    """
    Split the input text into sentences based on full stops.

    Args:
        text (str): The text to split. None or an empty string yields an
            empty list.

    Returns:
        list: The non-empty sentences, stripped of surrounding whitespace and
        without their terminating full stop.
    """
    if not text:
        return []
    # Strip before filtering: a segment made only of whitespace (such as the
    # space or newline left after the final full stop) must not come back as
    # an empty "sentence".
    stripped_segments = (segment.strip() for segment in text.split('.'))
    return [segment for segment in stripped_segments if segment]
def process_text_by_fullstop(text, num_return_sequences=1, num_beams=3):
    """
    Process the input text by splitting it into sentences and paraphrasing each sentence.

    Args:
        text (str): The text to paraphrase.
        num_return_sequences (int): Number of paraphrased sequences per sentence.
        num_beams (int): Number of beams for beam search.

    Returns:
        str: The paraphrased text. With a single return sequence this is the
        paraphrased sentences joined by spaces. With several, each complete
        variant of the text is built from the i-th paraphrase of every
        sentence and the variants are separated by a blank line. Returns an
        empty string when the text contains no sentences.
    """
    # One list of alternatives per input sentence, in input order
    alternatives_per_sentence = []
    for sentence in split_text_by_fullstop(text):
        sentence = sentence.strip()
        # Skip blank segments so the model is never asked to paraphrase "."
        if not sentence:
            continue
        # Ensure each sentence ends with a period
        if not sentence.endswith('.'):
            sentence += '.'
        paraphrases = get_response(sentence, num_return_sequences, num_beams)
        if paraphrases:
            alternatives_per_sentence.append(paraphrases)

    if not alternatives_per_sentence:
        return ''

    # Build whole-text variants instead of concatenating every alternative of
    # every sentence into one run. If a sentence has fewer alternatives than
    # the others, its last one is reused.
    variant_count = max(len(options) for options in alternatives_per_sentence)
    variants = [
        ' '.join(
            options[min(index, len(options) - 1)]
            for options in alternatives_per_sentence
        )
        for index in range(variant_count)
    ]
    return '\n\n'.join(variants)
def paraphrase(text, num_beams, num_return_sequences):
    """
    Gradio callback: paraphrase the submitted text with the chosen settings.

    The parameter order here follows the order of the UI inputs, which differs
    from the order expected by process_text_by_fullstop, so the values are
    forwarded by keyword.

    Args:
        text (str): The input text to paraphrase.
        num_beams (int): Number of beams for beam search.
        num_return_sequences (int): Number of paraphrased sequences to return.

    Returns:
        str: The paraphrased text.
    """
    return process_text_by_fullstop(
        text,
        num_return_sequences=num_return_sequences,
        num_beams=num_beams,
    )
# Build the UI components first; their order in `inputs` below must match the
# parameter order of paraphrase(text, num_beams, num_return_sequences).
text_input = gr.components.Textbox(
    label="Input Text",
    placeholder="Enter text here...",
    lines=10,
)
beams_slider = gr.components.Slider(
    label="Number of Beams",
    minimum=1,
    maximum=10,
    step=1,
    value=3,
)
sequences_slider = gr.components.Slider(
    label="Number of Return Sequences",
    minimum=1,
    maximum=5,
    step=1,
    value=1,
)
text_output = gr.components.Textbox(
    label="Paraphrased Text",
    lines=10,
)

# Wire the components to the paraphrase callback
iface = gr.Interface(
    fn=paraphrase,
    inputs=[text_input, beams_slider, sequences_slider],
    outputs=text_output,
    title="Text Paraphrasing App",
    description="Enter your text and adjust the parameters to receive paraphrased versions using the Pegasus model.",
    allow_flagging="never",
)

# Start the app only when this file is executed as a script
if __name__ == "__main__":
    iface.launch()