Jay7478 committed on
Commit b71da86 · Parent: f3ca51f

Upload app.py

Files changed (1)
  1. app.py +144 -0
app.py ADDED
@@ -0,0 +1,144 @@
+ # import streamlit as st
+ # from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+ # import nltk
+ # import math
+ # import torch
+
+ # model_name = "afnanmmir/t5-base-abstract-to-plain-language-1"
+ # # model_name = "afnanmmir/t5-base-axriv-to-abstract-3"
+ # max_input_length = 1024
+ # max_output_length = 256
+
+ # st.header("Generate summaries")
+
+ # st_model_load = st.text('Loading summary generator model...')
+
+ # # @st.cache(allow_output_mutation=True)
+ # @st.cache_data
+ # def load_model():
+ #     print("Loading model...")
+ #     tokenizer = AutoTokenizer.from_pretrained(model_name)
+ #     model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+ #     nltk.download('punkt')
+ #     print("Model loaded!")
+ #     return tokenizer, model
+
+ # tokenizer, model = load_model()
+ # st.success('Model loaded!')
+ # st_model_load.text("")
+
+ # with st.sidebar:
+ #     # st.header("Model parameters")
+ #     # if 'num_titles' not in st.session_state:
+ #     #     st.session_state.num_titles = 5
+ #     # def on_change_num_titles():
+ #     #     st.session_state.num_titles = num_titles
+ #     # num_titles = st.slider("Number of titles to generate", min_value=1, max_value=10, value=1, step=1, on_change=on_change_num_titles)
+ #     # if 'temperature' not in st.session_state:
+ #     #     st.session_state.temperature = 0.7
+ #     # def on_change_temperatures():
+ #     #     st.session_state.temperature = temperature
+ #     # temperature = st.slider("Temperature", min_value=0.1, max_value=1.5, value=0.6, step=0.05, on_change=on_change_temperatures)
+ #     # st.markdown("_High temperature means that results are more random_")
+
+ # if 'text' not in st.session_state:
+ #     st.session_state.text = ""
+ # st_text_area = st.text_area('Text to generate the summary for', value=st.session_state.text, height=500)
+
+ # def generate_summary():
+ #     st.session_state.text = st_text_area
+
+ #     # tokenize text
+ #     inputs = ["summarize: " + st_text_area]
+ #     # print(inputs)
+ #     inputs = tokenizer(inputs, return_tensors="pt", max_length=max_input_length, truncation=True)
+ #     # print("Tokenized inputs: ")
+ #     # print(inputs)
+
+ #     outputs = model.generate(**inputs, do_sample=True, max_length=max_output_length, early_stopping=True, num_beams=8, length_penalty=2.0, no_repeat_ngram_size=2, min_length=64)
+ #     # print("outputs", outputs)
+ #     decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
+ #     # print("Decoded_outputs", decoded_outputs)
+ #     predicted_summaries = nltk.sent_tokenize(decoded_outputs.strip())
+ #     # print("Predicted summaries", predicted_summaries)
+
+ #     # decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+ #     # predicted_summaries = [nltk.sent_tokenize(decoded_output.strip())[0] for decoded_output in decoded_outputs]
+
+ #     st.session_state.summaries = predicted_summaries
+
+ # # generate title button
+ # st_generate_button = st.button('Generate summary', on_click=generate_summary)
+
+ # # title generation labels
+ # if 'summaries' not in st.session_state:
+ #     st.session_state.summaries = []
+
+ # if len(st.session_state.summaries) > 0:
+ #     # print("In summaries if")
+ #     with st.container():
+ #         st.subheader("Generated summaries")
+ #         for summary in st.session_state.summaries:
+ #             st.markdown("__" + summary + "__")
+
+
+ # -------------------------------
+
+ import streamlit as st
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+ import nltk
+ import math
+ import torch
+
+ model_name = "afnanmmir/t5-base-abstract-to-plain-language-1"
+ max_input_length = 1024
+ max_output_length = 256
+ min_output_length = 64
+
+ st.header("Generate summaries for articles")
+
+ st_model_load = st.text('Loading summary generator model...')
+
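+ # Note: st.cache(allow_output_mutation=True) is deprecated in recent Streamlit
+ # releases; st.cache_resource is the modern equivalent for caching a loaded model.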
+ @st.cache(allow_output_mutation=True)
+ def load_model():
+     print("Loading model...")
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+     nltk.download('punkt')
+     print("Model loaded!")
+     return tokenizer, model
+
+ tokenizer, model = load_model()
+ st.success('Model loaded!')
+ st_model_load.text("")
+
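+ # st.session_state preserves the entered text and the generated summaries
+ # across Streamlit reruns, so both survive each button click.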
+ if 'text' not in st.session_state:
+     st.session_state.text = ""
+ st_text_area = st.text_area('Text to generate the summary for', value=st.session_state.text, height=500)
+
+ def generate_summary():
+     st.session_state.text = st_text_area
+
+     # tokenize text
+     inputs = ["summarize: " + st_text_area]
+     inputs = tokenizer(inputs, return_tensors="pt", max_length=max_input_length, truncation=True)
+
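+     # num_beams=8 together with do_sample=True performs beam-search multinomial
+     # sampling; length_penalty and no_repeat_ngram_size favor longer, less
+     # repetitive output.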
+     # compute predictions
+     outputs = model.generate(**inputs, do_sample=True, max_length=max_output_length, early_stopping=True, num_beams=8, length_penalty=2.0, no_repeat_ngram_size=2, min_length=min_output_length)
+     decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
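+     # Note: sent_tokenize(...)[0] keeps only the first sentence of each decoded
+     # output; dropping the [0] indexing would display the full generated summary.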
+     predicted_summaries = [nltk.sent_tokenize(decoded_output.strip())[0] for decoded_output in decoded_outputs]
+
+     st.session_state.summaries = predicted_summaries
+
+ # generate summary button
+ st_generate_button = st.button('Generate summary', on_click=generate_summary)
+
+ # summary generation labels
+ if 'summaries' not in st.session_state:
+     st.session_state.summaries = []
+
+ if len(st.session_state.summaries) > 0:
+     with st.container():
+         st.subheader("Generated summaries")
+         for summary in st.session_state.summaries:
+             st.markdown("__" + summary + "__")
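
For reference, a minimal sketch of the same generation pipeline run outside Streamlit, e.g. for testing the model from a script. It assumes transformers, torch, and nltk are installed; the abstract string is a placeholder, and the generation parameters mirror those in app.py:

    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
    import nltk

    nltk.download('punkt')

    model_name = "afnanmmir/t5-base-abstract-to-plain-language-1"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    # Placeholder input; replace with a real article abstract.
    abstract = "Paste an article abstract here."
    inputs = tokenizer(["summarize: " + abstract], return_tensors="pt",
                       max_length=1024, truncation=True)
    outputs = model.generate(**inputs, do_sample=True, num_beams=8,
                             max_length=256, min_length=64, length_penalty=2.0,
                             no_repeat_ngram_size=2, early_stopping=True)
    summary = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    print(summary)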