Jay7478 committed
Commit 1228bd8 · 1 Parent(s): 69a9209

Update app.py

Files changed (1)
  1. app.py +107 -64
app.py CHANGED
@@ -1,15 +1,103 @@
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import nltk
import math
import torch

- model_name = "fabiochiu/t5-base-medium-title-generation"
- max_input_length = 512

- st.header("Generate candidate titles for articles")

- st_model_load = st.text('Loading title generator model...')

@st.cache(allow_output_mutation=True)
def load_model():
@@ -24,78 +112,33 @@ tokenizer, model = load_model()
st.success('Model loaded!')
st_model_load.text("")

- with st.sidebar:
-     st.header("Model parameters")
-     if 'num_titles' not in st.session_state:
-         st.session_state.num_titles = 5
-     def on_change_num_titles():
-         st.session_state.num_titles = num_titles
-     num_titles = st.slider("Number of titles to generate", min_value=1, max_value=10, value=1, step=1, on_change=on_change_num_titles)
-     if 'temperature' not in st.session_state:
-         st.session_state.temperature = 0.7
-     def on_change_temperatures():
-         st.session_state.temperature = temperature
-     temperature = st.slider("Temperature", min_value=0.1, max_value=1.5, value=0.6, step=0.05, on_change=on_change_temperatures)
-     st.markdown("_High temperature means that results are more random_")
-
if 'text' not in st.session_state:
    st.session_state.text = ""
- st_text_area = st.text_area('Text to generate the title for', value=st.session_state.text, height=500)

- def generate_title():
    st.session_state.text = st_text_area

    # tokenize text
    inputs = ["summarize: " + st_text_area]
-     inputs = tokenizer(inputs, return_tensors="pt")
-
-     # compute span boundaries
-     num_tokens = len(inputs["input_ids"][0])
-     print(f"Input has {num_tokens} tokens")
-     max_input_length = 500
-     num_spans = math.ceil(num_tokens / max_input_length)
-     print(f"Input has {num_spans} spans")
-     overlap = math.ceil((num_spans * max_input_length - num_tokens) / max(num_spans - 1, 1))
-     spans_boundaries = []
-     start = 0
-     for i in range(num_spans):
-         spans_boundaries.append([start + max_input_length * i, start + max_input_length * (i + 1)])
-         start -= overlap
-     print(f"Span boundaries are {spans_boundaries}")
-     spans_boundaries_selected = []
-     j = 0
-     for _ in range(num_titles):
-         spans_boundaries_selected.append(spans_boundaries[j])
-         j += 1
-         if j == len(spans_boundaries):
-             j = 0
-     print(f"Selected span boundaries are {spans_boundaries_selected}")
-
-     # transform input with spans
-     tensor_ids = [inputs["input_ids"][0][boundary[0]:boundary[1]] for boundary in spans_boundaries_selected]
-     tensor_masks = [inputs["attention_mask"][0][boundary[0]:boundary[1]] for boundary in spans_boundaries_selected]
-
-     inputs = {
-         "input_ids": torch.stack(tensor_ids),
-         "attention_mask": torch.stack(tensor_masks)
-     }

    # compute predictions
-     outputs = model.generate(**inputs, do_sample=True, temperature=temperature)
    decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
-     predicted_titles = [nltk.sent_tokenize(decoded_output.strip())[0] for decoded_output in decoded_outputs]

-     st.session_state.titles = predicted_titles

- # generate title button
- st_generate_button = st.button('Generate title', on_click=generate_title)

- # title generation labels
- if 'titles' not in st.session_state:
-     st.session_state.titles = []

- if len(st.session_state.titles) > 0:
    with st.container():
-         st.subheader("Generated titles")
-         for title in st.session_state.titles:
-             st.markdown("__" + title + "__")
+ # import streamlit as st
+ # from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+ # import nltk
+ # import math
+ # import torch
+
+ # model_name = "afnanmmir/t5-base-abstract-to-plain-language-1"
+ # # model_name = "afnanmmir/t5-base-axriv-to-abstract-3"
+ # max_input_length = 1024
+ # max_output_length = 256
+
+ # st.header("Generate summaries")
+
+ # st_model_load = st.text('Loading summary generator model...')
+
+ # # @st.cache(allow_output_mutation=True)
+ # @st.cache_data
+ # def load_model():
+ #     print("Loading model...")
+ #     tokenizer = AutoTokenizer.from_pretrained(model_name)
+ #     model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+ #     nltk.download('punkt')
+ #     print("Model loaded!")
+ #     return tokenizer, model
+
+ # tokenizer, model = load_model()
+ # st.success('Model loaded!')
+ # st_model_load.text("")
+
+ # with st.sidebar:
+ #     # st.header("Model parameters")
+ #     # if 'num_titles' not in st.session_state:
+ #     #     st.session_state.num_titles = 5
+ #     # def on_change_num_titles():
+ #     #     st.session_state.num_titles = num_titles
+ #     # num_titles = st.slider("Number of titles to generate", min_value=1, max_value=10, value=1, step=1, on_change=on_change_num_titles)
+ #     # if 'temperature' not in st.session_state:
+ #     #     st.session_state.temperature = 0.7
+ #     # def on_change_temperatures():
+ #     #     st.session_state.temperature = temperature
+ #     # temperature = st.slider("Temperature", min_value=0.1, max_value=1.5, value=0.6, step=0.05, on_change=on_change_temperatures)
+ #     # st.markdown("_High temperature means that results are more random_")
+
+ # if 'text' not in st.session_state:
+ #     st.session_state.text = ""
+ # st_text_area = st.text_area('Text to generate the summary for', value=st.session_state.text, height=500)
+
+ # def generate_summary():
+ #     st.session_state.text = st_text_area
+
+ #     # tokenize text
+ #     inputs = ["summarize: " + st_text_area]
+ #     # print(inputs)
+ #     inputs = tokenizer(inputs, return_tensors="pt", max_length=max_input_length, truncation=True)
+ #     # print("Tokenized inputs: ")
+ #     # print(inputs)
+
+ #     outputs = model.generate(**inputs, do_sample=True, max_length=max_output_length, early_stopping=True, num_beams=8, length_penalty=2.0, no_repeat_ngram_size=2, min_length=64)
+ #     # print("outputs", outputs)
+ #     decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
+ #     # print("Decoded_outputs", decoded_outputs)
+ #     predicted_summaries = nltk.sent_tokenize(decoded_outputs.strip())
+ #     # print("Predicted summaries", predicted_summaries)
+
+ #     # decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+ #     # predicted_summaries = [nltk.sent_tokenize(decoded_output.strip())[0] for decoded_output in decoded_outputs]
+
+ #     st.session_state.summaries = predicted_summaries
+
+ # # generate summary button
+ # st_generate_button = st.button('Generate summary', on_click=generate_summary)
+
+ # # summary generation labels
+ # if 'summaries' not in st.session_state:
+ #     st.session_state.summaries = []
+
+ # if len(st.session_state.summaries) > 0:
+ #     # print("In summaries if")
+ #     with st.container():
+ #         st.subheader("Generated summaries")
+ #         for summary in st.session_state.summaries:
+ #             st.markdown("__" + summary + "__")
+
+
+ # -------------------------------
+
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import nltk
import math
import torch

+ model_name = "afnanmmir/t5-base-abstract-to-plain-language-1"
+ max_input_length = 1024
+ max_output_length = 256
+ min_output_length = 64

+ st.header("Generate summaries for articles")

+ st_model_load = st.text('Loading summary generator model...')

@st.cache(allow_output_mutation=True)
def load_model():
 
st.success('Model loaded!')
st_model_load.text("")

if 'text' not in st.session_state:
    st.session_state.text = ""
+ st_text_area = st.text_area('Text to generate the summary for', value=st.session_state.text, height=500)

+ def generate_summary():
    st.session_state.text = st_text_area

    # tokenize text
    inputs = ["summarize: " + st_text_area]
+     inputs = tokenizer(inputs, return_tensors="pt", max_length=max_input_length, truncation=True)

    # compute predictions
+     outputs = model.generate(**inputs, do_sample=True, max_length=max_output_length, early_stopping=True, num_beams=8, length_penalty=2.0, no_repeat_ngram_size=2, min_length=min_output_length)
    decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+     predicted_summaries = [nltk.sent_tokenize(decoded_output.strip())[0] for decoded_output in decoded_outputs]

+     st.session_state.summaries = predicted_summaries

+ # generate summary button
+ st_generate_button = st.button('Generate summary', on_click=generate_summary)

+ # summary generation labels
+ if 'summaries' not in st.session_state:
+     st.session_state.summaries = []

+ if len(st.session_state.summaries) > 0:
    with st.container():
+         st.subheader("Generated summaries")
+         for summary in st.session_state.summaries:
+             st.markdown("__" + summary + "__")
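
For reference, the code path this commit lands on (tokenize with truncation, beam-sample a summary, sentence-split the output) can be exercised outside Streamlit. This is a minimal sketch reusing the model name and generate() arguments from the diff above; the input text is a placeholder, not from the source.

# Minimal sketch of the pipeline introduced by this commit, run outside Streamlit.
# Model name and generation parameters are copied from the diff; the example
# text is a hypothetical placeholder.
import nltk
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model_name = "afnanmmir/t5-base-abstract-to-plain-language-1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
nltk.download('punkt')  # required by nltk.sent_tokenize below

text = "Paste an article abstract here."
inputs = tokenizer(["summarize: " + text], return_tensors="pt",
                   max_length=1024, truncation=True)
outputs = model.generate(**inputs, do_sample=True, max_length=256,
                         early_stopping=True, num_beams=8, length_penalty=2.0,
                         no_repeat_ngram_size=2, min_length=64)
decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
# like the live app, keep only the first sentence of each decoded output
summaries = [nltk.sent_tokenize(d.strip())[0] for d in decoded]
print(summaries)

Two details worth noting: the live code keeps only the first sentence of each decoded output (carried over from the title generator), whereas the commented-out draft kept every sentence of the summary; and `st.cache(allow_output_mutation=True)` is deprecated in recent Streamlit releases, where `st.cache_resource` is the intended replacement for model handles (the `@st.cache_data` line in the commented block was an experiment in that direction).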