NursNurs committed on
Commit
1ad084f
·
1 Parent(s): 26e2925

Added sorting by price, relevancy, rating

Browse files
Files changed (1) hide show
  1. app.py +255 -64
app.py CHANGED
@@ -9,6 +9,7 @@ from transformers import BertTokenizer, BertModel
9
  from collections import defaultdict, Counter
10
  from tqdm.auto import tqdm
11
  from sklearn.metrics.pairwise import cosine_similarity
 
12
 
13
  #Loading the model
14
  @st.cache_resource
@@ -31,13 +32,26 @@ def load_data():
31
  vectors_df = pd.read_csv('restaurants_dataframe_with_embeddings.csv')
32
  embeds = dict(enumerate(vectors_df['Embeddings']))
33
  rest_names = list(vectors_df['Names'])
 
34
  return embeds, rest_names, vectors_df
35
 
36
  #type: dict; keys: 0-n
37
- restaurants_embeds, rest_names, df = load_data()
38
 
39
  model, tokenizer = get_models()
40
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  #a function that takes a sentence and converts it into embeddings
42
  def get_bert_embeddings(sentence, model, tokenizer):
43
  inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True)
@@ -47,19 +61,70 @@ def get_bert_embeddings(sentence, model, tokenizer):
47
  return embeddings
48
 
49
  # a function that return top-K best restaurants
50
- def return_top_k(query, k=10):
51
  embedded_query = get_bert_embeddings(query, model, tokenizer)
52
  embedded_query = embedded_query.numpy()
53
-
54
- top_similar = dict()
55
  for i in range(len(restaurants_embeds)):
56
  name = rest_names[i]
57
- top_similar[i] = cosine_similarity(embedded_query, str_to_numpy(restaurants_embeds[i]))[0][0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
- top_similar = dict(sorted(top_similar.items(), key=lambda item: item[1], reverse=True))
60
- top_similar = dict([(key, value) for key, value in top_similar.items()][:k])
61
- names = [rest_names[i] for i in top_similar.keys()]
62
- result = dict(zip(names, top_similar.values()))
 
 
 
63
  return result
64
 
65
  #combines 2 users preferences into 1 string and fetches best options
@@ -67,14 +132,69 @@ def get_combined_preferences(user1, user2):
67
  #TODO: optimize for more users
68
  shared_pref = ''
69
  for pref in user1:
70
- shared_pref += pref
71
  shared_pref += " "
72
  shared_pref += " "
73
  for pref in user2:
74
- shared_pref += pref
75
  shared_pref += " "
76
- return shared_pref
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  if 'preferences_1' not in st.session_state:
79
  st.session_state.preferences_1 = []
80
 
@@ -87,81 +207,152 @@ if 'food' not in st.session_state:
87
  if 'ambiance' not in st.session_state:
88
  st.session_state.ambiance = ['Romantic date', 'Friends catching up', 'Family gathering', 'Big group', 'Business-meeting', 'Other']
89
 
90
-
 
 
91
  if 'price' not in st.session_state:
92
- st.session_state.price = dict(enumerate(['$', '$$', '$$$', '$$$$']))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
  # Configure Streamlit page and state
95
  st.title("GoTogether!")
96
- st.markdown(
97
- "Tell us about your preferences!")
98
  st.caption("In section 'Others', you can describe any wishes.")
99
 
100
- st.write('User 1')
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
 
 
102
  food_1 = st.selectbox('Select the food type you prefer', st.session_state.food, key=1)
103
  if food_1 == 'Other':
104
  food_1 = st.text_input(label="Your description", placeholder="What kind of food would you like to eat?", key=10)
105
-
106
- st.session_state.preferences_1.append(food_1)
107
 
108
  ambiance_1 = st.selectbox('What describes your occasion the best?', st.session_state.ambiance, key=2)
109
  if ambiance_1 == 'Other':
110
- ambiance_1 = st.text_input(label="Your description", placeholder="How would you describe your meeting?", key=11)
111
-
112
- price_1 = st.select_slider("Your preferred price range", options=('$', '$$', '$$$', '$$$$'), key=3)
113
 
114
- st.session_state.preferences_1.append(ambiance_1)
 
 
115
 
116
- st.write('User 2')
 
 
 
 
 
117
 
118
- food_2 = st.selectbox('Select the food type you prefer', st.session_state.food, key=4)
 
119
  if food_2 == 'Other':
120
- food_2 = st.text_input(label="Your description", placeholder="What kind of food would you like to eat?", key=13)
121
-
122
- st.session_state.preferences_2.append(food_2)
123
 
124
  ambiance_2 = st.selectbox('What describes your occasion the best?', st.session_state.ambiance, key=5)
125
  if ambiance_2 == 'Other':
126
- ambiance_2 = st.text_input(label="Your description", placeholder="How would you describe your meeting?", key=12)
 
 
 
 
 
 
 
 
127
 
128
- price_2 = st.select_slider("Your preferred price range", options=('$', '$$', '$$$', '$$$$'), key=6)
 
 
 
 
 
 
 
129
 
130
- st.session_state.preferences_2.append(ambiance_2)
 
 
 
 
 
 
 
131
 
132
- submit = st.button("Submit")
133
- if submit:
134
- with st.spinner("Please wait while we are finding the best solution..."):
135
- query = get_combined_preferences(st.session_state.preferences_1, st.session_state.preferences_2)
136
- st.write("Your query is:", query)
137
- results = return_top_k(query, k=10)
138
- st.write("Here are the best matches to your preferences:")
139
- i = 1
140
- for name, score in results.items():
141
- st.write("Top", i, ':', name, score)
142
- condition = df['Names'] == name
143
- # Use the condition to extract the value(s)
144
- description = df.loc[condition, 'Strings']
145
- st.write(description)
146
- i+=1
147
 
148
- st.session_state.preferences_1, st.session_state.preferences_2 = [], []
149
-
150
- #TODO: include rating and price as variables
151
-
152
- # if input:
153
- # input_embed = model.encode(input)
154
- # sim_score = similarity_top(input_embed, icd_embeddings)
155
- # i = 1
156
- # for dis, value in sim_score:
157
- # st.write(f":green[Prediction number] {i}:")
158
- # st.write(f"{dis} (similarity score:", value, ")")
159
- # i+= 1
160
-
161
- # text_spinner_placeholder = st.empty()
162
- # with st.spinner("Please wait while your visualizations are being generated..."):
163
- # time.sleep(5)
164
- # vis_results_2d(input_embed)
165
- # vis_results_3d(input_embed)
166
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  # #TODO: implement price range as a sliding bar
 
 
 
 
9
  from collections import defaultdict, Counter
10
  from tqdm.auto import tqdm
11
  from sklearn.metrics.pairwise import cosine_similarity
12
+ import time
13
 
14
  #Loading the model
15
  @st.cache_resource
 
32
  vectors_df = pd.read_csv('restaurants_dataframe_with_embeddings.csv')
33
  embeds = dict(enumerate(vectors_df['Embeddings']))
34
  rest_names = list(vectors_df['Names'])
35
+ vectors_df['Weights'] = [1]*len(vectors_df)
36
  return embeds, rest_names, vectors_df
37
 
38
  #type: dict; keys: 0-n
39
+ restaurants_embeds, rest_names, init_df = load_data()
40
 
41
  model, tokenizer = get_models()
42
 
43
+ # query_params = st.experimental_get_query_params()
44
+ # st.write("query_params")
45
+ # st.write(query_params)
46
+
47
+ # def update_params():
48
+ # st.experimental_set_query_params(
49
+ # sorting=st.session_state.sort_by)
50
+
51
+ # if query_params:
52
+ # sort_by = query_params["sorting"][0]
53
+ # st.session_state.sort_by = sort_by
54
+
55
  #a function that takes a sentence and converts it into embeddings
56
  def get_bert_embeddings(sentence, model, tokenizer):
57
  inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True)
 
61
  return embeddings
62
 
63
  # a function that scores every restaurant against the query (weighted cosine similarity)
64
def compute_cos_sim(query):
    '''
    Score every restaurant against the query and store the scores in the session dataframe.

    query - str - combined preference text of the users
    returns: st.session_state.df with two (re)written columns:
        'cos_sim'   - cosine similarity between the query embedding and the restaurant embedding
        'Relevancy' - 'cos_sim' multiplied element-wise by the 'Weights' column
    '''
    embedded_query = get_bert_embeddings(query, model, tokenizer).numpy()

    # Build the similarity vector in one pass; np.append inside a loop
    # re-allocates the whole array on every iteration.
    similarities = np.array([
        cosine_similarity(embedded_query, str_to_numpy(restaurants_embeds[i]))[0][0]
        for i in range(len(restaurants_embeds))
    ])

    places = st.session_state.df
    places['cos_sim'] = similarities.tolist()
    # multiply weights by the cosine similarity
    weights = np.array(places['Weights'])
    places['Relevancy'] = np.multiply(similarities, weights)
    return places
79
+
80
def sort_by_relevancy(k):
    '''
    Pick the k places with the highest 'Relevancy' score.

    k - int - how many top-matching places to show
    returns: dict - restaurant name -> relevancy score, best match first
             (a name that occurs more than once keeps a single entry)
    '''
    scores = np.asarray(st.session_state.precalculated_df['Relevancy'], dtype=float)
    # NaN compares false with everything and would scramble the order: rank it last
    sort_keys = np.where(np.isnan(scores), -np.inf, scores)
    # stable sort on the negated keys = descending order, ties keep their row order
    top_rows = np.argsort(-sort_keys, kind='stable')[:k]
    return {rest_names[i]: float(scores[i]) for i in top_rows}
93
+
94
def sort_by_price(k):
    '''
    Pick the k best places when relevancy is scaled by the price factor.

    The factor of each place is looked up in st.session_state.price by the
    string form of its 'Price' value. The combined score is also written to
    the 'Sort_price' column of st.session_state.precalculated_df.

    k - int - how many top-matching places to show
    returns: dict - restaurant name -> price-adjusted score, best match first
    '''
    places = st.session_state.precalculated_df
    price_factor = st.session_state.price
    # A price label missing from the mapping used to raise KeyError; give it
    # the same neutral factor as a missing price ("nan").
    neutral = price_factor.get('nan', 1)

    relevance = np.asarray(places['Relevancy'], dtype=float)
    factors = np.array(
        [price_factor.get(str(val), neutral) for val in places['Price']],
        dtype=float,
    )
    scores = np.multiply(relevance, factors)
    places['Sort_price'] = scores

    # NaN compares false with everything and would scramble the order: rank it last
    sort_keys = np.where(np.isnan(scores), -np.inf, scores)
    # stable sort on the negated keys = descending order, ties keep their row order
    top_rows = np.argsort(-sort_keys, kind='stable')[:k]
    return {rest_names[i]: float(scores[i]) for i in top_rows}
111
+
112
def sort_by_rating(k):
    '''
    Pick the k best places when relevancy is multiplied by the rating.

    The combined score is also written to the 'Sort_rating' column of
    st.session_state.precalculated_df.

    k - int - how many top-matching places to show
    returns: dict - restaurant name -> rating-adjusted score, best match first
    '''
    places = st.session_state.precalculated_df
    relevance = np.asarray(places['Relevancy'], dtype=float)
    rating = np.asarray(places['Rating'], dtype=float)
    scores = np.multiply(relevance, rating)
    places['Sort_rating'] = scores

    # A missing rating yields a NaN score; NaN compares false with everything
    # and would scramble the order, so such places are ranked last.
    sort_keys = np.where(np.isnan(scores), -np.inf, scores)
    # stable sort on the negated keys = descending order, ties keep their row order
    top_rows = np.argsort(-sort_keys, kind='stable')[:k]
    return {rest_names[i]: float(scores[i]) for i in top_rows}
129
 
130
  #combines 2 users' preferences into 1 lower-cased string and counts how often each word occurs
 
132
  #TODO: optimize for more users
133
  shared_pref = ''
134
  for pref in user1:
135
+ shared_pref += pref.lower()
136
  shared_pref += " "
137
  shared_pref += " "
138
  for pref in user2:
139
+ shared_pref += pref.lower()
140
  shared_pref += " "
 
141
 
142
+ freq_words = Counter(shared_pref.split())
143
+
144
+ return shared_pref, freq_words
145
+
146
def filter_places(restrictions):
    '''
    Punish the weight of places that don't fit the restrictions.

    For every place in st.session_state.df the 'Weights' value is multiplied
    by 0.1 once for each distinct restriction (compared lower-cased) that is
    not one of the whitespace-separated, lower-cased words of its 'Strings'
    description.

    restrictions - iterable of str - keywords a place has to mention
    returns: st.session_state.df with the updated 'Weights' column
    '''
    places = st.session_state.df
    taboo = {word.lower() for word in restrictions}

    updated_weights = []
    for description, weight in zip(places['Strings'], places['Weights']):
        # str() keeps a missing (NaN) description from crashing the split
        words = set(str(description).lower().split())
        for _ in taboo - words:
            weight = 0.1 * weight
        updated_weights.append(weight)

    # Write the column in one go: chained `df['Weights'][i] = ...` assignment
    # is not guaranteed to reach the dataframe (and never does under pandas
    # copy-on-write).
    places['Weights'] = updated_weights
    return places
160
+
161
def promote_places(preferences):
    '''
    Increase the weight of places whose description mentions preferred words.

    For every place in st.session_state.df the 'Weights' value is doubled once
    for each entry of `preferences` that is one of the whitespace-separated,
    lower-cased words of its 'Strings' description. The preference words are
    compared as given, so they are expected to be lower-case already.

    preferences - dict (or any iterable of str) - the users' most common words
    returns: st.session_state.df with the updated 'Weights' column
    '''
    places = st.session_state.df

    updated_weights = []
    for description, weight in zip(places['Strings'], places['Weights']):
        # str() keeps a missing (NaN) description from crashing the split
        words = set(str(description).lower().split())
        for pref in preferences:
            if pref in words:
                weight = 2 * weight
        updated_weights.append(weight)

    # Write the column in one go: chained `df['Weights'][i] = ...` assignment
    # is not guaranteed to reach the dataframe (and never does under pandas
    # copy-on-write).
    places['Weights'] = updated_weights
    return places
178
+
179
def generate_results(sort_by, k=10):
    '''
    Produce the top-k recommendations for the chosen sorting option.

    sort_by - str - 'Price', 'Rating' or 'Relevancy (default)'; any other
              value (e.g. an option that is not implemented yet) shows an
              apology and falls back to sorting by relevancy
    k - int - how many top-matching places to return (default 10)
    returns: dict - restaurant name -> score, as returned by the selected sort_by_* function
    '''
    # option -> (status message, sorting function)
    sorters = {
        'Price': ("Sorting your results by price...", sort_by_price),
        'Rating': ("Sorting your results by rating...", sort_by_rating),
        'Relevancy (default)': ("Sorting your results by relevancy...", sort_by_relevancy),
    }

    if sort_by in sorters:
        message, sorter = sorters[sort_by]
        with st.spinner(message):
            st.write(message)
            return sorter(k)

    st.write("Sorry, we are still working on this option. For now, the results are sorted by relevance")
    with st.spinner("Sorting your results by relevancy..."):
        return sort_by_relevancy(k)
197
+
198
  if 'preferences_1' not in st.session_state:
199
  st.session_state.preferences_1 = []
200
 
 
207
  if 'ambiance' not in st.session_state:
208
  st.session_state.ambiance = ['Romantic date', 'Friends catching up', 'Family gathering', 'Big group', 'Business-meeting', 'Other']
209
 
210
# Per-session defaults. Each entry is created only if the key is not in
# st.session_state yet, so values survive Streamlit's script re-runs.
_session_defaults = {
    # keywords every recommended place should mention
    'restrictions': list,
    # price label -> factor applied to relevancy when sorting by price
    # ("nan" is the string form of a missing price)
    'price': lambda: {'$': 2, '₩': 2, '$$': 1, '₩₩': 1, '$$$': 0.5, '$$$$': 0.1, "nan": 1},
    # currently selected sorting option
    'sort_by': str,
    # choices offered in the "Sort by:" selectbox
    'options': lambda: ['Relevancy (default)', 'Price', 'Rating', 'Distance'],
    # Working dataframe whose 'Weights' column is edited in place. Take a copy
    # so those edits cannot leak into init_df if load_data() hands out a
    # shared object (NOTE(review): load_data's caching is not visible here).
    'df': lambda: init_df.copy(),
    # scored dataframe of the current search; empty means "nothing computed yet"
    'precalculated_df': pd.DataFrame,
    'stop_search': lambda: False,
}

for _key, _make_default in _session_defaults.items():
    if _key not in st.session_state:
        st.session_state[_key] = _make_default()
230
 
231
  # Configure Streamlit page and state
232
  st.title("GoTogether!")
233
+ st.markdown("Tell us about your preferences!")
 
234
  st.caption("In section 'Others', you can describe any wishes.")
235
 
236
+ # options_disability_1 = st.multiselect(
237
+ # 'Do you need a wheelchair?',
238
+ # ['Yes', 'No'], ['No'], key=101)
239
+
240
+ # if options_disability_1 == 'Yes':
241
+ # st.session_state.restrictions.append('Wheelchair')
242
+
243
+ # price_1 = st.select_slider("Your preferred price range", options=('$', '$$', '$$$', '$$$$'), key=3)
244
+
245
+ # st.session_state.preferences_1.append(ambiance_1)
246
+
247
+ # Complete example of using the 'with' notation
248
+ # with st.form('my_form_1'):
249
+ # st.subheader('**User 1**')
250
 
251
def _ask_user(title, keys, diet_choices):
    '''
    Render the input widgets of one user and return their current values.

    title - str - caption written above the widgets
    keys - dict - Streamlit widget keys for 'food', 'food_other', 'ambiance',
           'ambiance_other', 'diet', 'extra' and 'kids'
    diet_choices - list - options of the dietary-restrictions multiselect
    returns: tuple - (food, ambiance, dietary restrictions, free-text wishes, with-kids flag)
    '''
    st.write(title)

    # input widgets; choosing 'Other' swaps the selection for a free-text answer
    food = st.selectbox('Select the food type you prefer', st.session_state.food, key=keys['food'])
    if food == 'Other':
        food = st.text_input(label="Your description", placeholder="What kind of food would you like to eat?", key=keys['food_other'])

    ambiance = st.selectbox('What describes your occasion the best?', st.session_state.ambiance, key=keys['ambiance'])
    if ambiance == 'Other':
        ambiance = st.text_input(label="Your description", placeholder="How would you describe your meeting?", key=keys['ambiance_other'])

    diet = st.multiselect(
        'Do you have any dietary restrictions?',
        diet_choices, key=keys['diet'])

    extra = st.text_input(label="Your description", placeholder="Anything else you wanna share?", key=keys['extra'])

    kids = st.checkbox('I will come with kids', key=keys['kids'])
    return food, ambiance, diet, extra, kids


food_1, ambiance_1, options_food_1, additional_1, with_kids = _ask_user(
    "User 1",
    {'food': 1, 'food_other': 10, 'ambiance': 2, 'ambiance_other': 11,
     'diet': 100, 'extra': 102, 'kids': 200},
    ['Vegan', 'Vegetarian', 'Halal'],
)

food_2, ambiance_2, options_food_2, additional_2, with_kids_2 = _ask_user(
    "User 2",
    {'food': 3, 'food_other': 4, 'ambiance': 5, 'ambiance_other': 6,
     'diet': 7, 'extra': 8, 'kids': 201},
    ['Vegan', 'Vegetarian', 'Halal', 'Other'],
)
288
 
289
# Mirror the current widget values into the session state on every script run.
# The lists used to be filled only while they were empty, which froze the
# default selections of the very first run and ignored every later change
# until "New search!" was pressed.
st.session_state.preferences_1 = [food_1, ambiance_1]
if additional_1:
    st.session_state.preferences_1.append(additional_1)

st.session_state.preferences_2 = [food_2, ambiance_2]
if additional_2:
    st.session_state.preferences_2.append(additional_2)

# restrictions of both users are pooled into one list
st.session_state.restrictions = list(options_food_1)
if with_kids:
    st.session_state.restrictions.append('kids')
st.session_state.restrictions.extend(options_food_2)
if with_kids_2:
    st.session_state.restrictions.append('kids')
306
 
307
# Acknowledge a submission, or nudge the users towards the form above.
submitted = st.button('Submit!')

if not submitted:
    st.write('☝️ Describe your preferences!')
else:
    st.markdown("Thanks, we received your preferences!")
    st.session_state.stop_search = False

# Primary action button; its value is checked right below to start the search.
submit = st.button("Find best matches!", type='primary')
318
+
319
+ if submit or (not st.session_state.precalculated_df.empty):
320
+ with st.spinner("Please wait while we are finding the best solution..."):
321
+ if st.session_state.precalculated_df.empty:
322
+ query = get_combined_preferences(st.session_state.preferences_1, st.session_state.preferences_2)
323
+ st.write("Your query is:", query[0])
324
+ #sort places based on restrictions
325
+ st.session_state.precalculated_df = filter_places(st.session_state.restrictions)
326
+ #sort places by elevating preferences
327
+ st.session_state.precalculated_df = promote_places(query[1])
328
+ st.session_state.precalculated_df = compute_cos_sim(query[0])
329
+ sort_by = st.selectbox(('Sort by:'), st.session_state.options, key=400,
330
+ index=st.session_state.options.index('Relevancy (default)'))
331
+ if sort_by:
332
+ st.session_state.sort_by = sort_by
333
+ results = generate_results(st.session_state.sort_by)
334
+ k = 10
335
+ st.write(f"Here are the best {k} matches to your preferences:")
336
+ i = 1
337
+ for name, score in results.items():
338
+ st.write("Top", i, ':', name, score)
339
+ condition = st.session_state.precalculated_df['Names'] == name
340
+ # Use the condition to extract the value(s)
341
+ description = st.session_state.precalculated_df.loc[condition, 'Strings']
342
+ st.write(description)
343
+ i+=1
344
+
345
+
346
+
347
# "New search!" wipes everything the previous search left in the session.
stop = st.button("New search!", type='primary', key=500)
if stop:
    st.session_state.preferences_1, st.session_state.preferences_2 = [], []
    st.session_state.restrictions = []
    st.session_state.sort_by = ""
    # The working dataframe has its 'Weights' edited in place and may be the
    # very object init_df refers to, so re-assigning init_df alone would keep
    # the old weights. Start from a copy and restore the initial weight of 1
    # that load_data() assigns.
    fresh_df = init_df.copy()
    fresh_df['Weights'] = 1
    st.session_state.df = fresh_df
    # an empty dataframe marks "no results computed yet"
    st.session_state.precalculated_df = pd.DataFrame()
354
+
355
  # #TODO: implement price range as a sliding bar
356
+ # When the user presses "New search", erase everything
357
+ # Propose URLs
358
+ # Show keywords instead of whole strings