Spaces:
Build error
Build error
from typing import List | |
import numpy as np | |
import streamlit as st | |
import tweepy | |
import hdbscan | |
from bokeh.models import ColumnDataSource, HoverTool | |
from bokeh.palettes import Cividis256 as Pallete | |
from bokeh.plotting import Figure, figure | |
from bokeh.transform import factor_cmap | |
from sklearn.manifold import TSNE | |
from sentence_transformers import SentenceTransformer | |
# Maps the language choice offered in the UI to the name of the
# sentence-transformers model that is loaded for it.
model_to_use = {
    "English": "all-MiniLM-L12-v2",
    "Use all the ones you know (~15 lang)": "paraphrase-multilingual-MiniLM-L12-v2",
}

# Tweepy client authenticated with the bearer token kept in Streamlit secrets.
client = tweepy.Client(bearer_token=st.secrets["tw_bearer_token"])

# Original implementation from: https://huggingface.co/spaces/edugp/embedding-lenses/blob/main/app.py

# Default ``random_state`` used for the t-SNE projection.
SEED = 42
def load_model(model_name: str) -> SentenceTransformer:
    """Build the sentence-transformers model identified by ``model_name``.

    Args:
        model_name: Name handed unchanged to the ``SentenceTransformer``
            constructor.

    Returns:
        The newly constructed ``SentenceTransformer`` instance.
    """
    return SentenceTransformer(model_name)
def embed_text(text: List[str], model: SentenceTransformer) -> np.ndarray:
    """Encode ``text`` with ``model``.

    Args:
        text: The strings to embed.
        model: Object whose ``encode`` method produces the embeddings.

    Returns:
        Whatever ``model.encode(text)`` returns.
    """
    vectors = model.encode(text)
    return vectors
def get_tsne_embeddings(
    embeddings: np.ndarray,
    perplexity: int = 10,
    n_components: int = 2,
    init: str = "pca",
    n_iter: int = 5000,
    random_state: int = SEED,
) -> np.ndarray:
    """Project ``embeddings`` to ``n_components`` dimensions with t-SNE.

    Args:
        embeddings: One row per sample.
        perplexity: Requested t-SNE perplexity. It is lowered to
            ``n_samples - 1`` when there are too few samples, because
            scikit-learn rejects a perplexity that is not strictly smaller
            than the number of samples.
        n_components: Dimensionality of the projection.
        init: Initialisation strategy forwarded to ``TSNE``.
        n_iter: Maximum number of optimisation iterations.
        random_state: Seed forwarded to ``TSNE``.

    Returns:
        The result of ``TSNE.fit_transform(embeddings)``.
    """
    import inspect

    # Small inputs (e.g. a handful of tweets) would otherwise make TSNE raise
    # "perplexity must be less than n_samples".
    n_samples = len(embeddings)
    effective_perplexity = min(perplexity, max(n_samples - 1, 1))

    # Newer scikit-learn releases renamed the ``n_iter`` constructor argument
    # to ``max_iter``; use whichever name the installed version accepts.
    tsne_params = inspect.signature(TSNE).parameters
    iter_kwarg = "max_iter" if "max_iter" in tsne_params else "n_iter"

    tsne = TSNE(
        perplexity=effective_perplexity,
        n_components=n_components,
        init=init,
        random_state=random_state,
        **{iter_kwarg: n_iter},
    )
    return tsne.fit_transform(embeddings)
def draw_interactive_scatter_plot(
    texts: np.ndarray,
    xs: np.ndarray,
    ys: np.ndarray,
    values: np.ndarray,
    labels: np.ndarray,
    text_column: str,
    label_column: str,
) -> Figure:
    """Build a Bokeh scatter plot whose points are coloured by ``values``.

    Args:
        texts: Text displayed for each point in the hover tooltip.
        xs: X coordinate of each point.
        ys: Y coordinate of each point.
        values: Numeric value of each point; it selects the point's colour.
        labels: Label of each point, displayed in the hover tooltip.
        text_column: Tooltip caption for the text.
        label_column: Tooltip caption for the label.

    Returns:
        The configured figure (axes, grid lines and toolbar logo hidden).

    Raises:
        ValueError: Raised by NumPy when ``values`` is empty.
    """
    min_value = values.min()
    max_value = values.max()
    value_span = max_value - min_value

    # Build exactly one factor and one palette entry per distinct value, both
    # derived from the same numerically sorted array, so that position i of
    # the palette always belongs to position i of the factors. Sorting the
    # stringified colours and labels independently (and per point) would pair
    # them lexicographically, e.g. "10" before "2", and mix colours up.
    unique_values = np.unique(values)
    if value_span == 0:
        # Every point shares one value: give them all the same palette entry.
        color_ids = np.ones(len(unique_values), dtype=int)
    else:
        # Rescale to 0-255, the index range of the 256-colour palette.
        color_ids = ((unique_values - min_value) / value_span * 255).round().astype(int)
    factors = unique_values.astype(str).tolist()
    palette = [Pallete[int(color_id)] for color_id in color_ids]

    values_list = values.astype(str).tolist()
    labels_list = labels.astype(str).tolist()
    source = ColumnDataSource(
        data=dict(x=xs, y=ys, text=texts, label=values_list, original_label=labels_list)
    )
    # NOTE(review): "{safe}" asks Bokeh to render the text as HTML without
    # escaping it; confirm callers only pass text that is safe to render.
    hover = HoverTool(tooltips=[(text_column, "@text{safe}"), (label_column, "@original_label")])
    p = figure(plot_width=800, plot_height=800, tools=[hover])
    p.circle(
        "x",
        "y",
        size=10,
        source=source,
        fill_color=factor_cmap("label", palette=palette, factors=factors),
    )
    p.axis.visible = False
    p.xgrid.grid_line_color = None
    p.ygrid.grid_line_color = None
    p.toolbar.logo = None
    return p
# End of the code adapted from https://huggingface.co/spaces/edugp/embedding-lenses/blob/main/app.py
def generate_plot(
    tws: List[str],
    model: SentenceTransformer,
    tw_user: str
) -> Figure:
    """Embed, cluster and project the tweets, then draw them as a scatter plot.

    Args:
        tws: Tweet texts to visualise.
        model: Model used by ``embed_text`` to embed the tweets.
        tw_user: Twitter handle, only used in the progress message.

    Returns:
        The figure produced by ``draw_interactive_scatter_plot``; the HDBSCAN
        cluster ids serve both as colour values and as hover labels.
    """
    with st.spinner(text=f"Trying to understand '{tw_user}' tweets..."):
        tweet_vectors = embed_text(tws, model)
        clusterer = hdbscan.HDBSCAN(
            min_cluster_size=3,
            metric='euclidean',
            cluster_selection_method='eom',
        )
        clusterer.fit(tweet_vectors)
        cluster_ids = clusterer.labels_

    with st.spinner("Now trying to express them with my own words..."):
        coords_2d = get_tsne_embeddings(tweet_vectors)
        x_coords = coords_2d[:, 0]
        y_coords = coords_2d[:, 1]
        return draw_interactive_scatter_plot(
            tws, x_coords, y_coords, cluster_ids, cluster_ids, 'text', 'label'
        )
# ---------------------------------------------------------------------------
# Streamlit page: collect the inputs, fetch the tweets and show the plot.
# ---------------------------------------------------------------------------
st.title("Tweet-SNEst")
st.write("Visualize tweets embeddings in 2D using colors for topics labels.")
st.caption('Please beware this is using Twitter free version of their API and might be needed to wait sometimes.')

col1, col2 = st.columns(2)
with col1:
    tw_user = st.text_input("Twitter handle", "huggingface")
with col2:
    tw_sample = st.number_input("Maximum number of tweets to use", 1, 300, 100, 10)

col1, col2 = st.columns(2)
with col1:
    expected_lang = st.radio(
        "What language should be assumed to be found?",
        ('English', 'Use all the ones you know (~15 lang)'),
        0
    )
with col2:
    go_btn = st.button('Visualize')

with st.spinner(text="Loading model..."):
    model = load_model(model_to_use[expected_lang])

if go_btn:
    with st.spinner(f"Getting to know the '{tw_user}'..."):
        # Look the user up only when a plot was requested, so that plain
        # reruns of the script do not spend API quota.
        usr = client.get_user(username=tw_user)
        if usr.data is None:
            st.error(f"Could not find the Twitter user '{tw_user}'.")
            st.stop()

        tweets_objs = []
        remaining = int(tw_sample)
        next_token = None
        while remaining > 0:
            # The endpoint only accepts page sizes between 5 and 100; ask for
            # at least 5 and trim the surplus below.
            request_kwargs = {
                "max_results": min(100, max(5, remaining)),
                "exclude": ['retweets', 'replies'],
            }
            # Without the pagination token every request would return the
            # same first page again.
            if next_token is not None:
                request_kwargs["pagination_token"] = next_token
            tweets_response = client.get_users_tweets(usr.data.id, **request_kwargs)

            # ``data`` is None (not an empty list) when a page has no tweets.
            page = (tweets_response.data or [])[:remaining]
            tweets_objs += page
            remaining -= len(page)

            next_token = (tweets_response.meta or {}).get("next_token")
            if next_token is None:
                break

        # Drop duplicated texts while keeping the order returned by the API.
        tweets_txt = list(dict.fromkeys(tweet.text for tweet in tweets_objs))

    if not tweets_txt:
        st.warning(f"No tweets found for '{tw_user}'.")
        st.stop()

    plot = generate_plot(tweets_txt, model, tw_user)
    st.bokeh_chart(plot)