Dhanush S Gowda commited on
Commit
401ffed
·
verified ·
1 Parent(s): 53545b3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -26
app.py CHANGED
@@ -2,39 +2,74 @@ import streamlit as st
2
  from transformers import pipeline
3
  import os
4
 
5
- # Set the cache directory
6
- CACHE_DIR = os.getenv('HF_HOME', os.path.expanduser('~/.cache/huggingface/hub'))
7
 
8
- # Function to load a single model
9
  @st.cache_resource
10
- def load_model(model_name):
11
- if model_name == 'BART':
12
- return pipeline("summarization", model="facebook/bart-large-cnn", cache_dir=CACHE_DIR)
13
- elif model_name == 'T5':
14
- return pipeline("summarization", model="t5-large", cache_dir=CACHE_DIR)
15
- elif model_name == 'Pegasus':
16
- return pipeline("summarization", model="google/pegasus-cnn_dailymail", cache_dir=CACHE_DIR)
17
 
18
  # Streamlit app layout
19
- st.title("Text Summarization with Pre-trained Models (BART, T5, Pegasus)")
20
 
 
 
 
 
 
21
  text_input = st.text_area("Enter text to summarize:")
22
 
 
 
 
 
 
23
  if text_input:
24
- # Display word count of input text
25
  word_count = len(text_input.split())
26
- st.write(f"**Word Count:** {word_count}")
27
-
28
- model_choice = st.selectbox("Choose a model:", ['BART', 'T5', 'Pegasus'])
29
-
30
- if st.button("Generate Summary"):
31
- with st.spinner(f"Generating summary using {model_choice}..."):
32
- summarizer = load_model(model_choice)
33
- summary = summarizer(text_input, max_length=150, min_length=50, length_penalty=2.0, num_beams=4, early_stopping=True)[0]['summary_text']
34
-
35
- summary_word_count = len(summary.split())
36
- st.subheader(f"Summary using {model_choice}")
37
- st.write(summary.replace('<n>', ''))
38
- st.write(f"**Summary Word Count:** {summary_word_count}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  else:
40
- st.error("Please enter text to summarize.")
 
2
  from transformers import pipeline
3
  import os
4
 
5
+ # Set Hugging Face cache directory
6
+ os.environ['TRANSFORMERS_CACHE'] = os.getenv('HF_HOME', os.path.expanduser('~/.cache/huggingface/hub'))
7
 
8
+ # Function to load all three models
9
  @st.cache_resource
10
+ def load_models():
11
+ bart_summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
12
+ t5_summarizer = pipeline("summarization", model="t5-large")
13
+ pegasus_summarizer = pipeline("summarization", model="google/pegasus-cnn_dailymail")
14
+ return bart_summarizer, t5_summarizer, pegasus_summarizer
 
 
15
 
16
  # Streamlit app layout
17
+ st.title("Text Summarization with Pre-trained Models: BART, T5, Pegasus")
18
 
19
+ # Load models
20
+ with st.spinner("Loading models..."):
21
+ bart_model, t5_model, pegasus_model = load_models()
22
+
23
+ # Input text
24
  text_input = st.text_area("Enter text to summarize:")
25
 
26
+ # User input for min and max words
27
+ st.sidebar.header("Summary Length Settings")
28
+ min_words = st.sidebar.slider("Minimum words in summary:", 10, 100, 50, step=5)
29
+ max_words = st.sidebar.slider("Maximum words in summary:", min_words + 10, 300, 150, step=10)
30
+
31
  if text_input:
 
32
  word_count = len(text_input.split())
33
+ st.write(f"**Input Word Count:** {word_count}")
34
+
35
+ if st.button("Generate Summaries"):
36
+ with st.spinner("Generating summaries..."):
37
+ # Generate summaries with dynamic length constraints
38
+ bart_summary = bart_model(
39
+ text_input,
40
+ max_length=max_words,
41
+ min_length=min_words,
42
+ num_beams=4,
43
+ early_stopping=True
44
+ )[0]['summary_text']
45
+
46
+ t5_summary = t5_model(
47
+ text_input,
48
+ max_length=max_words,
49
+ min_length=min_words,
50
+ num_beams=4,
51
+ early_stopping=True
52
+ )[0]['summary_text']
53
+
54
+ pegasus_summary = pegasus_model(
55
+ text_input,
56
+ max_length=max_words,
57
+ min_length=min_words,
58
+ num_beams=4,
59
+ early_stopping=True
60
+ )[0]['summary_text']
61
+
62
+ # Display summaries
63
+ st.subheader("BART Summary")
64
+ st.write(bart_summary)
65
+ st.write(f"**Word Count:** {len(bart_summary.split())}")
66
+
67
+ st.subheader("T5 Summary")
68
+ st.write(t5_summary)
69
+ st.write(f"**Word Count:** {len(t5_summary.split())}")
70
+
71
+ st.subheader("Pegasus Summary")
72
+ st.write(pegasus_summary)
73
+ st.write(f"**Word Count:** {len(pegasus_summary.split())}")
74
  else:
75
+ st.warning("Please enter text to summarize.")