"""Streamlit app: plot a Twitter user's tweets in 2D, coloured by topic.

Tweets are fetched with tweepy, cleaned, embedded with a sentence-transformers
model, clustered with HDBSCAN and projected to two dimensions with t-SNE. The
result is rendered as an interactive Bokeh scatter plot.
"""
import re
import string
from typing import List

import hdbscan
import numpy as np
import streamlit as st
import tweepy
from bokeh.models import ColumnDataSource, HoverTool, Label
from bokeh.palettes import Colorblind as Pallete
from bokeh.palettes import Set3 as AuxPallete
from bokeh.plotting import Figure, figure
from bokeh.transform import factor_cmap
from gensim.utils import deaccent  # gensim==3.8.1
from sentence_transformers import SentenceTransformer
from sklearn.manifold import TSNE

client = tweepy.Client(bearer_token=st.secrets["tw_bearer_token"])

# Maps the language option shown in the UI to the sentence-transformers model.
model_to_use = {
    "English": "all-MiniLM-L12-v2",
    "Use all the ones you know (~15 lang)": "paraphrase-multilingual-MiniLM-L12-v2"
}

# Compiled once at import time instead of on every `_remove_urls` call.
_URL_REGEX = re.compile(
    r'^(?:http|ftp)s?://'  # http:// or https://
    r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|'  # domain...
    r'localhost|'  # localhost...
    r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})'  # ...or ip
    r'(?::\d+)?'  # optional port
    r'(?:/?|[/?]\S+)$', re.IGNORECASE)

# ASCII punctuation plus the inverted question/exclamation marks and the pipe.
# The extra characters are written as escapes so that a wrong file encoding
# cannot silently turn them into unrelated letters.
_PUNCTUATION_TABLE = str.maketrans('', '', string.punctuation + '\u00bf\u00a1|')

# Smallest cluster HDBSCAN may report; also the minimum number of tweets
# required before a plot is attempted.
MIN_CLUSTER_SIZE = 3

# Page-size bounds used when requesting a user's tweets; the Twitter API v2
# timeline endpoint rejects `max_results` values outside this range.
_TW_MIN_PAGE_SIZE = 5
_TW_MAX_PAGE_SIZE = 100


def _remove_unk_chars(txt_list: List[str]) -> List[str]:
    """Normalise whitespace, drop apostrophes, strip accents and lowercase.

    Args:
        txt_list: Raw tweet texts.

    Returns:
        A new list with one cleaned string per input tweet.
    """
    txt_list = [re.sub(r'\s+', ' ', tweet) for tweet in txt_list]
    txt_list = [tweet.replace("'", "") for tweet in txt_list]
    return [deaccent(tweet).lower() for tweet in txt_list]


# The function was originally defined under this public name while the
# pipeline referenced the underscored one; keep both names working.
remove_unk_chars = _remove_unk_chars


def _remove_urls(txt_list: List[str]) -> List[str]:
    """Drop every space-separated word that is, in full, a URL.

    Args:
        txt_list: Tweet texts, words separated by single spaces.

    Returns:
        A new list where URL-only words have been removed from each tweet.
    """
    return [
        ' '.join(word for word in tweet.split(' ') if not _URL_REGEX.match(word))
        for tweet in txt_list
    ]


def _remove_punctuation(txt_list: List[str]) -> List[str]:
    """Remove punctuation characters from every tweet.

    Spaces are not in the deletion table, so translating the whole tweet is
    equivalent to translating it word by word.

    Args:
        txt_list: Tweet texts.

    Returns:
        A new list with punctuation stripped from each tweet.
    """
    return [tweet.translate(_PUNCTUATION_TABLE) for tweet in txt_list]


# Cleaning steps, applied in order by `preprocess`.
preprocess_pipeline = [
    _remove_unk_chars,
    _remove_urls,
    _remove_punctuation
]


def preprocess(txt_list: List[str]) -> List[str]:
    """Run every step of `preprocess_pipeline` over the tweets, in order.

    Args:
        txt_list: Raw tweet texts.

    Returns:
        The cleaned tweet texts.
    """
    for op in preprocess_pipeline:
        txt_list = op(txt_list)
    return txt_list


# Original implementation from: https://huggingface.co/spaces/edugp/embedding-lenses/blob/main/app.py
SEED = 42


@st.cache(show_spinner=False, allow_output_mutation=True)
def load_model(model_name: str) -> SentenceTransformer:
    """Load the sentence-transformers model `model_name` (cached by Streamlit)."""
    return SentenceTransformer(model_name)


def embed_text(text: List[str], model: SentenceTransformer) -> np.ndarray:
    """Encode `text` with `model` and return the result of `model.encode`."""
    return model.encode(text)


def get_tsne_embeddings(
    embeddings: np.ndarray,
    perplexity: int = 10,
    n_components: int = 2,
    init: str = "pca",
    n_iter: int = 5000,
    random_state: int = SEED
) -> np.ndarray:
    """Project `embeddings` to `n_components` dimensions with t-SNE.

    All keyword arguments are forwarded unchanged to `sklearn.manifold.TSNE`.

    Returns:
        The output of `TSNE.fit_transform` for `embeddings`.
    """
    tsne = TSNE(
        perplexity=perplexity,
        n_components=n_components,
        init=init,
        n_iter=n_iter,
        random_state=random_state,
    )
    return tsne.fit_transform(embeddings)


def _pick_palette(n_colors: int) -> List[str]:
    """Return `n_colors` colours, preferring the colour-blind-safe palette.

    Bokeh palettes are dicts keyed by size and only exist for some sizes, so
    the smallest available size that is large enough is truncated. If even
    the auxiliary palette is too small, its largest variant is cycled.
    """
    for family in (Pallete, AuxPallete):
        big_enough = [size for size in sorted(family) if size >= n_colors]
        if big_enough:
            return list(family[big_enough[0]][:n_colors])
    largest = AuxPallete[max(AuxPallete)]
    return [largest[i % len(largest)] for i in range(n_colors)]


def draw_interactive_scatter_plot(
    texts: np.ndarray,
    xs: np.ndarray,
    ys: np.ndarray,
    values: np.ndarray,
    labels: np.ndarray,
    text_column: str,
    label_column: str
) -> Figure:
    """Build a Bokeh scatter plot with one circle per text, coloured by value.

    Args:
        texts: Text shown in the hover tooltip of each point.
        xs: X coordinate of each point.
        ys: Y coordinate of each point.
        values: Per-point value that drives the colour and the legend group.
        labels: Per-point label shown in the hover tooltip.
        text_column: Tooltip caption for the text.
        label_column: Tooltip caption for the label.

    Returns:
        The configured Bokeh figure.
    """
    values = np.asarray(values)
    labels = np.asarray(labels)
    values_list = values.astype(str).tolist()
    labels_list = labels.astype(str).tolist()

    # One factor per distinct value. np.unique sorts on the original dtype,
    # so numeric values are ordered numerically ("-1" first, "10" after "9")
    # rather than as strings.
    factors = np.unique(values).astype(str).tolist()
    # Colours are assigned by position in `factors`, which stays in range
    # whether or not a "-1" (no topic) value is present.
    palette = _pick_palette(len(factors))

    source = ColumnDataSource(
        data=dict(x=xs, y=ys, text=texts, label=values_list, original_label=labels_list)
    )
    # NOTE(review): "{safe}" renders the text as raw HTML in the tooltip;
    # callers should only pass sanitised text.
    hover = HoverTool(tooltips=[(text_column, "@text{safe}"), (label_column, "@original_label")])

    p = figure(
        plot_width=800,
        plot_height=800,
        tools=[hover],
        title='2D visualization of tweets',
        background_fill_color="#fafafa",
    )
    colors = factor_cmap("label", palette=palette, factors=factors)
    p.circle(
        "x", "y",
        size=12,
        source=source,
        fill_alpha=0.4,
        line_color=colors,
        fill_color=colors,
        legend_group="label",
    )

    p.axis.visible = False
    p.xgrid.grid_line_dash = "dashed"
    p.ygrid.grid_line_dash = "dashed"
    p.toolbar.logo = None
    p.legend.location = "top_left"
    p.legend.title = "Topics ID"
    p.legend.background_fill_alpha = 0.2

    disclaimer = Label(
        x=0, y=0, x_units="screen", y_units="screen",
        text_font_size="14px", text_color="gray",
        text="Topic equals -1 means no topic was detected for such tweet",
    )
    p.add_layout(disclaimer, "below")
    return p
# Up to here


def generate_plot(
    tws: List[str],
    model: SentenceTransformer,
    tw_user: str
) -> Figure:
    """Embed, cluster and project the tweets, then build the scatter plot.

    Args:
        tws: Preprocessed tweet texts.
        model: Sentence-transformers model used to embed the tweets.
        tw_user: Twitter handle, only used in the spinner message.

    Returns:
        The Bokeh figure produced by `draw_interactive_scatter_plot`, with the
        HDBSCAN cluster ids used both as colour values and as tooltip labels.
    """
    with st.spinner(text=f"Trying to understand '{tw_user}' tweets... 🤔"):
        embeddings = embed_text(tws, model)
        cluster = hdbscan.HDBSCAN(
            min_cluster_size=MIN_CLUSTER_SIZE,
            metric='euclidean',
            cluster_selection_method='eom'
        ).fit(embeddings)
        encoded_labels = cluster.labels_

    with st.spinner("Now trying to express them with my own words... 💬"):
        # t-SNE needs perplexity < number of samples; clamp it for small
        # samples while keeping the default of 10 otherwise.
        perplexity = min(10, max(1, len(tws) - 1))
        embeddings_2d = get_tsne_embeddings(embeddings, perplexity=perplexity)
        plot = draw_interactive_scatter_plot(
            tws, embeddings_2d[:, 0], embeddings_2d[:, 1],
            encoded_labels, encoded_labels, 'Tweet', 'Topic'
        )
    return plot


def fetch_tweets(user_id, limit: int) -> List[str]:
    """Return the text of up to `limit` tweets of `user_id`.

    Retweets and replies are excluded. Pages are followed through
    `tweepy.Paginator`, so each request continues where the previous one
    ended instead of returning the first page again.

    Args:
        user_id: Twitter user id, as found in the `get_user` response.
        limit: Maximum number of tweets to return.

    Returns:
        Tweet texts in the order the API returned them; may be shorter than
        `limit` (or empty) when the user has fewer matching tweets.
    """
    page_size = max(_TW_MIN_PAGE_SIZE, min(_TW_MAX_PAGE_SIZE, limit))
    paginator = tweepy.Paginator(
        client.get_users_tweets,
        user_id,
        max_results=page_size,
        exclude=['retweets', 'replies'],
    )
    return [tweet.text for tweet in paginator.flatten(limit=limit)]


st.title("Tweet-SNEst")
st.write("Visualize tweets embeddings in 2D using colors for topics labels.")
st.caption('Please beware this is using Twitter free version of their API and might be needed to wait sometimes.')

col1, col2 = st.columns(2)
with col1:
    tw_user = st.text_input("Twitter handle", "huggingface")
with col2:
    tw_sample = st.number_input("Maximum number of tweets to use", 1, 300, 100, 10)

col1, col2 = st.columns(2)
with col1:
    expected_lang = st.radio(
        "What language should be assumed to be found?",
        ('English', 'Use all the ones you know (~15 lang)'),
        0
    )
with col2:
    go_btn = st.button('Visualize 🚀')

with st.spinner(text="Loading brain... 🧠"):
    model = load_model(model_to_use[expected_lang])

if go_btn:
    # Strip spaces before the lookup so the API never sees them, and so a
    # whitespace-only handle is reported as empty.
    tw_user = tw_user.replace(' ', '')
    if not tw_user:
        st.warning('Twitter handler field is empty 🙄')
    else:
        usr = client.get_user(username=tw_user)
        if usr.data is None:
            # The lookup returned no user object for this handle.
            st.error(f"Could not find the Twitter user '{tw_user}'")
        else:
            with st.spinner(f"Getting to know the '{tw_user}'... 🔍"):
                tweets_txt = fetch_tweets(usr.data.id, int(tw_sample))
                # Order-preserving de-duplication keeps runs reproducible.
                tweets_txt = list(dict.fromkeys(tweets_txt))
                tweets_txt = preprocess(tweets_txt)

            if len(tweets_txt) < MIN_CLUSTER_SIZE:
                st.warning(
                    f"Found only {len(tweets_txt)} tweet(s) for '{tw_user}'; "
                    f"at least {MIN_CLUSTER_SIZE} are needed to detect topics"
                )
            else:
                plot = generate_plot(tweets_txt, model, tw_user)
                st.bokeh_chart(plot)