Jay7478 commited on
Commit
69a9209
·
1 Parent(s): 4ad4e1d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -107
app.py CHANGED
@@ -1,104 +1,15 @@
1
- # import streamlit as st
2
- # from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
- # import nltk
4
- # import math
5
- # import torch
6
-
7
- # model_name = "afnanmmir/t5-base-abstract-to-plain-language-1"
8
- # # model_name = "afnanmmir/t5-base-axriv-to-abstract-3"
9
- # max_input_length = 1024
10
- # max_output_length = 256
11
-
12
- # st.header("Generate summaries")
13
-
14
- # st_model_load = st.text('Loading summary generator model...')
15
-
16
- # # @st.cache(allow_output_mutation=True)
17
- # @st.cache_data
18
- # def load_model():
19
- # print("Loading model...")
20
- # tokenizer = AutoTokenizer.from_pretrained(model_name)
21
- # model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
22
- # nltk.download('punkt')
23
- # print("Model loaded!")
24
- # return tokenizer, model
25
-
26
- # tokenizer, model = load_model()
27
- # st.success('Model loaded!')
28
- # st_model_load.text("")
29
-
30
- # with st.sidebar:
31
- # # st.header("Model parameters")
32
- # # if 'num_titles' not in st.session_state:
33
- # # st.session_state.num_titles = 5
34
- # # def on_change_num_titles():
35
- # # st.session_state.num_titles = num_titles
36
- # # num_titles = st.slider("Number of titles to generate", min_value=1, max_value=10, value=1, step=1, on_change=on_change_num_titles)
37
- # # if 'temperature' not in st.session_state:
38
- # # st.session_state.temperature = 0.7
39
- # # def on_change_temperatures():
40
- # # st.session_state.temperature = temperature
41
- # # temperature = st.slider("Temperature", min_value=0.1, max_value=1.5, value=0.6, step=0.05, on_change=on_change_temperatures)
42
- # # st.markdown("_High temperature means that results are more random_")
43
-
44
- # if 'text' not in st.session_state:
45
- # st.session_state.text = ""
46
- # st_text_area = st.text_area('Text to generate the summary for', value=st.session_state.text, height=500)
47
-
48
- # def generate_summary():
49
- # st.session_state.text = st_text_area
50
-
51
- # # tokenize text
52
- # inputs = ["summarize: " + st_text_area]
53
- # # print(inputs)
54
- # inputs = tokenizer(inputs, return_tensors="pt", max_length=max_input_length, truncation=True)
55
- # # print("Tokenized inputs: ")
56
- # # print(inputs)
57
-
58
- # outputs = model.generate(**inputs, do_sample=True, max_length=max_output_length, early_stopping=True, num_beams=8, length_penalty=2.0, no_repeat_ngram_size=2, min_length=64)
59
- # # print("outputs", outputs)
60
- # decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
61
- # # print("Decoded_outputs", decoded_outputs)
62
- # predicted_summaries = nltk.sent_tokenize(decoded_outputs.strip())
63
- # # print("Predicted summaries", predicted_summaries)
64
-
65
- # # decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
66
- # # predicted_summaries = [nltk.sent_tokenize(decoded_output.strip())[0] for decoded_output in decoded_outputs]
67
-
68
- # st.session_state.summaries = predicted_summaries
69
-
70
- # # generate title button
71
- # st_generate_button = st.button('Generate summary', on_click=generate_summary)
72
-
73
- # # title generation labels
74
- # if 'summaries' not in st.session_state:
75
- # st.session_state.summaries = []
76
-
77
- # if len(st.session_state.summaries) > 0:
78
- # # print("In summaries if")
79
- # with st.container():
80
- # st.subheader("Generated summaries")
81
- # for summary in st.session_state.summaries:
82
- # st.markdown("__" + summary + "__")
83
-
84
-
85
- # -------------------------------
86
-
87
  import streamlit as st
88
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
89
  import nltk
90
  import math
91
  import torch
92
 
93
- # model_name = "afnanmmir/t5-base-abstract-to-plain-language-1"
94
  model_name = "fabiochiu/t5-base-medium-title-generation"
95
- max_input_length = 1024
96
- max_output_length = 256
97
- min_output_length = 64
98
 
99
- st.header("Generate summaries for articles")
100
 
101
- st_model_load = st.text('Loading summary generator model...')
102
 
103
  @st.cache(allow_output_mutation=True)
104
  def load_model():
@@ -113,33 +24,78 @@ tokenizer, model = load_model()
113
  st.success('Model loaded!')
114
  st_model_load.text("")
115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  if 'text' not in st.session_state:
117
  st.session_state.text = ""
118
- st_text_area = st.text_area('Text to generate the summary for', value=st.session_state.text, height=500)
119
 
120
- def generate_summary():
121
  st.session_state.text = st_text_area
122
 
123
  # tokenize text
124
  inputs = ["summarize: " + st_text_area]
125
- inputs = tokenizer(inputs, return_tensors="pt", max_length=max_input_length, truncation=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
 
127
  # compute predictions
128
- outputs = model.generate(**inputs, do_sample=True, max_length=max_output_length, early_stopping=True, num_beams=8, length_penalty=2.0, no_repeat_ngram_size=2, min_length=min_output_length)
129
  decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
130
- predicted_summaries = [nltk.sent_tokenize(decoded_output.strip())[0] for decoded_output in decoded_outputs]
131
 
132
- st.session_state.summaries = predicted_summaries
133
 
134
- # generate summary button
135
- st_generate_button = st.button('Generate summary', on_click=generate_summary)
136
 
137
- # summary generation labels
138
- if 'summaries' not in st.session_state:
139
- st.session_state.summaries = []
140
 
141
- if len(st.session_state.summaries) > 0:
142
  with st.container():
143
- st.subheader("Generated summaries")
144
- for summary in st.session_state.summaries:
145
- st.markdown("__" + summary + "__")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
  import nltk
4
  import math
5
  import torch
6
 
 
7
  model_name = "fabiochiu/t5-base-medium-title-generation"
8
+ max_input_length = 512
 
 
9
 
10
+ st.header("Generate candidate titles for articles")
11
 
12
+ st_model_load = st.text('Loading title generator model...')
13
 
14
  @st.cache(allow_output_mutation=True)
15
  def load_model():
 
24
  st.success('Model loaded!')
25
  st_model_load.text("")
26
 
27
+ with st.sidebar:
28
+ st.header("Model parameters")
29
+ if 'num_titles' not in st.session_state:
30
+ st.session_state.num_titles = 5
31
+ def on_change_num_titles():
32
+ st.session_state.num_titles = num_titles
33
+ num_titles = st.slider("Number of titles to generate", min_value=1, max_value=10, value=1, step=1, on_change=on_change_num_titles)
34
+ if 'temperature' not in st.session_state:
35
+ st.session_state.temperature = 0.7
36
+ def on_change_temperatures():
37
+ st.session_state.temperature = temperature
38
+ temperature = st.slider("Temperature", min_value=0.1, max_value=1.5, value=0.6, step=0.05, on_change=on_change_temperatures)
39
+ st.markdown("_High temperature means that results are more random_")
40
+
41
  if 'text' not in st.session_state:
42
  st.session_state.text = ""
43
+ st_text_area = st.text_area('Text to generate the title for', value=st.session_state.text, height=500)
44
 
45
+ def generate_title():
46
  st.session_state.text = st_text_area
47
 
48
  # tokenize text
49
  inputs = ["summarize: " + st_text_area]
50
+ inputs = tokenizer(inputs, return_tensors="pt")
51
+
52
+ # compute span boundaries
53
+ num_tokens = len(inputs["input_ids"][0])
54
+ print(f"Input has {num_tokens} tokens")
55
+ max_input_length = 500
56
+ num_spans = math.ceil(num_tokens / max_input_length)
57
+ print(f"Input has {num_spans} spans")
58
+ overlap = math.ceil((num_spans * max_input_length - num_tokens) / max(num_spans - 1, 1))
59
+ spans_boundaries = []
60
+ start = 0
61
+ for i in range(num_spans):
62
+ spans_boundaries.append([start + max_input_length * i, start + max_input_length * (i + 1)])
63
+ start -= overlap
64
+ print(f"Span boundaries are {spans_boundaries}")
65
+ spans_boundaries_selected = []
66
+ j = 0
67
+ for _ in range(num_titles):
68
+ spans_boundaries_selected.append(spans_boundaries[j])
69
+ j += 1
70
+ if j == len(spans_boundaries):
71
+ j = 0
72
+ print(f"Selected span boundaries are {spans_boundaries_selected}")
73
+
74
+ # transform input with spans
75
+ tensor_ids = [inputs["input_ids"][0][boundary[0]:boundary[1]] for boundary in spans_boundaries_selected]
76
+ tensor_masks = [inputs["attention_mask"][0][boundary[0]:boundary[1]] for boundary in spans_boundaries_selected]
77
+
78
+ inputs = {
79
+ "input_ids": torch.stack(tensor_ids),
80
+ "attention_mask": torch.stack(tensor_masks)
81
+ }
82
 
83
  # compute predictions
84
+ outputs = model.generate(**inputs, do_sample=True, temperature=temperature)
85
  decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
86
+ predicted_titles = [nltk.sent_tokenize(decoded_output.strip())[0] for decoded_output in decoded_outputs]
87
 
88
+ st.session_state.titles = predicted_titles
89
 
90
+ # generate title button
91
+ st_generate_button = st.button('Generate title', on_click=generate_title)
92
 
93
+ # title generation labels
94
+ if 'titles' not in st.session_state:
95
+ st.session_state.titles = []
96
 
97
+ if len(st.session_state.titles) > 0:
98
  with st.container():
99
+ st.subheader("Generated titles")
100
+ for title in st.session_state.titles:
101
+ st.markdown("__" + title + "__")