wilmerags commited on
Commit
f040401
·
1 Parent(s): ba42dce

feat: Add experimental version of plot after adaption from referred space

Browse files
Files changed (1) hide show
  1. app.py +88 -5
app.py CHANGED
@@ -1,13 +1,75 @@
 
 
 
 
1
  import streamlit as st
2
  import tweepy
3
- from sentence_transformers import SentenceTransformer
4
 
 
 
 
 
 
 
 
5
 
 
6
  @st.cache(show_spinner=False, allow_output_mutation=True)
7
  def load_model(model_name: str) -> SentenceTransformer:
8
  embedder = model_name
9
  return SentenceTransformer(embedder)
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  client = tweepy.Client(bearer_token=st.secrets["tw_bearer_token"])
13
  model_to_use = {
@@ -21,14 +83,35 @@ col1, col2 = st.columns(2)
21
  with col1:
22
  tw_user = st.text_input("Twitter handle", "huggingface")
23
  with col2:
24
- sample = st.number_input("Maximum number of tweets to use", 1, 300, 100, 10)
25
 
26
  expected_lang = st.radio(
27
  "What language should be assumed to be found?",
28
- ('English', 'Use all the ones you know (~15 lang)'))
 
 
29
 
 
 
30
 
31
  usr = client.get_user(username=tw_user)
32
 
33
- st.write(usr.data.id)
34
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ import numpy as np
4
+
5
  import streamlit as st
6
  import tweepy
 
7
 
8
+ from bokeh.models import ColumnDataSource, HoverTool
9
+ from bokeh.palettes import Cividis256 as Pallete
10
+ from bokeh.plotting import Figure, figure
11
+ from bokeh.transform import factor_cmap
12
+
13
+ from sklearn.manifold import TSNE
14
+ from sentence_transformers import SentenceTransformer
15
 
16
+ # Original implementation from: https://huggingface.co/spaces/edugp/embedding-lenses/blob/main/app.py
17
  @st.cache(show_spinner=False, allow_output_mutation=True)
18
  def load_model(model_name: str) -> SentenceTransformer:
19
  embedder = model_name
20
  return SentenceTransformer(embedder)
21
 
22
+ def embed_text(text: List[str], model: SentenceTransformer) -> np.ndarray:
23
+ return model.encode(text)
24
+
25
+ def get_tsne_embeddings(
26
+ embeddings: np.ndarray, perplexity: int = 30, n_components: int = 2, init: str = "pca", n_iter: int = 5000, random_state: int = SEED
27
+ ) -> np.ndarray:
28
+ tsne = TSNE(perplexity=perplexity, n_components=n_components, init=init, n_iter=n_iter, random_state=random_state)
29
+ return tsne.fit_transform(embeddings)
30
+
31
+ def draw_interactive_scatter_plot(
32
+ texts: np.ndarray, xs: np.ndarray, ys: np.ndarray, values: np.ndarray, labels: np.ndarray, text_column: str, label_column: str
33
+ ) -> Figure:
34
+ # Normalize values to range between 0-255, to assign a color for each value
35
+ max_value = values.max()
36
+ min_value = values.min()
37
+ if max_value - min_value == 0:
38
+ values_color = np.ones(len(values))
39
+ else:
40
+ values_color = ((values - min_value) / (max_value - min_value) * 255).round().astype(int).astype(str)
41
+ values_color_set = sorted(values_color)
42
+ values_list = values.astype(str).tolist()
43
+ values_set = sorted(values_list)
44
+ labels_list = labels.astype(str).tolist()
45
+ source = ColumnDataSource(data=dict(x=xs, y=ys, text=texts, label=values_list, original_label=labels_list))
46
+ hover = HoverTool(tooltips=[(text_column, "@text{safe}"), (label_column, "@original_label")])
47
+ p = figure(plot_width=800, plot_height=800, tools=[hover])
48
+ p.circle("x", "y", size=10, source=source, fill_color=factor_cmap("label", palette=[Pallete[int(id_)] for id_ in values_color_set], factors=values_set))
49
+ p.axis.visible = False
50
+ p.xgrid.grid_line_color = None
51
+ p.ygrid.grid_line_color = None
52
+ p.toolbar.logo = None
53
+ return p
54
+
55
+ # Up to here
56
+ def generate_plot(
57
+ df: List[str],
58
+ labels: List[int],
59
+ model: SentenceTransformer,
60
+ ) -> Figure:
61
+ with st.spinner(text="Embedding text..."):
62
+ embeddings = embed_text(df, model)
63
+ logger.info("Encoding labels")
64
+ encoded_labels = encode_labels(labels)
65
+ with st.spinner("Reducing dimensionality..."):
66
+ embeddings_2d = get_tsne_embeddings(embeddings)
67
+ logger.info("Generating figure")
68
+ plot = draw_interactive_scatter_plot(
69
+ df, embeddings_2d[:, 0], embeddings_2d[:, 1], encoded_labels.values, labels, 'text', 'label'
70
+ )
71
+ return plot
72
+
73
 
74
  client = tweepy.Client(bearer_token=st.secrets["tw_bearer_token"])
75
  model_to_use = {
 
83
  with col1:
84
  tw_user = st.text_input("Twitter handle", "huggingface")
85
  with col2:
86
+ tw_sample = st.number_input("Maximum number of tweets to use", 1, 300, 100, 10)
87
 
88
  expected_lang = st.radio(
89
  "What language should be assumed to be found?",
90
+ ('English', 'Use all the ones you know (~15 lang)'),
91
+ 0
92
+ )
93
 
94
+ with st.spinner(text="Loading model..."):
95
+ model = load_model(model_to_use[expected_lang])
96
 
97
  usr = client.get_user(username=tw_user)
98
 
99
+ # st.write(usr.data.id)
100
+
101
+ if tw_user:
102
+ with st.spinner(f"Getting to know the '{tw_user}'..."):
103
+ tweets_objs = []
104
+ while tw_sample >= 100:
105
+ current_sample = min(100, tw_sample)
106
+ tweets_response = client.get_user_tweets(usr.data.id, max_results=current_sample)
107
+ tweets_objs += tweets_response.data
108
+ tw_sample -= current_sample
109
+ tweets_response = client.get_user_tweets(usr.data.id, max_results=tw_sample)
110
+ tweets_objs += tweets_response.data
111
+ tweets_txt = [tweet.text for tweet in tweets_objs]
112
+ labels = [0] * len(tweets_txt)
113
+ # plot = generate_plot(df, text_column, label_column, sample, dimensionality_reduction_function, model)
114
+ plot = generate_plot(tweets_txt, labels, model)
115
+ logger.info("Displaying plot")
116
+ st.bokeh_chart(plot)
117
+ logger.info("Done")