Spaces:
Build error
Build error
File size: 9,996 Bytes
f040401 7f1a60e 53bad35 b2c3406 8b81843 25a6734 7964bf4 8b81843 b2c3406 53bad35 6028fe6 5b5b795 f040401 c7010dd 25a6734 a5482a3 f0f3e26 a5482a3 8b81843 d1cba0c b2c3406 76a8518 b2c3406 9d2c094 b2c3406 4d8d3df 8b81843 4d8d3df b2c3406 d556203 6e3e890 4d8d3df b2c3406 f040401 a5482a3 d548156 f040401 1731f45 f040401 4b05a76 f040401 5b5b795 085853b a27bca6 2c7ffd9 7f9ee2f 1323e99 f040401 61ce655 d71c853 2c7ffd9 51f54ce ce50e18 2f7dfa0 71a7e2c 60bdc33 f040401 9c5d67e 0d13483 f040401 9c5d67e f040401 30f3de6 0d13483 f0f3e26 7964bf4 7f1a60e 30f3de6 7f1a60e da31fda ce50e18 1323e99 7f1a60e 4b206d5 7f1a60e 1dcaf3f 4b206d5 7f1a60e 4b206d5 401a74f 7f1a60e 1c4df76 fe1cb2a 7f1a60e dded833 1f3e17c fe1cb2a 578e511 6319da9 7f1a60e f040401 7f1a60e f040401 d548156 1bbc870 9c5d67e 0a699f8 1bbc870 0a699f8 f040401 61a29bd a38bdb0 d548156 30f3de6 d099a9b d548156 691ccbe dbebcc0 1ef5823 30f3de6 f040401 59ff44b f040401 5c585a4 59ff44b 1ef5823 92e2ae9 0d13483 5e899e5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 |
from typing import List
import itertools
import string
import re
import requests
import tweepy
import hdbscan
import numpy as np
import streamlit as st
from gensim.utils import deaccent
from bokeh.models import ColumnDataSource, HoverTool, Label, Legend
from bokeh.palettes import Colorblind as Pallete
from bokeh.palettes import Set3 as AuxPallete
from bokeh.plotting import Figure, figure
from bokeh.transform import factor_cmap
from sklearn.manifold import TSNE
from sentence_transformers import SentenceTransformer, util
client = tweepy.Client(bearer_token=st.secrets["tw_bearer_token"])
model_to_use = {
    "English": "all-MiniLM-L6-v2",
    "Use all the ones you know (~15 lang)": "paraphrase-multilingual-MiniLM-L12-v2"
}
_STOPWORDS_URL = "https://gist.githubusercontent.com/rg089/35e00abf8941d72d419224cfd5b5925d/raw/12d899b70156fd0041fa9778d657330b024b959c/stopwords.txt"


@st.cache(show_spinner=False)
def _download_stopwords(url: str) -> bytes:
    """Download the raw stopword file (one word per line) from ``url``.

    Cached with ``st.cache`` so the file is not re-downloaded on every script
    rerun. Uses a 30 s connect/read timeout, and raises ``requests.HTTPError`` on a
    non-2xx response rather than treating an error page as the stopword list.
    """
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    return response.content


stopwords_list = _download_stopwords(_STOPWORDS_URL)
stopwords = set(stopwords_list.decode().splitlines())
# Runs of any whitespace (spaces, tabs, newlines); raw string so '\s' is not
# parsed as an (invalid) string escape.
_WHITESPACE_RUN = re.compile(r'\s+')


def _remove_unk_chars(txt_list: List[str]) -> List[str]:
    """Normalize raw texts before tokenization.

    For each text: collapse every whitespace run to a single space, remove ASCII
    apostrophes, strip accents with ``gensim.utils.deaccent`` and lowercase.
    Leading/trailing whitespace is collapsed, not stripped.
    """
    return [
        deaccent(_WHITESPACE_RUN.sub(' ', tweet).replace("'", "")).lower()
        for tweet in txt_list
    ]
def _remove_urls(txt_list: List[str]):
url_regex = re.compile(
r'^(?:http|ftp)s?://' # http:// or https://
r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|' #domain...
r'localhost|' #localhost...
r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' # ...or ip
r'(?::\d+)?' # optional port
r'(?:/?|[/?]\S+)$', re.IGNORECASE)
txt_list = [tweet.split(' ') for tweet in txt_list]
return [' '.join([word for word in tweet if not bool(re.match(url_regex, word))]) for tweet in txt_list]
def _remove_punctuation(txt_list: List[str]):
punctuation = string.punctuation + 'ΒΏΒ‘|'
txt_list = [tweet.split(' ') for tweet in txt_list]
return [' '.join([word.translate(str.maketrans('', '', punctuation)) for word in tweet]) for tweet in txt_list]
def _remove_stopwords(txt_list: List[str]):
    """Remove, from each text, every space-separated token present in ``stopwords``.

    The surviving tokens are re-joined with single spaces in their original order.
    """
    cleaned = []
    for tweet in txt_list:
        kept_tokens = (token for token in tweet.split(' ') if token not in stopwords)
        cleaned.append(' '.join(kept_tokens))
    return cleaned
# Cleaning steps applied in this order by `preprocess`; each takes a list of
# strings and returns a new list with one cleaned string per input.
preprocess_pipeline = [
    _remove_unk_chars,  # collapse whitespace, drop apostrophes, deaccent, lowercase
    _remove_urls,  # before punctuation removal, which would strip the ':', '/' and '.' the URL regex needs
    _remove_punctuation,
    _remove_stopwords,  # last: compares already lowercased, punctuation-free tokens
]
def preprocess(txt_list: List[str], pipeline=None) -> List[str]:
    """Run each cleaning step over ``txt_list`` in order and return the result.

    Args:
        txt_list: Raw texts (tweets).
        pipeline: Optional sequence of callables, each mapping a list of strings to
            a new list of strings. Defaults to the module-level ``preprocess_pipeline``.

    Returns:
        The texts after every step has been applied.
    """
    steps = preprocess_pipeline if pipeline is None else pipeline
    for step in steps:
        txt_list = step(txt_list)
    return txt_list
# Original implementation from: https://huggingface.co/spaces/edugp/embedding-lenses/blob/main/app.py
# Default `random_state` of get_tsne_embeddings, so the t-SNE layout is repeatable between runs.
SEED: int = 42
@st.cache(show_spinner=False, allow_output_mutation=True)
def load_model(model_name: str) -> SentenceTransformer:
    """Instantiate the SentenceTransformer called ``model_name``.

    Wrapped in ``st.cache`` so the model is reused across Streamlit reruns;
    ``allow_output_mutation`` skips Streamlit's output-hash check on the model.
    """
    return SentenceTransformer(model_name)
def embed_text(text: List[str], model: SentenceTransformer) -> np.ndarray:
    """Encode every string of ``text`` with ``model`` and return ``model.encode``'s output."""
    encoded = model.encode(text)
    return encoded
def get_tsne_embeddings(
    embeddings: np.ndarray, perplexity: int = 10, n_components: int = 2, init: str = "pca", n_iter: int = 5000, random_state: int = SEED
) -> np.ndarray:
    """Project ``embeddings`` to ``n_components`` dimensions with scikit-learn's t-SNE.

    Args:
        embeddings: One embedding row per item.
        perplexity: Requested t-SNE perplexity; lowered to ``len(embeddings) - 1``
            when there are too few rows for it.
        n_components: Output dimensionality.
        init: t-SNE initialization method.
        n_iter: Number of optimization iterations.
        random_state: Seed for a repeatable layout.

    Returns:
        The projected coordinates, one row per input row.
    """
    n_samples = len(embeddings)
    # TSNE rejects perplexity >= n_samples; the UI allows fetching as few as one
    # tweet, so clamp instead of failing on small inputs.
    if n_samples > 1:
        perplexity = min(perplexity, n_samples - 1)
    tsne = TSNE(perplexity=perplexity, n_components=n_components, init=init, n_iter=n_iter, random_state=random_state)
    return tsne.fit_transform(embeddings)
def _label_palette(n_colors: int) -> List[str]:
    """Return exactly ``n_colors`` colors, preferring the colorblind-safe palette.

    ``Pallete`` (Colorblind) and ``AuxPallete`` (Set3) are dicts keyed by palette
    size, so each only covers a fixed range of sizes. Smaller requests take a
    prefix of the smallest Colorblind palette; larger ones cycle the largest Set3.
    """
    for palettes in (Pallete, AuxPallete):
        if n_colors in palettes:
            return list(palettes[n_colors])
    smallest = min(Pallete)
    if n_colors < smallest:
        return list(Pallete[smallest])[:n_colors]
    largest = list(AuxPallete[max(AuxPallete)])
    return [largest[ix % len(largest)] for ix in range(n_colors)]


def draw_interactive_scatter_plot(
    texts: np.ndarray, xs: np.ndarray, ys: np.ndarray, values: np.ndarray, labels: np.ndarray, text_column: str, label_column: str
) -> Figure:
    """Build an interactive Bokeh scatter plot with one color per distinct value.

    Args:
        texts: Text per point, shown in the hover tooltip.
        xs: X coordinate per point.
        ys: Y coordinate per point.
        values: Numeric group id per point (cluster ids here); selects the color.
        labels: Display label per point, shown in the tooltip and grouped in the legend.
        text_column: Tooltip caption for ``texts``.
        label_column: Tooltip caption for ``labels``.

    Returns:
        The configured Bokeh figure.
    """
    import html

    values = np.asarray(values)
    # The tooltip renders ``text`` with ``{safe}`` (raw HTML), so escape it here
    # instead of trusting the texts to be HTML-safe. Unescape first so entities
    # that may already be encoded upstream (e.g. "&amp;") still show as one character.
    safe_texts = [html.escape(html.unescape(str(text))) for text in texts]

    # One factor per distinct value (numerically sorted) and a palette of exactly
    # that length. Colors are assigned by position, not by id, so every index stays
    # in range whatever ids are present (e.g. no -1 noise label, or a single id).
    factors = np.unique(values).astype(str).tolist()
    colors = factor_cmap("label", palette=_label_palette(len(factors)), factors=factors)

    source = ColumnDataSource(data=dict(
        x=xs, y=ys, text=safe_texts, label=values.astype(str).tolist(), original_label=labels
    ))
    hover = HoverTool(tooltips=[(text_column, "@text{safe}"), (label_column, "@original_label")])
    # width/height supersede plot_width/plot_height (deprecated in Bokeh 2.4, removed in 3.0).
    p = figure(width=800, height=800, tools=[hover], title='2D visualization of tweets', background_fill_color="#fafafa")
    p.add_layout(Legend(location='top_left', title='Topics keywords', background_fill_alpha=0.2), 'above')
    p.circle("x", "y", size=12, source=source, fill_alpha=0.4, line_color=colors, fill_color=colors, legend_group="original_label")
    p.axis.visible = False
    p.xgrid.grid_line_dash = "dashed"
    p.ygrid.grid_line_dash = "dashed"
    p.toolbar.logo = None
    return p
# Up to here
def _cluster_keywords(
    centroid: np.ndarray, cluster_texts: List[str], model: SentenceTransformer, top_k: int = 3
) -> str:
    """Describe a cluster by the words of its texts that score highest against its centroid.

    Scores are dot products between ``centroid`` and each word's embedding.
    Returns the ``top_k`` best words joined with ', ', or '-' when the cluster
    has no non-empty words.
    """
    # Sorted rather than a raw ``set`` so ties resolve identically on every run;
    # empty tokens (left by repeated spaces or punctuation-only words) are skipped.
    words = sorted({word for text in cluster_texts for word in text.split(' ') if word})
    if not words:
        return '-'
    similarities = util.dot_score(centroid, embed_text(words, model))[0]
    # Stable sort: among equal scores the alphabetically earlier word ranks first.
    ranking = sorted(range(len(words)), key=lambda ix: float(similarities[ix]), reverse=True)
    return ', '.join(words[ix] for ix in ranking[:top_k])


def generate_plot(
    tws: List[str],
    tws_cleaned: List[str],
    model: SentenceTransformer,
    tw_user: str
) -> Figure:
    """Embed, cluster, name and project tweets, returning an interactive Bokeh plot.

    Args:
        tws: Original tweet texts, shown in the hover tooltips.
        tws_cleaned: Preprocessed texts aligned with ``tws``; used for embeddings
            and topic keywords.
        model: Sentence embedding model.
        tw_user: Twitter handle, only used in the progress message.

    Returns:
        The scatter plot built by ``draw_interactive_scatter_plot``.
    """
    with st.spinner(text=f"Trying to understand '{tw_user}' tweets... \N{THINKING FACE}"):
        embeddings = np.asarray(embed_text(tws_cleaned, model))
        encoded_labels = hdbscan.HDBSCAN(
            min_cluster_size=3,
            metric='euclidean',
            cluster_selection_method='eom'
        ).fit(embeddings).labels_

    with st.spinner("Now trying to express them with my own words... \N{SPEECH BALLOON}"):
        cluster_keyword = {}
        for label in set(encoded_labels):
            # Label -1 is HDBSCAN's noise label: points not assigned to any cluster.
            if label == -1:
                cluster_keyword[label] = 'Too diverse!'
                continue
            members = encoded_labels == label
            cluster_texts = [text for text, is_member in zip(tws_cleaned, members) if is_member]
            centroid = np.mean(embeddings[members], axis=0)
            cluster_keyword[label] = _cluster_keywords(centroid, cluster_texts, model)

        encoded_labels_keywords = [cluster_keyword[label] for label in encoded_labels]
        embeddings_2d = get_tsne_embeddings(embeddings)
        plot = draw_interactive_scatter_plot(
            tws, embeddings_2d[:, 0], embeddings_2d[:, 1], encoded_labels, encoded_labels_keywords, 'Tweet', 'Topic'
        )
    return plot
st.title("Tweet-SNEst")
st.write("Visualize tweets embeddings in 2D using colors for topics labels.")
st.caption('Please beware this is using Twitter free version of their API and might be needed to wait sometimes.')
col1, col2 = st.columns(2)
with col1:
    tw_user = st.text_input("Twitter handle", "huggingface")
with col2:
    tw_sample = st.number_input("Maximum number of tweets to use", 1, 300, 100, 10)
col1, col2 = st.columns(2)
with col1:
    expected_lang = st.radio(
        "What language should be assumed to be found?",
        ('English', 'Use all the ones you know (~15 lang)'),
        0
    )
with col2:
    # NOTE(review): the emoji in this label and in two messages below were garbled
    # upstream and cannot be recovered unambiguously; restore the intended ones.
    go_btn = st.button('Visualize π')
with st.spinner(text="Loading brain... \N{BRAIN}"):
    try:
        model = load_model(model_to_use[expected_lang])
    except FileNotFoundError:
        model = SentenceTransformer(model_to_use[expected_lang])


def _fetch_tweet_texts(user_id, max_tweets: int) -> List[str]:
    """Return the texts of up to ``max_tweets`` tweets by ``user_id``, excluding retweets and replies.

    ``tweepy.Paginator`` passes each page's ``next_token`` to the following
    request; repeating the bare request would return the same first page again.
    The endpoint only accepts page sizes between 5 and 100.
    """
    page_size = min(100, max(5, max_tweets))
    pages = tweepy.Paginator(
        client.get_users_tweets, user_id, max_results=page_size, exclude=['retweets', 'replies']
    )
    return [tweet.text for tweet in pages.flatten(limit=max_tweets)]


if go_btn and tw_user != '':
    # Remove spaces and a leading '@' so "@handle" is accepted too.
    tw_user = tw_user.replace(' ', '').lstrip('@')
    try:
        usr = client.get_user(username=tw_user)
        tweets_txt = []
        if usr.data is not None:
            with st.spinner(f"Getting to know the '{tw_user}'... π"):
                tweets_txt = _fetch_tweet_texts(usr.data.id, int(tw_sample))
    except tweepy.TweepyException as err:
        st.error(f"Could not fetch tweets of '{tw_user}': {err}")
    else:
        if usr.data is None:
            st.warning(f"Twitter user '{tw_user}' was not found")
        elif not tweets_txt:
            st.warning(f"'{tw_user}' has no tweets to visualize")
        else:
            # Drop duplicate texts, keeping the order in which the API returned them.
            tweets_txt = list(dict.fromkeys(tweets_txt))
            tweets_txt_cleaned = preprocess(tweets_txt)
            plot = generate_plot(tweets_txt, tweets_txt_cleaned, model, tw_user)
            st.bokeh_chart(plot)
elif go_btn and tw_user == '':
    st.warning('Twitter handler field is empty π')