edugp committed on
Commit
5aad559
·
1 Parent(s): 9447eb8

Add embedding lenses app

Browse files
Files changed (2) hide show
  1. app.py +94 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import Any, List, Optional
3
+
4
+ import numpy as np
5
+ import pandas as pd
6
+ import streamlit as st
7
+ from bokeh.models import ColumnDataSource, HoverTool
8
+ from bokeh.palettes import Cividis256 as Pallete
9
+ from bokeh.plotting import figure
10
+ from bokeh.transform import factor_cmap
11
+ from sentence_transformers import SentenceTransformer
12
+ from sklearn.manifold import TSNE
13
+
14
# Progress messages for each pipeline stage are logged at INFO level.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Single seed shared by row sampling and t-SNE so reruns give the same plot.
SEED = 0
17
+
18
+
19
@st.cache(show_spinner=False, allow_output_mutation=True)
def load_model():
    """Load the sentence-embedding model, cached across Streamlit reruns.

    ``allow_output_mutation=True`` stops ``st.cache`` from hashing the
    returned model on every cache hit to detect mutations. Hashing a large
    model object is slow and can fail for objects Streamlit does not know
    how to hash; the model is only read from, so the check adds nothing.

    Returns:
        The ``SentenceTransformer`` instance for
        ``distiluse-base-multilingual-cased-v1``.
    """
    embedder = "distiluse-base-multilingual-cased-v1"
    return SentenceTransformer(embedder)
23
+
24
+
25
def embed_text(text: List[str]) -> np.ndarray:
    """Encode ``text`` with the model provided by ``load_model``.

    Args:
        text: The sentences to embed.

    Returns:
        Whatever the model's ``encode`` method returns for ``text``.
    """
    return load_model().encode(text)
28
+
29
+
30
def encode_labels(labels: pd.Series) -> pd.Series:
    """Return a numeric encoding of ``labels`` suitable for color scaling.

    Numeric labels are returned unchanged. Everything else, including
    booleans, is converted to integer category codes.

    Args:
        labels: The label column.

    Returns:
        ``labels`` itself when it is already numeric (and not boolean),
        otherwise a Series of integer category codes.
    """
    dtype_checks = pd.api.types
    # Booleans count as "numeric" for pandas, but NumPy refuses to subtract
    # boolean arrays, which the min/max color normalisation relies on, so
    # they are encoded as categories instead.
    if dtype_checks.is_numeric_dtype(labels) and not dtype_checks.is_bool_dtype(labels):
        return labels
    return labels.astype("category").cat.codes
34
+
35
+
36
def get_tsne_embeddings(
    embeddings: np.ndarray,
    perplexity: int = 30,
    n_components: int = 2,
    init: str = "pca",
    n_iter: int = 5000,
    random_state: int = SEED,
) -> np.ndarray:
    """Project ``embeddings`` with scikit-learn's ``TSNE``.

    All keyword arguments are forwarded unchanged to the ``TSNE``
    constructor.

    Args:
        embeddings: The vectors to project.
        perplexity: ``TSNE`` perplexity.
        n_components: Number of output dimensions.
        init: ``TSNE`` initialisation method.
        n_iter: Number of optimisation iterations.
        random_state: Seed passed to ``TSNE``.

    Returns:
        The result of ``TSNE.fit_transform(embeddings)``.
    """
    tsne_options = dict(
        perplexity=perplexity,
        n_components=n_components,
        init=init,
        n_iter=n_iter,
        random_state=random_state,
    )
    reducer = TSNE(**tsne_options)
    return reducer.fit_transform(embeddings)
41
+
42
+
43
def draw_interactive_scatter_plot(
    texts: np.ndarray, xs: np.ndarray, ys: np.ndarray, values: np.ndarray, labels: np.ndarray, text_column: str, label_column: str
) -> Any:
    """Build a Bokeh scatter plot whose points are colored by ``values``.

    Args:
        texts: Text shown in the hover tooltip of each point.
        xs: x coordinate of each point.
        ys: y coordinate of each point.
        values: Numeric value of each point; it selects the fill color.
        labels: Original label of each point, shown in the hover tooltip.
        text_column: Tooltip caption for ``texts``.
        label_column: Tooltip caption for ``labels``.

    Returns:
        The Bokeh figure.

    Raises:
        ValueError: If ``values`` is empty (there is no min/max to scale by).
    """
    values_list = values.astype(str).tolist()
    labels_list = labels.astype(str).tolist()

    # Build the factor -> color table from the distinct values only, keeping
    # each value paired with its own color. Sorting the value strings and
    # the color strings independently would pair them up wrongly whenever
    # the two lexicographic orders disagree (e.g. values 2, 10, 100).
    distinct_values = np.unique(values)
    scaled = distinct_values.astype(float)
    min_value = scaled.min()
    value_span = scaled.max() - min_value
    if value_span > 0:
        # Normalize to 0-255 so every value gets an index into the palette.
        color_ids = ((scaled - min_value) / value_span * 255).round().astype(int)
    else:
        # All values are identical (e.g. no label column was supplied):
        # dividing by the zero span would yield NaN, so use a single color.
        color_ids = np.zeros(scaled.shape, dtype=int)
    # Same dtype -> same string form as in ``values_list``, so factors match.
    factors = distinct_values.astype(str).tolist()
    palette = [Pallete[int(color_id)] for color_id in color_ids]

    source = ColumnDataSource(data=dict(x=xs, y=ys, text=texts, label=values_list, original_label=labels_list))
    # NOTE(review): ``{safe}`` makes the tooltip render the text as raw HTML;
    # confirm that is intended for uploaded data.
    hover = HoverTool(tooltips=[(text_column, "@text{safe}"), (label_column, "@original_label")])
    p = figure(plot_width=800, plot_height=800, tools=[hover], title="Embedding Lenses")
    p.circle("x", "y", size=10, source=source, fill_color=factor_cmap("label", palette=palette, factors=factors))
    return p
61
+
62
+
63
def _read_table(tsv, text_column: str) -> pd.DataFrame:
    """Read the uploaded file as TSV, falling back to CSV if needed."""
    df = pd.read_csv(tsv, sep="\t")
    if text_column in df.columns:
        return df
    # The uploader also accepts comma-separated files; a tab-separated parse
    # of those yields a single fused column, so retry with the default comma.
    tsv.seek(0)
    try:
        comma_df = pd.read_csv(tsv)
    except pd.errors.ParserError:
        comma_df = None
    if comma_df is not None and text_column in comma_df.columns:
        return comma_df
    raise KeyError(f"Text column {text_column!r} not found in the uploaded file")


def generate_plot(tsv: st.uploaded_file_manager.UploadedFile, text_column: str, label_column: str, sample: Optional[int]):
    """Run the full pipeline from an uploaded file to a scatter plot.

    The file is loaded, rows with missing text or label are dropped, the
    data is optionally down-sampled, the texts are embedded, projected with
    t-SNE and finally drawn.

    Args:
        tsv: The uploaded tab- or comma-separated file.
        text_column: Name of the column holding the text to embed.
        label_column: Name of the column used for coloring; when the column
            does not exist every row gets the label ``0``.
        sample: Maximum number of rows to keep; falsy means keep all rows.

    Returns:
        The figure returned by ``draw_interactive_scatter_plot``.

    Raises:
        KeyError: If ``text_column`` is not a column of the file.
        ValueError: If no rows remain after dropping missing values.
    """
    logger.info("Loading dataset in memory")
    df = _read_table(tsv, text_column)
    if label_column not in df.columns:
        df[label_column] = 0
    df = df.dropna(subset=[text_column, label_column])
    if df.empty:
        # Fail here with a clear message instead of deep inside t-SNE.
        raise ValueError("No rows left to plot after dropping missing values")
    if sample:
        df = df.sample(min(sample, df.shape[0]), random_state=SEED)
    logger.info("Embedding sentences")
    embeddings = embed_text(df[text_column].values.tolist())
    logger.info("Encoding labels")
    encoded_labels = encode_labels(df[label_column])
    logger.info("Running t-SNE")
    tsne_embeddings = get_tsne_embeddings(embeddings)
    logger.info("Generating figure")
    plot = draw_interactive_scatter_plot(
        df[text_column].values, tsne_embeddings[:, 0], tsne_embeddings[:, 1], encoded_labels.values, df[label_column].values, text_column, label_column
    )
    return plot
82
+
83
+
84
# Page layout: title, file upload, and the options that drive the plot.
st.title("Embedding Lenses")
uploaded_file = st.file_uploader("Choose an csv/tsv file...", type=["csv", "tsv"])
text_column = st.text_input("Text column name", value="text")
label_column = st.text_input(
    "Numerical/categorical column name (ignore if not applicable)",
    value="label",
)
sample = st.number_input(
    "Maximum number of documents to use",
    min_value=1,
    max_value=100000,
    value=1000,
)

# Nothing is computed until a file has been uploaded.
if uploaded_file:
    plot = generate_plot(uploaded_file, text_column, label_column, sample)
    logger.info("Displaying plot")
    st.bokeh_chart(plot)
    logger.info("Done")
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ huggingface-hub==0.0.12
2
+ streamlit==0.84.1
3
+ transformers==4.8.2
4
+ watchdog==2.1.3
5
+ sentence-transformers==2.0.0
6
+ bokeh==2.2.2