Spaces:
Runtime error
Runtime error
Commit
·
384d3d8
1
Parent(s):
2574e04
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
A Streamlit application to visualize sentence embeddings
|
3 |
+
Author: Mohit Mayank
|
4 |
+
Contact: [email protected]
|
5 |
+
"""
|
6 |
+
|
7 |
+
## Import
|
8 |
+
## ----------------
|
9 |
+
# data
|
10 |
+
import pandas as pd
|
11 |
+
# model
|
12 |
+
from sentence_transformers import SentenceTransformer, util
|
13 |
+
# viz
|
14 |
+
import streamlit as st
|
15 |
+
import plotly.express as px
|
16 |
+
# DR
|
17 |
+
from sklearn.decomposition import PCA
|
18 |
+
from sklearn.manifold import TSNE
|
19 |
+
|
20 |
+
## Init
|
21 |
+
## ----------------
|
22 |
+
# set config
|
23 |
+
# st.set_page_config(layout="wide", page_title="SentenceViz 🕵")
|
24 |
+
st.markdown("# SentenceViz")
|
25 |
+
st.markdown("A Streamlit application to visulize sentence embeddings")
|
26 |
+
|
27 |
+
# load the summarization model (cache for faster loading)
|
28 |
+
@st.cache(allow_output_mutation=True)
|
29 |
+
def load_similarity_model(model_name='all-MiniLM-L6-v2'):
|
30 |
+
model = SentenceTransformer(model_name)
|
31 |
+
return model
|
32 |
+
|
33 |
+
@st.cache(allow_output_mutation=True)
|
34 |
+
def perform_embedding(df, text_col_name):
|
35 |
+
embeddings = model.encode(df[text_col_name])
|
36 |
+
return embeddings
|
37 |
+
|
38 |
+
# gloabl vars
|
39 |
+
df = None
|
40 |
+
model = None
|
41 |
+
embeddings = None
|
42 |
+
|
43 |
+
## Design Sidebar
|
44 |
+
## -----------------
|
45 |
+
## Data
|
46 |
+
st.sidebar.markdown("## Data")
|
47 |
+
uploaded_file = st.sidebar.file_uploader("Upload a CSV file with sentences (we remove NaN)")
|
48 |
+
if uploaded_file is not None:
|
49 |
+
progress = st.empty()
|
50 |
+
progress.text("Reading file...")
|
51 |
+
df = pd.read_csv(uploaded_file).dropna().reset_index(drop=True)
|
52 |
+
progress.text(f"Reading file...Done! Size: {df.shape[0]}")
|
53 |
+
|
54 |
+
## Embedding
|
55 |
+
st.sidebar.markdown("## Embedding")
|
56 |
+
supported_models = ['all-MiniLM-L6-v2', 'paraphrase-albert-small-v2', 'paraphrase-MiniLM-L3-v2', 'all-distilroberta-v1', 'all-mpnet-base-v2']
|
57 |
+
selected_model_option = st.sidebar.selectbox("Select Model:", supported_models)
|
58 |
+
text_col_name = st.sidebar.text_input("Text column to embed")
|
59 |
+
if len(text_col_name) > 0 and df is not None:
|
60 |
+
print("text_col_name -->", text_col_name)
|
61 |
+
df[text_col_name] = df[text_col_name].str.wrap(30)
|
62 |
+
df[text_col_name] = df[text_col_name].apply(lambda x: x.replace('\n', '<br>'))
|
63 |
+
progress = st.empty()
|
64 |
+
progress.text("Creating embedding...")
|
65 |
+
model = load_similarity_model(selected_model_option)
|
66 |
+
embeddings = perform_embedding(df, text_col_name)
|
67 |
+
progress.text("Creating embedding...Done!")
|
68 |
+
|
69 |
+
## Visualization
|
70 |
+
st.sidebar.markdown("## Visualization")
|
71 |
+
dr_algo = st.sidebar.selectbox("Dimensionality Reduction Algorithm", ('PCA', 't-SNE'))
|
72 |
+
color_col = st.sidebar.text_input("Color using this col")
|
73 |
+
if len(color_col.strip()) == 0:
|
74 |
+
color_col = None
|
75 |
+
|
76 |
+
if st.sidebar.button('Plot!'):
|
77 |
+
# get the embeddings and perform DR
|
78 |
+
if dr_algo == 'PCA':
|
79 |
+
pca = PCA(n_components=2)
|
80 |
+
reduced_embeddings = pca.fit_transform(embeddings)
|
81 |
+
elif dr_algo == 't-SNE':
|
82 |
+
tsne = TSNE(n_components=2)
|
83 |
+
reduced_embeddings = tsne.fit_transform(embeddings)
|
84 |
+
|
85 |
+
# modify the df
|
86 |
+
# df['complete_embeddings'] = embeddings
|
87 |
+
df['viz_embeddings_x'] = reduced_embeddings[:, 0]
|
88 |
+
df['viz_embeddings_y'] = reduced_embeddings[:, 1]
|
89 |
+
|
90 |
+
# plot the data
|
91 |
+
fig = px.scatter(df, x='viz_embeddings_x', y='viz_embeddings_y',
|
92 |
+
title=f'"{dr_algo}" on {df.shape[0]} "{selected_model_option}" embeddings',
|
93 |
+
color=color_col, hover_data=[text_col_name])
|
94 |
+
fig.update_layout(yaxis={'visible': False, 'showticklabels': False})
|
95 |
+
fig.update_layout(xaxis={'visible': False, 'showticklabels': False})
|
96 |
+
fig.update_traces(marker=dict(size=10, opacity=0.7, line=dict(width=1,color='DarkSlateGrey')),selector=dict(mode='markers'))
|
97 |
+
st.plotly_chart(fig, use_container_width=True)
|