ericlkc commited on
Commit
dc9d3f8
·
verified ·
1 Parent(s): a2bbddb

mini-proj1 part4

Browse files
embeddings_50d_temp.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e74f88cde3ff2e36c815d13955c67983cf6f81829d2582cb6789c10786e5ef66
3
+ size 477405680
group-p4.py ADDED
@@ -0,0 +1,473 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ # %load miniproject1_part4-2-1.py
4
+ import streamlit as st
5
+ import numpy as np
6
+ import numpy.linalg as la
7
+ import pickle
8
+ import os
9
+ import gdown
10
+ from sentence_transformers import SentenceTransformer
11
+ import matplotlib.pyplot as plt
12
+ import math
13
+
14
+
15
+ # Compute Cosine Similarity
16
+ def cosine_similarity(x, y):
17
+ """
18
+ Exponentiated cosine similarity
19
+ 1. Compute cosine similarity
20
+ 2. Exponentiate cosine similarity
21
+ 3. Return exponentiated cosine similarity
22
+ (20 pts)
23
+ """
24
+ # Compute cosine similarity
25
+ dot_product = np.dot(x, y)
26
+ norm_x = np.linalg.norm(x)
27
+ norm_y = np.linalg.norm(y)
28
+ cosine_sim = dot_product / (norm_x * norm_y)
29
+
30
+ # Exponentiate cosine similarity
31
+ exp_cosine_sim = np.exp(cosine_sim)
32
+
33
+ # Return exponentiated cosine similarity
34
+ return exp_cosine_sim
35
+
36
+
37
+ # Function to Load Glove Embeddings
38
+ def load_glove_embeddings(glove_path="Data/embeddings.pkl"):
39
+ with open(glove_path, "rb") as f:
40
+ embeddings_dict = pickle.load(f, encoding="latin1")
41
+
42
+ return embeddings_dict
43
+
44
+
45
+ def get_model_id_gdrive(model_type):
46
+ if model_type == "25d":
47
+ word_index_id = "13qMXs3-oB9C6kfSRMwbAtzda9xuAUtt8"
48
+ embeddings_id = "1-RXcfBvWyE-Av3ZHLcyJVsps0RYRRr_2"
49
+ elif model_type == "50d":
50
+ embeddings_id = "1DBaVpJsitQ1qxtUvV1Kz7ThDc3az16kZ"
51
+ word_index_id = "1rB4ksHyHZ9skes-fJHMa2Z8J1Qa7awQ9"
52
+ elif model_type == "100d":
53
+ word_index_id = "1-oWV0LqG3fmrozRZ7WB1jzeTJHRUI3mq"
54
+ embeddings_id = "1SRHfX130_6Znz7zbdfqboKosz-PfNvNp"
55
+
56
+ return word_index_id, embeddings_id
57
+
58
+
59
+ def download_glove_embeddings_gdrive(model_type):
60
+ # Get glove embeddings from google drive
61
+ word_index_id, embeddings_id = get_model_id_gdrive(model_type)
62
+
63
+ # Use gdown to get files from google drive
64
+
65
+ # 修改的
66
+ embeddings_temp = "embeddings_50d_temp.npy"
67
+ # embeddings_temp = "embeddings_" + str(model_type) + "_temp.npy"
68
+
69
+ # 修改的
70
+ word_index_temp = "word_index_dict_50d_temp.pkl"
71
+ # word_index_temp = "word_index_dict_" + str(model_type) + "_temp.pkl"
72
+
73
+ # Download word_index pickle file
74
+ print("Downloading word index dictionary....\n")
75
+ # gdown.download(id=word_index_id, output=word_index_temp, quiet=False)
76
+
77
+ # Download embeddings numpy file
78
+ print("Donwloading embedings...\n\n")
79
+ # gdown.download(id=embeddings_id, output=embeddings_temp, quiet=False)
80
+
81
+
82
+ # @st.cache_data()
83
+ def load_glove_embeddings_gdrive(model_type):
84
+ word_index_temp = "word_index_dict_" + str(model_type) + "_temp.pkl"
85
+ embeddings_temp = "embeddings_" + str(model_type) + "_temp.npy"
86
+
87
+ # Load word index dictionary
88
+ word_index_dict = pickle.load(open(word_index_temp, "rb"), encoding="latin")
89
+
90
+ # Load embeddings numpy
91
+ embeddings = np.load(embeddings_temp)
92
+
93
+ return word_index_dict, embeddings
94
+
95
+
96
+ @st.cache_resource()
97
+ def load_sentence_transformer_model(model_name):
98
+ sentenceTransformer = SentenceTransformer(model_name)
99
+ return sentenceTransformer
100
+
101
+
102
+ def get_sentence_transformer_embeddings(sentence, model_name="all-MiniLM-L6-v2"):
103
+ """
104
+ Get sentence transformer embeddings for a sentence
105
+ """
106
+ # 384 dimensional embedding
107
+ # Default model: https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2
108
+
109
+ sentenceTransformer = load_sentence_transformer_model(model_name)
110
+
111
+ try:
112
+ return sentenceTransformer.encode(sentence)
113
+ except:
114
+ if model_name == "all-MiniLM-L6-v2":
115
+ return np.zeros(384)
116
+ else:
117
+ return np.zeros(512)
118
+
119
+
120
+ def get_glove_embeddings(word, word_index_dict, embeddings, model_type):
121
+ """
122
+ Get glove embedding for a single word
123
+ """
124
+ if word.lower() in word_index_dict:
125
+ return embeddings[word_index_dict[word.lower()]]
126
+ else:
127
+ return np.zeros(int(model_type.split("d")[0]))
128
+
129
+
130
+ def averaged_glove_embeddings_gdrive(sentence, word_index_dict, embeddings, model_type=50):
131
+ """
132
+ Get averaged glove embeddings for a sentence
133
+ 1. Split sentence into words
134
+ 2. Get embeddings for each word
135
+ 3. Add embeddings for each word
136
+ 4. Divide by number of words
137
+ 5. Return averaged embeddings
138
+ (30 pts)
139
+ """
140
+ words = sentence.split() # Step 1: Split sentence into words
141
+ embedding_sum = np.zeros(int(model_type.split("d")[0]))
142
+ valid_word_count = 0
143
+
144
+ for word in words: # Step 2: Get embeddings for each word
145
+ word_embedding = get_glove_embeddings(word, word_index_dict, embeddings, model_type)
146
+ if np.any(word_embedding): # Only consider valid embeddings
147
+ embedding_sum += word_embedding
148
+ valid_word_count += 1
149
+
150
+ if valid_word_count > 0: # Step 4: Divide by number of words
151
+ averaged_embedding = embedding_sum / valid_word_count
152
+ else:
153
+ averaged_embedding = np.zeros(int(model_type.split("d")[0]))
154
+
155
+ return averaged_embedding # Step 5: Return averaged embeddings
156
+
157
+ def get_category_embeddings(embeddings_metadata):
158
+ """
159
+ Get embeddings for each category
160
+ 1. Split categories into words
161
+ 2. Get embeddings for each word
162
+ """
163
+ model_name = embeddings_metadata["model_name"]
164
+ st.session_state["cat_embed_" + model_name] = {}
165
+ for category in st.session_state.categories.split(" "):
166
+ if model_name:
167
+ if not category in st.session_state["cat_embed_" + model_name]:
168
+ st.session_state["cat_embed_" + model_name][category] = get_sentence_transformer_embeddings(category, model_name=model_name)
169
+ else:
170
+ if not category in st.session_state["cat_embed_" + model_name]:
171
+ st.session_state["cat_embed_" + model_name][category] = get_sentence_transformer_embeddings(category)
172
+
173
+
174
+ def update_category_embeddings(embedings_metadata):
175
+ """
176
+ Update embeddings for each category
177
+ """
178
+ get_category_embeddings(embeddings_metadata)
179
+
180
+
181
+ def get_sorted_cosine_similarity(embeddings_metadata):
182
+ """
183
+ Get sorted cosine similarity between input sentence and categories
184
+ Steps:
185
+ 1. Get embeddings for input sentence
186
+ 2. Get embeddings for categories (if not found, update category embeddings)
187
+ 3. Compute cosine similarity between input sentence and categories
188
+ 4. Sort cosine similarity
189
+ 5. Return sorted cosine similarity
190
+ (50 pts)
191
+ """
192
+ categories = st.session_state.categories.split(" ")
193
+ cosine_sim = {}
194
+ if embeddings_metadata["embedding_model"] == "glove":
195
+ word_index_dict = embeddings_metadata["word_index_dict"]
196
+ embeddings = embeddings_metadata["embeddings"]
197
+ model_type = embeddings_metadata["model_type"]
198
+
199
+ input_embedding = averaged_glove_embeddings_gdrive(st.session_state.text_search,
200
+ word_index_dict,
201
+ embeddings, model_type)
202
+
203
+ for category in categories:
204
+ # Get embedding for category
205
+ category_embedding = averaged_glove_embeddings_gdrive(category, word_index_dict, embeddings, model_type)
206
+ # Compute cosine similarity
207
+ cos_sim = cosine_similarity(input_embedding, category_embedding)
208
+ cosine_sim[category] = cos_sim
209
+
210
+ else:
211
+ model_name = embeddings_metadata["model_name"]
212
+ if not "cat_embed_" + model_name in st.session_state:
213
+ get_category_embeddings(embeddings_metadata)
214
+
215
+ category_embeddings = st.session_state["cat_embed_" + model_name]
216
+
217
+ print("text_search = ", st.session_state.text_search)
218
+ if model_name:
219
+ input_embedding = get_sentence_transformer_embeddings(st.session_state.text_search, model_name=model_name)
220
+ else:
221
+ input_embedding = get_sentence_transformer_embeddings(st.session_state.text_search)
222
+ for category in categories:
223
+ # Update category embeddings if category not found
224
+
225
+ if category not in category_embeddings:
226
+ update_category_embeddings(embeddings_metadata)
227
+ category_embeddings = st.session_state["cat_embed_" + model_name]
228
+
229
+ # Compute cosine similarity
230
+ category_embedding = category_embeddings[category]
231
+ cos_sim = cosine_similarity(input_embedding, category_embedding)
232
+ cosine_sim[category] = cos_sim
233
+
234
+ # Sort the cosine similarities
235
+ sorted_cosine_sim = dict(sorted(cosine_sim.items(), key=lambda item: item[1], reverse=True))
236
+
237
+ return sorted_cosine_sim
238
+
239
+
240
+ def plot_piechart(sorted_cosine_scores_items):
241
+ sorted_cosine_scores = np.array([
242
+ sorted_cosine_scores_items[index][1]
243
+ for index in range(len(sorted_cosine_scores_items))
244
+ ]
245
+ )
246
+ categories = st.session_state.categories.split(" ")
247
+ categories_sorted = [
248
+ categories[sorted_cosine_scores_items[index][0]]
249
+ for index in range(len(sorted_cosine_scores_items))
250
+ ]
251
+ fig, ax = plt.subplots()
252
+ ax.pie(sorted_cosine_scores, labels=categories_sorted, autopct="%1.1f%%")
253
+ st.pyplot(fig) # Figure
254
+
255
+
256
+ def plot_piechart_helper(sorted_cosine_scores_items):
257
+ sorted_cosine_scores = np.array(
258
+ [
259
+ sorted_cosine_scores_items[index][1]
260
+ for index in range(len(sorted_cosine_scores_items))
261
+ ]
262
+ )
263
+ categories = st.session_state.categories.split(" ")
264
+ categories_sorted = [
265
+ categories[sorted_cosine_scores_items[index][0]]
266
+ for index in range(len(sorted_cosine_scores_items))
267
+ ]
268
+ fig, ax = plt.subplots(figsize=(3, 3))
269
+ my_explode = np.zeros(len(categories_sorted))
270
+ my_explode[0] = 0.2
271
+ if len(categories_sorted) == 3:
272
+ my_explode[1] = 0.1 # explode this by 0.2
273
+ elif len(categories_sorted) > 3:
274
+ my_explode[2] = 0.05
275
+ ax.pie(
276
+ sorted_cosine_scores,
277
+ labels=categories_sorted,
278
+ autopct="%1.1f%%",
279
+ explode=my_explode,
280
+ )
281
+
282
+ return fig
283
+
284
+
285
+ def plot_piecharts(sorted_cosine_scores_models):
286
+ scores_list = []
287
+ categories = st.session_state.categories.split(" ")
288
+ index = 0
289
+ for model in sorted_cosine_scores_models:
290
+ scores_list.append(sorted_cosine_scores_models[model])
291
+ # scores_list[index] = np.array([scores_list[index][ind2][1] for ind2 in range(len(scores_list[index]))])
292
+ index += 1
293
+
294
+ if len(sorted_cosine_scores_models) == 2:
295
+ fig, (ax1, ax2) = plt.subplots(2)
296
+
297
+ categories_sorted = [
298
+ categories[scores_list[0][index][0]] for index in range(len(scores_list[0]))
299
+ ]
300
+ sorted_scores = np.array(
301
+ [scores_list[0][index][1] for index in range(len(scores_list[0]))]
302
+ )
303
+ ax1.pie(sorted_scores, labels=categories_sorted, autopct="%1.1f%%")
304
+
305
+ categories_sorted = [
306
+ categories[scores_list[1][index][0]] for index in range(len(scores_list[1]))
307
+ ]
308
+ sorted_scores = np.array(
309
+ [scores_list[1][index][1] for index in range(len(scores_list[1]))]
310
+ )
311
+ ax2.pie(sorted_scores, labels=categories_sorted, autopct="%1.1f%%")
312
+
313
+ st.pyplot(fig)
314
+
315
+
316
+ def plot_alatirchart(sorted_cosine_scores_models):
317
+ models = list(sorted_cosine_scores_models.keys())
318
+ tabs = st.tabs(models)
319
+ figs = {}
320
+ for model in models:
321
+ figs[model] = plot_piechart_helper(sorted_cosine_scores_models[model])
322
+
323
+ for index in range(len(tabs)):
324
+ with tabs[index]:
325
+ st.pyplot(figs[models[index]])
326
+
327
+
328
+ # 测试
329
+
330
+ import os
331
+ print("Current Working Directory:", os.getcwd())
332
+
333
+
334
+ ### Text Search ###
335
+ st.sidebar.title("GloVe Twitter")
336
+ st.sidebar.markdown(
337
+ """
338
+ GloVe is an unsupervised learning algorithm for obtaining vector representations for words. Pretrained on
339
+ 2 billion tweets with vocabulary size of 1.2 million. Download from [Stanford NLP](http://nlp.stanford.edu/data/glove.twitter.27B.zip).
340
+
341
+ Jeffrey Pennington, Richard Socher, and Christopher D. Manning. 2014. *GloVe: Global Vectors for Word Representation*.
342
+ """
343
+ )
344
+
345
+
346
+ # 初始化 Session State 变量
347
+ if 'categories' not in st.session_state:
348
+ st.session_state['categories'] = "Flowers Colors Cars Weather Food"
349
+ if 'text_search' not in st.session_state:
350
+ st.session_state['text_search'] = "Roses are red, trucks are blue, and Seattle is grey right now"
351
+
352
+ # ... [其余 Streamlit 代码]
353
+
354
+ model_type = st.sidebar.selectbox("Choose the model", ("25d", "50d"), index=1)
355
+
356
+
357
+ st.title("Search Based Retrieval Demo")
358
+ st.subheader(
359
+ "Pass in space separated categories you want this search demo to be about."
360
+ )
361
+ # st.selectbox(label="Pick the categories you want this search demo to be about...",
362
+ # options=("Flowers Colors Cars Weather Food", "Chocolate Milk", "Anger Joy Sad Frustration Worry Happiness", "Positive Negative"),
363
+ # key="categories"
364
+ # )
365
+
366
+
367
+ # 用户输入的 categories
368
+ user_categories = st.text_input(
369
+ label="Categories", value=st.session_state.categories
370
+ )
371
+
372
+ # 更新 Session State 变量 - 修改的地方
373
+ st.session_state.categories = user_categories
374
+
375
+ # st.text_input(
376
+ # label="Categories", key="categories", value="Flowers Colors Cars Weather Food"
377
+ # )
378
+
379
+ # Categories = st.session_state.get('categories', "Flowers Colors Cars Weather Food")
380
+
381
+
382
+ print(st.session_state.get("categories"))
383
+ # print(st.session_state["categories"])
384
+
385
+ print(type(st.session_state.get("categories")))
386
+ # print(type(st.session_state["categories"]))
387
+
388
+ # print("Categories = ", categories)
389
+ # st.session_state.categories = categories
390
+
391
+ st.subheader("Pass in an input word or even a sentence")
392
+ user_text_search = st.text_input(
393
+ label="Input your sentence",
394
+ value=st.session_state.text_search,
395
+
396
+ )
397
+
398
+ # 更新 Session State 变量 - 修改的地方
399
+ st.session_state.text_search = user_text_search
400
+ # st.session_state.text_search = text_search
401
+
402
+ # Download glove embeddings if it doesn't exist
403
+ embeddings_path = "embeddings_" + str(model_type) + "_temp.npy"
404
+ word_index_dict_path = "word_index_dict_" + str(model_type) + "_temp.pkl"
405
+ if not os.path.isfile(embeddings_path) or not os.path.isfile(word_index_dict_path):
406
+ print("Model type = ", model_type)
407
+ glove_path = "Data/glove_" + str(model_type) + ".pkl"
408
+ print("glove_path = ", glove_path)
409
+
410
+ # Download embeddings from google drive
411
+ with st.spinner("Downloading glove embeddings..."):
412
+ download_glove_embeddings_gdrive(model_type)
413
+
414
+
415
+ # Load glove embeddings
416
+ word_index_dict, embeddings = load_glove_embeddings_gdrive(model_type)
417
+
418
+
419
+ # Find closest word to an input word
420
+ if st.session_state.text_search:
421
+ # Glove embeddings
422
+ print("Glove Embedding")
423
+ embeddings_metadata = {
424
+ "embedding_model": "glove",
425
+ "word_index_dict": word_index_dict,
426
+ "embeddings": embeddings,
427
+ "model_type": model_type,
428
+ "text_search": st.session_state.text_search
429
+ }
430
+ with st.spinner("Obtaining Cosine similarity for Glove..."):
431
+ sorted_cosine_sim_glove = get_sorted_cosine_similarity(
432
+ embeddings_metadata
433
+ )
434
+
435
+ # Sentence transformer embeddings
436
+ print("Sentence Transformer Embedding")
437
+ embeddings_metadata = {"embedding_model": "transformers", "model_name": "",
438
+ "text_search": st.session_state.text_search }
439
+ with st.spinner("Obtaining Cosine similarity for 384d sentence transformer..."):
440
+ sorted_cosine_sim_transformer = get_sorted_cosine_similarity(
441
+ embeddings_metadata
442
+ )
443
+
444
+ # Results and Plot Pie Chart for Glove
445
+ print("Categories are: ", st.session_state.categories)
446
+ st.subheader(
447
+ "Closest word I have between: "
448
+ + st.session_state.categories
449
+ + " as per different Embeddings"
450
+ )
451
+
452
+ # print(sorted_cosine_sim_glove)
453
+ # print(sorted_cosine_sim_transformer)
454
+ # print(sorted_distilbert)
455
+ # Altair Chart for all models
456
+ # plot_alatirchart(
457
+ # {
458
+ # "glove_" + str(model_type): sorted_cosine_sim_glove,
459
+ # "sentence_transformer_384": sorted_cosine_sim_transformer,
460
+ # }
461
+ # )
462
+ # "distilbert_512": sorted_distilbert})
463
+
464
+ # 修改的地方!
465
+ # Display the closest category result for GloVe and Sentence Transformer Embeddings
466
+ st.write(f"The closest category in GloVe embeddings is: {list(sorted_cosine_sim_glove.keys())[0]}")
467
+ st.write(
468
+ f"The closest category in Sentence Transformer embeddings is: {list(sorted_cosine_sim_transformer.keys())[0]}")
469
+
470
+ st.write("")
471
+ st.write(
472
+ "Demo developed by [Dr. Karthik Mohan](https://www.linkedin.com/in/karthik-mohan-72a4b323/)"
473
+ )
word_index_dict_50d_temp.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:674af352f703098ef122f6a8db7c5e08c5081829d49daea32e5aeac1fe582900
3
+ size 60284151