NursNurs commited on
Commit
696b1e3
·
1 Parent(s): 4cbe5e0

Upload 4 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ restaurants_dataframe_with_embeddings.csv filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import numpy as np
4
+ import time
5
+ import string
6
+ import pandas as pd
7
+ import numpy as np
8
+ from transformers import BertTokenizer, BertModel
9
+ from collections import defaultdict, Counter
10
+ from tqdm.auto import tqdm
11
+ from sklearn.metrics.pairwise import cosine_similarity
12
+
13
+ #Loading the model
14
+ @st.cache_resource
15
+ def get_models():
16
+ st.write('Loading the model...')
17
+ tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
18
+ model = BertModel.from_pretrained("bert-base-uncased")
19
+ st.write("_The model is loaded and ready to use! :tada:_")
20
+ return model, tokenizer
21
+
22
+ #convert numpy arrays from strings back to arrays
23
+ def str_to_numpy(array_string):
24
+ array_string = array_string.replace('\n', '').replace('[','').replace(']','')
25
+ numpy_array = np.fromstring(array_string, sep=' ')
26
+ numpy_array = numpy_array.reshape((1, -1))
27
+ return numpy_array
28
+
29
+ @st.cache_data # 👈 Add the caching decorator
30
+ def load_data():
31
+ vectors_df = pd.read_csv('restaurants_dataframe_with_embeddings.csv')
32
+ embeds = dict(enumerate(vectors_df['Embeddings']))
33
+ rest_names = list(vectors_df['Names'])
34
+ return embeds, rest_names, vectors_df
35
+
36
+ #type: dict; keys: 0-n
37
+ restaurants_embeds, rest_names, df = load_data()
38
+
39
+ model, tokenizer = get_models()
40
+
41
+ #a function that takes a sentence and converts it into embeddings
42
+ def get_bert_embeddings(sentence, model, tokenizer):
43
+ inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True)
44
+ with torch.no_grad():
45
+ outputs = model(**inputs)
46
+ embeddings = outputs.last_hidden_state.mean(dim=1) # Average pool over tokens
47
+ return embeddings
48
+
49
+ # a function that return top-K best restaurants
50
+ def return_top_k(query, k=10):
51
+ embedded_query = get_bert_embeddings(query, model, tokenizer)
52
+ embedded_query = embedded_query.numpy()
53
+
54
+ top_similar = dict()
55
+ for i in range(len(restaurants_embeds)):
56
+ name = rest_names[i]
57
+ top_similar[i] = cosine_similarity(embedded_query, str_to_numpy(restaurants_embeds[i]))[0][0]
58
+
59
+ top_similar = dict(sorted(top_similar.items(), key=lambda item: item[1], reverse=True))
60
+ top_similar = dict([(key, value) for key, value in top_similar.items()][:k])
61
+ names = [rest_names[i] for i in top_similar.keys()]
62
+ result = dict(zip(names, top_similar.values()))
63
+ return result
64
+
65
+ #combines 2 users preferences into 1 string and fetches best options
66
+ def get_combined_preferences(user1, user2):
67
+ #TODO: optimize for more users
68
+ shared_pref = ''
69
+ for pref in user1:
70
+ shared_pref += pref
71
+ shared_pref += " "
72
+ shared_pref += " "
73
+ for pref in user2:
74
+ shared_pref += pref
75
+ shared_pref += " "
76
+ return shared_pref
77
+
78
+ if 'preferences_1' not in st.session_state:
79
+ st.session_state.preferences_1 = []
80
+
81
+ if 'preferences_2' not in st.session_state:
82
+ st.session_state.preferences_2 = []
83
+
84
+ if 'food' not in st.session_state:
85
+ st.session_state.food = ['Coffee', 'Italian', 'Mexican', 'Chinese', 'Indian', 'Asian', 'Fast food', 'Other']
86
+
87
+ if 'ambiance' not in st.session_state:
88
+ st.session_state.ambiance = ['Romantic date', 'Friends catching up', 'Family gathering', 'Big group', 'Business-meeting', 'Other']
89
+
90
+
91
+ if 'price' not in st.session_state:
92
+ st.session_state.price = dict(enumerate(['$', '$$', '$$$', '$$$$']))
93
+
94
+ # Configure Streamlit page and state
95
+ st.title("GoTogether!")
96
+ st.markdown(
97
+ "Tell us about your preferences!")
98
+ st.caption("In section 'Others', you can describe any wishes.")
99
+
100
+ st.write('User 1')
101
+
102
+ food_1 = st.selectbox('Select the food type you prefer', st.session_state.food, key=1)
103
+ if food_1 == 'Other':
104
+ food_1 = st.text_input(label="Your description", placeholder="What kind of food would you like to eat?", key=10)
105
+
106
+ st.session_state.preferences_1.append(food_1)
107
+
108
+ ambiance_1 = st.selectbox('What describes your occasion the best?', st.session_state.ambiance, key=2)
109
+ if ambiance_1 == 'Other':
110
+ ambiance_1 = st.text_input(label="Your description", placeholder="How would you describe your meeting?", key=11)
111
+
112
+ price_1 = st.select_slider("Your preferred price range", options=('$', '$$', '$$$', '$$$$'), key=3)
113
+
114
+ st.session_state.preferences_1.append(ambiance_1)
115
+
116
+ st.write('User 2')
117
+
118
+ food_2 = st.selectbox('Select the food type you prefer', st.session_state.food, key=4)
119
+ if food_2 == 'Other':
120
+ food_2 = st.text_input(label="Your description", placeholder="What kind of food would you like to eat?", key=13)
121
+
122
+ st.session_state.preferences_2.append(food_2)
123
+
124
+ ambiance_2 = st.selectbox('What describes your occasion the best?', st.session_state.ambiance, key=5)
125
+ if ambiance_2 == 'Other':
126
+ ambiance_2 = st.text_input(label="Your description", placeholder="How would you describe your meeting?", key=12)
127
+
128
+ price_2 = st.select_slider("Your preferred price range", options=('$', '$$', '$$$', '$$$$'), key=6)
129
+
130
+ st.session_state.preferences_2.append(ambiance_2)
131
+
132
+ submit = st.button("Submit")
133
+ if submit:
134
+ with st.spinner("Please wait while we are finding the best solution..."):
135
+ query = get_combined_preferences(st.session_state.preferences_1, st.session_state.preferences_2)
136
+ st.write("Your query is:", query)
137
+ results = return_top_k(query, k=10)
138
+ st.write("Here are the best matches to your preferences:")
139
+ i = 1
140
+ for name, score in results.items():
141
+ st.write("Top", i, ':', name, score)
142
+ condition = df['Names'] == name
143
+ # Use the condition to extract the value(s)
144
+ description = df.loc[condition, 'Strings']
145
+ st.write(description)
146
+ i+=1
147
+
148
+ st.session_state.preferences_1, st.session_state.preferences_2 = [], []
149
+
150
+ #TODO: include rating and price as variables
151
+
152
+ # if input:
153
+ # input_embed = model.encode(input)
154
+ # sim_score = similarity_top(input_embed, icd_embeddings)
155
+ # i = 1
156
+ # for dis, value in sim_score:
157
+ # st.write(f":green[Prediction number] {i}:")
158
+ # st.write(f"{dis} (similarity score:", value, ")")
159
+ # i+= 1
160
+
161
+ # text_spinner_placeholder = st.empty()
162
+ # with st.spinner("Please wait while your visualizations are being generated..."):
163
+ # time.sleep(5)
164
+ # vis_results_2d(input_embed)
165
+ # vis_results_3d(input_embed)
166
+
167
+ # #TODO: implement price range as a sliding bar
restaurants_dataframe_with_embeddings.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:31926e0abc4ff33c12761b0cc1d2b2f855bb097463e72a2dccbe9f2f2df3cf70
3
+ size 19014982