ashhadahsan commited on
Commit
4e736ad
·
1 Parent(s): 8f190f9

added the classification model

Browse files
Files changed (1) hide show
  1. app.py +108 -10
app.py CHANGED
@@ -4,18 +4,44 @@ from transformers import pipeline
4
  from stqdm import stqdm
5
  from simplet5 import SimpleT5
6
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
7
-
8
-
9
- @st.cache
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  def load_t5():
11
  model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
12
 
13
  tokenizer = AutoTokenizer.from_pretrained("t5-base")
 
14
  return model, tokenizer
15
 
16
 
17
- @st.cache
18
  def custom_model():
 
19
  return pipeline("summarization", model="my_awesome_sum/")
20
 
21
 
@@ -25,11 +51,20 @@ def convert_df(df):
25
  return df.to_csv(index=False).encode("utf-8")
26
 
27
 
28
- @st.cache
29
  def load_one_line_summarizer(model):
 
30
  return model.load_model("t5", "snrspeaks/t5-one-line-summary")
31
 
32
 
 
 
 
 
 
 
 
 
33
  st.set_page_config(layout="wide", page_title="Amazon Review Summarizer")
34
  st.title("Amazon Review Summarizer")
35
 
@@ -38,6 +73,7 @@ summarizer_option = st.selectbox(
38
  "Select Summarizer",
39
  ("Custom trained on the dataset", "t5-base", "t5-one-line-summary"),
40
  )
 
41
  hide_streamlit_style = """
42
  <style>
43
  #MainMenu {visibility: hidden;}
@@ -63,8 +99,7 @@ if st.button("Process"):
63
  text = df["text"].values.tolist()
64
  if summarizer_option == "Custom trained on the dataset":
65
  model = custom_model()
66
- print(summarizer_option)
67
-
68
  progress_text = "Summarization in progress. Please wait."
69
  summary = []
70
 
@@ -82,11 +117,31 @@ if st.button("Process"):
82
  output = pd.DataFrame(
83
  {"text": df["text"].values.tolist(), "summary": summary}
84
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  csv = convert_df(output)
86
  st.download_button(
87
  label="Download data as CSV",
88
  data=csv,
89
- file_name=f"{summarizer_option}_df.csv",
90
  mime="text/csv",
91
  )
92
  if summarizer_option == "t5-base":
@@ -115,11 +170,31 @@ if st.button("Process"):
115
  output = pd.DataFrame(
116
  {"text": df["text"].values.tolist(), "summary": summary}
117
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  csv = convert_df(output)
119
  st.download_button(
120
  label="Download data as CSV",
121
  data=csv,
122
- file_name=f"{summarizer_option}_df.csv",
123
  mime="text/csv",
124
  )
125
 
@@ -136,16 +211,39 @@ if st.button("Process"):
136
  output = pd.DataFrame(
137
  {"text": df["text"].values.tolist(), "summary": summary}
138
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  csv = convert_df(output)
140
  st.download_button(
141
  label="Download data as CSV",
142
  data=csv,
143
- file_name=f"{summarizer_option}_df.csv",
144
  mime="text/csv",
145
  )
 
146
  except KeyError:
147
  st.error(
148
  "Please Make sure that your data must have a column named text",
149
  icon="🚨",
150
  )
151
  st.info("Text column must have amazon reviews", icon="ℹ️")
 
 
 
4
  from stqdm import stqdm
5
  from simplet5 import SimpleT5
6
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
7
+ from transformers import BertTokenizer
8
+ from tensorflow.keras.models import load_model
9
+ from tensorflow.nn import softmax
10
+ import numpy as np
11
+ from datetime import datetime
12
+ import logging
13
+
14
+ date = datetime.now().strftime(r"%Y-%m-%d")
15
+ model_classes = {
16
+ 0: "Ads",
17
+ 1: "Apps",
18
+ 2: "Battery",
19
+ 3: "Charging",
20
+ 4: "Delivery",
21
+ 5: "Display",
22
+ 6: "FOS",
23
+ 7: "HW",
24
+ 8: "Order",
25
+ 9: "Refurb",
26
+ 10: "SD",
27
+ 11: "Setup",
28
+ 12: "Unknown",
29
+ 13: "WiFi",
30
+ }
31
+
32
+
33
+ @st.cache_resource
34
  def load_t5():
35
  model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
36
 
37
  tokenizer = AutoTokenizer.from_pretrained("t5-base")
38
+ st.success("Loaded T5 Model")
39
  return model, tokenizer
40
 
41
 
42
+ @st.cache_resource
43
  def custom_model():
44
+ st.success("Loaded custom model")
45
  return pipeline("summarization", model="my_awesome_sum/")
46
 
47
 
 
51
  return df.to_csv(index=False).encode("utf-8")
52
 
53
 
54
+ @st.cache_resource
55
  def load_one_line_summarizer(model):
56
+ st.success("Loaded one line summarizer")
57
  return model.load_model("t5", "snrspeaks/t5-one-line-summary")
58
 
59
 
60
+ @st.cache_resource
61
+ def classify_category():
62
+ tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
63
+ new_model = load_model("model")
64
+ st.success("Loaded custom classification model")
65
+ return tokenizer, new_model
66
+
67
+
68
  st.set_page_config(layout="wide", page_title="Amazon Review Summarizer")
69
  st.title("Amazon Review Summarizer")
70
 
 
73
  "Select Summarizer",
74
  ("Custom trained on the dataset", "t5-base", "t5-one-line-summary"),
75
  )
76
+ classification = st.checkbox("Classify Category", value=True)
77
  hide_streamlit_style = """
78
  <style>
79
  #MainMenu {visibility: hidden;}
 
99
  text = df["text"].values.tolist()
100
  if summarizer_option == "Custom trained on the dataset":
101
  model = custom_model()
102
+
 
103
  progress_text = "Summarization in progress. Please wait."
104
  summary = []
105
 
 
117
  output = pd.DataFrame(
118
  {"text": df["text"].values.tolist(), "summary": summary}
119
  )
120
+ if classification:
121
+ classification_token, classification_model = classify_category()
122
+ tf_batch = classification_token(
123
+ text,
124
+ max_length=128,
125
+ padding=True,
126
+ truncation=True,
127
+ return_tensors="tf",
128
+ )
129
+ with st.spinner(text="identifying theme"):
130
+ tf_outputs = classification_model(tf_batch)
131
+ classes = []
132
+ with st.spinner(text="creating output file"):
133
+ for x in stqdm(range(len(text))):
134
+ tf_o = softmax(tf_outputs["logits"][x], axis=-1)
135
+ label = np.argmax(tf_o, axis=0)
136
+ keys = model_classes
137
+ classes.append(keys.get(label))
138
+ output["category"] = classes
139
+
140
  csv = convert_df(output)
141
  st.download_button(
142
  label="Download data as CSV",
143
  data=csv,
144
+ file_name=f"{summarizer_option}_{date}_df.csv",
145
  mime="text/csv",
146
  )
147
  if summarizer_option == "t5-base":
 
170
  output = pd.DataFrame(
171
  {"text": df["text"].values.tolist(), "summary": summary}
172
  )
173
+ if classification:
174
+ classification_token, classification_model = classify_category()
175
+ tf_batch = classification_token(
176
+ text,
177
+ max_length=128,
178
+ padding=True,
179
+ truncation=True,
180
+ return_tensors="tf",
181
+ )
182
+ with st.spinner(text="identifying theme"):
183
+ tf_outputs = classification_model(tf_batch)
184
+ classes = []
185
+ with st.spinner(text="creating output file"):
186
+ for x in stqdm(range(len(text))):
187
+ tf_o = tf.nn.softmax(tf_outputs["logits"][x], axis=-1)
188
+ label = np.argmax(tf_o, axis=0)
189
+ keys = model_classes
190
+ classes.append(keys.get(label))
191
+ output["category"] = classes
192
+
193
  csv = convert_df(output)
194
  st.download_button(
195
  label="Download data as CSV",
196
  data=csv,
197
+ file_name=f"{summarizer_option}_{date}_df.csv",
198
  mime="text/csv",
199
  )
200
 
 
211
  output = pd.DataFrame(
212
  {"text": df["text"].values.tolist(), "summary": summary}
213
  )
214
+ if classification:
215
+ classification_token, classification_model = classify_category()
216
+ tf_batch = classification_token(
217
+ text,
218
+ max_length=128,
219
+ padding=True,
220
+ truncation=True,
221
+ return_tensors="tf",
222
+ )
223
+ with st.spinner(text="identifying theme"):
224
+ tf_outputs = classification_model(tf_batch)
225
+ classes = []
226
+ with st.spinner(text="creating output file"):
227
+ for x in stqdm(range(len(text))):
228
+ tf_o = tf.nn.softmax(tf_outputs["logits"][x], axis=-1)
229
+ label = np.argmax(tf_o, axis=0)
230
+ keys = model_classes
231
+ classes.append(keys.get(label))
232
+ output["category"] = classes
233
+
234
  csv = convert_df(output)
235
  st.download_button(
236
  label="Download data as CSV",
237
  data=csv,
238
+ file_name=f"{summarizer_option}_{date}_df.csv",
239
  mime="text/csv",
240
  )
241
+
242
  except KeyError:
243
  st.error(
244
  "Please Make sure that your data must have a column named text",
245
  icon="🚨",
246
  )
247
  st.info("Text column must have amazon reviews", icon="ℹ️")
248
+ except BaseException as e:
249
+ logging.exception("An exception was occured")