Dhanush S Gowda commited on
Commit
a37af66
·
verified ·
1 Parent(s): 27d07c4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -14
app.py CHANGED
@@ -1,10 +1,10 @@
1
  import streamlit as st
2
  from transformers import pipeline
 
3
  import time
4
 
5
- # Load selected model's pipeline
6
- @st.cache_resource
7
- def load_pipeline(model_name):
8
  if model_name == 'BART':
9
  return pipeline("summarization", model="facebook/bart-large-cnn")
10
  elif model_name == 'T5':
@@ -12,6 +12,18 @@ def load_pipeline(model_name):
12
  elif model_name == 'Pegasus':
13
  return pipeline("summarization", model="google/pegasus-cnn_dailymail")
14
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  # Streamlit app layout
16
  st.title("Text Summarization with Pre-trained Models (BART, T5, Pegasus)")
17
 
@@ -22,19 +34,23 @@ if text_input:
22
  word_count = len(text_input.split())
23
  st.write(f"**Word Count:** {word_count}")
24
 
25
- model_choice = st.selectbox("Choose a model:", ['BART', 'T5', 'Pegasus'])
26
-
27
- if st.button("Generate Summary"):
28
- with st.spinner(f"Generating summary using {model_choice}..."):
29
  start_time = time.time()
30
- summarizer = load_pipeline(model_choice)
31
- summary = summarizer(text_input, max_length=150, min_length=50, length_penalty=2.0, num_beams=4, early_stopping=True)[0]['summary_text']
 
 
 
32
  end_time = time.time()
33
- summary_word_count = len(summary.split())
34
 
35
- st.subheader(f"Summary using {model_choice}")
36
- st.write(summary.replace('<n>', ''))
37
- st.write(f"**Summary Word Count:** {summary_word_count}")
38
- st.write(f"**Time Taken:** {end_time - start_time:.2f} seconds")
 
 
 
 
39
  else:
40
  st.error("Please enter text to summarize.")
 
1
  import streamlit as st
2
  from transformers import pipeline
3
+ from concurrent.futures import ThreadPoolExecutor
4
  import time
5
 
6
+ # Function to load a single model
7
+ def load_model(model_name):
 
8
  if model_name == 'BART':
9
  return pipeline("summarization", model="facebook/bart-large-cnn")
10
  elif model_name == 'T5':
 
12
  elif model_name == 'Pegasus':
13
  return pipeline("summarization", model="google/pegasus-cnn_dailymail")
14
 
15
+ # Function to load all models concurrently
16
+ @st.cache_resource
17
+ def load_all_models():
18
+ model_names = ['BART', 'T5', 'Pegasus']
19
+ models = {}
20
+ with ThreadPoolExecutor() as executor:
21
+ futures = {executor.submit(load_model, name): name for name in model_names}
22
+ for future in futures:
23
+ model_name = futures[future]
24
+ models[model_name] = future.result()
25
+ return models
26
+
27
  # Streamlit app layout
28
  st.title("Text Summarization with Pre-trained Models (BART, T5, Pegasus)")
29
 
 
34
  word_count = len(text_input.split())
35
  st.write(f"**Word Count:** {word_count}")
36
 
37
+ if st.button("Generate Summaries"):
38
+ with st.spinner("Loading models and generating summaries..."):
 
 
39
  start_time = time.time()
40
+ models = load_all_models()
41
+ summaries = {}
42
+ for model_name, model in models.items():
43
+ summary = model(text_input, max_length=150, min_length=50, length_penalty=2.0, num_beams=4, early_stopping=True)[0]['summary_text']
44
+ summaries[model_name] = summary
45
  end_time = time.time()
 
46
 
47
+ st.subheader("Summaries")
48
+ for model_name, summary in summaries.items():
49
+ summary_word_count = len(summary.split())
50
+ st.write(f"**{model_name}**")
51
+ st.write(summary.replace('<n>', ''))
52
+ st.write(f"**Summary Word Count:** {summary_word_count}")
53
+ st.write("---")
54
+ st.write(f"**Total Time Taken:** {end_time - start_time:.2f} seconds")
55
  else:
56
  st.error("Please enter text to summarize.")