File size: 7,301 Bytes
f040401
 
b2c3406
25a6734
7964bf4
b2c3406
 
 
 
 
25a6734
2f7dfa0
5b5b795
 
f040401
 
 
 
 
25a6734
a5482a3
 
 
 
 
 
b2c3406
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f040401
a5482a3
 
d548156
 
 
 
 
f040401
 
 
 
1731f45
f040401
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b5b795
 
 
085853b
a27bca6
c959b1a
f040401
61ce655
 
d71c853
 
f040401
c959b1a
 
9460aa5
2f7dfa0
dc798d1
8e41537
2f7dfa0
 
f040401
 
 
 
9c5d67e
f040401
9c5d67e
f040401
30f3de6
9c5d67e
a5482a3
20c61de
2d7af1c
bff1960
20c61de
 
7964bf4
30f3de6
f040401
 
2f7dfa0
f040401
 
 
d548156
1bbc870
 
9c5d67e
0a699f8
 
1bbc870
0a699f8
f040401
61a29bd
 
 
 
 
 
 
 
 
 
a38bdb0
d548156
30f3de6
f040401
d548156
691ccbe
dbebcc0
 
30f3de6
f040401
 
 
59ff44b
f040401
 
5c585a4
59ff44b
f040401
92e2ae9
 
b2c3406
f040401
9c5d67e
5e899e5
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import re
import string
from typing import List

import hdbscan
import numpy as np
import streamlit as st
import tweepy
from bokeh.models import ColumnDataSource, HoverTool, Label
from bokeh.palettes import Colorblind as Pallete
from bokeh.palettes import Set3 as AuxPallete
from bokeh.plotting import Figure, figure
from bokeh.transform import factor_cmap
from gensim.utils import deaccent # gensim==3.8.1
from sentence_transformers import SentenceTransformer
from sklearn.manifold import TSNE

# tweepy API client authenticated with the bearer token stored in Streamlit's
# secrets under "tw_bearer_token".
client = tweepy.Client(
    bearer_token=st.secrets["tw_bearer_token"],
)

# Language choice offered by the UI radio button -> name of the
# sentence-transformers model that load_model() is asked to load for it.
model_to_use = dict(
    [
        ("English", "all-MiniLM-L12-v2"),
        ("Use all the ones you know (~15 lang)", "paraphrase-multilingual-MiniLM-L12-v2"),
    ]
)

def remove_unk_chars(txt_list: List[str]):
    txt_list = [re.sub('\s+', ' ', tweet) for tweet in txt_list]
    txt_list = [re.sub("\'", "", tweet) for tweet in txt_list]
    txt_list = [deaccent(tweet).lower() for tweet in txt_list]

def _remove_urls(txt_list: List[str]):
    url_regex = re.compile(
        r'^(?:http|ftp)s?://' # http:// or https://
        r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|' #domain...
        r'localhost|' #localhost...
        r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' # ...or ip
        r'(?::\d+)?' # optional port
        r'(?:/?|[/?]\S+)$', re.IGNORECASE)
    txt_list = [tweet.split(' ') for tweet in txt_list]
    return [' '.join([word for word in tweet if not bool(re.match(url_regex, word))]) for tweet in txt_list]

def _remove_punctuation(txt_list: List[str]):
    punctuation = string.punctuation + 'ΒΏΒ‘|'
    txt_list = [tweet.split(' ') for tweet in txt_list]
   return [' '.join([word.translate(str.maketrans('', '', punctuation)) for word in tweet]) for tweet in txt_list]

# Text-cleaning steps that ``preprocess`` applies in this order; each step is
# given the list of tweet texts produced by the previous one.
preprocess_pipeline = [_remove_unk_chars, _remove_urls, _remove_punctuation]

def preprocess(txt_list: List[str]) -> List[str]:
    """Clean tweet texts by running them through ``preprocess_pipeline``.

    Args:
        txt_list: raw tweet texts.

    Returns:
        The output of the last pipeline step; every step receives the list
        returned by the step before it.
    """
    for step in preprocess_pipeline:
        txt_list = step(txt_list)
    return txt_list

# Original implementation from: https://huggingface.co/spaces/edugp/embedding-lenses/blob/main/app.py
# Default ``random_state`` for t-SNE so repeated runs give the same layout.
SEED = 42


@st.cache(show_spinner=False, allow_output_mutation=True)
def load_model(model_name: str) -> SentenceTransformer:
    """Instantiate the sentence-transformers model called ``model_name``.

    ``st.cache`` memoises the result per ``model_name`` so Streamlit reruns
    reuse the already-loaded model; ``allow_output_mutation`` tells Streamlit
    not to check the returned model for mutation.
    """
    return SentenceTransformer(model_name)

def embed_text(text: List[str], model: SentenceTransformer) -> np.ndarray:
    """Turn each string of ``text`` into a vector using ``model.encode``.

    Args:
        text: strings to embed.
        model: loaded sentence-transformers model.

    Returns:
        The result of ``model.encode`` for the whole batch.
    """
    vectors = model.encode(text)
    return vectors

def get_tsne_embeddings(
    embeddings: np.ndarray, perplexity: int = 10, n_components: int = 2, init: str = "pca", n_iter: int = 5000, random_state: int = SEED
) -> np.ndarray:
    """Project ``embeddings`` to ``n_components`` dimensions with t-SNE.

    Args:
        embeddings: one row per sample.
        perplexity: requested t-SNE perplexity; lowered to ``n_samples - 1``
            when there are too few samples for it.
        n_components: output dimensionality.
        init: t-SNE initialisation method.
        n_iter: number of optimisation iterations.
        random_state: seed for reproducible layouts.

    Returns:
        Array of shape ``(n_samples, n_components)``.
    """
    n_samples = len(embeddings)
    if n_samples < 2:
        # t-SNE cannot run on fewer than two points (perplexity must be below
        # n_samples); place the lone point, if any, at the origin.
        return np.zeros((n_samples, n_components))
    # scikit-learn raises ValueError when perplexity >= n_samples, which
    # happens whenever only a handful of distinct tweets were fetched.
    perplexity = min(perplexity, n_samples - 1)
    tsne = TSNE(perplexity=perplexity, n_components=n_components, init=init, n_iter=n_iter, random_state=random_state)
    return tsne.fit_transform(embeddings)

def _categorical_palette(n_colors: int) -> List[str]:
    """Return ``n_colors`` colours, preferring ``Pallete`` then ``AuxPallete``.

    Both palettes are dicts keyed by palette size covering only some sizes;
    any other size is served by truncating/cycling the nearest palette.
    """
    if n_colors in Pallete:
        return list(Pallete[n_colors])
    if n_colors in AuxPallete:
        return list(AuxPallete[n_colors])
    if n_colors < min(Pallete):
        base = Pallete[min(Pallete)]
    else:
        base = AuxPallete[max(AuxPallete)]
    return [base[i % len(base)] for i in range(n_colors)]


def draw_interactive_scatter_plot(
    texts: np.ndarray, xs: np.ndarray, ys: np.ndarray, values: np.ndarray, labels: np.ndarray, text_column: str, label_column: str
) -> Figure:
    """Build a Bokeh scatter plot of 2-D points coloured by category.

    Args:
        texts: text shown in the hover tooltip for each point.
        xs: x coordinate of each point.
        ys: y coordinate of each point.
        values: per-point category; drives colour and legend grouping.
        labels: per-point label shown in the tooltip.
        text_column: tooltip caption for ``texts``.
        label_column: tooltip caption for ``labels``.

    Returns:
        The configured Bokeh figure.
    """
    values_list = np.asarray(values).astype(str).tolist()
    labels_list = np.asarray(labels).astype(str).tolist()
    # One colour / legend entry per distinct value, sorted by value (np.unique
    # sorts), so with integer topic ids -1 always gets the first colour.
    factors = np.unique(values).astype(str).tolist()
    source = ColumnDataSource(data=dict(x=xs, y=ys, text=texts, label=values_list, original_label=labels_list))
    # NOTE(review): "{safe}" renders the text without HTML escaping; fine only
    # while callers strip markup (the app's preprocessing removes '<' and '>').
    hover = HoverTool(tooltips=[(text_column, "@text{safe}"), (label_column, "@original_label")])
    p = figure(plot_width=800, plot_height=800, tools=[hover], title='2D visualization of tweets', background_fill_color="#fafafa")
    colors = factor_cmap("label", palette=_categorical_palette(len(factors)), factors=factors)
    p.circle("x", "y", size=12, source=source, fill_alpha=0.4, line_color=colors, fill_color=colors, legend_group="label")
    p.axis.visible = False
    p.xgrid.grid_line_dash = "dashed"
    p.ygrid.grid_line_dash = "dashed"
    p.toolbar.logo = None
    p.legend.location = "top_left"
    p.legend.title = "Topics ID"
    p.legend.background_fill_alpha = 0.2

    disclaimer = Label(x=0, y=0, x_units="screen", y_units="screen",
                       text_font_size="14px", text_color="gray",
                       text="Topic equals -1 means no topic was detected for such tweet")
    p.add_layout(disclaimer, "below")
    return p

# End of the code adapted from the embedding-lenses app linked above.
def generate_plot(
    tws: List[str],
    model: SentenceTransformer,
    tw_user: str
) -> Figure:
    """Embed tweets, cluster them into topics and draw them in 2-D.

    Args:
        tws: (preprocessed) tweet texts to plot.
        model: sentence-transformers model used to embed ``tws``.
        tw_user: Twitter handle; only used in the progress messages.

    Returns:
        Bokeh figure with one point per tweet, coloured by HDBSCAN topic id
        (``-1`` means the tweet was not assigned to any topic).
    """
    with st.spinner(text=f"Trying to understand '{tw_user}' tweets... 🤔"):
        embeddings = embed_text(tws, model)
    # Cluster in the full embedding space (before the 2-D projection); a
    # topic needs at least three tweets.
    clusterer = hdbscan.HDBSCAN(
        min_cluster_size=3,
        metric='euclidean',
        cluster_selection_method='eom'
    ).fit(embeddings)
    topic_ids = clusterer.labels_
    with st.spinner("Now trying to express them with my own words... 💬"):
        embeddings_2d = get_tsne_embeddings(embeddings)
    # Topic ids drive both the colour and the tooltip label.
    return draw_interactive_scatter_plot(
        tws, embeddings_2d[:, 0], embeddings_2d[:, 1], topic_ids, topic_ids, 'Tweet', 'Topic'
    )


def _fetch_user_tweets(user_id, n_tweets: int) -> list:
    """Collect up to ``n_tweets`` tweets of ``user_id`` (retweets/replies excluded).

    Follows the ``next_token`` from each response's ``meta`` so every request
    returns a new page instead of re-reading the first one, and stops early
    when no more tweets are available.
    """
    collected = []
    paging_kwargs = {}
    while len(collected) < n_tweets:
        remaining = n_tweets - len(collected)
        # The endpoint only accepts max_results in [5, 100]; request at least
        # 5 and trim the surplus below.
        response = client.get_users_tweets(
            user_id,
            max_results=max(5, min(100, remaining)),
            exclude=['retweets', 'replies'],
            **paging_kwargs,
        )
        if not response.data:
            break  # data may be None when there are no (more) tweets
        collected.extend(response.data[:remaining])
        next_token = (response.meta or {}).get('next_token')
        if not next_token:
            break
        paging_kwargs = {'pagination_token': next_token}
    return collected


st.title("Tweet-SNEst")
st.write("Visualize tweets embeddings in 2D using colors for topics labels.")
st.caption('Please beware this is using Twitter free version of their API and might be needed to wait sometimes.')
col1, col2 = st.columns(2)
with col1:
    tw_user = st.text_input("Twitter handle", "huggingface")
with col2:
    tw_sample = st.number_input("Maximum number of tweets to use", 1, 300, 100, 10)

col1, col2 = st.columns(2)

with col1:
    expected_lang = st.radio(
        "What language should be assumed to be found?",
        ('English', 'Use all the ones you know (~15 lang)'),
        0
    )
with col2:
    go_btn = st.button('Visualize 🚀')

with st.spinner(text="Loading brain... 🧠"):
    model = load_model(model_to_use[expected_lang])

# Normalise the handle before it is checked or looked up: handles cannot
# contain spaces, and users often type a leading '@'.
tw_user = tw_user.replace(' ', '').lstrip('@')

if go_btn and tw_user:
    usr = client.get_user(username=tw_user)
    if usr.data is None:
        st.error(f"Could not find the Twitter user '{tw_user}'.")
    else:
        with st.spinner(f"Getting to know the '{tw_user}'... 🔍"):
            tweets_objs = _fetch_user_tweets(usr.data.id, int(tw_sample))
        tweets_txt = preprocess([tweet.text for tweet in tweets_objs])
        # Deduplicate after cleaning (tweets differing only by URL/punctuation
        # collapse), keep first-seen order for reproducible plots, and drop
        # tweets that became empty.
        tweets_txt = [txt for txt in dict.fromkeys(tweets_txt) if txt.strip()]
        if not tweets_txt:
            st.warning(f"No usable tweets found for '{tw_user}'.")
        else:
            plot = generate_plot(tweets_txt, model, tw_user)
            st.bokeh_chart(plot)
elif go_btn:
    st.warning('Twitter handler field is empty 🙄')