sarahselim18 commited on
Commit
c37b1f1
·
1 Parent(s): f86ad12

file creation

Browse files
Files changed (1) hide show
  1. app.py +434 -0
app.py ADDED
@@ -0,0 +1,434 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+ import numpy.linalg as la
4
+ import pickle
5
+ import os
6
+ #import gdown
7
+ #from sentence_transformers import SentenceTransformer
8
+ import matplotlib.pyplot as plt
9
+ import math
10
+
11
+
12
+ # Compute Cosine Similarity
13
+ def cosine_similarity(x, y):
14
+ """
15
+ Exponentiated cosine similarity
16
+ 1. Compute cosine similarity
17
+ 2. Exponentiate cosine similarity
18
+ 3. Return exponentiated cosine similarity
19
+ (20 pts)
20
+ """
21
+ ##################################
22
+ ### TODO: Add code here ##########
23
+ cos_sim = np.dot(x,y)/(np.linalg.norm(x)*np.linalg.norm(y))
24
+ exp_cos = 0 ######## find formula for exponentiated cosine similarity
25
+
26
+ ##################################
27
+ return exp_cos
28
+
29
+
30
+ # Function to Load Glove Embeddings
31
+ def load_glove_embeddings(glove_path="Data/embeddings.pkl"):
32
+ with open(glove_path, "rb") as f:
33
+ embeddings_dict = pickle.load(f, encoding="latin1")
34
+
35
+ return embeddings_dict
36
+
37
+
38
+ def get_model_id_gdrive(model_type):
39
+ if model_type == "25d":
40
+ word_index_id = "13qMXs3-oB9C6kfSRMwbAtzda9xuAUtt8"
41
+ embeddings_id = "1-RXcfBvWyE-Av3ZHLcyJVsps0RYRRr_2"
42
+ elif model_type == "50d":
43
+ embeddings_id = "1DBaVpJsitQ1qxtUvV1Kz7ThDc3az16kZ"
44
+ word_index_id = "1rB4ksHyHZ9skes-fJHMa2Z8J1Qa7awQ9"
45
+ elif model_type == "100d":
46
+ word_index_id = "1-oWV0LqG3fmrozRZ7WB1jzeTJHRUI3mq"
47
+ embeddings_id = "1SRHfX130_6Znz7zbdfqboKosz-PfNvNp"
48
+
49
+ return word_index_id, embeddings_id
50
+
51
+
52
+ def download_glove_embeddings_gdrive(model_type):
53
+ # Get glove embeddings from google drive
54
+ word_index_id, embeddings_id = get_model_id_gdrive(model_type)
55
+
56
+ # Use gdown to get files from google drive
57
+ embeddings_temp = "embeddings_" + str(model_type) + "_temp.npy"
58
+ word_index_temp = "word_index_dict_" + str(model_type) + "_temp.pkl"
59
+
60
+ # Download word_index pickle file
61
+ print("Downloading word index dictionary....\n")
62
+ gdown.download(id=word_index_id, output=word_index_temp, quiet=False)
63
+
64
+ # Download embeddings numpy file
65
+ print("Donwloading embedings...\n\n")
66
+ gdown.download(id=embeddings_id, output=embeddings_temp, quiet=False)
67
+
68
+
69
+ @st.cache_data()
70
+ def load_glove_embeddings_gdrive(model_type):
71
+ word_index_temp = "word_index_dict_" + str(model_type) + "_temp.pkl"
72
+ embeddings_temp = "embeddings_" + str(model_type) + "_temp.npy"
73
+
74
+ # Load word index dictionary
75
+ word_index_dict = pickle.load(open(word_index_temp, "rb"), encoding="latin")
76
+
77
+ # Load embeddings numpy
78
+ embeddings = np.load(embeddings_temp)
79
+
80
+ return word_index_dict, embeddings
81
+
82
+
83
+ @st.cache_resource()
84
+ def load_sentence_transformer_model(model_name):
85
+ sentenceTransformer = SentenceTransformer(model_name)
86
+ return sentenceTransformer
87
+
88
+
89
+ def get_sentence_transformer_embeddings(sentence, model_name="all-MiniLM-L6-v2"):
90
+ """
91
+ Get sentence transformer embeddings for a sentence
92
+ """
93
+ # 384 dimensional embedding
94
+ # Default model: https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2
95
+
96
+ sentenceTransformer = load_sentence_transformer_model(model_name)
97
+
98
+ try:
99
+ return sentenceTransformer.encode(sentence)
100
+ except:
101
+ if model_name == "all-MiniLM-L6-v2":
102
+ return np.zeros(384)
103
+ else:
104
+ return np.zeros(512)
105
+
106
+
107
+ def get_glove_embeddings(word, word_index_dict, embeddings, model_type):
108
+ """
109
+ Get glove embedding for a single word
110
+ """
111
+ if word.lower() in word_index_dict:
112
+ return embeddings[word_index_dict[word.lower()]]
113
+ else:
114
+ return np.zeros(int(model_type.split("d")[0]))
115
+
116
+
117
+ def averaged_glove_embeddings_gdrive(sentence, word_index_dict, embeddings, model_type=50):
118
+ """
119
+ Get averaged glove embeddings for a sentence
120
+ 1. Split sentence into words
121
+ 2. Get embeddings for each word
122
+ 3. Add embeddings for each word
123
+ 4. Divide by number of words
124
+ 5. Return averaged embeddings
125
+ (30 pts)
126
+ """
127
+ embedding = np.zeros(int(model_type.split("d")[0]))
128
+ ##################################
129
+ ##### TODO: Add code here ########
130
+
131
+ #glove_word_set= load_glove_embeddings_gdrive(model_type)
132
+
133
+ for word in sentence:
134
+ #print(sentence)
135
+ words = [word.strip('.,?!').lower() for word in sentence.split()]
136
+ total = 0
137
+ for w in words:
138
+ if w in embeddings:
139
+ embed += embeddings[w]
140
+ total +=1
141
+ if total != 0:
142
+ embed = embed/total
143
+
144
+ return embed
145
+ ##################################
146
+
147
+
148
+ def get_category_embeddings(embeddings_metadata):
149
+ """
150
+ Get embeddings for each category
151
+ 1. Split categories into words
152
+ 2. Get embeddings for each word
153
+ """
154
+ model_name = embeddings_metadata["model_name"]
155
+ st.session_state["cat_embed_" + model_name] = {}
156
+ for category in st.session_state.categories.split(" "):
157
+ if model_name:
158
+ if not category in st.session_state["cat_embed_" + model_name]:
159
+ st.session_state["cat_embed_" + model_name][category] = get_sentence_transformer_embeddings(category, model_name=model_name)
160
+ else:
161
+ if not category in st.session_state["cat_embed_" + model_name]:
162
+ st.session_state["cat_embed_" + model_name][category] = get_sentence_transformer_embeddings(category)
163
+
164
+
165
+ def update_category_embeddings(embedings_metadata):
166
+ """
167
+ Update embeddings for each category
168
+ """
169
+ get_category_embeddings(embeddings_metadata)
170
+
171
+
172
+ def get_sorted_cosine_similarity(embeddings_metadata):
173
+ """
174
+ Get sorted cosine similarity between input sentence and categories
175
+ Steps:
176
+ 1. Get embeddings for input sentence
177
+ 2. Get embeddings for categories (if not found, update category embeddings)
178
+ 3. Compute cosine similarity between input sentence and categories
179
+ 4. Sort cosine similarity
180
+ 5. Return sorted cosine similarity
181
+ (50 pts)
182
+ """
183
+ categories = st.session_state.categories.split(" ")
184
+ cosine_sim = {}
185
+ if embeddings_metadata["embedding_model"] == "glove":
186
+ word_index_dict = embeddings_metadata["word_index_dict"]
187
+ embeddings = embeddings_metadata["embeddings"]
188
+ model_type = embeddings_metadata["model_type"]
189
+
190
+ input_embedding = averaged_glove_embeddings_gdrive(st.session_state.text_search,
191
+ word_index_dict,
192
+ embeddings, model_type)
193
+
194
+ ##########################################
195
+ ## TODO: Get embeddings for categories ###
196
+ cat_embed = []
197
+ for cat in categories:
198
+ cat_embed.append(get_category_embeddings(categories))
199
+
200
+
201
+ ##########################################
202
+
203
+ else:
204
+ model_name = embeddings_metadata["model_name"]
205
+ if not "cat_embed_" + model_name in st.session_state:
206
+ get_category_embeddings(embeddings_metadata)
207
+
208
+ category_embeddings = st.session_state["cat_embed_" + model_name]
209
+
210
+ print("text_search = ", st.session_state.text_search)
211
+ if model_name:
212
+ input_embedding = get_sentence_transformer_embeddings(st.session_state.text_search, model_name=model_name)
213
+ else:
214
+ input_embedding = get_sentence_transformer_embeddings(st.session_state.text_search)
215
+ #for index in range(len(categories)):
216
+ #pass
217
+ ##########################################
218
+ # TODO: Compute cosine similarity between input sentence and categories
219
+
220
+ cat_scores = []
221
+ cat_idx = 0
222
+ for cat_embed in category_embeddings:
223
+ # Calc cosine sim
224
+ cat_scores.append((cat_idx, np.dot(input,cat_embed)))
225
+ # Store doc_id and score as a tuple
226
+ cat_idx +=1
227
+
228
+
229
+ sorted_list = sorted(cat_scores, key=lambda x: x[1])
230
+
231
+ sorted_cats = [element[0] for element in sorted_list]
232
+
233
+ #flip sorting order
234
+ sorted_cats = sorted_cats[::-1]
235
+ # Add list to Map
236
+ result = sorted_cats[0]
237
+ selected_cat = categories[result]
238
+
239
+ # TODO: Update category embeddings if category not found
240
+ ##########################################
241
+
242
+ return
243
+
244
+
245
+ def plot_piechart(sorted_cosine_scores_items):
246
+ sorted_cosine_scores = np.array([
247
+ sorted_cosine_scores_items[index][1]
248
+ for index in range(len(sorted_cosine_scores_items))
249
+ ]
250
+ )
251
+ categories = st.session_state.categories.split(" ")
252
+ categories_sorted = [
253
+ categories[sorted_cosine_scores_items[index][0]]
254
+ for index in range(len(sorted_cosine_scores_items))
255
+ ]
256
+ fig, ax = plt.subplots()
257
+ ax.pie(sorted_cosine_scores, labels=categories_sorted, autopct="%1.1f%%")
258
+ st.pyplot(fig) # Figure
259
+
260
+
261
+ def plot_piechart_helper(sorted_cosine_scores_items):
262
+ sorted_cosine_scores = np.array(
263
+ [
264
+ sorted_cosine_scores_items[index][1]
265
+ for index in range(len(sorted_cosine_scores_items))
266
+ ]
267
+ )
268
+ categories = st.session_state.categories.split(" ")
269
+ categories_sorted = [
270
+ categories[sorted_cosine_scores_items[index][0]]
271
+ for index in range(len(sorted_cosine_scores_items))
272
+ ]
273
+ fig, ax = plt.subplots(figsize=(3, 3))
274
+ my_explode = np.zeros(len(categories_sorted))
275
+ my_explode[0] = 0.2
276
+ if len(categories_sorted) == 3:
277
+ my_explode[1] = 0.1 # explode this by 0.2
278
+ elif len(categories_sorted) > 3:
279
+ my_explode[2] = 0.05
280
+ ax.pie(
281
+ sorted_cosine_scores,
282
+ labels=categories_sorted,
283
+ autopct="%1.1f%%",
284
+ explode=my_explode,
285
+ )
286
+
287
+ return fig
288
+
289
+
290
+ def plot_piecharts(sorted_cosine_scores_models):
291
+ scores_list = []
292
+ categories = st.session_state.categories.split(" ")
293
+ index = 0
294
+ for model in sorted_cosine_scores_models:
295
+ scores_list.append(sorted_cosine_scores_models[model])
296
+ # scores_list[index] = np.array([scores_list[index][ind2][1] for ind2 in range(len(scores_list[index]))])
297
+ index += 1
298
+
299
+ if len(sorted_cosine_scores_models) == 2:
300
+ fig, (ax1, ax2) = plt.subplots(2)
301
+
302
+ categories_sorted = [
303
+ categories[scores_list[0][index][0]] for index in range(len(scores_list[0]))
304
+ ]
305
+ sorted_scores = np.array(
306
+ [scores_list[0][index][1] for index in range(len(scores_list[0]))]
307
+ )
308
+ ax1.pie(sorted_scores, labels=categories_sorted, autopct="%1.1f%%")
309
+
310
+ categories_sorted = [
311
+ categories[scores_list[1][index][0]] for index in range(len(scores_list[1]))
312
+ ]
313
+ sorted_scores = np.array(
314
+ [scores_list[1][index][1] for index in range(len(scores_list[1]))]
315
+ )
316
+ ax2.pie(sorted_scores, labels=categories_sorted, autopct="%1.1f%%")
317
+
318
+ st.pyplot(fig)
319
+
320
+
321
+ def plot_alatirchart(sorted_cosine_scores_models):
322
+ models = list(sorted_cosine_scores_models.keys())
323
+ tabs = st.tabs(models)
324
+ figs = {}
325
+ for model in models:
326
+ figs[model] = plot_piechart_helper(sorted_cosine_scores_models[model])
327
+
328
+ for index in range(len(tabs)):
329
+ with tabs[index]:
330
+ st.pyplot(figs[models[index]])
331
+
332
+
333
+ ### Text Search ###
334
+ st.sidebar.title("GloVe Twitter")
335
+ st.sidebar.markdown(
336
+ """
337
+ GloVe is an unsupervised learning algorithm for obtaining vector representations for words. Pretrained on
338
+ 2 billion tweets with vocabulary size of 1.2 million. Download from [Stanford NLP](http://nlp.stanford.edu/data/glove.twitter.27B.zip).
339
+
340
+ Jeffrey Pennington, Richard Socher, and Christopher D. Manning. 2014. *GloVe: Global Vectors for Word Representation*.
341
+ """
342
+ )
343
+
344
+ model_type = st.sidebar.selectbox("Choose the model", ("25d", "50d"), index=1)
345
+
346
+
347
+ st.title("Search Based Retrieval Demo")
348
+ st.subheader(
349
+ "Pass in space separated categories you want this search demo to be about."
350
+ )
351
+ # st.selectbox(label="Pick the categories you want this search demo to be about...",
352
+ # options=("Flowers Colors Cars Weather Food", "Chocolate Milk", "Anger Joy Sad Frustration Worry Happiness", "Positive Negative"),
353
+ # key="categories"
354
+ # )
355
+ st.text_input(
356
+ label="Categories", key="categories", value="Flowers Colors Cars Weather Food"
357
+ )
358
+ print(st.session_state["categories"])
359
+ print(type(st.session_state["categories"]))
360
+ # print("Categories = ", categories)
361
+ # st.session_state.categories = categories
362
+
363
+ st.subheader("Pass in an input word or even a sentence")
364
+ text_search = st.text_input(
365
+ label="Input your sentence",
366
+ key="text_search",
367
+ value="Roses are red, trucks are blue, and Seattle is grey right now",
368
+ )
369
+ # st.session_state.text_search = text_search
370
+
371
+ # Download glove embeddings if it doesn't exist
372
+ embeddings_path = "embeddings_" + str(model_type) + "_temp.npy"
373
+ word_index_dict_path = "word_index_dict_" + str(model_type) + "_temp.pkl"
374
+ if not os.path.isfile(embeddings_path) or not os.path.isfile(word_index_dict_path):
375
+ print("Model type = ", model_type)
376
+ glove_path = "Data/glove_" + str(model_type) + ".pkl"
377
+ print("glove_path = ", glove_path)
378
+
379
+ # Download embeddings from google drive
380
+ with st.spinner("Downloading glove embeddings..."):
381
+ download_glove_embeddings_gdrive(model_type)
382
+
383
+
384
+ # Load glove embeddings
385
+ word_index_dict, embeddings = load_glove_embeddings_gdrive(model_type)
386
+
387
+
388
+ # Find closest word to an input word
389
+ if st.session_state.text_search:
390
+ # Glove embeddings
391
+ print("Glove Embedding")
392
+ embeddings_metadata = {
393
+ "embedding_model": "glove",
394
+ "word_index_dict": word_index_dict,
395
+ "embeddings": embeddings,
396
+ "model_type": model_type,
397
+ }
398
+ with st.spinner("Obtaining Cosine similarity for Glove..."):
399
+ sorted_cosine_sim_glove = get_sorted_cosine_similarity(
400
+ st.session_state.text_search, embeddings_metadata
401
+ )
402
+
403
+ # Sentence transformer embeddings
404
+ print("Sentence Transformer Embedding")
405
+ embeddings_metadata = {"embedding_model": "transformers", "model_name": ""}
406
+ with st.spinner("Obtaining Cosine similarity for 384d sentence transformer..."):
407
+ sorted_cosine_sim_transformer = get_sorted_cosine_similarity(
408
+ st.session_state.text_search, embeddings_metadata
409
+ )
410
+
411
+ # Results and Plot Pie Chart for Glove
412
+ print("Categories are: ", st.session_state.categories)
413
+ st.subheader(
414
+ "Closest word I have between: "
415
+ + st.session_state.categories
416
+ + " as per different Embeddings"
417
+ )
418
+
419
+ print(sorted_cosine_sim_glove)
420
+ print(sorted_cosine_sim_transformer)
421
+ # print(sorted_distilbert)
422
+ # Altair Chart for all models
423
+ plot_alatirchart(
424
+ {
425
+ "glove_" + str(model_type): sorted_cosine_sim_glove,
426
+ "sentence_transformer_384": sorted_cosine_sim_transformer,
427
+ }
428
+ )
429
+ # "distilbert_512": sorted_distilbert})
430
+
431
+ st.write("")
432
+ st.write(
433
+ "Demo developed by [Dr. Karthik Mohan](https://www.linkedin.com/in/karthik-mohan-72a4b323/)"
434
+ )