Update app.py for Part4
Browse files
app.py
CHANGED
@@ -236,42 +236,35 @@ def get_sorted_cosine_similarity(embeddings_metadata):
|
|
236 |
|
237 |
return sorted_cosine_sim
|
238 |
|
239 |
-
|
240 |
-
def plot_piechart(sorted_cosine_scores_items):
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
|
255 |
|
256 |
def plot_piechart_helper(sorted_cosine_scores_items):
|
257 |
-
sorted_cosine_scores = np.array(
|
258 |
-
|
259 |
-
|
260 |
-
for index in range(len(sorted_cosine_scores_items))
|
261 |
-
]
|
262 |
-
)
|
263 |
-
categories = st.session_state.categories.split(" ")
|
264 |
-
categories_sorted = [
|
265 |
-
categories[sorted_cosine_scores_items[index][0]]
|
266 |
-
for index in range(len(sorted_cosine_scores_items))
|
267 |
-
]
|
268 |
fig, ax = plt.subplots(figsize=(3, 3))
|
269 |
my_explode = np.zeros(len(categories_sorted))
|
270 |
my_explode[0] = 0.2
|
271 |
if len(categories_sorted) == 3:
|
272 |
-
my_explode[1] = 0.1
|
273 |
elif len(categories_sorted) > 3:
|
274 |
my_explode[2] = 0.05
|
|
|
275 |
ax.pie(
|
276 |
sorted_cosine_scores,
|
277 |
labels=categories_sorted,
|
@@ -314,10 +307,13 @@ def plot_piecharts(sorted_cosine_scores_models):
|
|
314 |
|
315 |
|
316 |
def plot_alatirchart(sorted_cosine_scores_models):
|
|
|
|
|
317 |
models = list(sorted_cosine_scores_models.keys())
|
318 |
tabs = st.tabs(models)
|
319 |
figs = {}
|
320 |
for model in models:
|
|
|
321 |
figs[model] = plot_piechart_helper(sorted_cosine_scores_models[model])
|
322 |
|
323 |
for index in range(len(tabs)):
|
@@ -325,12 +321,6 @@ def plot_alatirchart(sorted_cosine_scores_models):
|
|
325 |
st.pyplot(figs[models[index]])
|
326 |
|
327 |
|
328 |
-
# 测试
|
329 |
-
|
330 |
-
import os
|
331 |
-
print("Current Working Directory:", os.getcwd())
|
332 |
-
|
333 |
-
|
334 |
### Text Search ###
|
335 |
st.sidebar.title("GloVe Twitter")
|
336 |
st.sidebar.markdown(
|
@@ -343,17 +333,14 @@ Jeffrey Pennington, Richard Socher, and Christopher D. Manning. 2014. *GloVe: Gl
|
|
343 |
)
|
344 |
|
345 |
|
346 |
-
# 初始化 Session State 变量
|
347 |
if 'categories' not in st.session_state:
|
348 |
st.session_state['categories'] = "Flowers Colors Cars Weather Food"
|
349 |
if 'text_search' not in st.session_state:
|
350 |
st.session_state['text_search'] = "Roses are red, trucks are blue, and Seattle is grey right now"
|
351 |
|
352 |
-
# ... [其余 Streamlit 代码]
|
353 |
|
354 |
model_type = st.sidebar.selectbox("Choose the model", ("25d", "50d"), index=1)
|
355 |
|
356 |
-
|
357 |
st.title("Search Based Retrieval Demo")
|
358 |
st.subheader(
|
359 |
"Pass in space separated categories you want this search demo to be about."
|
@@ -364,12 +351,11 @@ st.subheader(
|
|
364 |
# )
|
365 |
|
366 |
|
367 |
-
#
|
368 |
user_categories = st.text_input(
|
369 |
label="Categories", value=st.session_state.categories
|
370 |
)
|
371 |
|
372 |
-
# 更新 Session State 变量 - 修改的地方
|
373 |
st.session_state.categories = user_categories
|
374 |
|
375 |
# st.text_input(
|
@@ -395,7 +381,6 @@ user_text_search = st.text_input(
|
|
395 |
|
396 |
)
|
397 |
|
398 |
-
# 更新 Session State 变量 - 修改的地方
|
399 |
st.session_state.text_search = user_text_search
|
400 |
# st.session_state.text_search = text_search
|
401 |
|
@@ -449,25 +434,21 @@ if st.session_state.text_search:
|
|
449 |
+ " as per different Embeddings"
|
450 |
)
|
451 |
|
452 |
-
|
453 |
-
|
454 |
-
|
455 |
-
|
456 |
-
# plot_alatirchart(
|
457 |
-
# {
|
458 |
-
# "glove_" + str(model_type): sorted_cosine_sim_glove,
|
459 |
-
# "sentence_transformer_384": sorted_cosine_sim_transformer,
|
460 |
-
# }
|
461 |
-
# )
|
462 |
-
# "distilbert_512": sorted_distilbert})
|
463 |
-
|
464 |
-
# 修改的地方!
|
465 |
-
# Display the closest category result for GloVe and Sentence Transformer Embeddings
|
466 |
-
st.write(f"The closest category in GloVe embeddings is: {list(sorted_cosine_sim_glove.keys())[0]}")
|
467 |
st.write(
|
468 |
-
f"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
469 |
|
470 |
st.write("")
|
471 |
st.write(
|
472 |
-
"Demo developed by [
|
473 |
)
|
|
|
236 |
|
237 |
return sorted_cosine_sim
|
238 |
|
239 |
+
#
|
240 |
+
# def plot_piechart(sorted_cosine_scores_items):
|
241 |
+
# sorted_cosine_scores = np.array([
|
242 |
+
# sorted_cosine_scores_items[index][1]
|
243 |
+
# for index in range(len(sorted_cosine_scores_items))
|
244 |
+
# ]
|
245 |
+
# )
|
246 |
+
# categories = st.session_state.categories.split(" ")
|
247 |
+
# categories_sorted = [
|
248 |
+
# categories[sorted_cosine_scores_items[index][0]]
|
249 |
+
# for index in range(len(sorted_cosine_scores_items))
|
250 |
+
# ]
|
251 |
+
# fig, ax = plt.subplots()
|
252 |
+
# ax.pie(sorted_cosine_scores, labels=categories_sorted, autopct="%1.1f%%")
|
253 |
+
# st.pyplot(fig) # Figure
|
254 |
|
255 |
|
256 |
def plot_piechart_helper(sorted_cosine_scores_items):
|
257 |
+
sorted_cosine_scores = np.array(list(sorted_cosine_scores_items.values()))
|
258 |
+
categories_sorted = list(sorted_cosine_scores_items.keys())
|
259 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
260 |
fig, ax = plt.subplots(figsize=(3, 3))
|
261 |
my_explode = np.zeros(len(categories_sorted))
|
262 |
my_explode[0] = 0.2
|
263 |
if len(categories_sorted) == 3:
|
264 |
+
my_explode[1] = 0.1
|
265 |
elif len(categories_sorted) > 3:
|
266 |
my_explode[2] = 0.05
|
267 |
+
|
268 |
ax.pie(
|
269 |
sorted_cosine_scores,
|
270 |
labels=categories_sorted,
|
|
|
307 |
|
308 |
|
309 |
def plot_alatirchart(sorted_cosine_scores_models):
|
310 |
+
|
311 |
+
|
312 |
models = list(sorted_cosine_scores_models.keys())
|
313 |
tabs = st.tabs(models)
|
314 |
figs = {}
|
315 |
for model in models:
|
316 |
+
# modified
|
317 |
figs[model] = plot_piechart_helper(sorted_cosine_scores_models[model])
|
318 |
|
319 |
for index in range(len(tabs)):
|
|
|
321 |
st.pyplot(figs[models[index]])
|
322 |
|
323 |
|
|
|
|
|
|
|
|
|
|
|
|
|
324 |
### Text Search ###
|
325 |
st.sidebar.title("GloVe Twitter")
|
326 |
st.sidebar.markdown(
|
|
|
333 |
)
|
334 |
|
335 |
|
|
|
336 |
if 'categories' not in st.session_state:
|
337 |
st.session_state['categories'] = "Flowers Colors Cars Weather Food"
|
338 |
if 'text_search' not in st.session_state:
|
339 |
st.session_state['text_search'] = "Roses are red, trucks are blue, and Seattle is grey right now"
|
340 |
|
|
|
341 |
|
342 |
model_type = st.sidebar.selectbox("Choose the model", ("25d", "50d"), index=1)
|
343 |
|
|
|
344 |
st.title("Search Based Retrieval Demo")
|
345 |
st.subheader(
|
346 |
"Pass in space separated categories you want this search demo to be about."
|
|
|
351 |
# )
|
352 |
|
353 |
|
354 |
+
# categories of user input
|
355 |
user_categories = st.text_input(
|
356 |
label="Categories", value=st.session_state.categories
|
357 |
)
|
358 |
|
|
|
359 |
st.session_state.categories = user_categories
|
360 |
|
361 |
# st.text_input(
|
|
|
381 |
|
382 |
)
|
383 |
|
|
|
384 |
st.session_state.text_search = user_text_search
|
385 |
# st.session_state.text_search = text_search
|
386 |
|
|
|
434 |
+ " as per different Embeddings"
|
435 |
)
|
436 |
|
437 |
+
print(sorted_cosine_sim_glove)
|
438 |
+
print(sorted_cosine_sim_transformer)
|
439 |
+
|
440 |
+
st.write(f"Closest category using GloVe embeddings : {list(sorted_cosine_sim_glove.keys())[0]}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
441 |
st.write(
|
442 |
+
f"Closest category using Sentence Transformer embeddings : {list(sorted_cosine_sim_transformer.keys())[0]}")
|
443 |
+
|
444 |
+
plot_alatirchart(
|
445 |
+
{
|
446 |
+
"glove_" + str(model_type): sorted_cosine_sim_glove,
|
447 |
+
"sentence_transformer_384": sorted_cosine_sim_transformer,
|
448 |
+
}
|
449 |
+
)
|
450 |
|
451 |
st.write("")
|
452 |
st.write(
|
453 |
+
"Demo developed by [V50](https://huggingface.co/spaces/ericlkc/V50)"
|
454 |
)
|