arnabk1 commited on
Commit
dc44a45
·
verified ·
1 Parent(s): bba39c8
Files changed (3) hide show
  1. app.py +427 -0
  2. embeddings_50d_temp.npy +3 -0
  3. word_index_dict_50d_temp.pkl +3 -0
app.py ADDED
@@ -0,0 +1,427 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+ import numpy.linalg as la
4
+ import pickle
5
+ import os
6
+ import gdown
7
+ from sentence_transformers import SentenceTransformer
8
+ import matplotlib.pyplot as plt
9
+ import math
10
+
11
+
12
+ # Compute Cosine Similarity
13
+ def cosine_similarity(x, y):
14
+ """
15
+ Exponentiated cosine similarity
16
+ 1. Compute cosine similarity
17
+ 2. Exponentiate cosine similarity
18
+ 3. Return exponentiated cosine similarity
19
+ (20 pts)
20
+ """
21
+ ##################################
22
+ ### TODO: Add code here ##########
23
+ ##################################
24
+ # Compute dot product
25
+ dot_product = np.dot(x, y)
26
+
27
+ # Compute magnitudes of the vectors
28
+ magnitude_x = np.linalg.norm(x)
29
+ magnitude_y = np.linalg.norm(y)
30
+
31
+ # Compute cosine similarity
32
+ similarity = dot_product / (magnitude_x * magnitude_y)
33
+
34
+ # Exponentiate cosine similarity
35
+ exponentiated_similarity = np.exp(similarity)
36
+
37
+ return exponentiated_similarity
38
+
39
+
40
+ # Function to Load Glove Embeddings
41
+ def load_glove_embeddings(glove_path="Data/embeddings.pkl"):
42
+ with open(glove_path, "rb") as f:
43
+ embeddings_dict = pickle.load(f, encoding="latin1")
44
+
45
+ return embeddings_dict
46
+
47
+
48
+ def get_model_id_gdrive(model_type):
49
+ if model_type == "25d":
50
+ word_index_id = "13qMXs3-oB9C6kfSRMwbAtzda9xuAUtt8"
51
+ embeddings_id = "1-RXcfBvWyE-Av3ZHLcyJVsps0RYRRr_2"
52
+ elif model_type == "50d":
53
+ embeddings_id = "1DBaVpJsitQ1qxtUvV1Kz7ThDc3az16kZ"
54
+ word_index_id = "1rB4ksHyHZ9skes-fJHMa2Z8J1Qa7awQ9"
55
+ elif model_type == "100d":
56
+ word_index_id = "1-oWV0LqG3fmrozRZ7WB1jzeTJHRUI3mq"
57
+ embeddings_id = "1SRHfX130_6Znz7zbdfqboKosz-PfNvNp"
58
+
59
+ return word_index_id, embeddings_id
60
+
61
+
62
+ def download_glove_embeddings_gdrive(model_type):
63
+ # Get glove embeddings from google drive
64
+ word_index_id, embeddings_id = get_model_id_gdrive(model_type)
65
+
66
+ # Use gdown to get files from google drive
67
+ embeddings_temp = "embeddings_" + str(model_type) + "_temp.npy"
68
+ word_index_temp = "word_index_dict_" + str(model_type) + "_temp.pkl"
69
+
70
+ # Download word_index pickle file
71
+ print("Downloading word index dictionary....\n")
72
+ gdown.download(id=word_index_id, output=word_index_temp, quiet=False)
73
+
74
+ # Download embeddings numpy file
75
+ print("Donwloading embedings...\n\n")
76
+ gdown.download(id=embeddings_id, output=embeddings_temp, quiet=False)
77
+
78
+
79
+ # @st.cache_data()
80
+ def load_glove_embeddings_gdrive(model_type):
81
+ word_index_temp = "word_index_dict_" + str(model_type) + "_temp.pkl"
82
+ embeddings_temp = "embeddings_" + str(model_type) + "_temp.npy"
83
+
84
+ # Load word index dictionary
85
+ word_index_dict = pickle.load(open(word_index_temp, "rb"), encoding="latin")
86
+
87
+ # Load embeddings numpy
88
+ embeddings = np.load(embeddings_temp)
89
+
90
+ return word_index_dict, embeddings
91
+
92
+
93
+ @st.cache_resource()
94
+ def load_sentence_transformer_model(model_name):
95
+ sentenceTransformer = SentenceTransformer(model_name)
96
+ return sentenceTransformer
97
+
98
+
99
+ def get_sentence_transformer_embeddings(sentence, model_name="all-MiniLM-L6-v2"):
100
+ """
101
+ Get sentence transformer embeddings for a sentence
102
+ """
103
+ # 384 dimensional embedding
104
+ # Default model: https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2
105
+
106
+ sentenceTransformer = load_sentence_transformer_model(model_name)
107
+
108
+ try:
109
+ return sentenceTransformer.encode(sentence)
110
+ except:
111
+ if model_name == "all-MiniLM-L6-v2":
112
+ return np.zeros(384)
113
+ else:
114
+ return np.zeros(512)
115
+
116
+
117
+ def get_glove_embeddings(word, word_index_dict, embeddings, model_type):
118
+ """
119
+ Get glove embedding for a single word
120
+ """
121
+ if word.lower() in word_index_dict:
122
+ return embeddings[word_index_dict[word.lower()]]
123
+ else:
124
+ return np.zeros(int(model_type.split("d")[0]))
125
+
126
+
127
+ def averaged_glove_embeddings_gdrive(sentence, word_index_dict, embeddings, model_type=50):
128
+ """
129
+ Get averaged glove embeddings for a sentence
130
+ 1. Split sentence into words
131
+ 2. Get embeddings for each word
132
+ 3. Add embeddings for each word
133
+ 4. Divide by number of words
134
+ 5. Return averaged embeddings
135
+ (30 pts)
136
+ """
137
+ embedding = np.zeros(int(model_type.split("d")[0]))
138
+ ##################################
139
+ ##### TODO: Add code here ########
140
+ ##################################
141
+ words = sentence.split()
142
+ # total_embedding = np.zeros(len(word_index_dict[0]))
143
+
144
+ for word in words:
145
+ if word.lower() in word_index_dict.keys():
146
+ embedding += get_glove_embeddings(word.lower(), word_index_dict, embeddings, model_type)
147
+
148
+ if len(words) > 0:
149
+ averaged_embedding = embedding / len(words)
150
+ else:
151
+ averaged_embedding = embedding
152
+
153
+ return averaged_embedding
154
+
155
+ def get_category_embeddings(embeddings_metadata):
156
+ """
157
+ Get embeddings for each category
158
+ 1. Split categories into words
159
+ 2. Get embeddings for each word
160
+ """
161
+ model_name = embeddings_metadata["model_name"]
162
+ st.session_state["cat_embed_" + model_name] = {}
163
+ for category in st.session_state.categories.split(" "):
164
+ if model_name:
165
+ if not category in st.session_state["cat_embed_" + model_name]:
166
+ st.session_state["cat_embed_" + model_name][category] = get_sentence_transformer_embeddings(category, model_name=model_name)
167
+ else:
168
+ if not category in st.session_state["cat_embed_" + model_name]:
169
+ st.session_state["cat_embed_" + model_name][category] = get_sentence_transformer_embeddings(category)
170
+
171
+
172
+ def update_category_embeddings(embedings_metadata):
173
+ """
174
+ Update embeddings for each category
175
+ """
176
+ get_category_embeddings(embeddings_metadata)
177
+
178
+
179
+ def get_sorted_cosine_similarity(_, embeddings_metadata):
180
+ """
181
+ Get sorted cosine similarity between input sentence and categories
182
+ Steps:
183
+ 1. Get embeddings for input sentence
184
+ 2. Get embeddings for categories (if not found, update category embeddings)
185
+ 3. Compute cosine similarity between input sentence and categories
186
+ 4. Sort cosine similarity
187
+ 5. Return sorted cosine similarity
188
+ (50 pts)
189
+ """
190
+ categories = st.session_state.categories.split(" ")
191
+ cosine_sim = {}
192
+ if embeddings_metadata["embedding_model"] == "glove":
193
+ word_index_dict = embeddings_metadata["word_index_dict"]
194
+ embeddings = embeddings_metadata["embeddings"]
195
+ model_type = embeddings_metadata["model_type"]
196
+
197
+ input_embedding = averaged_glove_embeddings_gdrive(st.session_state.text_search,
198
+ word_index_dict,
199
+ embeddings, model_type)
200
+
201
+ ##########################################
202
+ ## TODO: Get embeddings for categories ###
203
+ ##########################################
204
+
205
+ if categories != None:
206
+ # Get and compute embeddings for each category
207
+ for index, category in enumerate(categories):
208
+ # category_embeddings.append(averaged_glove_embeddings_gdrive(category,word_index_dict,embeddings, model_type))
209
+ # if category not in cosine_sim:
210
+ cosine_sim[index] = cosine_similarity(input_embedding, embeddings[index])
211
+
212
+
213
+ else:
214
+ model_name = embeddings_metadata["model_name"]
215
+ if not "cat_embed_" + model_name in st.session_state:
216
+ get_category_embeddings(embeddings_metadata)
217
+
218
+ category_embeddings = st.session_state["cat_embed_" + model_name]
219
+
220
+ print("text_search = ", st.session_state.text_search)
221
+ if model_name:
222
+ input_embedding = get_sentence_transformer_embeddings(st.session_state.text_search, model_name=model_name)
223
+ else:
224
+ input_embedding = get_sentence_transformer_embeddings(st.session_state.text_search)
225
+ for index, category in enumerate(categories):
226
+ ##########################################
227
+ # TODO: Compute cosine similarity between input sentence and categories
228
+ # TODO: Update category embeddings if category not found
229
+ ##########################################
230
+ category_embedding = category_embeddings[category]
231
+
232
+ cosine_sim[index] = cosine_similarity(input_embedding, category_embedding)
233
+
234
+ # Sort cosine similarities in descending order
235
+ sorted_cosine_sim = sorted(cosine_sim.items(), key=lambda x: x[1], reverse=True)
236
+
237
+
238
+ return sorted_cosine_sim
239
+
240
+
241
+ def plot_piechart(sorted_cosine_scores_items):
242
+ sorted_cosine_scores = np.array([
243
+ sorted_cosine_scores_items[index][1]
244
+ for index in range(len(sorted_cosine_scores_items))
245
+ ]
246
+ )
247
+ categories = st.session_state.categories.split(" ")
248
+ categories_sorted = [
249
+ categories[sorted_cosine_scores_items[index][0]]
250
+ for index in range(len(sorted_cosine_scores_items))
251
+ ]
252
+ fig, ax = plt.subplots()
253
+ ax.pie(sorted_cosine_scores, labels=categories_sorted, autopct="%1.1f%%")
254
+ st.pyplot(fig) # Figure
255
+
256
+
257
+ def plot_piechart_helper(sorted_cosine_scores_items):
258
+ sorted_cosine_scores = np.array(
259
+ [
260
+ sorted_cosine_scores_items[index][1]
261
+ for index in range(len(sorted_cosine_scores_items))
262
+ ]
263
+ )
264
+ categories = st.session_state.categories.split(" ")
265
+ categories_sorted = [categories[sorted_cosine_scores_items[index][0]] for index in range(len(sorted_cosine_scores_items)) ]
266
+ fig, ax = plt.subplots(figsize=(3, 3))
267
+ my_explode = np.zeros(len(categories_sorted))
268
+ my_explode[0] = 0.2
269
+ if len(categories_sorted) == 3:
270
+ my_explode[1] = 0.1 # explode this by 0.2
271
+ elif len(categories_sorted) > 3:
272
+ my_explode[2] = 0.05
273
+ ax.pie(
274
+ sorted_cosine_scores,
275
+ labels=categories_sorted,
276
+ autopct="%1.1f%%",
277
+ explode=my_explode,
278
+ )
279
+
280
+ return fig
281
+
282
+
283
+ def plot_piecharts(sorted_cosine_scores_models):
284
+ scores_list = []
285
+ categories = st.session_state.categories.split(" ")
286
+ index = 0
287
+ for model in sorted_cosine_scores_models:
288
+ scores_list.append(sorted_cosine_scores_models[model])
289
+ # scores_list[index] = np.array([scores_list[index][ind2][1] for ind2 in range(len(scores_list[index]))])
290
+ index += 1
291
+
292
+ if len(sorted_cosine_scores_models) == 2:
293
+ fig, (ax1, ax2) = plt.subplots(2)
294
+
295
+ categories_sorted = [
296
+ categories[scores_list[0][index][0]] for index in range(len(scores_list[0]))
297
+ ]
298
+ sorted_scores = np.array(
299
+ [scores_list[0][index][1] for index in range(len(scores_list[0]))]
300
+ )
301
+ ax1.pie(sorted_scores, labels=categories_sorted, autopct="%1.1f%%")
302
+
303
+ categories_sorted = [
304
+ categories[scores_list[1][index][0]] for index in range(len(scores_list[1]))
305
+ ]
306
+ sorted_scores = np.array(
307
+ [scores_list[1][index][1] for index in range(len(scores_list[1]))]
308
+ )
309
+ ax2.pie(sorted_scores, labels=categories_sorted, autopct="%1.1f%%")
310
+
311
+ st.pyplot(fig)
312
+
313
+
314
+ def plot_alatirchart(sorted_cosine_scores_models):
315
+ models = list(sorted_cosine_scores_models.keys())
316
+ tabs = st.tabs(models)
317
+ figs = {}
318
+ for model in models:
319
+ figs[model] = plot_piechart_helper(sorted_cosine_scores_models[model])
320
+
321
+ for index in range(len(tabs)):
322
+ with tabs[index]:
323
+ st.pyplot(figs[models[index]])
324
+
325
+
326
+ ### Text Search ###
327
+ st.sidebar.title("GloVe Twitter")
328
+ st.sidebar.markdown(
329
+ """
330
+ GloVe is an unsupervised learning algorithm for obtaining vector representations for words. Pretrained on
331
+ 2 billion tweets with vocabulary size of 1.2 million. Download from [Stanford NLP](http://nlp.stanford.edu/data/glove.twitter.27B.zip).
332
+
333
+ Jeffrey Pennington, Richard Socher, and Christopher D. Manning. 2014. *GloVe: Global Vectors for Word Representation*.
334
+ """
335
+ )
336
+
337
+ model_type = st.sidebar.selectbox("Choose the model", ("25d", "50d"), index=1)
338
+
339
+
340
+ st.title("Search Based Retrieval Demo")
341
+ st.subheader(
342
+ "Pass in space separated categories you want this search demo to be about."
343
+ )
344
+ # st.selectbox(label="Pick the categories you want this search demo to be about...",
345
+ # options=("Flowers Colors Cars Weather Food", "Chocolate Milk", "Anger Joy Sad Frustration Worry Happiness", "Positive Negative"),
346
+ # key="categories"
347
+ # )
348
+ st.text_input(
349
+ label="Categories", key="categories", value="Flowers Colors Cars Weather Food"
350
+ )
351
+ print(st.session_state["categories"])
352
+ print(type(st.session_state["categories"]))
353
+ # print("Categories = ", categories)
354
+ # st.session_state.categories = categories
355
+
356
+ st.subheader("Pass in an input word or even a sentence")
357
+ text_search = st.text_input(
358
+ label="Input your sentence",
359
+ key="text_search",
360
+ value="Roses are red, trucks are blue, and Seattle is grey right now",
361
+ )
362
+ # st.session_state.text_search = text_search
363
+
364
+ # Download glove embeddings if it doesn't exist
365
+ embeddings_path = "embeddings_" + str(model_type) + "_temp.npy"
366
+ word_index_dict_path = "word_index_dict_" + str(model_type) + "_temp.pkl"
367
+ if not os.path.isfile(embeddings_path) or not os.path.isfile(word_index_dict_path):
368
+ print("Model type = ", model_type)
369
+ glove_path = "Data/glove_" + str(model_type) + ".pkl"
370
+ print("glove_path = ", glove_path)
371
+
372
+ # Download embeddings from google drive
373
+ with st.spinner("Downloading glove embeddings..."):
374
+ download_glove_embeddings_gdrive(model_type)
375
+
376
+
377
+ # Load glove embeddings
378
+ word_index_dict, embeddings = load_glove_embeddings_gdrive(model_type)
379
+
380
+
381
+ # Find closest word to an input word
382
+ if st.session_state.text_search:
383
+ # Glove embeddings
384
+ print("Glove Embedding")
385
+ embeddings_metadata = {
386
+ "embedding_model": "glove",
387
+ "word_index_dict": word_index_dict,
388
+ "embeddings": embeddings,
389
+ "model_type": model_type,
390
+ }
391
+ with st.spinner("Obtaining Cosine similarity for Glove..."):
392
+ sorted_cosine_sim_glove = get_sorted_cosine_similarity(
393
+ st.session_state.text_search, embeddings_metadata
394
+ )
395
+
396
+ # Sentence transformer embeddings
397
+ print("Sentence Transformer Embedding")
398
+ embeddings_metadata = {"embedding_model": "transformers", "model_name": ""}
399
+ with st.spinner("Obtaining Cosine similarity for 384d sentence transformer..."):
400
+ sorted_cosine_sim_transformer = get_sorted_cosine_similarity(
401
+ st.session_state.text_search, embeddings_metadata
402
+ )
403
+
404
+ # Results and Plot Pie Chart for Glove
405
+ print("Categories are: ", st.session_state.categories)
406
+ st.subheader(
407
+ "Closest word I have between: "
408
+ + st.session_state.categories
409
+ + " as per different Embeddings"
410
+ )
411
+
412
+ print(sorted_cosine_sim_glove)
413
+ print(sorted_cosine_sim_transformer)
414
+ # print(sorted_distilbert)
415
+ # Altair Chart for all models
416
+ plot_alatirchart(
417
+ {
418
+ "glove_" + str(model_type): sorted_cosine_sim_glove,
419
+ "sentence_transformer_384": sorted_cosine_sim_transformer,
420
+ }
421
+ )
422
+ # "distilbert_512": sorted_distilbert})
423
+
424
+ st.write("")
425
+ st.write(
426
+ "Demo developed by [Dr. Karthik Mohan](https://www.linkedin.com/in/karthik-mohan-72a4b323/)"
427
+ )
embeddings_50d_temp.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e74f88cde3ff2e36c815d13955c67983cf6f81829d2582cb6789c10786e5ef66
3
+ size 477405680
word_index_dict_50d_temp.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:674af352f703098ef122f6a8db7c5e08c5081829d49daea32e5aeac1fe582900
3
+ size 60284151