tushifire commited on
Commit
12e8f06
Β·
1 Parent(s): 0d803eb

Adding visualization function

Browse files
Files changed (1) hide show
  1. app.py +96 -1
app.py CHANGED
@@ -105,7 +105,100 @@ _, labels = cluster.affinity_propagation(edge_model.covariance_, random_state=0)
105
  n_labels = labels.max()
106
 
107
 
 
 
108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  import gradio as gr
110
 
111
  title = " πŸ“ˆ Visualizing the stock market structure πŸ“ˆ"
@@ -119,6 +212,8 @@ with gr.Blocks(title=title) as demo:
119
 
120
  for i in range(n_labels + 1):
121
  gr.Markdown( f"Cluster {i + 1}: {', '.join(names[labels == i])}")
122
-
 
 
123
  gr.Markdown( f"## In progress")
124
  demo.launch()
 
105
  n_labels = labels.max()
106
 
107
 
108
+ # Finding a low-dimension embedding for visualization: find the best position of
109
+ # the nodes (the stocks) on a 2D plane
110
 
111
+ from sklearn import manifold
112
+
113
+ node_position_model = manifold.LocallyLinearEmbedding(
114
+ n_components=2, eigen_solver="dense", n_neighbors=6
115
+ )
116
+
117
+ embedding = node_position_model.fit_transform(X.T).T
118
+
119
+ import matplotlib.pyplot as plt
120
+ from matplotlib.collections import LineCollection
121
+
122
+ def visualize_stocks():
123
+ fig = plt.figure(1, facecolor="w", figsize=(10, 8))
124
+ plt.clf()
125
+ ax = plt.axes([0.0, 0.0, 1.0, 1.0])
126
+ plt.axis("off")
127
+
128
+ # Plot the graph of partial correlations
129
+ partial_correlations = edge_model.precision_.copy()
130
+ d = 1 / np.sqrt(np.diag(partial_correlations))
131
+ partial_correlations *= d
132
+ partial_correlations *= d[:, np.newaxis]
133
+ non_zero = np.abs(np.triu(partial_correlations, k=1)) > 0.02
134
+
135
+ # Plot the nodes using the coordinates of our embedding
136
+ plt.scatter(
137
+ embedding[0], embedding[1], s=100 * d**2, c=labels, cmap=plt.cm.nipy_spectral
138
+ )
139
+
140
+ # Plot the edges
141
+ start_idx, end_idx = np.where(non_zero)
142
+ # a sequence of (*line0*, *line1*, *line2*), where::
143
+ # linen = (x0, y0), (x1, y1), ... (xm, ym)
144
+ segments = [
145
+ [embedding[:, start], embedding[:, stop]] for start, stop in zip(start_idx, end_idx)
146
+ ]
147
+ values = np.abs(partial_correlations[non_zero])
148
+ lc = LineCollection(
149
+ segments, zorder=0, cmap=plt.cm.hot_r, norm=plt.Normalize(0, 0.7 * values.max())
150
+ )
151
+ lc.set_array(values)
152
+ lc.set_linewidths(15 * values)
153
+ ax.add_collection(lc)
154
+
155
+ # Add a label to each node. The challenge here is that we want to
156
+ # position the labels to avoid overlap with other labels
157
+ for index, (name, label, (x, y)) in enumerate(zip(names, labels, embedding.T)):
158
+
159
+ dx = x - embedding[0]
160
+ dx[index] = 1
161
+ dy = y - embedding[1]
162
+ dy[index] = 1
163
+ this_dx = dx[np.argmin(np.abs(dy))]
164
+ this_dy = dy[np.argmin(np.abs(dx))]
165
+ if this_dx > 0:
166
+ horizontalalignment = "left"
167
+ x = x + 0.002
168
+ else:
169
+ horizontalalignment = "right"
170
+ x = x - 0.002
171
+ if this_dy > 0:
172
+ verticalalignment = "bottom"
173
+ y = y + 0.002
174
+ else:
175
+ verticalalignment = "top"
176
+ y = y - 0.002
177
+ plt.text(
178
+ x,
179
+ y,
180
+ name,
181
+ size=10,
182
+ horizontalalignment=horizontalalignment,
183
+ verticalalignment=verticalalignment,
184
+ bbox=dict(
185
+ facecolor="w",
186
+ edgecolor=plt.cm.nipy_spectral(label / float(n_labels)),
187
+ alpha=0.6,
188
+ ),
189
+ )
190
+
191
+ plt.xlim(
192
+ embedding[0].min() - 0.15 * embedding[0].ptp(),
193
+ embedding[0].max() + 0.10 * embedding[0].ptp(),
194
+ )
195
+ plt.ylim(
196
+ embedding[1].min() - 0.03 * embedding[1].ptp(),
197
+ embedding[1].max() + 0.03 * embedding[1].ptp(),
198
+ )
199
+
200
+ return fig
201
+
202
  import gradio as gr
203
 
204
  title = " πŸ“ˆ Visualizing the stock market structure πŸ“ˆ"
 
212
 
213
  for i in range(n_labels + 1):
214
  gr.Markdown( f"Cluster {i + 1}: {', '.join(names[labels == i])}")
215
+
216
+ btn = gr.Button(value="Visualize")
217
+ btn.click(visualize_stocks, outputs= gr.Plot(label='Visualizing stock into clusters') )
218
  gr.Markdown( f"## In progress")
219
  demo.launch()