Build error
Build error
File size: 5,802 Bytes
f040401 7eae2ec 25a6734 7964bf4 25a6734 2f7dfa0 5b5b795 f040401 25a6734 a5482a3 f040401 a5482a3 d548156 f040401 1731f45 f040401 5b5b795 085853b e98d306 61ce655 f040401 61ce655 d71c853 f040401 2f7dfa0 dc798d1 8e41537 2f7dfa0 f040401 9c5d67e f040401 9c5d67e f040401 30f3de6 9c5d67e a5482a3 20c61de 2d7af1c bff1960 20c61de 7964bf4 30f3de6 f040401 2f7dfa0 f040401 d548156 1bbc870 9c5d67e 0a699f8 1bbc870 0a699f8 f040401 61a29bd a38bdb0 d548156 30f3de6 f040401 d548156 691ccbe dbebcc0 30f3de6 f040401 59ff44b f040401 5c585a4 59ff44b f040401 92e2ae9 f040401 9c5d67e 5e899e5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
from typing import List
import numpy as np
import streamlit as st
import tweepy
import hdbscan
from bokeh.models import ColumnDataSource, HoverTool, Label
from bokeh.palettes import Colorblind as Pallete
from bokeh.palettes import Set3 as AuxPallete
from bokeh.plotting import Figure, figure
from bokeh.transform import factor_cmap
from sklearn.manifold import TSNE
from sentence_transformers import SentenceTransformer
# Tweepy client used for every Twitter request in this app; it authenticates
# with the bearer token read from Streamlit secrets under "tw_bearer_token".
client = tweepy.Client(bearer_token=st.secrets["tw_bearer_token"])
# Maps each language option offered in the UI to the sentence-transformers
# model name that is handed to load_model() for it.
model_to_use = {
    "English": "all-MiniLM-L12-v2",
    "Use all the ones you know (~15 lang)": "paraphrase-multilingual-MiniLM-L12-v2",
}
# Original implementation from: (source reference missing in this copy); the borrowed code ends at the "# Up to here" marker.
# Fixed seed; used as the default `random_state` of get_tsne_embeddings().
SEED = 42
@st.cache(show_spinner=False, allow_output_mutation=True)
def load_model(model_name: str) -> SentenceTransformer:
    """Instantiate the sentence-transformers model called ``model_name``.

    The function is wrapped in ``st.cache`` (spinner disabled, output
    mutation allowed), so the construction is delegated to Streamlit's cache.

    Args:
        model_name: Identifier passed straight to ``SentenceTransformer``.

    Returns:
        The constructed ``SentenceTransformer`` instance.
    """
    return SentenceTransformer(model_name)
def embed_text(text: List[str], model: SentenceTransformer) -> np.ndarray:
    """Encode ``text`` with ``model`` and return whatever ``model.encode`` yields.

    Args:
        text: The strings to embed.
        model: Object exposing ``encode``; called once with the whole list.

    Returns:
        The result of ``model.encode(text)``.
    """
    vectors = model.encode(text)
    return vectors
def get_tsne_embeddings(
    embeddings: np.ndarray,
    perplexity: int = 10,
    n_components: int = 2,
    init: str = "pca",
    n_iter: int = 5000,
    random_state: int = SEED,
) -> np.ndarray:
    """Reduce ``embeddings`` with t-SNE and return the fitted projection.

    Every keyword argument is forwarded unchanged to ``TSNE``; the default
    ``random_state`` is the module-level ``SEED``.

    Args:
        embeddings: Input matrix handed to ``TSNE.fit_transform``.
        perplexity: Forwarded to ``TSNE``.
        n_components: Forwarded to ``TSNE``.
        init: Forwarded to ``TSNE``.
        n_iter: Forwarded to ``TSNE``.
        random_state: Forwarded to ``TSNE``.

    Returns:
        The array produced by ``TSNE.fit_transform``.
    """
    tsne_options = dict(
        perplexity=perplexity,
        n_components=n_components,
        init=init,
        n_iter=n_iter,
        random_state=random_state,
    )
    reducer = TSNE(**tsne_options)
    return reducer.fit_transform(embeddings)
def _select_palette(n_colors: int) -> List[str]:
    """Return ``n_colors`` colours, preferring the primary palette family.

    ``Pallete`` and ``AuxPallete`` are looked up by palette size.  An exact
    size is searched in ``Pallete`` first and in ``AuxPallete`` second.  When
    neither family offers that size, a small request takes a prefix of the
    smallest primary palette and a large request cycles through the largest
    auxiliary palette (so colours repeat instead of raising ``KeyError``).
    """
    for family in (Pallete, AuxPallete):
        if n_colors in family:
            return list(family[n_colors])

    smallest = Pallete[min(Pallete)]
    if n_colors <= len(smallest):
        return list(smallest[:n_colors])

    largest = AuxPallete[max(AuxPallete)]
    return [largest[i % len(largest)] for i in range(n_colors)]


def draw_interactive_scatter_plot(
    texts: np.ndarray, xs: np.ndarray, ys: np.ndarray, values: np.ndarray, labels: np.ndarray, text_column: str, label_column: str
) -> Figure:
    """Draw a hoverable 2-D scatter plot with one colour per distinct value.

    Args:
        texts: Text shown in the hover tooltip for each point.
        xs: X coordinate of each point.
        ys: Y coordinate of each point.
        values: Per-point value that decides the colour (stringified for Bokeh).
        labels: Per-point label shown in the hover tooltip.
        text_column: Tooltip caption for ``texts``.
        label_column: Tooltip caption for ``labels``.

    Returns:
        The configured Bokeh figure.
    """
    values_list = values.astype(str).tolist()
    labels_list = labels.astype(str).tolist()

    # One factor per distinct value, in ascending order of the raw values.
    # Colours are assigned by rank, so the mapping no longer depends on the
    # values being contiguous integers that start at -1, and the palette can
    # never be indexed out of range.
    factors = np.unique(values).astype(str).tolist()
    palette = _select_palette(len(factors))

    source = ColumnDataSource(data=dict(x=xs, y=ys, text=texts, label=values_list, original_label=labels_list))
    # NOTE(review): "{safe}" makes Bokeh render the text as raw HTML; confirm
    # this is acceptable for text coming from an external source.
    hover = HoverTool(tooltips=[(text_column, "@text{safe}"), (label_column, "@original_label")])

    p = figure(plot_width=800, plot_height=800, tools=[hover], title='2D visualization of tweets', background_fill_color="#fafafa")
    colors = factor_cmap("label", palette=palette, factors=factors)
    p.scatter("x", "y", size=12, source=source, fill_alpha=0.4, line_color=colors, fill_color=colors)

    p.axis.visible = False
    p.xgrid.grid_line_dash = "dashed"
    p.ygrid.grid_line_dash = "dashed"
    p.toolbar.logo = None

    disclaimer = Label(
        x=0,
        y=0,
        x_units="screen",
        y_units="screen",
        text_font_size="14px",
        text_color="gray",
        text="Topic equals -1 means no topic was detected for such tweet",
    )
    p.add_layout(disclaimer, "below")
    return p
# Up to here
# Build the topic scatter plot for a list of tweets: embed the texts, take
# cluster labels from HDBSCAN, project the embeddings with
# get_tsne_embeddings() and hand everything to draw_interactive_scatter_plot().
# NOTE(review): indentation and several tokens are missing from this copy of
# the function; the notes below mark the visible gaps.
def generate_plot(
tws: List[str],
model: SentenceTransformer,
tw_user: str
) -> Figure:
# tw_user is only used in the spinner message.
with st.spinner(text=f"Trying to understand '{tw_user}' tweets... π€"):
embeddings = embed_text(tws, model)
# encoded_labels = encode_labels(labels)
# NOTE(review): the HDBSCAN call below is truncated (constructor arguments,
# closing parenthesis and presumably a fit on `embeddings` are not visible);
# restore it from the original source rather than guessing the parameters.
cluster = hdbscan.HDBSCAN(
# The cluster labels are used both as colour values and as hover labels below.
encoded_labels = cluster.labels_
with st.spinner("Now trying to express them with my own words... π¬"):
embeddings_2d = get_tsne_embeddings(embeddings)
# Column 0 / column 1 of the projection are the x / y coordinates.
# NOTE(review): the closing parenthesis of this call is missing in this copy.
plot = draw_interactive_scatter_plot(
tws, embeddings_2d[:, 0], embeddings_2d[:, 1], encoded_labels, encoded_labels, 'Tweet', 'Topic'
return plot
# ---- Streamlit page: inputs, tweet download, plot generation ----
# NOTE(review): indentation and several tokens are missing from this copy of
# the script; the notes below mark the visible gaps.
st.write("Visualize tweets embeddings in 2D using colors for topics labels.")
st.caption('Please beware this is using Twitter free version of their API and might be needed to wait sometimes.')
# First input row: Twitter handle and how many tweets to fetch.
col1, col2 = st.columns(2)
with col1:
tw_user = st.text_input("Twitter handle", "huggingface")
with col2:
# Positional arguments after the label: 1, 300, 100, 10
# (per the Streamlit signature: min_value, max_value, value, step).
tw_sample = st.number_input("Maximum number of tweets to use", 1, 300, 100, 10)
# Second input row: language choice and the trigger button.
col1, col2 = st.columns(2)
with col1:
# NOTE(review): the widget call is truncated here (the function name, opening
# and closing parentheses are missing); the two options match the keys of
# `model_to_use`.
expected_lang =
"What language should be assumed to be found?",
('English', 'Use all the ones you know (~15 lang)'),
with col2:
go_btn = st.button('Visualize π')
# Load the model that corresponds to the selected language option.
with st.spinner(text="Loading brain... π§ "):
model = load_model(model_to_use[expected_lang])
if go_btn and tw_user != '':
# NOTE(review): the user lookup runs before spaces are stripped from the
# handle on the next line — confirm the order is intended.
usr = client.get_user(username=tw_user)
tw_user = tw_user.replace(' ', '')
with st.spinner(f"Getting to know the '{tw_user}'... π"):
tweets_objs = []
# Fetch in batches of 100 while at least 100 tweets are still requested;
# inside this loop min(100, tw_sample) is always 100.
# NOTE(review): no pagination token is visible in these calls, so repeated
# requests may return the same tweets — verify against the Tweepy docs.
while tw_sample >= 100:
current_sample = min(100, tw_sample)
# NOTE(review): the first argument (the user id) is missing in this copy.
tweets_response = client.get_users_tweets(, max_results=current_sample, exclude=['retweets', 'replies'])
# NOTE(review): the right-hand side of this statement is missing in this copy.
tweets_objs +=
tw_sample -= current_sample
# Remainder request for the last (< 100) tweets.
# NOTE(review): the API may enforce a minimum max_results — confirm that small
# remainders are accepted.
if tw_sample > 0:
# NOTE(review): the first argument (the user id) is missing in this copy.
tweets_response = client.get_users_tweets(, max_results=tw_sample, exclude=['retweets', 'replies'])
# NOTE(review): the right-hand side of this statement is missing in this copy.
tweets_objs +=
tweets_txt = [tweet.text for tweet in tweets_objs]
# Drop duplicate texts; going through a set does not preserve the order.
tweets_txt = list(set(tweets_txt))
# plot = generate_plot(df, text_column, label_column, sample, dimensionality_reduction_function, model)
# NOTE(review): no statement that displays `plot` is visible in this copy.
plot = generate_plot(tweets_txt, model, tw_user)
elif go_btn and tw_user == '':
# NOTE(review): the trailing " |" on the next line looks like residue from the
# page this file was copied from, not part of the program.
st.warning('Twitter handler field is empty π') |