ms1449 commited on
Commit
29022fa
·
verified ·
1 Parent(s): 4b6b28c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -11
app.py CHANGED
@@ -2,7 +2,7 @@ import streamlit as st
2
  from transformers import pipeline, AutoTokenizer
3
  import torch
4
 
5
- # Define the model
6
  model = "facebook/bart-large-cnn"
7
 
8
  @st.cache_resource
@@ -10,18 +10,13 @@ def load_summarizer():
10
  return pipeline("summarization", model=model)
11
 
12
  def get_model_max_length(model_name):
13
- try:
14
  tokenizer = AutoTokenizer.from_pretrained(model_name)
15
  return tokenizer.model_max_length
16
- except:
17
- return 1024 # default value if unable to determine
18
 
19
  def generate_summary(text):
20
  summarizer = load_summarizer()
21
  max_input_length = min(summarizer.tokenizer.model_max_length, 1024)
22
- truncated_text = summarizer.tokenizer.decode(
23
- summarizer.tokenizer.encode(text, truncation=True, max_length=max_input_length)
24
- )
25
  summary = summarizer(truncated_text, max_new_tokens=150, min_new_tokens=40, do_sample=False)[0]['summary_text']
26
  return summary
27
 
@@ -38,8 +33,6 @@ if st.button("Generate Summary"):
38
  st.subheader("Summary:")
39
  st.write(summary)
40
  else:
41
- st.warning("Please enter some text to summarize.")
42
 
43
- st.write("---")
44
- st.write("Model: facebook/bart-large-cnn")
45
- st.write(f"Max input length: {get_model_max_length(model)}")
 
2
  from transformers import pipeline, AutoTokenizer
3
  import torch
4
 
5
+ # model
6
  model = "facebook/bart-large-cnn"
7
 
8
  @st.cache_resource
 
10
  return pipeline("summarization", model=model)
11
 
12
  def get_model_max_length(model_name):
 
13
  tokenizer = AutoTokenizer.from_pretrained(model_name)
14
  return tokenizer.model_max_length
 
 
15
 
16
  def generate_summary(text):
17
  summarizer = load_summarizer()
18
  max_input_length = min(summarizer.tokenizer.model_max_length, 1024)
19
+ truncated_text = summarizer.tokenizer.decode(summarizer.tokenizer.encode(text, truncation=True, max_length=max_input_length))
 
 
20
  summary = summarizer(truncated_text, max_new_tokens=150, min_new_tokens=40, do_sample=False)[0]['summary_text']
21
  return summary
22
 
 
33
  st.subheader("Summary:")
34
  st.write(summary)
35
  else:
36
+ st.warning("Please enter text to summarize.")
37
 
38
+ st.write("---")