import torch
import gradio as gr
from transformers import PegasusTokenizer, PegasusForConditionalGeneration

# Define the PEGASUS model and tokenizer
MODEL_NAME = "VishnuPottabatthini/PEGASUS_Large"  # Swap in a different PEGASUS checkpoint if desired
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the tokenizer and model onto the available device
tokenizer = PegasusTokenizer.from_pretrained(MODEL_NAME)
model = PegasusForConditionalGeneration.from_pretrained(MODEL_NAME).to(device)


# Define the summarization function
def summarize(text, state):
    try:
        # Tokenize the input text, truncating to the model's input limit
        inputs = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=1024,  # Adjust to your model's maximum input length
        ).to(device)

        # Generate the summary with beam search
        summary_ids = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=150,  # Maximum length of the summary
            min_length=30,  # Minimum length of the summary
            num_beams=4,  # Beam search improves the quality of generated text
            early_stopping=True,
        )

        # Decode the generated token IDs back into text
        summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

        # Append the new summary to the running transcript kept in state; the
        # same string is returned for both the visible textbox and the state
        updated = (state + "\n" + summary).strip()
        return updated, updated
    except Exception as e:
        # Show the error in the output box without losing the accumulated state
        return str(e), state


# Create the Gradio interface
mf_summarize = gr.Interface(
    fn=summarize,
    inputs=[
        gr.Textbox(placeholder="Enter text to summarize...", lines=10),
        gr.State(value=""),
    ],
    outputs=[
        gr.Textbox(lines=15, label="Summary"),
        gr.State(),
    ],
    title="Article Summarization",
    description=(
        "Enter a long piece of text to generate a concise summary using a PEGASUS model. "
        "This demo uses a custom PEGASUS model from 🤗 Transformers."
    ),
)

# Launch the Gradio interface
mf_summarize.launch()
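
# --- Optional: calling summarize() directly ---
# A minimal sketch of invoking the function outside the UI; the sample text is
# illustrative only. launch() above blocks in a plain script, so either run this
# in a separate session or start the server with launch(prevent_thread_lock=True).
#
# sample = (
#     "The James Webb Space Telescope, launched in December 2021, has returned "
#     "infrared images of unprecedented depth, reshaping studies of early galaxies."
# )
# summary, _ = summarize(sample, "")
# print(summary)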
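
# --- Optional: inputs longer than the 1024-token limit ---
# The tokenizer call above truncates at max_length=1024, so the tail of a very
# long article is silently dropped. One common workaround (a sketch only, not
# part of the demo) is to split the text into token-sized chunks, summarize each
# chunk, and join the partial summaries. summarize_long is a hypothetical helper.
#
# def summarize_long(text, chunk_tokens=1024):
#     ids = tokenizer(text, truncation=False)["input_ids"]
#     chunks = [ids[i:i + chunk_tokens] for i in range(0, len(ids), chunk_tokens)]
#     parts = [summarize(tokenizer.decode(c, skip_special_tokens=True), "")[0] for c in chunks]
#     return " ".join(parts)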