Switched to Radio for metric, added padding to graphs to prevent the label overlapping.
Browse files- __pycache__/app.cpython-39.pyc +0 -0
- app.py +27 -19
__pycache__/app.cpython-39.pyc
ADDED
Binary file (4.62 kB). View file
|
|
app.py
CHANGED
@@ -44,19 +44,18 @@ def ground_truth_plot(n_features):
|
|
44 |
|
45 |
X = np.array(X)
|
46 |
y = np.array(y)
|
|
|
|
|
47 |
|
48 |
-
# Plot the ground-truth labelling
|
49 |
-
gt_plot = plt.figure()
|
50 |
-
plt.axes([0, 0, 1, 1])
|
51 |
for l, color, n in zip(range(n_clusters), colors, labels):
|
52 |
lines = plt.plot(X[y == l].T, c=color, alpha=0.5)
|
53 |
lines[0].set_label(n)
|
54 |
|
|
|
|
|
55 |
plt.legend(loc="best")
|
56 |
-
|
57 |
-
plt.axis("tight")
|
58 |
plt.axis("off")
|
59 |
-
|
60 |
return gt_plot, X, y
|
61 |
|
62 |
def plot_cluster_waves(metric, X, y):
|
@@ -64,18 +63,21 @@ def plot_cluster_waves(metric, X, y):
|
|
64 |
n_clusters=n_clusters, linkage="average", metric=metric
|
65 |
)
|
66 |
model.fit(X)
|
67 |
-
|
68 |
-
|
69 |
for l, color in zip(np.arange(model.n_clusters), colors):
|
70 |
plt.plot(X[model.labels_ == l].T, c=color, alpha=0.5)
|
|
|
|
|
|
|
71 |
plt.axis("tight")
|
72 |
plt.axis("off")
|
73 |
-
plt.suptitle("AgglomerativeClustering(metric=%s)" % metric, size=20, y=1)
|
74 |
return clust_plot
|
75 |
|
76 |
def plot_distances(metric, X, y):
|
77 |
avg_dist = np.zeros((n_clusters, n_clusters))
|
78 |
-
dist_plot = plt.
|
|
|
79 |
for i in range(n_clusters):
|
80 |
for j in range(n_clusters):
|
81 |
avg_dist[i, j] = pairwise_distances(
|
@@ -99,8 +101,9 @@ def plot_distances(metric, X, y):
|
|
99 |
plt.xticks(range(n_clusters), labels, rotation=45)
|
100 |
plt.yticks(range(n_clusters), labels)
|
101 |
plt.colorbar()
|
102 |
-
plt.
|
103 |
-
|
|
|
104 |
return dist_plot
|
105 |
|
106 |
def agg_cluster(n_feats, measure):
|
@@ -123,18 +126,23 @@ with gr.Blocks() as demo:
|
|
123 |
with gr.Row():
|
124 |
with gr.Column():
|
125 |
n_feats = gr.Slider(10, 4000, 2000, label="Number of Features")
|
126 |
-
measure = gr.
|
127 |
-
btn = gr.Button(label="Run")
|
128 |
gt_graph = gr.Plot(label="Ground Truth Graph")
|
|
|
129 |
with gr.Row():
|
130 |
dist_plot = gr.Plot(label="Interclass Distances")
|
131 |
clust_waves = gr.Plot(label="Agglomerative Clustering")
|
132 |
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
|
|
|
|
|
|
|
|
|
|
138 |
|
139 |
if __name__ == '__main__':
|
140 |
demo.launch()
|
|
|
44 |
|
45 |
X = np.array(X)
|
46 |
y = np.array(y)
|
47 |
+
|
48 |
+
gt_plot, ax = plt.subplots()
|
49 |
|
|
|
|
|
|
|
50 |
for l, color, n in zip(range(n_clusters), colors, labels):
|
51 |
lines = plt.plot(X[y == l].T, c=color, alpha=0.5)
|
52 |
lines[0].set_label(n)
|
53 |
|
54 |
+
plt.subplots_adjust(top=0.8, bottom=0, left=0, right=1.0)
|
55 |
+
ax.set_title("Ground Truth", size=20, pad=1)
|
56 |
plt.legend(loc="best")
|
|
|
|
|
57 |
plt.axis("off")
|
58 |
+
|
59 |
return gt_plot, X, y
|
60 |
|
61 |
def plot_cluster_waves(metric, X, y):
|
|
|
63 |
n_clusters=n_clusters, linkage="average", metric=metric
|
64 |
)
|
65 |
model.fit(X)
|
66 |
+
|
67 |
+
clust_plot, ax = plt.subplots()
|
68 |
for l, color in zip(np.arange(model.n_clusters), colors):
|
69 |
plt.plot(X[model.labels_ == l].T, c=color, alpha=0.5)
|
70 |
+
|
71 |
+
plt.subplots_adjust(top=0.75, bottom=0, left=0, right=1.0)
|
72 |
+
ax.set_title("Agglomerative Clustering\n(metric=%s)" % metric, size=20, pad=1.0)
|
73 |
plt.axis("tight")
|
74 |
plt.axis("off")
|
|
|
75 |
return clust_plot
|
76 |
|
77 |
def plot_distances(metric, X, y):
|
78 |
avg_dist = np.zeros((n_clusters, n_clusters))
|
79 |
+
dist_plot, ax = plt.subplots()
|
80 |
+
|
81 |
for i in range(n_clusters):
|
82 |
for j in range(n_clusters):
|
83 |
avg_dist[i, j] = pairwise_distances(
|
|
|
101 |
plt.xticks(range(n_clusters), labels, rotation=45)
|
102 |
plt.yticks(range(n_clusters), labels)
|
103 |
plt.colorbar()
|
104 |
+
plt.subplots_adjust(top=0.8)
|
105 |
+
ax.set_title("Interclass %s distances" % metric, size=20, pad=1.0)
|
106 |
+
plt.axis("off")
|
107 |
return dist_plot
|
108 |
|
109 |
def agg_cluster(n_feats, measure):
|
|
|
126 |
with gr.Row():
|
127 |
with gr.Column():
|
128 |
n_feats = gr.Slider(10, 4000, 2000, label="Number of Features")
|
129 |
+
measure = gr.Radio(["cosine", "euclidean", "cityblock"], label="Metric", value="cosine")
|
|
|
130 |
gt_graph = gr.Plot(label="Ground Truth Graph")
|
131 |
+
gt_graph.style()
|
132 |
with gr.Row():
|
133 |
dist_plot = gr.Plot(label="Interclass Distances")
|
134 |
clust_waves = gr.Plot(label="Agglomerative Clustering")
|
135 |
|
136 |
+
n_feats.change(
|
137 |
+
fn=agg_cluster,
|
138 |
+
inputs=[n_feats, measure],
|
139 |
+
outputs=[gt_graph, clust_waves, dist_plot]
|
140 |
+
)
|
141 |
+
measure.change(
|
142 |
+
fn=agg_cluster,
|
143 |
+
inputs=[n_feats, measure],
|
144 |
+
outputs=[gt_graph, clust_waves, dist_plot]
|
145 |
+
)
|
146 |
|
147 |
if __name__ == '__main__':
|
148 |
demo.launch()
|