wilmerags committed on
Commit
7f1a60e
·
1 Parent(s): e2fe46b

test: Experiment with keyword selection for topics

Browse files
Files changed (1) hide show
  1. app.py +20 -1
app.py CHANGED
@@ -1,5 +1,6 @@
1
  from typing import List
2
 
 
3
  import string
4
  import re
5
  import requests
@@ -138,10 +139,28 @@ def generate_plot(
138
  cluster_selection_method='eom'
139
  ).fit(embeddings)
140
  encoded_labels = cluster.labels_
 
141
  with st.spinner("Now trying to express them with my own words... 💬"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  embeddings_2d = get_tsne_embeddings(embeddings)
143
  plot = draw_interactive_scatter_plot(
144
- tws, embeddings_2d[:, 0], embeddings_2d[:, 1], encoded_labels, encoded_labels, 'Tweet', 'Topic'
145
  )
146
  return plot
147
 
 
1
  from typing import List
2
 
3
+ import itertools
4
  import string
5
  import re
6
  import requests
 
139
  cluster_selection_method='eom'
140
  ).fit(embeddings)
141
  encoded_labels = cluster.labels_
142
+ cluster_keyword = {}
143
  with st.spinner("Now trying to express them with my own words... 💬"):
144
+ for label in set(encoded_labels):
145
+ cluster_keyword[label] = []
146
+ cluster_tws = []
147
+ for ix, obs in enumerate(encoded_labels):
148
+ if obs == label:
149
+ cluster_tws.append(tws_cleaned)
150
+ cluster_words = [tw.split(' ') for tw in cluster_tws]
151
+ cluster_words = list(set(itertools.chain.from_iterable(cluster_words)))
152
+ cluster_embeddings = embed_text(cluster_tws, model)
153
+ cluster_embeddings_avg = np.mean(cluster_embeddings, axis=0)
154
+ cluster_words_embeddings = embed_text(cluster_words, model)
155
+ cluster_to_words_similarities = util.dot_score(cluster_embeddings_avg, cluster_words_embeddings)
156
+ while len(cluster_keyword[label]) < 3:
157
+ most_descriptive = np.argmax(cluster_to_words_similarities)
158
+ del cluster_to_words_similarities[most_descriptive]
159
+ cluster_keyword[label].append(cluster_words[most_descriptive])
160
+ encoded_labels_keywords = [cluster_keyword[encoded_label] for encoded_label in encoded_labels]
161
  embeddings_2d = get_tsne_embeddings(embeddings)
162
  plot = draw_interactive_scatter_plot(
163
+ tws, embeddings_2d[:, 0], embeddings_2d[:, 1], encoded_labels, encoded_labels_keywords, 'Tweet', 'Topic'
164
  )
165
  return plot
166