ericlkc commited on
Commit
dc6fc9a
·
verified ·
1 Parent(s): e668e37

Update app.py for Part4

Browse files
Files changed (1) hide show
  1. app.py +37 -56
app.py CHANGED
@@ -236,42 +236,35 @@ def get_sorted_cosine_similarity(embeddings_metadata):
236
 
237
  return sorted_cosine_sim
238
 
239
-
240
- def plot_piechart(sorted_cosine_scores_items):
241
- sorted_cosine_scores = np.array([
242
- sorted_cosine_scores_items[index][1]
243
- for index in range(len(sorted_cosine_scores_items))
244
- ]
245
- )
246
- categories = st.session_state.categories.split(" ")
247
- categories_sorted = [
248
- categories[sorted_cosine_scores_items[index][0]]
249
- for index in range(len(sorted_cosine_scores_items))
250
- ]
251
- fig, ax = plt.subplots()
252
- ax.pie(sorted_cosine_scores, labels=categories_sorted, autopct="%1.1f%%")
253
- st.pyplot(fig) # Figure
254
 
255
 
256
  def plot_piechart_helper(sorted_cosine_scores_items):
257
- sorted_cosine_scores = np.array(
258
- [
259
- sorted_cosine_scores_items[index][1]
260
- for index in range(len(sorted_cosine_scores_items))
261
- ]
262
- )
263
- categories = st.session_state.categories.split(" ")
264
- categories_sorted = [
265
- categories[sorted_cosine_scores_items[index][0]]
266
- for index in range(len(sorted_cosine_scores_items))
267
- ]
268
  fig, ax = plt.subplots(figsize=(3, 3))
269
  my_explode = np.zeros(len(categories_sorted))
270
  my_explode[0] = 0.2
271
  if len(categories_sorted) == 3:
272
- my_explode[1] = 0.1 # explode this by 0.2
273
  elif len(categories_sorted) > 3:
274
  my_explode[2] = 0.05
 
275
  ax.pie(
276
  sorted_cosine_scores,
277
  labels=categories_sorted,
@@ -314,10 +307,13 @@ def plot_piecharts(sorted_cosine_scores_models):
314
 
315
 
316
  def plot_alatirchart(sorted_cosine_scores_models):
 
 
317
  models = list(sorted_cosine_scores_models.keys())
318
  tabs = st.tabs(models)
319
  figs = {}
320
  for model in models:
 
321
  figs[model] = plot_piechart_helper(sorted_cosine_scores_models[model])
322
 
323
  for index in range(len(tabs)):
@@ -325,12 +321,6 @@ def plot_alatirchart(sorted_cosine_scores_models):
325
  st.pyplot(figs[models[index]])
326
 
327
 
328
- # 测试
329
-
330
- import os
331
- print("Current Working Directory:", os.getcwd())
332
-
333
-
334
  ### Text Search ###
335
  st.sidebar.title("GloVe Twitter")
336
  st.sidebar.markdown(
@@ -343,17 +333,14 @@ Jeffrey Pennington, Richard Socher, and Christopher D. Manning. 2014. *GloVe: Gl
343
  )
344
 
345
 
346
- # 初始化 Session State 变量
347
  if 'categories' not in st.session_state:
348
  st.session_state['categories'] = "Flowers Colors Cars Weather Food"
349
  if 'text_search' not in st.session_state:
350
  st.session_state['text_search'] = "Roses are red, trucks are blue, and Seattle is grey right now"
351
 
352
- # ... [其余 Streamlit 代码]
353
 
354
  model_type = st.sidebar.selectbox("Choose the model", ("25d", "50d"), index=1)
355
 
356
-
357
  st.title("Search Based Retrieval Demo")
358
  st.subheader(
359
  "Pass in space separated categories you want this search demo to be about."
@@ -364,12 +351,11 @@ st.subheader(
364
  # )
365
 
366
 
367
- # 用户输入的 categories
368
  user_categories = st.text_input(
369
  label="Categories", value=st.session_state.categories
370
  )
371
 
372
- # 更新 Session State 变量 - 修改的地方
373
  st.session_state.categories = user_categories
374
 
375
  # st.text_input(
@@ -395,7 +381,6 @@ user_text_search = st.text_input(
395
 
396
  )
397
 
398
- # 更新 Session State 变量 - 修改的地方
399
  st.session_state.text_search = user_text_search
400
  # st.session_state.text_search = text_search
401
 
@@ -449,25 +434,21 @@ if st.session_state.text_search:
449
  + " as per different Embeddings"
450
  )
451
 
452
- # print(sorted_cosine_sim_glove)
453
- # print(sorted_cosine_sim_transformer)
454
- # print(sorted_distilbert)
455
- # Altair Chart for all models
456
- # plot_alatirchart(
457
- # {
458
- # "glove_" + str(model_type): sorted_cosine_sim_glove,
459
- # "sentence_transformer_384": sorted_cosine_sim_transformer,
460
- # }
461
- # )
462
- # "distilbert_512": sorted_distilbert})
463
-
464
- # 修改的地方!
465
- # Display the closest category result for GloVe and Sentence Transformer Embeddings
466
- st.write(f"The closest category in GloVe embeddings is: {list(sorted_cosine_sim_glove.keys())[0]}")
467
  st.write(
468
- f"The closest category in Sentence Transformer embeddings is: {list(sorted_cosine_sim_transformer.keys())[0]}")
 
 
 
 
 
 
 
469
 
470
  st.write("")
471
  st.write(
472
- "Demo developed by [Dr. Karthik Mohan](https://www.linkedin.com/in/karthik-mohan-72a4b323/)"
473
  )
 
236
 
237
  return sorted_cosine_sim
238
 
239
+ #
240
+ # def plot_piechart(sorted_cosine_scores_items):
241
+ # sorted_cosine_scores = np.array([
242
+ # sorted_cosine_scores_items[index][1]
243
+ # for index in range(len(sorted_cosine_scores_items))
244
+ # ]
245
+ # )
246
+ # categories = st.session_state.categories.split(" ")
247
+ # categories_sorted = [
248
+ # categories[sorted_cosine_scores_items[index][0]]
249
+ # for index in range(len(sorted_cosine_scores_items))
250
+ # ]
251
+ # fig, ax = plt.subplots()
252
+ # ax.pie(sorted_cosine_scores, labels=categories_sorted, autopct="%1.1f%%")
253
+ # st.pyplot(fig) # Figure
254
 
255
 
256
  def plot_piechart_helper(sorted_cosine_scores_items):
257
+ sorted_cosine_scores = np.array(list(sorted_cosine_scores_items.values()))
258
+ categories_sorted = list(sorted_cosine_scores_items.keys())
259
+
 
 
 
 
 
 
 
 
260
  fig, ax = plt.subplots(figsize=(3, 3))
261
  my_explode = np.zeros(len(categories_sorted))
262
  my_explode[0] = 0.2
263
  if len(categories_sorted) == 3:
264
+ my_explode[1] = 0.1
265
  elif len(categories_sorted) > 3:
266
  my_explode[2] = 0.05
267
+
268
  ax.pie(
269
  sorted_cosine_scores,
270
  labels=categories_sorted,
 
307
 
308
 
309
  def plot_alatirchart(sorted_cosine_scores_models):
310
+
311
+
312
  models = list(sorted_cosine_scores_models.keys())
313
  tabs = st.tabs(models)
314
  figs = {}
315
  for model in models:
316
+ # modified
317
  figs[model] = plot_piechart_helper(sorted_cosine_scores_models[model])
318
 
319
  for index in range(len(tabs)):
 
321
  st.pyplot(figs[models[index]])
322
 
323
 
 
 
 
 
 
 
324
  ### Text Search ###
325
  st.sidebar.title("GloVe Twitter")
326
  st.sidebar.markdown(
 
333
  )
334
 
335
 
 
336
  if 'categories' not in st.session_state:
337
  st.session_state['categories'] = "Flowers Colors Cars Weather Food"
338
  if 'text_search' not in st.session_state:
339
  st.session_state['text_search'] = "Roses are red, trucks are blue, and Seattle is grey right now"
340
 
 
341
 
342
  model_type = st.sidebar.selectbox("Choose the model", ("25d", "50d"), index=1)
343
 
 
344
  st.title("Search Based Retrieval Demo")
345
  st.subheader(
346
  "Pass in space separated categories you want this search demo to be about."
 
351
  # )
352
 
353
 
354
+ # categories of user input
355
  user_categories = st.text_input(
356
  label="Categories", value=st.session_state.categories
357
  )
358
 
 
359
  st.session_state.categories = user_categories
360
 
361
  # st.text_input(
 
381
 
382
  )
383
 
 
384
  st.session_state.text_search = user_text_search
385
  # st.session_state.text_search = text_search
386
 
 
434
  + " as per different Embeddings"
435
  )
436
 
437
+ print(sorted_cosine_sim_glove)
438
+ print(sorted_cosine_sim_transformer)
439
+
440
+ st.write(f"Closest category using GloVe embeddings : {list(sorted_cosine_sim_glove.keys())[0]}")
 
 
 
 
 
 
 
 
 
 
 
441
  st.write(
442
+ f"Closest category using Sentence Transformer embeddings : {list(sorted_cosine_sim_transformer.keys())[0]}")
443
+
444
+ plot_alatirchart(
445
+ {
446
+ "glove_" + str(model_type): sorted_cosine_sim_glove,
447
+ "sentence_transformer_384": sorted_cosine_sim_transformer,
448
+ }
449
+ )
450
 
451
  st.write("")
452
  st.write(
453
+ "Demo developed by [V50](https://huggingface.co/spaces/ericlkc/V50)"
454
  )