Dhanush S Gowda committed on
Commit
6895495
·
verified ·
1 Parent(s): 401ffed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +134 -66
app.py CHANGED
@@ -1,75 +1,143 @@
1
  import streamlit as st
 
2
  from transformers import pipeline
3
  import os
 
4
 
5
- # Set Hugging Face cache directory
6
- os.environ['TRANSFORMERS_CACHE'] = os.getenv('HF_HOME', os.path.expanduser('~/.cache/huggingface/hub'))
7
-
8
- # Function to load all three models
9
  @st.cache_resource
10
- def load_models():
11
- bart_summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
12
- t5_summarizer = pipeline("summarization", model="t5-large")
13
- pegasus_summarizer = pipeline("summarization", model="google/pegasus-cnn_dailymail")
14
- return bart_summarizer, t5_summarizer, pegasus_summarizer
15
-
16
- # Streamlit app layout
17
- st.title("Text Summarization with Pre-trained Models: BART, T5, Pegasus")
18
-
19
- # Load models
20
- with st.spinner("Loading models..."):
21
- bart_model, t5_model, pegasus_model = load_models()
22
-
23
- # Input text
24
- text_input = st.text_area("Enter text to summarize:")
25
-
26
- # User input for min and max words
27
- st.sidebar.header("Summary Length Settings")
28
- min_words = st.sidebar.slider("Minimum words in summary:", 10, 100, 50, step=5)
29
- max_words = st.sidebar.slider("Maximum words in summary:", min_words + 10, 300, 150, step=10)
30
-
31
- if text_input:
32
- word_count = len(text_input.split())
33
- st.write(f"**Input Word Count:** {word_count}")
34
-
35
- if st.button("Generate Summaries"):
36
- with st.spinner("Generating summaries..."):
37
- # Generate summaries with dynamic length constraints
38
- bart_summary = bart_model(
39
- text_input,
40
- max_length=max_words,
41
- min_length=min_words,
42
- num_beams=4,
43
- early_stopping=True
44
- )[0]['summary_text']
45
-
46
- t5_summary = t5_model(
47
- text_input,
48
- max_length=max_words,
49
- min_length=min_words,
50
- num_beams=4,
51
- early_stopping=True
52
- )[0]['summary_text']
53
 
54
- pegasus_summary = pegasus_model(
55
- text_input,
56
- max_length=max_words,
57
- min_length=min_words,
58
- num_beams=4,
59
- early_stopping=True
60
- )[0]['summary_text']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
- # Display summaries
63
- st.subheader("BART Summary")
64
- st.write(bart_summary)
65
- st.write(f"**Word Count:** {len(bart_summary.split())}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
- st.subheader("T5 Summary")
68
- st.write(t5_summary)
69
- st.write(f"**Word Count:** {len(t5_summary.split())}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
- st.subheader("Pegasus Summary")
72
- st.write(pegasus_summary)
73
- st.write(f"**Word Count:** {len(pegasus_summary.split())}")
74
- else:
75
- st.warning("Please enter text to summarize.")
 
1
  import streamlit as st
2
+ import multiprocessing
3
  from transformers import pipeline
4
  import os
5
+ import torch
6
 
7
# Optimize model loading and caching
@st.cache_resource
def load_model(model_name):
    """Load and cache a Hugging Face summarization pipeline.

    Args:
        model_name: Model identifier on the Hugging Face hub
            (e.g. "facebook/bart-large-cnn").

    Returns:
        A transformers summarization pipeline, or None if loading failed
        (the error is shown in the Streamlit UI instead of raising, so
        callers must handle a None result).
    """
    try:
        # Use a device index rather than a string: 0 = first GPU, -1 = CPU.
        # Integer indices are accepted by every transformers release, while
        # the "cuda"/"cpu" string form only works on newer versions.
        device = 0 if torch.cuda.is_available() else -1
        return pipeline("summarization", model=model_name, device=device)
    except Exception as e:
        # Surface the failure in the UI and degrade gracefully.
        st.error(f"Error loading model {model_name}: {e}")
        return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
def generate_summary(model, text, length_percentage=0.3):
    """
    Generate a summary with length bounds derived from the input size.

    Args:
        model: Hugging Face summarization pipeline, or None if loading
            failed (load_model returns None on error).
        text: Input text to summarize.
        length_percentage: Fraction of the input word count used as the
            upper bound for the summary length.

    Returns:
        The generated summary string, or a fallback message on failure.
    """
    if model is None:
        # load_model() already reported its own error in the UI; avoid a
        # confusing TypeError from calling None.
        return "Could not generate summary."

    # NOTE(review): the pipeline measures max_length/min_length in tokens,
    # while these are computed from whitespace-split words — treat them as
    # rough bounds, not exact word counts.
    word_count = len(text.split())
    max_length = max(50, int(word_count * length_percentage))
    min_length = max(30, int(word_count * 0.1))
    # For small length_percentage values (or long inputs) the floor-based
    # minimum can exceed the maximum, which generation backends reject —
    # clamp so min_length <= max_length always holds.
    min_length = min(min_length, max_length)

    try:
        summary = model(
            text,
            max_length=max_length,
            min_length=min_length,
            num_beams=4,
            early_stopping=True
        )[0]['summary_text']
        return summary
    except Exception as e:
        st.error(f"Summarization error: {e}")
        return "Could not generate summary."
48
 
49
def parallel_summarize(text, length_percentage=0.3):
    """
    Generate summaries from all configured models.

    Note: despite the historical name (kept so existing callers keep
    working), the models are run sequentially. The previous
    multiprocessing.Pool approach had to pickle entire pipeline objects
    into worker processes — which fails or duplicates gigabytes of weights
    for large/GPU-backed models, and Streamlit calls (st.error,
    st.cache_resource) do not work inside workers. The models were also
    loaded serially in the parent before the pool started, so no
    parallelism was actually gained.

    Args:
        text: Input text to summarize.
        length_percentage: Fraction of the input word count used as the
            summary's upper length bound.

    Returns:
        Dictionary mapping model display name -> generated summary.
    """
    model_configs = [
        ("facebook/bart-large-cnn", "BART"),
        ("t5-large", "T5"),
        ("google/pegasus-cnn_dailymail", "Pegasus")
    ]

    summaries = {}
    for model_name, display_name in model_configs:
        # load_model is cached via st.cache_resource, so repeat runs reuse
        # the already-loaded pipelines.
        model = load_model(model_name)
        summaries[display_name] = generate_summary(model, text, length_percentage)
    return summaries
73
 
74
def main():
    """Render the Streamlit page and drive the multi-model summarization flow."""
    st.set_page_config(
        page_title="Multi-Model Text Summarization",
        page_icon="📝",
        layout="wide"
    )

    # --- Header ---
    st.title("🤖 Advanced Text Summarization")
    st.markdown("""
    Generate concise summaries using multiple state-of-the-art models.
    Intelligently adapts summary length based on input text.
    """)

    # --- Inputs ---
    text_input = st.text_area(
        "Paste your text here:",
        height=250,
        help="Enter the text you want to summarize"
    )
    length_control = st.slider(
        "Summary Compression Rate",
        min_value=0.1,
        max_value=0.5,
        value=0.3,
        step=0.05,
        help="Adjust how much of the original text to keep in the summary"
    )

    # Guard clauses: only proceed once the button is pressed with real input.
    if not st.button("Generate Summaries", type="primary"):
        return
    if not text_input:
        st.warning("Please enter some text to summarize.")
        return

    status_box = st.empty()
    bar = st.progress(0)
    stages = ["Initializing Models", "Running BART", "Running T5", "Running Pegasus", "Completed"]

    try:
        step = 0
        for label in stages[:-1]:
            step += 1
            status_box.info(label)
            bar.progress(step * 20)
            # All the heavy lifting happens during the second stage.
            if step == 2:
                summaries = parallel_summarize(text_input, length_control)

        status_box.success("Summarization Complete!")
        bar.progress(100)

        # --- Results: one column per model ---
        st.subheader("📝 Generated Summaries")
        for column, (model, summary) in zip(st.columns(3), summaries.items()):
            with column:
                st.markdown(f"### {model} Summary")
                st.write(summary)
                st.caption(f"Word Count: {len(summary.split())}")

    except Exception as e:
        st.error(f"An error occurred: {e}")

    finally:
        # Clear transient progress widgets regardless of outcome.
        status_box.empty()
        bar.empty()


if __name__ == "__main__":
    main()