ms1449 commited on
Commit
9a5450f
·
verified ·
1 Parent(s): 45acf72

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -25
app.py CHANGED
@@ -1,38 +1,63 @@
1
  import streamlit as st
2
- from transformers import pipeline, AutoTokenizer
3
- import torch
4
-
5
- # model
6
- model = "facebook/bart-large-cnn"
7
 
8
  @st.cache_resource
9
- def load_summarizer():
10
- return pipeline("summarization", model=model)
11
-
12
- def get_model_max_length(model_name):
13
- tokenizer = AutoTokenizer.from_pretrained(model_name)
14
- return tokenizer.model_max_length
15
-
16
- def generate_summary(text):
17
- summarizer = load_summarizer()
18
- max_input_length = min(summarizer.tokenizer.model_max_length, 1024)
19
- truncated_text = summarizer.tokenizer.decode(summarizer.tokenizer.encode(text, truncation=True, max_length=max_input_length))
20
- summary = summarizer(truncated_text, max_new_tokens=150, min_new_tokens=40, do_sample=False)[0]['summary_text']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  return summary
22
 
23
- st.title("A simple text-summarization-tool")
 
24
 
25
- st.write("Using the BART-large-CNN model.")
 
26
 
27
- input_text = st.text_area("Enter the text:", height=200)
 
28
 
29
- if st.button("Generate Summary"):
30
  if input_text:
31
  with st.spinner("Generating summary..."):
32
- summary = generate_summary(input_text)
33
  st.subheader("Summary:")
34
  st.write(summary)
35
  else:
36
- st.warning("Please enter text to summarize.")
37
-
38
- st.write("---")
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ from transformers import BartTokenizer, BartForConditionalGeneration
 
 
 
 
3
 
4
  @st.cache_resource
5
+ def load_model():
6
+ model_path = "bart_small_samsum" # Update this if your model path is different
7
+ tokenizer = BartTokenizer.from_pretrained(model_path)
8
+ model = BartForConditionalGeneration.from_pretrained(model_path)
9
+ return tokenizer, model
10
+
11
+ # Set maximum lengths for input and target sequences
12
+ max_input_length = 128
13
+ max_target_length = 64
14
+
15
+ def summarize(input_text, tokenizer, model):
16
+ # Tokenize input text
17
+ inputs = tokenizer(input_text, return_tensors="pt", max_length=max_input_length, truncation=True)
18
+
19
+ # Generate summary
20
+ summary_ids = model.generate(
21
+ inputs["input_ids"],
22
+ max_length=max_target_length,
23
+ min_length=30,
24
+ length_penalty=2.0,
25
+ num_beams=4,
26
+ early_stopping=True
27
+ )
28
+
29
+ # Decode the generated summary
30
+ summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
31
+
32
  return summary
33
 
34
+ # Streamlit app
35
+ st.title("Summarization Tool Using Bart-small Finetuned on Small sized Samsum Dataset")
36
 
37
+ # Load model
38
+ tokenizer, model = load_model()
39
 
40
+ # Text input
41
+ input_text = st.text_area("Enter your dialogue here:", height=200)
42
 
43
+ if st.button("Summarize"):
44
  if input_text:
45
  with st.spinner("Generating summary..."):
46
+ summary = summarize(input_text, tokenizer, model)
47
  st.subheader("Summary:")
48
  st.write(summary)
49
  else:
50
+ st.warning("Please enter some text to summarize.")
51
+
52
+ # Add some information about the model
53
+ st.sidebar.header("About")
54
+ st.sidebar.info(
55
+ "This app uses a fine-tuned BART-Small model to summarize dialogues. "
56
+ "Enter your dialogue in the text area and click 'Summarize' to generate a summary."
57
+ )
58
+
59
+ # You can add more information or customization in the sidebar
60
+ st.sidebar.header("Model Details")
61
+ st.sidebar.text("Model: BART-small")
62
+ st.sidebar.text("Max Input Length: 128 tokens")
63
+ st.sidebar.text("Max Summary Length: 64 tokens")