Spaces:
Build error
Build error
from typing import List | |
import itertools | |
import string | |
import re | |
import requests | |
import tweepy | |
import hdbscan | |
import numpy as np | |
import streamlit as st | |
from gensim.utils import deaccent | |
from bokeh.models import ColumnDataSource, HoverTool, Label, Legend | |
from bokeh.palettes import Colorblind as Pallete | |
from bokeh.palettes import Set3 as AuxPallete | |
from bokeh.plotting import Figure, figure | |
from bokeh.transform import factor_cmap | |
from sklearn.manifold import TSNE | |
from sentence_transformers import SentenceTransformer, util | |
# Twitter API v2 client; the bearer token is read from Streamlit secrets.
client = tweepy.Client(bearer_token=st.secrets["tw_bearer_token"])
# Maps the language option shown in the UI radio to a sentence-transformers model name.
model_to_use = {
    "English": "all-MiniLM-L12-v2",
    "Use all the ones you know (~15 lang)": "paraphrase-multilingual-MiniLM-L12-v2"
}
# Stopword list (one word per line), downloaded once when the script runs.
_STOPWORDS_URL = "https://gist.githubusercontent.com/rg089/35e00abf8941d72d419224cfd5b5925d/raw/12d899b70156fd0041fa9778d657330b024b959c/stopwords.txt"
# A timeout keeps a stalled download from hanging the app indefinitely.
_stopwords_response = requests.get(_STOPWORDS_URL, timeout=30)
# Fail loudly rather than silently using an HTTP error page as the stopword list.
_stopwords_response.raise_for_status()
stopwords_list = _stopwords_response.content
stopwords = set(stopwords_list.decode().splitlines())
def _remove_unk_chars(txt_list: List[str]) -> List[str]:
    """Normalize raw tweet texts before tokenization.

    For each text:
    1. Collapse every whitespace run (including newlines) into a single space.
    2. Delete apostrophes.
    3. Strip accents with gensim's ``deaccent`` and lowercase the result.
    """
    # Raw string: a backslash-s inside a plain literal is an invalid escape
    # sequence (SyntaxWarning on Python 3.12+). The regex itself is unchanged.
    txt_list = [re.sub(r'\s+', ' ', tweet) for tweet in txt_list]
    # A literal single-character removal needs no regex.
    txt_list = [tweet.replace("'", "") for tweet in txt_list]
    txt_list = [deaccent(tweet).lower() for tweet in txt_list]
    return txt_list
def _remove_urls(txt_list: List[str]): | |
url_regex = re.compile( | |
r'^(?:http|ftp)s?://' # http:// or https:// | |
r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|' #domain... | |
r'localhost|' #localhost... | |
r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' # ...or ip | |
r'(?::\d+)?' # optional port | |
r'(?:/?|[/?]\S+)$', re.IGNORECASE) | |
txt_list = [tweet.split(' ') for tweet in txt_list] | |
return [' '.join([word for word in tweet if not bool(re.match(url_regex, word))]) for tweet in txt_list] | |
def _remove_punctuation(txt_list: List[str]): | |
punctuation = string.punctuation + 'ΒΏΒ‘|' | |
txt_list = [tweet.split(' ') for tweet in txt_list] | |
return [' '.join([word.translate(str.maketrans('', '', punctuation)) for word in tweet]) for tweet in txt_list] | |
def _remove_stopwords(txt_list: List[str]):
    """Remove every space-separated token that appears in the module-level ``stopwords`` set.

    Surviving tokens are re-joined with single spaces, in their original order.
    """
    filtered = []
    for tweet in txt_list:
        tokens = tweet.split(' ')
        filtered.append(' '.join(filter(lambda token: token not in stopwords, tokens)))
    return filtered
# Cleaning steps applied in order by preprocess(); each maps List[str] -> List[str].
# _remove_urls must run before _remove_punctuation, which would otherwise strip
# ':' and '/' and leave URL fragments that no longer look like URLs.
preprocess_pipeline = [
    _remove_unk_chars,
    _remove_urls,
    _remove_punctuation,
    _remove_stopwords,
]


def preprocess(txt_list: List[str]) -> List[str]:
    """Run every step of ``preprocess_pipeline`` over ``txt_list`` and return the cleaned texts.

    The output has one cleaned string per input text, in the same order.
    """
    for step in preprocess_pipeline:
        txt_list = step(txt_list)
    return txt_list
# Original implementation from: https://huggingface.co/spaces/edugp/embedding-lenses/blob/main/app.py | |
SEED: int = 42  # default random_state for get_tsne_embeddings, so the t-SNE layout is repeatable
def load_model(model_name: str) -> SentenceTransformer:
    """Instantiate and return the sentence-transformers model named ``model_name``.

    NOTE(review): the script calls this at top level on every run without any
    caching. Consider a Streamlit cache decorator if reload cost matters.
    """
    return SentenceTransformer(model_name)
def embed_text(text: List[str], model: SentenceTransformer) -> np.ndarray:
    """Encode each string in ``text`` with ``model`` and return the resulting embeddings."""
    embeddings = model.encode(text)
    return embeddings
def get_tsne_embeddings(
    embeddings: np.ndarray,
    perplexity: int = 10,
    n_components: int = 2,
    init: str = "pca",
    n_iter: int = 5000,
    random_state: int = SEED,
) -> np.ndarray:
    """Project ``embeddings`` (one row per sample) to ``n_components`` dims with t-SNE.

    scikit-learn's TSNE rejects ``perplexity >= n_samples``, and the app lets
    users request as few as one tweet. The requested ``perplexity`` is
    therefore lowered to ``n_samples - 1`` when there are too few rows.
    A single sample still cannot be embedded.

    Returns:
        The t-SNE coordinates, as returned by ``TSNE.fit_transform``.
    """
    n_samples = len(embeddings)
    effective_perplexity = min(perplexity, max(n_samples - 1, 1))
    tsne = TSNE(
        perplexity=effective_perplexity,
        n_components=n_components,
        init=init,
        # NOTE(review): renamed to max_iter in newer scikit-learn releases.
        n_iter=n_iter,
        random_state=random_state,
    )
    return tsne.fit_transform(embeddings)
def _pick_palette(n_colors: int) -> List[str]:
    """Return exactly ``n_colors`` colors: Colorblind if it has that size, else Set3, else slice/cycle."""
    if n_colors in Pallete:
        return list(Pallete[n_colors])
    if n_colors in AuxPallete:
        return list(AuxPallete[n_colors])
    smallest = Pallete[min(Pallete)]
    if n_colors < len(smallest):
        # Fewer colors than the smallest palette variant offers: take a prefix.
        return list(smallest[:n_colors])
    # More colors than any palette variant offers: repeat the largest one.
    largest = AuxPallete[max(AuxPallete)]
    return [largest[i % len(largest)] for i in range(n_colors)]


def draw_interactive_scatter_plot(
    texts: np.ndarray, xs: np.ndarray, ys: np.ndarray, values: np.ndarray, labels: np.ndarray, text_column: str, label_column: str
) -> Figure:
    """Build a Bokeh scatter plot of 2D points colored by their label value.

    Args:
        texts: Per-point text shown in the hover tooltip.
        xs, ys: Per-point coordinates.
        values: Per-point label values used for coloring. Assumed to be
            integer-valued (e.g. cluster ids); -1 sorts first.
        labels: Per-point human-readable label, shown in the tooltip and
            used to group the legend.
        text_column, label_column: Tooltip captions for ``texts`` / ``labels``.

    Returns:
        The configured Bokeh figure.
    """
    import html

    values_list = values.astype(str).tolist()
    # One factor per distinct value, in numeric order, so '10' sorts after
    # '9'. factor_cmap expects unique factors, and the palette must have
    # exactly one color per factor.
    factors = sorted(set(values_list), key=float)
    palette = _pick_palette(len(factors))

    # The tooltip escapes HTML (no {safe}), because tweet text is untrusted.
    # Unescape entities such as '&amp;' first so they still render correctly.
    display_texts = [html.unescape(str(text)) for text in texts]
    source = ColumnDataSource(data=dict(x=xs, y=ys, text=display_texts, label=values_list, original_label=labels))
    hover = HoverTool(tooltips=[(text_column, "@text"), (label_column, "@original_label")])

    p = figure(plot_width=800, plot_height=800, tools=[hover], title='2D visualization of tweets', background_fill_color="#fafafa")
    colors = factor_cmap("label", palette=palette, factors=factors)
    p.add_layout(Legend(location='top_left', title='Topics keywords', background_fill_alpha=0.2), 'above')
    p.circle("x", "y", size=12, source=source, fill_alpha=0.4, line_color=colors, fill_color=colors, legend_group="original_label")
    p.axis.visible = False
    p.xgrid.grid_line_dash = "dashed"
    p.ygrid.grid_line_dash = "dashed"
    p.toolbar.logo = None
    return p
# Up to here | |
def _cluster_keywords(
    cluster_tws: List[str],
    cluster_embeddings: np.ndarray,
    model: SentenceTransformer,
    n_keywords: int = 3,
) -> str:
    """Return up to ``n_keywords`` cluster words closest to the cluster's mean embedding, comma-joined.

    Args:
        cluster_tws: Cleaned texts belonging to one cluster.
        cluster_embeddings: Embeddings of ``cluster_tws``, one row per text.
        model: Model used to embed the candidate words.
        n_keywords: Maximum number of words to return.
    """
    # Unique, non-empty words. Empty strings come from consecutive spaces left
    # by preprocessing. dict.fromkeys keeps first-seen order, so ties between
    # equally similar words resolve the same way on every run.
    words = list(dict.fromkeys(
        word for tw in cluster_tws for word in tw.split(' ') if word
    ))
    if not words:
        return 'No keywords found'
    centroid = np.mean(cluster_embeddings, axis=0)
    similarities = util.dot_score(centroid, embed_text(words, model))[0].tolist()
    ranked = sorted(range(len(words)), key=similarities.__getitem__, reverse=True)
    # Slicing (instead of popping until 3 are found) copes with clusters that
    # have fewer than n_keywords distinct words.
    return ', '.join(words[ix] for ix in ranked[:n_keywords])


def generate_plot(
    tws: List[str],
    tws_cleaned: List[str],
    model: SentenceTransformer,
    tw_user: str
) -> Figure:
    """Cluster tweets by embedding, name each cluster with keywords and plot them in 2D.

    Args:
        tws: Original tweet texts, shown in the plot tooltips.
        tws_cleaned: Preprocessed texts, aligned with ``tws``; used for
            embedding and keywords.
        model: Sentence embedding model.
        tw_user: Handle, used only in progress messages.

    Returns:
        The Bokeh figure from ``draw_interactive_scatter_plot``.
    """
    with st.spinner(text=f"Trying to understand '{tw_user}' tweets... π€"):
        embeddings = embed_text(tws_cleaned, model)
        cluster = hdbscan.HDBSCAN(
            min_cluster_size=3,
            metric='euclidean',
            cluster_selection_method='eom'
        ).fit(embeddings)
        encoded_labels = cluster.labels_
    cluster_keyword = {}
    with st.spinner("Now trying to express them with my own words... π¬"):
        for label in set(encoded_labels):
            if label == -1:
                # HDBSCAN labels points that belong to no cluster as -1.
                cluster_keyword[label] = 'Too diverse!'
                continue
            in_cluster = encoded_labels == label
            cluster_tws = [tw for tw, member in zip(tws_cleaned, in_cluster) if member]
            # Reuse the embeddings computed above instead of re-encoding the texts.
            cluster_keyword[label] = _cluster_keywords(cluster_tws, embeddings[in_cluster], model)
    encoded_labels_keywords = [cluster_keyword[encoded_label] for encoded_label in encoded_labels]
    embeddings_2d = get_tsne_embeddings(embeddings)
    return draw_interactive_scatter_plot(
        tws, embeddings_2d[:, 0], embeddings_2d[:, 1], encoded_labels, encoded_labels_keywords, 'Tweet', 'Topic'
    )
def _fetch_tweet_texts(user_id, n_tweets: int) -> List[str]:
    """Fetch the texts of up to ``n_tweets`` tweets by ``user_id``, excluding retweets and replies.

    - Follows the ``next_token`` pagination cursor, so requests above 100
      tweets return new tweets instead of re-reading the first page.
    - Never asks for fewer than 5 per page, since the v2 endpoint only
      accepts max_results between 5 and 100.
    - Stops early when there are no more pages.
    """
    texts: List[str] = []
    next_token = None
    while len(texts) < n_tweets:
        page_kwargs = {'pagination_token': next_token} if next_token else {}
        response = client.get_users_tweets(
            user_id,
            max_results=min(100, max(5, n_tweets - len(texts))),
            exclude=['retweets', 'replies'],
            **page_kwargs,
        )
        # response.data is None when a page holds no tweets.
        texts += [tweet.text for tweet in response.data or []]
        next_token = (response.meta or {}).get('next_token')
        if not next_token:
            break
    return texts[:n_tweets]


st.title("Tweet-SNEst")
st.write("Visualize tweets embeddings in 2D using colors for topics labels.")
st.caption('Please beware this is using Twitter free version of their API and might be needed to wait sometimes.')
col1, col2 = st.columns(2)
with col1:
    tw_user = st.text_input("Twitter handle", "huggingface")
with col2:
    tw_sample = st.number_input("Maximum number of tweets to use", 1, 300, 100, 10)
col1, col2 = st.columns(2)
with col1:
    expected_lang = st.radio(
        "What language should be assumed to be found?",
        ('English', 'Use all the ones you know (~15 lang)'),
        0
    )
with col2:
    go_btn = st.button('Visualize π')
with st.spinner(text="Loading brain... π§ "):
    model = load_model(model_to_use[expected_lang])

# Normalize the handle before it is used for the lookup; spaces and a leading
# '@' are not part of a username.
tw_user = tw_user.replace(' ', '').lstrip('@')
if go_btn and tw_user:
    usr = client.get_user(username=tw_user)
    if usr.data is None:
        st.error(f"Could not find the Twitter user '{tw_user}'")
    else:
        with st.spinner(f"Getting to know the '{tw_user}'... π"):
            tweets_txt = _fetch_tweet_texts(usr.data.id, int(tw_sample))
        # Drop duplicate texts, keeping the order in which they were fetched.
        tweets_txt = list(dict.fromkeys(tweets_txt))
        if not tweets_txt:
            st.warning(f"No tweets found for '{tw_user}'")
        else:
            tweets_txt_cleaned = preprocess(tweets_txt)
            plot = generate_plot(tweets_txt, tweets_txt_cleaned, model, tw_user)
            st.bokeh_chart(plot)
elif go_btn:
    st.warning('Twitter handler field is empty π')