Catherine Breslin commited on
Commit
2486cd2
·
1 Parent(s): 8e8b7d6

Clustering

Browse files
Files changed (1) hide show
  1. app.py +18 -0
app.py CHANGED
@@ -7,6 +7,21 @@ import numpy as np
7
  import seaborn as sns
8
  import matplotlib.pyplot as plt
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  def plot_heatmap(labels, heatmap, rotation=90):
11
  sns.set(font_scale=1.2)
12
  fig, ax = plt.subplots()
@@ -29,6 +44,8 @@ st.markdown("This demo uses the sentence_transformers library to plot sentence s
29
  # Streamlit text boxes
30
  text = st.text_area('Enter sentences:', value="The sun is hotter than the moon.\nThe sun is very bright.\nI hear that the universe is very large.\nToday is Tuesday.")
31
 
 
 
32
  # Model setup
33
  model = SentenceTransformer('paraphrase-distilroberta-base-v1')
34
  nltk.download('punkt')
@@ -43,5 +60,6 @@ if text:
43
  for j,ea in enumerate(embed):
44
  sim[i][j] = 1.0-cosine(em,ea)
45
  plot_heatmap(sentences, sim)
 
46
 
47
 
 
7
  import seaborn as sns
8
  import matplotlib.pyplot as plt
9
 
10
+ def cluster_examples(messages, embed, nc=3):
11
+ km = KMeans(
12
+ n_clusters=nc, init='random',
13
+ n_init=10, max_iter=300,
14
+ tol=1e-04, random_state=0
15
+ )
16
+ km = km.fit_predict(embed)
17
+ for n in range(nc):
18
+ idxs = [i for i in range(len(km)) if km[i] == n]
19
+ ms = [messages[i] for i in idxs]
20
+ st.markdown ("CLUSTER : %d"%n)
21
+ for m in ms:
22
+ st.markdown (m)
23
+
24
+
25
  def plot_heatmap(labels, heatmap, rotation=90):
26
  sns.set(font_scale=1.2)
27
  fig, ax = plt.subplots()
 
44
  # Streamlit text boxes
45
  text = st.text_area('Enter sentences:', value="The sun is hotter than the moon.\nThe sun is very bright.\nI hear that the universe is very large.\nToday is Tuesday.")
46
 
47
+ nc = st.slider('Select a number of clusters', min_value=1, max_value=15, value=3)
48
+
49
  # Model setup
50
  model = SentenceTransformer('paraphrase-distilroberta-base-v1')
51
  nltk.download('punkt')
 
60
  for j,ea in enumerate(embed):
61
  sim[i][j] = 1.0-cosine(em,ea)
62
  plot_heatmap(sentences, sim)
63
+ cluster_examples(sentences, embed, nc)
64
 
65