Anupam251272 committed
Commit 2ff086b · verified · 1 Parent(s): 0265055

Create app.py

Files changed (1)
  1. app.py +138 -0
app.py ADDED
@@ -0,0 +1,138 @@
+ import nltk
+ nltk.download('punkt')
+
+ # Main implementation
+ import torch
+ from transformers import PegasusForConditionalGeneration, PegasusTokenizer
+ from newspaper import Article
+ import gradio as gr
+ import warnings
+ warnings.filterwarnings('ignore')
+
+ # Check if GPU is available
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ print(f"Using device: {device}")
+
+ # Initialize model and tokenizer
+ model_name = "google/pegasus-large"
+ try:
+     tokenizer = PegasusTokenizer.from_pretrained(model_name)
+     model = PegasusForConditionalGeneration.from_pretrained(model_name)
+     model = model.to(device)
+     print("Model loaded successfully!")
+ except Exception as e:
+     print(f"Error loading model: {e}")
+
+ def fetch_article_text(url):
+     """Fetch and extract text from a given URL"""
+     try:
+         article = Article(url)
+         article.download()
+         article.parse()
+         return article.text
+     except Exception as e:
+         return f"Error fetching article: {e}"
+
+ def summarize_text(text, max_length=150, min_length=40):
+     """Generate a summary using the Pegasus model"""
+     try:
+         # Tokenize with padding and truncation
+         inputs = tokenizer(
+             text,
+             max_length=1024,
+             truncation=True,
+             padding="max_length",
+             return_tensors="pt"
+         ).to(device)
+
+         # Generate summary
+         summary_ids = model.generate(
+             inputs["input_ids"],
+             max_length=max_length,
+             min_length=min_length,
+             length_penalty=2.0,
+             num_beams=4,
+             early_stopping=True
+         )
+
+         # Decode and return summary
+         summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
+         return summary
+
+     except Exception as e:
+         return f"Error generating summary: {e}"
+
+ def process_input(input_text, input_type, max_length=150, min_length=40):
+     """Process either a URL or direct text input"""
+     try:
+         if input_type == "URL":
+             text = fetch_article_text(input_text)
+             if "Error" in text:
+                 return text
+         else:
+             text = input_text
+
+         if not text or len(text.strip()) < 100:
+             return "Error: Input text is too short or empty."
+
+         return summarize_text(text, max_length, min_length)
+
+     except Exception as e:
+         return f"Error processing input: {e}"
+
+ # Create Gradio interface
+ def create_interface():
+     with gr.Blocks(title="Research Article Summarizer") as interface:
+         gr.Markdown("# Research Article Summarizer")
+         gr.Markdown("Enter either a URL or paste the article text directly.")
+
+         with gr.Row():
+             input_type = gr.Radio(
+                 choices=["URL", "Text"],
+                 value="URL",
+                 label="Input Type"
+             )
+
+         with gr.Row():
+             input_text = gr.Textbox(
+                 lines=5,
+                 placeholder="Enter URL or paste article text here...",
+                 label="Input"
+             )
+
+         with gr.Row():
+             max_length = gr.Slider(
+                 minimum=50,
+                 maximum=500,
+                 value=150,
+                 step=10,
+                 label="Maximum Summary Length"
+             )
+             min_length = gr.Slider(
+                 minimum=20,
+                 maximum=200,
+                 value=40,
+                 step=10,
+                 label="Minimum Summary Length"
+             )
+
+         with gr.Row():
+             submit_btn = gr.Button("Generate Summary")
+
+         with gr.Row():
+             output = gr.Textbox(
+                 lines=5,
+                 label="Generated Summary"
+             )
+
+         submit_btn.click(
+             fn=process_input,
+             inputs=[input_text, input_type, max_length, min_length],
+             outputs=output
+         )
+
+     return interface
+
+ # Launch the interface
+ demo = create_interface()
+ demo.launch(debug=True, share=True)
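
For a quick sanity check outside the Gradio UI, process_input can be called directly once the file above has been executed and the model has loaded. The snippet below is an illustrative sketch only; the sample text and length settings are arbitrary and are not part of this commit.

# Illustrative usage sketch (not part of app.py): assumes app.py has been run
# so that `process_input` and the Pegasus model are already in scope.
sample_text = (
    "Large language models such as Pegasus are pretrained with a gap-sentence "
    "generation objective and then fine-tuned for abstractive summarization, "
    "which makes them well suited to condensing research articles."
)
# "Text" skips the URL fetch path and summarizes the string directly.
print(process_input(sample_text, "Text", max_length=120, min_length=30))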