Added 3d Plot
Browse files
app.py
CHANGED
@@ -98,7 +98,6 @@ X /= X.std(axis=0)
|
|
98 |
edge_model.fit(X)
|
99 |
|
100 |
|
101 |
-
|
102 |
from sklearn import cluster
|
103 |
|
104 |
_, labels = cluster.affinity_propagation(edge_model.covariance_, random_state=0)
|
@@ -118,13 +117,10 @@ embedding = node_position_model.fit_transform(X.T).T
|
|
118 |
|
119 |
import matplotlib.pyplot as plt
|
120 |
from matplotlib.collections import LineCollection
|
|
|
121 |
|
122 |
-
def visualize_stocks():
|
123 |
-
fig = plt.figure(1, facecolor="w", figsize=(10, 8))
|
124 |
-
plt.clf()
|
125 |
-
ax = plt.axes([0.0, 0.0, 1.0, 1.0])
|
126 |
-
plt.axis("off")
|
127 |
|
|
|
128 |
# Plot the graph of partial correlations
|
129 |
partial_correlations = edge_model.precision_.copy()
|
130 |
d = 1 / np.sqrt(np.diag(partial_correlations))
|
@@ -133,72 +129,53 @@ def visualize_stocks():
|
|
133 |
non_zero = np.abs(np.triu(partial_correlations, k=1)) > 0.02
|
134 |
|
135 |
# Plot the nodes using the coordinates of our embedding
|
136 |
-
|
137 |
-
embedding[0],
|
|
|
|
|
|
|
|
|
|
|
|
|
138 |
)
|
139 |
|
140 |
-
# Plot the edges
|
141 |
start_idx, end_idx = np.where(non_zero)
|
142 |
-
#
|
143 |
-
#
|
144 |
segments = [
|
145 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
146 |
]
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
dy[index] = 1
|
163 |
-
this_dx = dx[np.argmin(np.abs(dy))]
|
164 |
-
this_dy = dy[np.argmin(np.abs(dx))]
|
165 |
-
if this_dx > 0:
|
166 |
-
horizontalalignment = "left"
|
167 |
-
x = x + 0.002
|
168 |
-
else:
|
169 |
-
horizontalalignment = "right"
|
170 |
-
x = x - 0.002
|
171 |
-
if this_dy > 0:
|
172 |
-
verticalalignment = "bottom"
|
173 |
-
y = y + 0.002
|
174 |
-
else:
|
175 |
-
verticalalignment = "top"
|
176 |
-
y = y - 0.002
|
177 |
-
plt.text(
|
178 |
-
x,
|
179 |
-
y,
|
180 |
-
name,
|
181 |
-
size=10,
|
182 |
-
horizontalalignment=horizontalalignment,
|
183 |
-
verticalalignment=verticalalignment,
|
184 |
-
bbox=dict(
|
185 |
-
facecolor="w",
|
186 |
-
edgecolor=plt.cm.nipy_spectral(label / float(n_labels)),
|
187 |
-
alpha=0.6,
|
188 |
),
|
189 |
)
|
190 |
-
|
191 |
-
plt.xlim(
|
192 |
-
embedding[0].min() - 0.15 * embedding[0].ptp(),
|
193 |
-
embedding[0].max() + 0.10 * embedding[0].ptp(),
|
194 |
-
)
|
195 |
-
plt.ylim(
|
196 |
-
embedding[1].min() - 0.03 * embedding[1].ptp(),
|
197 |
-
embedding[1].max() + 0.03 * embedding[1].ptp(),
|
198 |
-
)
|
199 |
|
200 |
return fig
|
201 |
-
|
|
|
202 |
import gradio as gr
|
203 |
|
204 |
title = " π Visualizing the stock market structure π"
|
@@ -206,14 +183,20 @@ title = " π Visualizing the stock market structure π"
|
|
206 |
with gr.Blocks(title=title) as demo:
|
207 |
gr.Markdown(f"# {title}")
|
208 |
gr.Markdown(" Data is of 56 stocks between the period of 2003 - 2008 <br>")
|
209 |
-
gr.Markdown(
|
|
|
|
|
210 |
|
211 |
-
gr.Markdown(
|
|
|
|
|
212 |
|
213 |
for i in range(n_labels + 1):
|
214 |
-
gr.Markdown(
|
215 |
-
|
216 |
btn = gr.Button(value="Visualize")
|
217 |
-
btn.click(
|
218 |
-
|
219 |
-
|
|
|
|
|
|
98 |
edge_model.fit(X)
|
99 |
|
100 |
|
|
|
101 |
from sklearn import cluster
|
102 |
|
103 |
_, labels = cluster.affinity_propagation(edge_model.covariance_, random_state=0)
|
|
|
117 |
|
118 |
import matplotlib.pyplot as plt
|
119 |
from matplotlib.collections import LineCollection
|
120 |
+
import plotly.graph_objs as go
|
121 |
|
|
|
|
|
|
|
|
|
|
|
122 |
|
123 |
+
def visualize_stocks():
|
124 |
# Plot the graph of partial correlations
|
125 |
partial_correlations = edge_model.precision_.copy()
|
126 |
d = 1 / np.sqrt(np.diag(partial_correlations))
|
|
|
129 |
non_zero = np.abs(np.triu(partial_correlations, k=1)) > 0.02
|
130 |
|
131 |
# Plot the nodes using the coordinates of our embedding
|
132 |
+
scatter = go.Scatter3d(
|
133 |
+
x=embedding[0],
|
134 |
+
y=embedding[1],
|
135 |
+
z=embedding[2],
|
136 |
+
mode="markers",
|
137 |
+
marker=dict(size=35 * d**2, color=labels, colorscale="Viridis"),
|
138 |
+
hovertext=names,
|
139 |
+
hovertemplate="%{hovertext}<br>",
|
140 |
)
|
141 |
|
142 |
+
# # Plot the edges
|
143 |
start_idx, end_idx = np.where(non_zero)
|
144 |
+
# print(non_zero, non_zero.shape)
|
145 |
+
# print(start_idx, start_idx.shape)
|
146 |
segments = [
|
147 |
+
dict(
|
148 |
+
x=[embedding[0][start], embedding[0][stop]],
|
149 |
+
y=[embedding[1][start], embedding[1][stop]],
|
150 |
+
z=[embedding[2][start], embedding[2][stop]],
|
151 |
+
colorscale="Hot",
|
152 |
+
color=np.abs(partial_correlations[start, stop]),
|
153 |
+
line=dict(width=10 * np.abs(partial_correlations[start, stop])),
|
154 |
+
)
|
155 |
+
for start, stop in zip(start_idx, end_idx)
|
156 |
]
|
157 |
+
fig = go.Figure(data=[scatter])
|
158 |
+
|
159 |
+
for idx, segment in enumerate(segments, 1):
|
160 |
+
fig.add_trace(
|
161 |
+
go.Scatter3d(
|
162 |
+
x=segment["x"], # x-coordinates of the line segment
|
163 |
+
y=segment["y"], # y-coordinates of the line segment
|
164 |
+
z=segment["z"], # z-coordinates of the line segment
|
165 |
+
mode="lines", # type of the plot (line)
|
166 |
+
line=dict(
|
167 |
+
color=segment["color"], # color of the line
|
168 |
+
colorscale=segment["colorscale"], # color scale of the line
|
169 |
+
width=segment["line"]["width"] * 2.5, # width of the line
|
170 |
+
),
|
171 |
+
hoverinfo="none", # disable hover for the line segments
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
172 |
),
|
173 |
)
|
174 |
+
fig.data[idx].showlegend = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
175 |
|
176 |
return fig
|
177 |
+
|
178 |
+
|
179 |
import gradio as gr
|
180 |
|
181 |
title = " π Visualizing the stock market structure π"
|
|
|
183 |
with gr.Blocks(title=title) as demo:
|
184 |
gr.Markdown(f"# {title}")
|
185 |
gr.Markdown(" Data is of 56 stocks between the period of 2003 - 2008 <br>")
|
186 |
+
gr.Markdown(
|
187 |
+
" Stocks the move in together with each other are grouped together in a cluster <br>"
|
188 |
+
)
|
189 |
|
190 |
+
gr.Markdown(
|
191 |
+
" **[Demo is based on sklearn docs](https://scikit-learn.org/stable/auto_examples/applications/plot_stock_market.html)**"
|
192 |
+
)
|
193 |
|
194 |
for i in range(n_labels + 1):
|
195 |
+
gr.Markdown(f"Cluster {i + 1}: {', '.join(names[labels == i])}")
|
196 |
+
|
197 |
btn = gr.Button(value="Visualize")
|
198 |
+
btn.click(
|
199 |
+
visualize_stocks, outputs=gr.Plot(label="Visualizing stock into clusters")
|
200 |
+
)
|
201 |
+
gr.Markdown(f"## In progress")
|
202 |
+
demo.launch()
|