LenixC commited on
Commit
15f8762
·
1 Parent(s): 936c6dc

Switched to Radio for metric, added padding to graphs to prevent the label overlapping.

Browse files
Files changed (2) hide show
  1. __pycache__/app.cpython-39.pyc +0 -0
  2. app.py +27 -19
__pycache__/app.cpython-39.pyc ADDED
Binary file (4.62 kB). View file
 
app.py CHANGED
@@ -44,19 +44,18 @@ def ground_truth_plot(n_features):
44
 
45
  X = np.array(X)
46
  y = np.array(y)
 
 
47
 
48
- # Plot the ground-truth labelling
49
- gt_plot = plt.figure()
50
- plt.axes([0, 0, 1, 1])
51
  for l, color, n in zip(range(n_clusters), colors, labels):
52
  lines = plt.plot(X[y == l].T, c=color, alpha=0.5)
53
  lines[0].set_label(n)
54
 
 
 
55
  plt.legend(loc="best")
56
-
57
- plt.axis("tight")
58
  plt.axis("off")
59
- plt.suptitle("Ground truth", size=20, y=1)
60
  return gt_plot, X, y
61
 
62
  def plot_cluster_waves(metric, X, y):
@@ -64,18 +63,21 @@ def plot_cluster_waves(metric, X, y):
64
  n_clusters=n_clusters, linkage="average", metric=metric
65
  )
66
  model.fit(X)
67
- clust_plot = plt.figure()
68
- plt.axes([0, 0, 1, 1])
69
  for l, color in zip(np.arange(model.n_clusters), colors):
70
  plt.plot(X[model.labels_ == l].T, c=color, alpha=0.5)
 
 
 
71
  plt.axis("tight")
72
  plt.axis("off")
73
- plt.suptitle("AgglomerativeClustering(metric=%s)" % metric, size=20, y=1)
74
  return clust_plot
75
 
76
  def plot_distances(metric, X, y):
77
  avg_dist = np.zeros((n_clusters, n_clusters))
78
- dist_plot = plt.figure()
 
79
  for i in range(n_clusters):
80
  for j in range(n_clusters):
81
  avg_dist[i, j] = pairwise_distances(
@@ -99,8 +101,9 @@ def plot_distances(metric, X, y):
99
  plt.xticks(range(n_clusters), labels, rotation=45)
100
  plt.yticks(range(n_clusters), labels)
101
  plt.colorbar()
102
- plt.suptitle("Interclass %s distances" % metric, size=18, y=1)
103
- plt.tight_layout()
 
104
  return dist_plot
105
 
106
  def agg_cluster(n_feats, measure):
@@ -123,18 +126,23 @@ with gr.Blocks() as demo:
123
  with gr.Row():
124
  with gr.Column():
125
  n_feats = gr.Slider(10, 4000, 2000, label="Number of Features")
126
- measure = gr.Dropdown(["cosine", "euclidean", "cityblock"], value="cosine")
127
- btn = gr.Button(label="Run")
128
  gt_graph = gr.Plot(label="Ground Truth Graph")
 
129
  with gr.Row():
130
  dist_plot = gr.Plot(label="Interclass Distances")
131
  clust_waves = gr.Plot(label="Agglomerative Clustering")
132
 
133
- btn.click(
134
- fn=agg_cluster,
135
- inputs=[n_feats, measure],
136
- outputs=[gt_graph, clust_waves, dist_plot]
137
- )
 
 
 
 
 
138
 
139
  if __name__ == '__main__':
140
  demo.launch()
 
44
 
45
  X = np.array(X)
46
  y = np.array(y)
47
+
48
+ gt_plot, ax = plt.subplots()
49
 
 
 
 
50
  for l, color, n in zip(range(n_clusters), colors, labels):
51
  lines = plt.plot(X[y == l].T, c=color, alpha=0.5)
52
  lines[0].set_label(n)
53
 
54
+ plt.subplots_adjust(top=0.8, bottom=0, left=0, right=1.0)
55
+ ax.set_title("Ground Truth", size=20, pad=1)
56
  plt.legend(loc="best")
 
 
57
  plt.axis("off")
58
+
59
  return gt_plot, X, y
60
 
61
  def plot_cluster_waves(metric, X, y):
 
63
  n_clusters=n_clusters, linkage="average", metric=metric
64
  )
65
  model.fit(X)
66
+
67
+ clust_plot, ax = plt.subplots()
68
  for l, color in zip(np.arange(model.n_clusters), colors):
69
  plt.plot(X[model.labels_ == l].T, c=color, alpha=0.5)
70
+
71
+ plt.subplots_adjust(top=0.75, bottom=0, left=0, right=1.0)
72
+ ax.set_title("Agglomerative Clustering\n(metric=%s)" % metric, size=20, pad=1.0)
73
  plt.axis("tight")
74
  plt.axis("off")
 
75
  return clust_plot
76
 
77
  def plot_distances(metric, X, y):
78
  avg_dist = np.zeros((n_clusters, n_clusters))
79
+ dist_plot, ax = plt.subplots()
80
+
81
  for i in range(n_clusters):
82
  for j in range(n_clusters):
83
  avg_dist[i, j] = pairwise_distances(
 
101
  plt.xticks(range(n_clusters), labels, rotation=45)
102
  plt.yticks(range(n_clusters), labels)
103
  plt.colorbar()
104
+ plt.subplots_adjust(top=0.8)
105
+ ax.set_title("Interclass %s distances" % metric, size=20, pad=1.0)
106
+ plt.axis("off")
107
  return dist_plot
108
 
109
  def agg_cluster(n_feats, measure):
 
126
  with gr.Row():
127
  with gr.Column():
128
  n_feats = gr.Slider(10, 4000, 2000, label="Number of Features")
129
+ measure = gr.Radio(["cosine", "euclidean", "cityblock"], label="Metric", value="cosine")
 
130
  gt_graph = gr.Plot(label="Ground Truth Graph")
131
+ gt_graph.style()
132
  with gr.Row():
133
  dist_plot = gr.Plot(label="Interclass Distances")
134
  clust_waves = gr.Plot(label="Agglomerative Clustering")
135
 
136
+ n_feats.change(
137
+ fn=agg_cluster,
138
+ inputs=[n_feats, measure],
139
+ outputs=[gt_graph, clust_waves, dist_plot]
140
+ )
141
+ measure.change(
142
+ fn=agg_cluster,
143
+ inputs=[n_feats, measure],
144
+ outputs=[gt_graph, clust_waves, dist_plot]
145
+ )
146
 
147
  if __name__ == '__main__':
148
  demo.launch()