File size: 9,996 Bytes
f040401
 
7f1a60e
53bad35
b2c3406
8b81843
25a6734
7964bf4
8b81843
b2c3406
 
 
53bad35
6028fe6
5b5b795
 
f040401
 
 
 
c7010dd
25a6734
a5482a3
 
f0f3e26
a5482a3
 
 
8b81843
 
 
 
d1cba0c
b2c3406
 
 
76a8518
b2c3406
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9d2c094
b2c3406
4d8d3df
 
8b81843
4d8d3df
b2c3406
d556203
6e3e890
4d8d3df
 
b2c3406
 
 
 
 
 
 
f040401
a5482a3
 
d548156
 
 
 
 
f040401
 
 
 
1731f45
f040401
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4b05a76
f040401
5b5b795
 
 
085853b
a27bca6
2c7ffd9
7f9ee2f
1323e99
f040401
61ce655
 
d71c853
 
2c7ffd9
51f54ce
ce50e18
 
 
2f7dfa0
71a7e2c
 
 
 
60bdc33
f040401
 
 
 
9c5d67e
0d13483
f040401
9c5d67e
f040401
30f3de6
0d13483
f0f3e26
 
 
 
 
 
7964bf4
7f1a60e
30f3de6
7f1a60e
da31fda
ce50e18
1323e99
7f1a60e
 
4b206d5
7f1a60e
 
1dcaf3f
4b206d5
7f1a60e
 
4b206d5
401a74f
7f1a60e
 
 
1c4df76
fe1cb2a
7f1a60e
dded833
 
 
1f3e17c
fe1cb2a
578e511
 
6319da9
 
 
 
7f1a60e
f040401
 
7f1a60e
f040401
 
 
d548156
1bbc870
 
9c5d67e
0a699f8
 
1bbc870
0a699f8
f040401
61a29bd
 
 
 
 
 
 
 
 
 
a38bdb0
d548156
30f3de6
d099a9b
 
 
 
d548156
691ccbe
dbebcc0
1ef5823
30f3de6
f040401
 
 
59ff44b
f040401
 
5c585a4
59ff44b
1ef5823
92e2ae9
 
0d13483
 
5e899e5
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
from typing import List

import itertools
import string
import re
import requests
import tweepy
import hdbscan

import numpy as np
import streamlit as st

from gensim.utils import deaccent
from bokeh.models import ColumnDataSource, HoverTool, Label, Legend
from bokeh.palettes import Colorblind as Pallete
from bokeh.palettes import Set3 as AuxPallete
from bokeh.plotting import Figure, figure
from bokeh.transform import factor_cmap

from sklearn.manifold import TSNE
from sentence_transformers import SentenceTransformer, util

# Twitter API v2 client; the bearer token is read from Streamlit's secrets store.
client = tweepy.Client(bearer_token=st.secrets["tw_bearer_token"])

# Maps the language option shown in the UI to the sentence-transformers model to load.
model_to_use = {
    "English": "all-MiniLM-L6-v2",
    "Use all the ones you know (~15 lang)": "paraphrase-multilingual-MiniLM-L12-v2"
}


STOPWORDS_URL = "https://gist.githubusercontent.com/rg089/35e00abf8941d72d419224cfd5b5925d/raw/12d899b70156fd0041fa9778d657330b024b959c/stopwords.txt"
# The list is parsed as one stopword per line. The timeout keeps app start-up
# from hanging forever on a stalled connection. raise_for_status() stops an HTTP
# error page from being silently parsed as a stopword list.
_stopwords_response = requests.get(STOPWORDS_URL, timeout=30)
_stopwords_response.raise_for_status()
stopwords_list = _stopwords_response.content
stopwords = set(stopwords_list.decode().splitlines())

def _remove_unk_chars(txt_list: List[str]):
    txt_list = [re.sub('\s+', ' ', tweet) for tweet in txt_list]
    txt_list = [re.sub("\'", "", tweet) for tweet in txt_list]
    txt_list = [deaccent(tweet).lower() for tweet in txt_list]
    return txt_list

def _remove_urls(txt_list: List[str]):
    url_regex = re.compile(
        r'^(?:http|ftp)s?://' # http:// or https://
        r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|' #domain...
        r'localhost|' #localhost...
        r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' # ...or ip
        r'(?::\d+)?' # optional port
        r'(?:/?|[/?]\S+)$', re.IGNORECASE)
    txt_list = [tweet.split(' ') for tweet in txt_list]
    return [' '.join([word for word in tweet if not bool(re.match(url_regex, word))]) for tweet in txt_list]

def _remove_punctuation(txt_list: List[str]):
    punctuation = string.punctuation + 'ΒΏΒ‘|'
    txt_list = [tweet.split(' ') for tweet in txt_list]
    return [' '.join([word.translate(str.maketrans('', '', punctuation)) for word in tweet]) for tweet in txt_list]

def _remove_stopwords(txt_list: List[str]):
    txt_list = [tweet.split(' ') for tweet in txt_list]
    return [' '.join([word for word in tweet if word not in stopwords]) for tweet in txt_list]

# Cleaning steps that `preprocess` applies in order. The order matters:
# - whitespace is normalized and text lowercased first, so later steps can split
#   on single spaces and compare lowercase words;
# - URLs are removed before punctuation stripping, which would delete the "://"
#   that identifies them.
preprocess_pipeline = [_remove_unk_chars, _remove_urls, _remove_punctuation, _remove_stopwords]

def preprocess(txt_list: List[str], *, pipeline=None) -> List[str]:
    """Run each cleaning step over ``txt_list`` and return the result.

    Args:
        txt_list: Raw texts (tweets), one string per item.
        pipeline: Optional sequence of callables, each taking and returning a
            list of strings, used instead of the module-level
            ``preprocess_pipeline``.

    Returns:
        The list produced by the last step. With the default pipeline this is
        one cleaned string per input text, in the same order.
    """
    steps = preprocess_pipeline if pipeline is None else pipeline
    for step in steps:
        txt_list = step(txt_list)
    return txt_list

# Original implementation from: https://huggingface.co/spaces/edugp/embedding-lenses/blob/main/app.py
# Default random_state for t-SNE, so the same input always gives the same 2D layout.
SEED: int = 42

@st.cache(show_spinner=False, allow_output_mutation=True)
def load_model(model_name: str) -> SentenceTransformer:
    """Instantiate the sentence-transformers model identified by ``model_name``.

    The ``st.cache`` decorator memoizes the model per ``model_name`` across
    Streamlit reruns. ``allow_output_mutation`` disables Streamlit's check for
    mutation of the returned model.
    NOTE(review): ``st.cache`` is deprecated in recent Streamlit releases in
    favour of ``st.cache_resource``.
    """
    return SentenceTransformer(model_name)

def embed_text(text: List[str], model: SentenceTransformer) -> np.ndarray:
    """Encode every string in ``text`` with ``model`` and return the embeddings."""
    embeddings = model.encode(text)
    return embeddings

def get_tsne_embeddings(
    embeddings: np.ndarray, perplexity: int = 10, n_components: int = 2, init: str = "pca", n_iter: int = 5000, random_state: int = SEED
) -> np.ndarray:
    """Project ``embeddings`` (one row per item) to ``n_components`` dims with t-SNE.

    t-SNE requires ``perplexity < n_samples``. When fewer points are given (e.g.
    a user with only a handful of tweets), the perplexity is lowered to
    ``n_samples - 1`` instead of letting scikit-learn raise. With fewer than two
    points there is nothing to lay out, so all-zero coordinates are returned.

    Returns:
        Array of shape ``(n_samples, n_components)``.
    """
    embeddings = np.asarray(embeddings)
    n_samples = embeddings.shape[0]
    if n_samples < 2:
        return np.zeros((n_samples, n_components))
    perplexity = min(perplexity, n_samples - 1)
    tsne = TSNE(perplexity=perplexity, n_components=n_components, init=init, n_iter=n_iter, random_state=random_state)
    return tsne.fit_transform(embeddings)

def _pick_palette(n_colors: int) -> List[str]:
    """Return exactly ``n_colors`` colors for the scatter plot.

    The order of preference is:
    - ``Pallete`` (Colorblind) when it has a variant with at least ``n_colors``
      entries; the first ``n_colors`` colors of its smallest variant are used for
      counts below its minimum size.
    - ``AuxPallete`` (Set3) when it defines that exact size.
    - Otherwise the largest Set3 variant, cycled, so the plot never fails for
      lack of colors.
    """
    if n_colors <= 0:
        return []
    if n_colors <= max(Pallete):
        return list(Pallete[max(n_colors, min(Pallete))][:n_colors])
    if n_colors in AuxPallete:
        return list(AuxPallete[n_colors])
    largest = AuxPallete[max(AuxPallete)]
    return [largest[i % len(largest)] for i in range(n_colors)]


def draw_interactive_scatter_plot(
    texts: np.ndarray, xs: np.ndarray, ys: np.ndarray, values: np.ndarray, labels: np.ndarray, text_column: str, label_column: str
) -> Figure:
    """Draw an interactive Bokeh scatter plot of 2D points colored by cluster id.

    Args:
        texts: Text shown in the hover tooltip for each point.
        xs, ys: Point coordinates.
        values: Numeric cluster id per point; each distinct id gets one color.
        labels: Human-readable label per point, shown in the tooltip and legend.
        text_column: Tooltip caption for ``texts``.
        label_column: Tooltip caption for ``labels``.

    Returns:
        The configured Bokeh figure.
    """
    values = np.asarray(values)
    values_list = values.astype(str).tolist()
    # One factor per *distinct* id, ordered numerically (np.unique sorts before
    # the string conversion), so every factor gets its own palette entry.
    factors = np.unique(values).astype(str).tolist()
    colors = factor_cmap("label", palette=_pick_palette(len(factors)), factors=factors)

    source = ColumnDataSource(data=dict(x=xs, y=ys, text=texts, label=values_list, original_label=labels))
    # NOTE(review): "{safe}" tells Bokeh not to HTML-escape the text. If texts can
    # contain untrusted markup, consider dropping it so the tooltip is escaped.
    hover = HoverTool(tooltips=[(text_column, "@text{safe}"), (label_column, "@original_label")])
    p = figure(plot_width=800, plot_height=800, tools=[hover], title='2D visualization of tweets', background_fill_color="#fafafa")

    p.add_layout(Legend(location='top_left', title='Topics keywords', background_fill_alpha=0.2), 'above')
    p.circle("x", "y", size=12, source=source, fill_alpha=0.4, line_color=colors, fill_color=colors, legend_group="original_label")
    p.axis.visible = False
    p.xgrid.grid_line_dash = "dashed"
    p.ygrid.grid_line_dash = "dashed"
    p.toolbar.logo = None
    return p

# End of the code adapted from the embedding-lenses app credited above SEED.
def _cluster_keywords(
    cluster_tws: List[str],
    cluster_embeddings: List[np.ndarray],
    model: SentenceTransformer,
    n_keywords: int = 3,
) -> str:
    """Describe one cluster by the words closest to its mean tweet embedding.

    Each distinct word of the cluster's cleaned tweets is embedded on its own
    and scored against the average tweet embedding with a dot product. The
    ``n_keywords`` best-scoring words are joined with ``', '``. If the cluster's
    tweets contain no words at all, ``'-'`` is returned.
    """
    # sorted() instead of set iteration makes tie-breaking deterministic across
    # runs. Empty tokens from repeated spaces are dropped so '' is never a keyword.
    words = sorted({word for tw in cluster_tws for word in tw.split(' ') if word})
    if not words:
        return '-'
    centroid = np.mean(cluster_embeddings, axis=0)
    scores = util.dot_score(centroid, embed_text(words, model))[0].tolist()
    # sorted() is stable even with reverse=True, so equal scores keep word order.
    best = sorted(range(len(words)), key=scores.__getitem__, reverse=True)[:n_keywords]
    return ', '.join(words[ix] for ix in best)


def generate_plot(
    tws: List[str],
    tws_cleaned: List[str],
    model: SentenceTransformer,
    tw_user: str
) -> Figure:
    """Cluster a user's tweets, name each cluster and return a 2D scatter plot.

    Args:
        tws: Original tweet texts, shown in the plot tooltips.
        tws_cleaned: Preprocessed texts aligned index-by-index with ``tws``.
            These are what get embedded and mined for keywords.
        model: Sentence-transformers model used for every embedding.
        tw_user: Twitter handle, only used in the progress messages.

    Returns:
        The Bokeh figure built by ``draw_interactive_scatter_plot``.
    """
    min_cluster_size = 3
    with st.spinner(text=f"Trying to understand '{tw_user}' tweets... 🤔"):
        embeddings = embed_text(tws_cleaned, model)
        if len(embeddings) < min_cluster_size:
            # Too few tweets to form even one cluster: label them all as noise (-1).
            encoded_labels = np.full(len(embeddings), -1)
        else:
            encoded_labels = hdbscan.HDBSCAN(
                min_cluster_size=min_cluster_size,
                metric='euclidean',
                cluster_selection_method='eom'
            ).fit(embeddings).labels_
    with st.spinner("Now trying to express them with my own words... 💬"):
        cluster_keyword = {}
        for label in set(encoded_labels):
            if label == -1:
                # -1 marks tweets that were not assigned to any topic.
                cluster_keyword[label] = 'Too diverse!'
                continue
            member_ixs = np.flatnonzero(encoded_labels == label)
            cluster_keyword[label] = _cluster_keywords(
                [tws_cleaned[ix] for ix in member_ixs],
                [embeddings[ix] for ix in member_ixs],
                model,
            )
        encoded_labels_keywords = [cluster_keyword[label] for label in encoded_labels]
        embeddings_2d = get_tsne_embeddings(embeddings)
    return draw_interactive_scatter_plot(
        tws, embeddings_2d[:, 0], embeddings_2d[:, 1], encoded_labels, encoded_labels_keywords, 'Tweet', 'Topic'
    )


st.title("Tweet-SNEst")
st.write("Visualize tweets embeddings in 2D using colors for topics labels.")
st.caption('Please beware this is using Twitter free version of their API and might be needed to wait sometimes.')
col1, col2 = st.columns(2)
with col1:
    tw_user = st.text_input("Twitter handle", "huggingface")
with col2:
    tw_sample = st.number_input("Maximum number of tweets to use", min_value=1, max_value=300, value=100, step=10)

col1, col2 = st.columns(2)

with col1:
    # The options are the keys of model_to_use, so every choice has a model behind it.
    expected_lang = st.radio(
        "What language should be assumed to be found?",
        tuple(model_to_use),
        index=0,
    )
with col2:
    # The label previously contained mis-encoded bytes instead of the rocket emoji.
    go_btn = st.button('Visualize 🚀')

with st.spinner(text="Loading brain... 🧠"):
    try:
        model = load_model(model_to_use[expected_lang])
    except FileNotFoundError:
        # NOTE(review): falls back to loading without the Streamlit cache wrapper.
        # Why FileNotFoundError is expected from the cached path is not visible here.
        model = SentenceTransformer(model_to_use[expected_lang])

def _fetch_tweet_texts(user_id, n_tweets: int) -> List[str]:
    """Return the texts of up to ``n_tweets`` of the user's own tweets.

    Retweets and replies are excluded. Pages are followed via the API's
    ``next_token``, so each request returns new tweets; repeating the same
    request would return the same newest page again. Duplicate texts are
    removed while keeping the API's order.
    """
    texts = []
    pagination_token = None
    while len(texts) < n_tweets:
        # The Twitter API v2 user-tweets endpoint only accepts max_results
        # between 5 and 100, so request at least 5 and trim the surplus below.
        page_size = min(100, max(5, n_tweets - len(texts)))
        response = client.get_users_tweets(
            user_id,
            max_results=page_size,
            exclude=['retweets', 'replies'],
            pagination_token=pagination_token,
        )
        # tweepy leaves ``data`` as None when a page holds no tweets.
        texts.extend(tweet.text for tweet in (response.data or []))
        pagination_token = (response.meta or {}).get('next_token')
        if pagination_token is None:
            break
    # dict.fromkeys de-duplicates while preserving order (set() would scramble it).
    return list(dict.fromkeys(texts[:n_tweets]))


if go_btn:
    # Accept handles typed with stray spaces or a leading "@".
    tw_user = tw_user.replace(' ', '').lstrip('@')
    if tw_user == '':
        st.warning('Twitter handler field is empty 🙄')
    else:
        usr = client.get_user(username=tw_user)
        if usr.data is None:
            st.error(f"Could not find the Twitter user '{tw_user}'.")
        else:
            with st.spinner(f"Getting to know the '{tw_user}'... 🔍"):
                tweets_txt = _fetch_tweet_texts(usr.data.id, int(tw_sample))
            if not tweets_txt:
                st.warning(f"No tweets found for '{tw_user}'.")
            else:
                tweets_txt_cleaned = preprocess(tweets_txt)
                plot = generate_plot(tweets_txt, tweets_txt_cleaned, model, tw_user)
                st.bokeh_chart(plot)