Dhanush S Gowda committed
Commit 53545b3 · verified · 1 Parent(s): e06b7dd

Update app.py

Files changed (1)
  1. app.py +18 -34
app.py CHANGED
@@ -1,28 +1,19 @@
 import streamlit as st
 from transformers import pipeline
-from concurrent.futures import ThreadPoolExecutor
-import time
+import os
+
+# Set the cache directory
+CACHE_DIR = os.getenv('HF_HOME', os.path.expanduser('~/.cache/huggingface/hub'))
 
 # Function to load a single model
+@st.cache_resource
 def load_model(model_name):
     if model_name == 'BART':
-        return pipeline("summarization", model="facebook/bart-large-cnn")
+        return pipeline("summarization", model="facebook/bart-large-cnn", cache_dir=CACHE_DIR)
     elif model_name == 'T5':
-        return pipeline("summarization", model="t5-large")
+        return pipeline("summarization", model="t5-large", cache_dir=CACHE_DIR)
     elif model_name == 'Pegasus':
-        return pipeline("summarization", model="google/pegasus-cnn_dailymail")
-
-# Function to load all models concurrently and cache them
-@st.cache_resource
-def load_all_models():
-    model_names = ['BART', 'T5', 'Pegasus']
-    models = {}
-    with ThreadPoolExecutor() as executor:
-        futures = {executor.submit(load_model, name): name for name in model_names}
-        for future in futures:
-            model_name = futures[future]
-            models[model_name] = future.result()
-    return models
+        return pipeline("summarization", model="google/pegasus-cnn_dailymail", cache_dir=CACHE_DIR)
 
 # Streamlit app layout
 st.title("Text Summarization with Pre-trained Models (BART, T5, Pegasus)")
@@ -34,23 +25,16 @@ if text_input:
     word_count = len(text_input.split())
     st.write(f"**Word Count:** {word_count}")
 
-    if st.button("Generate Summaries"):
-        with st.spinner("Loading models and generating summaries..."):
-            start_time = time.time()
-            models = load_all_models()
-            summaries = {}
-            for model_name, model in models.items():
-                summary = model(text_input, max_length=150, min_length=50, length_penalty=2.0, num_beams=4, early_stopping=True)[0]['summary_text']
-                summaries[model_name] = summary
-            end_time = time.time()
+    model_choice = st.selectbox("Choose a model:", ['BART', 'T5', 'Pegasus'])
+
+    if st.button("Generate Summary"):
+        with st.spinner(f"Generating summary using {model_choice}..."):
+            summarizer = load_model(model_choice)
+            summary = summarizer(text_input, max_length=150, min_length=50, length_penalty=2.0, num_beams=4, early_stopping=True)[0]['summary_text']
 
-        st.subheader("Summaries")
-        for model_name, summary in summaries.items():
-            summary_word_count = len(summary.split())
-            st.write(f"**{model_name}**")
-            st.write(summary.replace('<n>', ''))
-            st.write(f"**Summary Word Count:** {summary_word_count}")
-            st.write("---")
-        st.write(f"**Total Time Taken:** {end_time - start_time:.2f} seconds")
+            summary_word_count = len(summary.split())
+            st.subheader(f"Summary using {model_choice}")
+            st.write(summary.replace('<n>', ''))
+            st.write(f"**Summary Word Count:** {summary_word_count}")
 else:
     st.error("Please enter text to summarize.")
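
A hedged note on the cache directory: pipeline() does not advertise a top-level cache_dir parameter, so if the cache_dir=CACHE_DIR keyword in the new load_model() is not forwarded to the download step, the documented alternative is model_kwargs, whose entries pipeline() passes to the model's from_pretrained(). A minimal sketch under that assumption, reusing the commit's CACHE_DIR default:

import os
from transformers import pipeline

# Same default as the commit's CACHE_DIR. (HF_HOME itself conventionally points
# at ~/.cache/huggingface, with the hub cache in its hub/ subfolder.)
CACHE_DIR = os.getenv('HF_HOME', os.path.expanduser('~/.cache/huggingface/hub'))

# model_kwargs is the pipeline() parameter whose entries are forwarded to the
# model's from_pretrained(), cache_dir included; this could replace the direct
# cache_dir=... keyword inside load_model().
summarizer = pipeline(
    "summarization",
    model="facebook/bart-large-cnn",
    model_kwargs={"cache_dir": CACHE_DIR},
)

Separately, @st.cache_resource keys its cache on the function arguments, so each selectbox choice ('BART', 'T5', 'Pegasus') loads its pipeline at most once per server process; re-selecting a model reuses the cached object instead of reloading it.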