Spaces:
Running
Running
ashhadahsan
commited on
Commit
·
4e736ad
1
Parent(s):
8f190f9
added the classification model
Browse files
app.py
CHANGED
@@ -4,18 +4,44 @@ from transformers import pipeline
|
|
4 |
from stqdm import stqdm
|
5 |
from simplet5 import SimpleT5
|
6 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
7 |
-
|
8 |
-
|
9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
def load_t5():
|
11 |
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
|
12 |
|
13 |
tokenizer = AutoTokenizer.from_pretrained("t5-base")
|
|
|
14 |
return model, tokenizer
|
15 |
|
16 |
|
17 |
-
@st.
|
18 |
def custom_model():
|
|
|
19 |
return pipeline("summarization", model="my_awesome_sum/")
|
20 |
|
21 |
|
@@ -25,11 +51,20 @@ def convert_df(df):
|
|
25 |
return df.to_csv(index=False).encode("utf-8")
|
26 |
|
27 |
|
28 |
-
@st.
|
29 |
def load_one_line_summarizer(model):
|
|
|
30 |
return model.load_model("t5", "snrspeaks/t5-one-line-summary")
|
31 |
|
32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
st.set_page_config(layout="wide", page_title="Amazon Review Summarizer")
|
34 |
st.title("Amazon Review Summarizer")
|
35 |
|
@@ -38,6 +73,7 @@ summarizer_option = st.selectbox(
|
|
38 |
"Select Summarizer",
|
39 |
("Custom trained on the dataset", "t5-base", "t5-one-line-summary"),
|
40 |
)
|
|
|
41 |
hide_streamlit_style = """
|
42 |
<style>
|
43 |
#MainMenu {visibility: hidden;}
|
@@ -63,8 +99,7 @@ if st.button("Process"):
|
|
63 |
text = df["text"].values.tolist()
|
64 |
if summarizer_option == "Custom trained on the dataset":
|
65 |
model = custom_model()
|
66 |
-
|
67 |
-
|
68 |
progress_text = "Summarization in progress. Please wait."
|
69 |
summary = []
|
70 |
|
@@ -82,11 +117,31 @@ if st.button("Process"):
|
|
82 |
output = pd.DataFrame(
|
83 |
{"text": df["text"].values.tolist(), "summary": summary}
|
84 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
csv = convert_df(output)
|
86 |
st.download_button(
|
87 |
label="Download data as CSV",
|
88 |
data=csv,
|
89 |
-
file_name=f"{summarizer_option}_df.csv",
|
90 |
mime="text/csv",
|
91 |
)
|
92 |
if summarizer_option == "t5-base":
|
@@ -115,11 +170,31 @@ if st.button("Process"):
|
|
115 |
output = pd.DataFrame(
|
116 |
{"text": df["text"].values.tolist(), "summary": summary}
|
117 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
csv = convert_df(output)
|
119 |
st.download_button(
|
120 |
label="Download data as CSV",
|
121 |
data=csv,
|
122 |
-
file_name=f"{summarizer_option}_df.csv",
|
123 |
mime="text/csv",
|
124 |
)
|
125 |
|
@@ -136,16 +211,39 @@ if st.button("Process"):
|
|
136 |
output = pd.DataFrame(
|
137 |
{"text": df["text"].values.tolist(), "summary": summary}
|
138 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
csv = convert_df(output)
|
140 |
st.download_button(
|
141 |
label="Download data as CSV",
|
142 |
data=csv,
|
143 |
-
file_name=f"{summarizer_option}_df.csv",
|
144 |
mime="text/csv",
|
145 |
)
|
|
|
146 |
except KeyError:
|
147 |
st.error(
|
148 |
"Please Make sure that your data must have a column named text",
|
149 |
icon="🚨",
|
150 |
)
|
151 |
st.info("Text column must have amazon reviews", icon="ℹ️")
|
|
|
|
|
|
4 |
from stqdm import stqdm
|
5 |
from simplet5 import SimpleT5
|
6 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
7 |
+
from transformers import BertTokenizer
|
8 |
+
from tensorflow.keras.models import load_model
|
9 |
+
from tensorflow.nn import softmax
|
10 |
+
import numpy as np
|
11 |
+
from datetime import datetime
|
12 |
+
import logging
|
13 |
+
|
14 |
+
date = datetime.now().strftime(r"%Y-%m-%d")
|
15 |
+
model_classes = {
|
16 |
+
0: "Ads",
|
17 |
+
1: "Apps",
|
18 |
+
2: "Battery",
|
19 |
+
3: "Charging",
|
20 |
+
4: "Delivery",
|
21 |
+
5: "Display",
|
22 |
+
6: "FOS",
|
23 |
+
7: "HW",
|
24 |
+
8: "Order",
|
25 |
+
9: "Refurb",
|
26 |
+
10: "SD",
|
27 |
+
11: "Setup",
|
28 |
+
12: "Unknown",
|
29 |
+
13: "WiFi",
|
30 |
+
}
|
31 |
+
|
32 |
+
|
33 |
+
@st.cache_resource
|
34 |
def load_t5():
|
35 |
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
|
36 |
|
37 |
tokenizer = AutoTokenizer.from_pretrained("t5-base")
|
38 |
+
st.success("Loaded T5 Model")
|
39 |
return model, tokenizer
|
40 |
|
41 |
|
42 |
+
@st.cache_resource
|
43 |
def custom_model():
|
44 |
+
st.success("Loaded custom model")
|
45 |
return pipeline("summarization", model="my_awesome_sum/")
|
46 |
|
47 |
|
|
|
51 |
return df.to_csv(index=False).encode("utf-8")
|
52 |
|
53 |
|
54 |
+
@st.cache_resource
|
55 |
def load_one_line_summarizer(model):
|
56 |
+
st.success("Loaded one line summarizer")
|
57 |
return model.load_model("t5", "snrspeaks/t5-one-line-summary")
|
58 |
|
59 |
|
60 |
+
@st.cache_resource
|
61 |
+
def classify_category():
|
62 |
+
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
63 |
+
new_model = load_model("model")
|
64 |
+
st.success("Loaded custom classification model")
|
65 |
+
return tokenizer, new_model
|
66 |
+
|
67 |
+
|
68 |
st.set_page_config(layout="wide", page_title="Amazon Review Summarizer")
|
69 |
st.title("Amazon Review Summarizer")
|
70 |
|
|
|
73 |
"Select Summarizer",
|
74 |
("Custom trained on the dataset", "t5-base", "t5-one-line-summary"),
|
75 |
)
|
76 |
+
classification = st.checkbox("Classify Category", value=True)
|
77 |
hide_streamlit_style = """
|
78 |
<style>
|
79 |
#MainMenu {visibility: hidden;}
|
|
|
99 |
text = df["text"].values.tolist()
|
100 |
if summarizer_option == "Custom trained on the dataset":
|
101 |
model = custom_model()
|
102 |
+
|
|
|
103 |
progress_text = "Summarization in progress. Please wait."
|
104 |
summary = []
|
105 |
|
|
|
117 |
output = pd.DataFrame(
|
118 |
{"text": df["text"].values.tolist(), "summary": summary}
|
119 |
)
|
120 |
+
if classification:
|
121 |
+
classification_token, classification_model = classify_category()
|
122 |
+
tf_batch = classification_token(
|
123 |
+
text,
|
124 |
+
max_length=128,
|
125 |
+
padding=True,
|
126 |
+
truncation=True,
|
127 |
+
return_tensors="tf",
|
128 |
+
)
|
129 |
+
with st.spinner(text="identifying theme"):
|
130 |
+
tf_outputs = classification_model(tf_batch)
|
131 |
+
classes = []
|
132 |
+
with st.spinner(text="creating output file"):
|
133 |
+
for x in stqdm(range(len(text))):
|
134 |
+
tf_o = softmax(tf_outputs["logits"][x], axis=-1)
|
135 |
+
label = np.argmax(tf_o, axis=0)
|
136 |
+
keys = model_classes
|
137 |
+
classes.append(keys.get(label))
|
138 |
+
output["category"] = classes
|
139 |
+
|
140 |
csv = convert_df(output)
|
141 |
st.download_button(
|
142 |
label="Download data as CSV",
|
143 |
data=csv,
|
144 |
+
file_name=f"{summarizer_option}_{date}_df.csv",
|
145 |
mime="text/csv",
|
146 |
)
|
147 |
if summarizer_option == "t5-base":
|
|
|
170 |
output = pd.DataFrame(
|
171 |
{"text": df["text"].values.tolist(), "summary": summary}
|
172 |
)
|
173 |
+
if classification:
|
174 |
+
classification_token, classification_model = classify_category()
|
175 |
+
tf_batch = classification_token(
|
176 |
+
text,
|
177 |
+
max_length=128,
|
178 |
+
padding=True,
|
179 |
+
truncation=True,
|
180 |
+
return_tensors="tf",
|
181 |
+
)
|
182 |
+
with st.spinner(text="identifying theme"):
|
183 |
+
tf_outputs = classification_model(tf_batch)
|
184 |
+
classes = []
|
185 |
+
with st.spinner(text="creating output file"):
|
186 |
+
for x in stqdm(range(len(text))):
|
187 |
+
tf_o = tf.nn.softmax(tf_outputs["logits"][x], axis=-1)
|
188 |
+
label = np.argmax(tf_o, axis=0)
|
189 |
+
keys = model_classes
|
190 |
+
classes.append(keys.get(label))
|
191 |
+
output["category"] = classes
|
192 |
+
|
193 |
csv = convert_df(output)
|
194 |
st.download_button(
|
195 |
label="Download data as CSV",
|
196 |
data=csv,
|
197 |
+
file_name=f"{summarizer_option}_{date}_df.csv",
|
198 |
mime="text/csv",
|
199 |
)
|
200 |
|
|
|
211 |
output = pd.DataFrame(
|
212 |
{"text": df["text"].values.tolist(), "summary": summary}
|
213 |
)
|
214 |
+
if classification:
|
215 |
+
classification_token, classification_model = classify_category()
|
216 |
+
tf_batch = classification_token(
|
217 |
+
text,
|
218 |
+
max_length=128,
|
219 |
+
padding=True,
|
220 |
+
truncation=True,
|
221 |
+
return_tensors="tf",
|
222 |
+
)
|
223 |
+
with st.spinner(text="identifying theme"):
|
224 |
+
tf_outputs = classification_model(tf_batch)
|
225 |
+
classes = []
|
226 |
+
with st.spinner(text="creating output file"):
|
227 |
+
for x in stqdm(range(len(text))):
|
228 |
+
tf_o = tf.nn.softmax(tf_outputs["logits"][x], axis=-1)
|
229 |
+
label = np.argmax(tf_o, axis=0)
|
230 |
+
keys = model_classes
|
231 |
+
classes.append(keys.get(label))
|
232 |
+
output["category"] = classes
|
233 |
+
|
234 |
csv = convert_df(output)
|
235 |
st.download_button(
|
236 |
label="Download data as CSV",
|
237 |
data=csv,
|
238 |
+
file_name=f"{summarizer_option}_{date}_df.csv",
|
239 |
mime="text/csv",
|
240 |
)
|
241 |
+
|
242 |
except KeyError:
|
243 |
st.error(
|
244 |
"Please Make sure that your data must have a column named text",
|
245 |
icon="🚨",
|
246 |
)
|
247 |
st.info("Text column must have amazon reviews", icon="ℹ️")
|
248 |
+
except BaseException as e:
|
249 |
+
logging.exception("An exception was occured")
|