File size: 5,200 Bytes
f040401
 
 
 
7eae2ec
25a6734
7964bf4
25a6734
f040401
 
 
 
 
 
 
25a6734
a5482a3
 
 
 
 
 
f040401
a5482a3
 
d548156
 
 
 
 
f040401
 
 
 
1731f45
f040401
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aecdaab
 
f040401
 
 
 
 
 
 
 
9c5d67e
f040401
9c5d67e
f040401
9c5d67e
 
a5482a3
20c61de
2d7af1c
bff1960
20c61de
 
7964bf4
9c5d67e
f040401
 
9c5d67e
f040401
 
 
d548156
1bbc870
 
9c5d67e
0a699f8
 
1bbc870
0a699f8
f040401
61a29bd
 
 
 
 
 
 
 
 
 
 
d548156
bff1960
f040401
d548156
691ccbe
dbebcc0
 
f040401
 
 
 
59ff44b
f040401
 
5c585a4
59ff44b
f040401
92e2ae9
 
f040401
9c5d67e
5e899e5
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
from typing import List

import numpy as np

import streamlit as st
import tweepy
import hdbscan

from bokeh.models import ColumnDataSource, HoverTool
from bokeh.palettes import Cividis256 as Pallete
from bokeh.plotting import Figure, figure
from bokeh.transform import factor_cmap

from sklearn.manifold import TSNE
from sentence_transformers import SentenceTransformer

SEED = 42  # default random_state for get_tsne_embeddings, so layouts are reproducible

# Label offered by the language radio button -> sentence-transformers checkpoint to load.
model_to_use = dict(
    English="all-MiniLM-L12-v2",
    **{"Use all the ones you know (~15 lang)": "paraphrase-multilingual-MiniLM-L12-v2"},
)

# Tweepy client authenticated with the bearer token stored in Streamlit secrets.
client = tweepy.Client(bearer_token=st.secrets["tw_bearer_token"])

# Original implementation from: https://huggingface.co/spaces/edugp/embedding-lenses/blob/main/app.py

@st.cache(show_spinner=False, allow_output_mutation=True)
def load_model(model_name: str) -> SentenceTransformer:
    """Instantiate the sentence-transformers model named ``model_name``.

    The result is cached by Streamlit (keyed on ``model_name``), so re-runs of the
    script reuse the already loaded model instead of constructing it again.
    """
    return SentenceTransformer(model_name)

def embed_text(text: List[str], model: SentenceTransformer) -> np.ndarray:
    """Encode every string in ``text`` with ``model`` and return the resulting embeddings."""
    vectors = model.encode(text)
    return vectors

def get_tsne_embeddings(
    embeddings: np.ndarray, perplexity: int = 10, n_components: int = 2, init: str = "pca", n_iter: int = 5000, random_state: int = SEED
) -> np.ndarray:
    """Project ``embeddings`` down to ``n_components`` dimensions with scikit-learn's t-SNE.

    Args:
        embeddings: 2-D array with one row per sample (one row per tweet in this app).
        perplexity: Requested t-SNE perplexity. scikit-learn requires it to be smaller
            than the number of samples, so for small inputs the value actually used is
            clamped to ``n_samples - 1`` (never below 1).
        n_components: Dimensionality of the output.
        init: t-SNE initialisation strategy, passed through unchanged.
        n_iter: Number of optimisation iterations, passed through unchanged.
        random_state: Seed so identical input yields an identical layout.

    Returns:
        The fitted low-dimensional coordinates, one row per input row.
    """
    n_samples = len(embeddings)
    # A short timeline (or heavy de-duplication of tweets) would otherwise make TSNE
    # raise "perplexity must be less than n_samples".
    effective_perplexity = min(perplexity, max(n_samples - 1, 1))
    # NOTE(review): newer scikit-learn releases renamed ``n_iter`` to ``max_iter``;
    # confirm against the pinned scikit-learn version.
    tsne = TSNE(
        perplexity=effective_perplexity,
        n_components=n_components,
        init=init,
        n_iter=n_iter,
        random_state=random_state,
    )
    return tsne.fit_transform(embeddings)

def _label_colors(values: np.ndarray):
    """Return ``(factors, palette)`` assigning one Cividis colour per distinct value.

    ``factors`` are the distinct values sorted *numerically* and converted with
    ``astype(str)`` (the same conversion used for the plot's ``label`` column), and
    ``palette[i]`` is the colour for ``factors[i]``, so both can go straight into
    ``factor_cmap``. Colours are spread linearly over the palette by value.
    """
    distinct = np.unique(values)  # de-duplicated and numerically sorted
    min_value, max_value = distinct[0], distinct[-1]
    if max_value - min_value == 0:
        # A single distinct value: keep the original fixed palette slot.
        color_ids = np.ones(len(distinct), dtype=int)
    else:
        color_ids = ((distinct - min_value) / (max_value - min_value) * 255).round().astype(int)
    factors = distinct.astype(str).tolist()
    palette = [Pallete[int(color_id)] for color_id in color_ids]
    return factors, palette

def draw_interactive_scatter_plot(
    texts: np.ndarray, xs: np.ndarray, ys: np.ndarray, values: np.ndarray, labels: np.ndarray, text_column: str, label_column: str
) -> Figure:
    """Build an 800x800 Bokeh scatter plot of the points ``(xs[i], ys[i])``.

    Args:
        texts: Text shown in the hover tooltip of each point.
        xs, ys: Point coordinates.
        values: Numeric value per point; it selects the point's colour, and points
            with equal values share a colour.
        labels: Label shown in the tooltip of each point.
        text_column: Tooltip caption for ``texts``.
        label_column: Tooltip caption for ``labels``.

    Returns:
        The configured figure (axes, grid lines and Bokeh logo hidden).
    """
    # Previously the factors and the palette were both sorted as *strings*, whose orders
    # disagree (e.g. "10" < "2" but "255" < "46"), so points got mismatched colours.
    factors, palette = _label_colors(values)
    source = ColumnDataSource(
        data=dict(
            x=xs,
            y=ys,
            text=texts,
            label=values.astype(str).tolist(),  # must use the same conversion as `factors`
            original_label=labels.astype(str).tolist(),
        )
    )
    # NOTE(review): ``{safe}`` disables Bokeh's HTML escaping of the tooltip text; that
    # is only appropriate if ``texts`` are trusted or already HTML-escaped — confirm.
    hover = HoverTool(tooltips=[(text_column, "@text{safe}"), (label_column, "@original_label")])
    p = figure(plot_width=800, plot_height=800, tools=[hover])
    colors = factor_cmap("label", palette=palette, factors=factors)
    p.circle("x", "y", size=10, source=source, fill_alpha=0.4, line_color=colors, fill_color=colors)
    p.axis.visible = False
    p.xgrid.grid_line_color = None
    p.ygrid.grid_line_color = None
    p.toolbar.logo = None
    return p

# End of code adapted from https://huggingface.co/spaces/edugp/embedding-lenses
def generate_plot(
    tws: List[str],
    model: SentenceTransformer,
    tw_user: str
) -> Figure:
    """Embed, cluster and project the tweets in ``tws``; return the Bokeh scatter plot.

    Each tweet is embedded with ``model``, then clustered with HDBSCAN (minimum cluster
    size 3, euclidean metric, 'eom' selection). The cluster id of each tweet is used both
    as its colour value and as its tooltip label. Tweets are laid out in 2-D with t-SNE.
    ``tw_user`` only appears in the progress messages.
    """
    understanding_msg = f"Trying to understand '{tw_user}' tweets..."
    with st.spinner(text=understanding_msg):
        embeddings = embed_text(tws, model)

    topic_ids = hdbscan.HDBSCAN(
        min_cluster_size=3, metric='euclidean', cluster_selection_method='eom'
    ).fit(embeddings).labels_

    with st.spinner("Now trying to express them with my own words..."):
        coords = get_tsne_embeddings(embeddings)

    return draw_interactive_scatter_plot(
        tws, coords[:, 0], coords[:, 1], topic_ids, topic_ids, 'text', 'label'
    )


# Bounds the tweets endpoint accepts for ``max_results`` in a single request.
_TWEETS_PAGE_MIN = 5
_TWEETS_PAGE_MAX = 100


def _fetch_tweet_texts(user_id, limit: int) -> List[str]:
    """Return up to ``limit`` distinct texts of ``user_id``'s own tweets (no retweets/replies).

    Follows the ``next_token`` cursor so every request returns older tweets instead of
    the same newest page again. Duplicate texts are dropped, keeping the order in which
    the API returned them.
    """
    texts: List[str] = []
    request_params = {'exclude': ['retweets', 'replies']}
    while len(texts) < limit:
        # max_results outside [5, 100] is rejected: over-ask when fewer than 5 remain and
        # trim the surplus below.
        remaining = limit - len(texts)
        request_params['max_results'] = min(_TWEETS_PAGE_MAX, max(_TWEETS_PAGE_MIN, remaining))
        response = client.get_users_tweets(user_id, **request_params)
        # ``data`` is None when a page contains no tweets.
        texts.extend(tweet.text for tweet in (response.data or []))
        next_token = (response.meta or {}).get('next_token')
        if next_token is None:
            break  # no older tweets available
        request_params['pagination_token'] = next_token
    # dict.fromkeys de-duplicates like set() but in a deterministic order, so re-running
    # on the same tweets feeds t-SNE identical input (its seed is fixed).
    return list(dict.fromkeys(texts[:limit]))


st.title("Tweet-SNEst")
st.write("Visualize tweets embeddings in 2D using colors for topics labels.")
st.caption('Please beware this is using Twitter free version of their API and might be needed to wait sometimes.')
col1, col2 = st.columns(2)
with col1:
    tw_user = st.text_input("Twitter handle", "huggingface")
with col2:
    tw_sample = st.number_input("Maximum number of tweets to use", 1, 300, 100, 10)

col1, col2 = st.columns(2)

with col1:
    expected_lang = st.radio(
        "What language should be assumed to be found?",
        ('English', 'Use all the ones you know (~15 lang)'),
        0
    )
with col2:
    go_btn = st.button('Visualize')

with st.spinner(text="Grabbing lenses..."):
    model = load_model(model_to_use[expected_lang])

# Normalise the handle *before* validating and looking it up (previously it was cleaned
# only after the lookup): drop spaces and a pasted leading '@'.
tw_user = tw_user.replace(' ', '').lstrip('@')

if go_btn and tw_user:
    usr = client.get_user(username=tw_user)
    if usr.data is None:
        # A lookup that finds no matching user comes back without ``data``.
        st.error(f"Could not find the Twitter user '{tw_user}'.")
    else:
        with st.spinner(f"Getting to know the '{tw_user}'..."):
            tweets_txt = _fetch_tweet_texts(usr.data.id, tw_sample)
        if tweets_txt:
            plot = generate_plot(tweets_txt, model, tw_user)
            st.bokeh_chart(plot)
        else:
            st.warning(f"No tweets (other than retweets and replies) were found for '{tw_user}'.")
elif go_btn:
    st.warning('Twitter handler field is empty 🙄')