Spaces:
Build error
Build error
feat: Add experimental version of plot after adaption from referred space
Browse files
app.py
CHANGED
@@ -1,13 +1,75 @@
|
|
|
|
|
|
|
|
|
|
1 |
import streamlit as st
|
2 |
import tweepy
|
3 |
-
from sentence_transformers import SentenceTransformer
|
4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
|
|
6 |
@st.cache(show_spinner=False, allow_output_mutation=True)
def load_model(model_name: str) -> SentenceTransformer:
    """Build the ``SentenceTransformer`` identified by ``model_name``.

    The function is wrapped in ``st.cache`` with the spinner disabled and
    output mutation allowed.

    Args:
        model_name: Identifier handed straight to ``SentenceTransformer``.

    Returns:
        The constructed ``SentenceTransformer`` instance.
    """
    return SentenceTransformer(model_name)
|
10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
client = tweepy.Client(bearer_token=st.secrets["tw_bearer_token"])
|
13 |
model_to_use = {
|
@@ -21,14 +83,35 @@ col1, col2 = st.columns(2)
|
|
21 |
with col1:
|
22 |
tw_user = st.text_input("Twitter handle", "huggingface")
|
23 |
with col2:
|
24 |
-
|
25 |
|
26 |
expected_lang = st.radio(
|
27 |
"What language should be assumed to be found?",
|
28 |
-
('English', 'Use all the ones you know (~15 lang)')
|
|
|
|
|
29 |
|
|
|
|
|
30 |
|
31 |
usr = client.get_user(username=tw_user)
|
32 |
|
33 |
-
st.write(usr.data.id)
|
34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
from typing import List

import numpy as np
import streamlit as st
import tweepy
from bokeh.models import ColumnDataSource, HoverTool
from bokeh.palettes import Cividis256 as Pallete
from bokeh.plotting import Figure, figure
from bokeh.transform import factor_cmap
from sentence_transformers import SentenceTransformer
from sklearn.manifold import TSNE
|
15 |
|
16 |
+
# Original implementation from: https://huggingface.co/spaces/edugp/embedding-lenses/blob/main/app.py
|
17 |
@st.cache(show_spinner=False, allow_output_mutation=True)
def load_model(model_name: str) -> SentenceTransformer:
    """Build the ``SentenceTransformer`` identified by ``model_name``.

    The function is wrapped in ``st.cache`` with the spinner disabled and
    output mutation allowed.

    Args:
        model_name: Identifier handed straight to ``SentenceTransformer``.

    Returns:
        The constructed ``SentenceTransformer`` instance.
    """
    return SentenceTransformer(model_name)
|
21 |
|
22 |
+
def embed_text(text: List[str], model: SentenceTransformer) -> np.ndarray:
    """Encode ``text`` with ``model``.

    Args:
        text: The strings to embed.
        model: Object whose ``encode`` method is called with ``text``.

    Returns:
        Whatever ``model.encode(text)`` returns.
    """
    embeddings = model.encode(text)
    return embeddings
|
24 |
+
|
25 |
+
# Default ``random_state`` for t-SNE. It has to exist before the function below
# is defined, because default argument values are evaluated at definition time.
SEED = 0


def get_tsne_embeddings(
    embeddings: np.ndarray, perplexity: int = 30, n_components: int = 2, init: str = "pca", n_iter: int = 5000, random_state: int = SEED
) -> np.ndarray:
    """Reduce ``embeddings`` with scikit-learn's ``TSNE``.

    Args:
        embeddings: Input vectors, one row per sample.
        perplexity: Forwarded to ``TSNE``; lowered to ``n_samples - 1`` when
            the input has too few rows for the requested value.
        n_components: Forwarded to ``TSNE``.
        init: Forwarded to ``TSNE``.
        n_iter: Forwarded to ``TSNE``.
        random_state: Forwarded to ``TSNE``; defaults to ``SEED``.

    Returns:
        The result of ``TSNE.fit_transform(embeddings)``.
    """
    # scikit-learn documents that perplexity must be smaller than the number of
    # samples; clamp it so that short inputs (fewer rows than the default of
    # 30) can still be projected instead of raising.
    n_samples = len(embeddings)
    if n_samples > 1:
        perplexity = min(perplexity, n_samples - 1)
    tsne = TSNE(perplexity=perplexity, n_components=n_components, init=init, n_iter=n_iter, random_state=random_state)
    return tsne.fit_transform(embeddings)
|
30 |
+
|
31 |
+
def draw_interactive_scatter_plot(
    texts: np.ndarray, xs: np.ndarray, ys: np.ndarray, values: np.ndarray, labels: np.ndarray, text_column: str, label_column: str
) -> Figure:
    """Build a Bokeh scatter plot whose points are coloured by ``values``.

    Args:
        texts: Per-point text, stored in the data source as ``text``.
        xs: Per-point x coordinates.
        ys: Per-point y coordinates.
        values: Per-point numeric values that select the colour.
        labels: Per-point labels, stored as strings in ``original_label``.
        text_column: Tooltip title shown next to ``@text``.
        label_column: Tooltip title shown next to ``@original_label``.

    Returns:
        The configured figure (axes hidden, grid lines and toolbar logo removed).
    """
    # Callers may hand in plain Python lists; ``.astype`` below needs arrays.
    values = np.asarray(values)
    labels = np.asarray(labels)

    # Normalize values to range between 0-255, to assign a color for each value.
    # The palette and the factor list are both derived from the same sorted
    # array of distinct values, so entry i of one always belongs to entry i of
    # the other (sorting the two lists independently as strings does not
    # guarantee that).
    max_value = values.max()
    min_value = values.min()
    unique_values = np.unique(values)
    if max_value - min_value == 0:
        color_ids = np.ones(len(unique_values), dtype=int)
    else:
        color_ids = ((unique_values - min_value) / (max_value - min_value) * 255).round().astype(int)
    palette = [Pallete[int(color_id)] for color_id in color_ids]
    factors = unique_values.astype(str).tolist()

    # Same string conversion as ``factors`` so every point matches a factor.
    values_list = values.astype(str).tolist()
    labels_list = labels.astype(str).tolist()
    source = ColumnDataSource(data=dict(x=xs, y=ys, text=texts, label=values_list, original_label=labels_list))
    hover = HoverTool(tooltips=[(text_column, "@text{safe}"), (label_column, "@original_label")])
    p = figure(plot_width=800, plot_height=800, tools=[hover])
    p.circle("x", "y", size=10, source=source, fill_color=factor_cmap("label", palette=palette, factors=factors))
    p.axis.visible = False
    p.xgrid.grid_line_color = None
    p.ygrid.grid_line_color = None
    p.toolbar.logo = None
    return p
|
54 |
+
|
55 |
+
# Up to here
|
56 |
+
# Module logger used for progress messages in this block.
logger = logging.getLogger(__name__)


def _encode_labels(labels: List[int]) -> np.ndarray:
    """Return ``labels`` as a numeric array, mapping non-numeric labels to integer codes."""
    label_array = np.asarray(labels)
    if np.issubdtype(label_array.dtype, np.number):
        return label_array
    _, codes = np.unique(label_array, return_inverse=True)
    return codes


def generate_plot(
    df: List[str],
    labels: List[int],
    model: SentenceTransformer,
) -> Figure:
    """Embed the texts, reduce them to 2-D and draw the scatter plot.

    Args:
        df: The texts to embed and plot.
        labels: One label per text; used both for colouring and the tooltip.
        model: Model passed to ``embed_text``.

    Returns:
        The figure produced by ``draw_interactive_scatter_plot``.
    """
    with st.spinner(text="Embedding text..."):
        embeddings = embed_text(df, model)
    logger.info("Encoding labels")
    encoded_labels = _encode_labels(labels)
    with st.spinner("Reducing dimensionality..."):
        embeddings_2d = get_tsne_embeddings(embeddings)
    logger.info("Generating figure")
    # Pass arrays rather than the raw list: the plotting function calls
    # ``.astype`` on both the values and the labels.
    plot = draw_interactive_scatter_plot(
        df, embeddings_2d[:, 0], embeddings_2d[:, 1], encoded_labels, np.asarray(labels), 'text', 'label'
    )
    return plot
|
72 |
+
|
73 |
|
74 |
client = tweepy.Client(bearer_token=st.secrets["tw_bearer_token"])
|
75 |
model_to_use = {
|
|
|
83 |
# Module logger used for progress messages in this script section.
logger = logging.getLogger(__name__)

with col1:
    tw_user = st.text_input("Twitter handle", "huggingface")
with col2:
    tw_sample = st.number_input("Maximum number of tweets to use", 1, 300, 100, 10)

expected_lang = st.radio(
    "What language should be assumed to be found?",
    ('English', 'Use all the ones you know (~15 lang)'),
    0
)

with st.spinner(text="Loading model..."):
    model = load_model(model_to_use[expected_lang])

if tw_user:
    # Only query the API once a handle has actually been entered.
    usr = client.get_user(username=tw_user)
    if usr.data is None:
        st.error(f"Could not find the Twitter user '{tw_user}'.")
    else:
        with st.spinner(f"Getting to know the '{tw_user}'..."):
            # The user-tweets endpoint accepts between 5 and 100 results per
            # request. The Paginator follows the pagination token, so each page
            # holds new tweets, and flatten() stops after ``tw_sample`` of them.
            page_size = max(5, min(100, tw_sample))
            tweets_objs = list(
                tweepy.Paginator(
                    client.get_user_tweets, usr.data.id, max_results=page_size
                ).flatten(limit=tw_sample)
            )
            tweets_txt = [tweet.text for tweet in tweets_objs]
            labels = [0] * len(tweets_txt)
            # Nothing to embed when the account has no retrievable tweets.
            plot = generate_plot(tweets_txt, labels, model) if tweets_txt else None
        if plot is None:
            st.warning(f"No tweets found for '{tw_user}'.")
        else:
            logger.info("Displaying plot")
            st.bokeh_chart(plot)
            logger.info("Done")
|