tushifire commited on
Commit
12f0ea9
Β·
1 Parent(s): 12e8f06

Added 3d Plot

Browse files
Files changed (1) hide show
  1. app.py +53 -70
app.py CHANGED
@@ -98,7 +98,6 @@ X /= X.std(axis=0)
98
  edge_model.fit(X)
99
 
100
 
101
-
102
  from sklearn import cluster
103
 
104
  _, labels = cluster.affinity_propagation(edge_model.covariance_, random_state=0)
@@ -118,13 +117,10 @@ embedding = node_position_model.fit_transform(X.T).T
118
 
119
  import matplotlib.pyplot as plt
120
  from matplotlib.collections import LineCollection
 
121
 
122
- def visualize_stocks():
123
- fig = plt.figure(1, facecolor="w", figsize=(10, 8))
124
- plt.clf()
125
- ax = plt.axes([0.0, 0.0, 1.0, 1.0])
126
- plt.axis("off")
127
 
 
128
  # Plot the graph of partial correlations
129
  partial_correlations = edge_model.precision_.copy()
130
  d = 1 / np.sqrt(np.diag(partial_correlations))
@@ -133,72 +129,53 @@ def visualize_stocks():
133
  non_zero = np.abs(np.triu(partial_correlations, k=1)) > 0.02
134
 
135
  # Plot the nodes using the coordinates of our embedding
136
- plt.scatter(
137
- embedding[0], embedding[1], s=100 * d**2, c=labels, cmap=plt.cm.nipy_spectral
 
 
 
 
 
 
138
  )
139
 
140
- # Plot the edges
141
  start_idx, end_idx = np.where(non_zero)
142
- # a sequence of (*line0*, *line1*, *line2*), where::
143
- # linen = (x0, y0), (x1, y1), ... (xm, ym)
144
  segments = [
145
- [embedding[:, start], embedding[:, stop]] for start, stop in zip(start_idx, end_idx)
 
 
 
 
 
 
 
 
146
  ]
147
- values = np.abs(partial_correlations[non_zero])
148
- lc = LineCollection(
149
- segments, zorder=0, cmap=plt.cm.hot_r, norm=plt.Normalize(0, 0.7 * values.max())
150
- )
151
- lc.set_array(values)
152
- lc.set_linewidths(15 * values)
153
- ax.add_collection(lc)
154
-
155
- # Add a label to each node. The challenge here is that we want to
156
- # position the labels to avoid overlap with other labels
157
- for index, (name, label, (x, y)) in enumerate(zip(names, labels, embedding.T)):
158
-
159
- dx = x - embedding[0]
160
- dx[index] = 1
161
- dy = y - embedding[1]
162
- dy[index] = 1
163
- this_dx = dx[np.argmin(np.abs(dy))]
164
- this_dy = dy[np.argmin(np.abs(dx))]
165
- if this_dx > 0:
166
- horizontalalignment = "left"
167
- x = x + 0.002
168
- else:
169
- horizontalalignment = "right"
170
- x = x - 0.002
171
- if this_dy > 0:
172
- verticalalignment = "bottom"
173
- y = y + 0.002
174
- else:
175
- verticalalignment = "top"
176
- y = y - 0.002
177
- plt.text(
178
- x,
179
- y,
180
- name,
181
- size=10,
182
- horizontalalignment=horizontalalignment,
183
- verticalalignment=verticalalignment,
184
- bbox=dict(
185
- facecolor="w",
186
- edgecolor=plt.cm.nipy_spectral(label / float(n_labels)),
187
- alpha=0.6,
188
  ),
189
  )
190
-
191
- plt.xlim(
192
- embedding[0].min() - 0.15 * embedding[0].ptp(),
193
- embedding[0].max() + 0.10 * embedding[0].ptp(),
194
- )
195
- plt.ylim(
196
- embedding[1].min() - 0.03 * embedding[1].ptp(),
197
- embedding[1].max() + 0.03 * embedding[1].ptp(),
198
- )
199
 
200
  return fig
201
-
 
202
  import gradio as gr
203
 
204
  title = " πŸ“ˆ Visualizing the stock market structure πŸ“ˆ"
@@ -206,14 +183,20 @@ title = " πŸ“ˆ Visualizing the stock market structure πŸ“ˆ"
206
  with gr.Blocks(title=title) as demo:
207
  gr.Markdown(f"# {title}")
208
  gr.Markdown(" Data is of 56 stocks between the period of 2003 - 2008 <br>")
209
- gr.Markdown(" Stocks the move in together with each other are grouped together in a cluster <br>")
 
 
210
 
211
- gr.Markdown(" **[Demo is based on sklearn docs](https://scikit-learn.org/stable/auto_examples/applications/plot_stock_market.html)**")
 
 
212
 
213
  for i in range(n_labels + 1):
214
- gr.Markdown( f"Cluster {i + 1}: {', '.join(names[labels == i])}")
215
-
216
  btn = gr.Button(value="Visualize")
217
- btn.click(visualize_stocks, outputs= gr.Plot(label='Visualizing stock into clusters') )
218
- gr.Markdown( f"## In progress")
219
- demo.launch()
 
 
 
98
  edge_model.fit(X)
99
 
100
 
 
101
  from sklearn import cluster
102
 
103
  _, labels = cluster.affinity_propagation(edge_model.covariance_, random_state=0)
 
117
 
118
  import matplotlib.pyplot as plt
119
  from matplotlib.collections import LineCollection
120
+ import plotly.graph_objs as go
121
 
 
 
 
 
 
122
 
123
+ def visualize_stocks():
124
  # Plot the graph of partial correlations
125
  partial_correlations = edge_model.precision_.copy()
126
  d = 1 / np.sqrt(np.diag(partial_correlations))
 
129
  non_zero = np.abs(np.triu(partial_correlations, k=1)) > 0.02
130
 
131
  # Plot the nodes using the coordinates of our embedding
132
+ scatter = go.Scatter3d(
133
+ x=embedding[0],
134
+ y=embedding[1],
135
+ z=embedding[2],
136
+ mode="markers",
137
+ marker=dict(size=35 * d**2, color=labels, colorscale="Viridis"),
138
+ hovertext=names,
139
+ hovertemplate="%{hovertext}<br>",
140
  )
141
 
142
+ # # Plot the edges
143
  start_idx, end_idx = np.where(non_zero)
144
+ # print(non_zero, non_zero.shape)
145
+ # print(start_idx, start_idx.shape)
146
  segments = [
147
+ dict(
148
+ x=[embedding[0][start], embedding[0][stop]],
149
+ y=[embedding[1][start], embedding[1][stop]],
150
+ z=[embedding[2][start], embedding[2][stop]],
151
+ colorscale="Hot",
152
+ color=np.abs(partial_correlations[start, stop]),
153
+ line=dict(width=10 * np.abs(partial_correlations[start, stop])),
154
+ )
155
+ for start, stop in zip(start_idx, end_idx)
156
  ]
157
+ fig = go.Figure(data=[scatter])
158
+
159
+ for idx, segment in enumerate(segments, 1):
160
+ fig.add_trace(
161
+ go.Scatter3d(
162
+ x=segment["x"], # x-coordinates of the line segment
163
+ y=segment["y"], # y-coordinates of the line segment
164
+ z=segment["z"], # z-coordinates of the line segment
165
+ mode="lines", # type of the plot (line)
166
+ line=dict(
167
+ color=segment["color"], # color of the line
168
+ colorscale=segment["colorscale"], # color scale of the line
169
+ width=segment["line"]["width"] * 2.5, # width of the line
170
+ ),
171
+ hoverinfo="none", # disable hover for the line segments
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  ),
173
  )
174
+ fig.data[idx].showlegend = False
 
 
 
 
 
 
 
 
175
 
176
  return fig
177
+
178
+
179
  import gradio as gr
180
 
181
  title = " πŸ“ˆ Visualizing the stock market structure πŸ“ˆ"
 
183
  with gr.Blocks(title=title) as demo:
184
  gr.Markdown(f"# {title}")
185
  gr.Markdown(" Data is of 56 stocks between the period of 2003 - 2008 <br>")
186
+ gr.Markdown(
187
+ " Stocks the move in together with each other are grouped together in a cluster <br>"
188
+ )
189
 
190
+ gr.Markdown(
191
+ " **[Demo is based on sklearn docs](https://scikit-learn.org/stable/auto_examples/applications/plot_stock_market.html)**"
192
+ )
193
 
194
  for i in range(n_labels + 1):
195
+ gr.Markdown(f"Cluster {i + 1}: {', '.join(names[labels == i])}")
196
+
197
  btn = gr.Button(value="Visualize")
198
+ btn.click(
199
+ visualize_stocks, outputs=gr.Plot(label="Visualizing stock into clusters")
200
+ )
201
+ gr.Markdown(f"## In progress")
202
+ demo.launch()