none commited on
Commit
ffd3002
·
1 Parent(s): 50cf0f6

Add some text

Browse files
Files changed (1) hide show
  1. streamlit_viz.py +132 -55
streamlit_viz.py CHANGED
@@ -56,43 +56,94 @@ FEATS = [
56
  'ct_dst_src_ltm',
57
  ]
58
 
 
 
59
  COLORS = [
60
- 'aliceblue','aqua','aquamarine','azure',
61
- 'bisque','black','blanchedalmond','blue',
62
- 'blueviolet','brown','burlywood','cadetblue',
63
- 'chartreuse','chocolate','coral','cornflowerblue',
64
- 'cornsilk','crimson','cyan','darkblue','darkcyan',
65
- 'darkgoldenrod','darkgray','darkgreen',
66
- 'darkkhaki','darkmagenta','darkolivegreen','darkorange',
67
- 'darkorchid','darkred','darksalmon','darkseagreen',
68
- 'darkslateblue','darkslategray',
69
- 'darkturquoise','darkviolet','deeppink','deepskyblue',
70
- 'dimgray','dodgerblue',
71
- 'forestgreen','fuchsia','gainsboro',
72
- 'gold','goldenrod','gray','green',
73
- 'greenyellow','honeydew','hotpink','indianred','indigo',
74
- 'ivory','khaki','lavender','lavenderblush','lawngreen',
75
- 'lemonchiffon','lightblue','lightcoral','lightcyan',
76
- 'lightgoldenrodyellow','lightgray',
77
- 'lightgreen','lightpink','lightsalmon','lightseagreen',
78
- 'lightskyblue','lightslategray',
79
- 'lightsteelblue','lightyellow','lime','limegreen',
80
- 'linen','magenta','maroon','mediumaquamarine',
81
- 'mediumblue','mediumorchid','mediumpurple',
82
- 'mediumseagreen','mediumslateblue','mediumspringgreen',
83
- 'mediumturquoise','mediumvioletred','midnightblue',
84
- 'mintcream','mistyrose','moccasin','navy',
85
- 'oldlace','olive','olivedrab','orange','orangered',
86
- 'orchid','palegoldenrod','palegreen','paleturquoise',
87
- 'palevioletred','papayawhip','peachpuff','peru','pink',
88
- 'plum','powderblue','purple','red','rosybrown',
89
- 'royalblue','saddlebrown','salmon','sandybrown',
90
- 'seagreen','seashell','sienna','silver','skyblue',
91
- 'slateblue','slategray','slategrey','snow','springgreen',
92
- 'steelblue','tan','teal','thistle','tomato','turquoise',
93
- 'violet','wheat','yellow','yellowgreen'
 
 
 
 
 
 
 
 
 
 
 
 
94
  ]
95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  def build_parents(tree, visit_order, node_id2plot_id):
97
  parents = [None]
98
  parent_plot_ids = [None]
@@ -188,27 +239,29 @@ def main():
188
  frames = [go.Frame(data=graph_obj) for graph_obj in graph_objs]
189
  # show them with streamlit
190
 
 
 
 
 
 
191
 
192
- # This works the way I want
193
- # but the plot is tiny
194
- # also it recalcualtes all of the plots
195
- # every time the slider value changes
196
- #
197
- # I tried to cache the plots but build_plot() takes
198
- # a DataFrame which is mutable and therefore unhashable I guess
199
- # so it won't let me cache that function
200
- # I could pack the dataframe bytes to smuggle them past that check
201
- # but whatever
202
- idx = st.slider(
203
- label='which step to show',
204
- min_value=0,
205
- max_value=len(figures)-1,
206
- value=0,
207
- step=1
208
- )
209
- st.plotly_chart(figures[idx])
210
- st.markdown(f'## Tree {idx}')
211
- st.dataframe(trees[idx])
212
 
213
  # Maybe just show a Plotly animated chart
214
  # https://plotly.com/python/animations/#using-a-slider-and-buttons
@@ -259,8 +312,32 @@ def main():
259
  )
260
  st.plotly_chart(ani_fig)
261
 
262
- st.markdown(f'## {len(FEATS)}')
 
 
 
 
263
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264
 
265
  if __name__=='__main__':
266
  main()
 
56
  'ct_dst_src_ltm',
57
  ]
58
 
59
+ # Generated from
60
+ # mokole.com/palette.html
61
  COLORS = [
62
+ '#808080',
63
+ '#2f4f4f',
64
+ '#556b2f',
65
+ '#8b4513',
66
+ '#6b8e23',
67
+ '#2e8b57',
68
+ '#800000',
69
+ '#191970',
70
+ '#006400',
71
+ '#b8860b',
72
+ '#4682b4',
73
+ '#d2691e',
74
+ '#9acd32',
75
+ '#20b2aa',
76
+ '#cd5c5c',
77
+ '#00008b',
78
+ '#32cd32',
79
+ '#8fbc8f',
80
+ '#800080',
81
+ '#b03060',
82
+ '#d2b48c',
83
+ '#ff4500',
84
+ '#ffa500',
85
+ '#ffff00',
86
+ '#c71585',
87
+ '#0000cd',
88
+ '#00ff00',
89
+ '#00ff7f',
90
+ '#dc143c',
91
+ '#00ffff',
92
+ '#00bfff',
93
+ '#f4a460',
94
+ '#9370db',
95
+ '#a020f0',
96
+ '#adff2f',
97
+ '#ff6347',
98
+ '#da70d6',
99
+ '#b0c4de',
100
+ '#ff00ff',
101
+ '#f0e68c',
102
+ '#6495ed',
103
+ '#dda0dd',
104
+ '#afeeee',
105
+ '#98fb98',
106
+ '#7fffd4',
107
+ '#ffb6c1',
108
  ]
109
 
110
+ #COLORS = [
111
+ # 'aliceblue','aqua','aquamarine','azure',
112
+ # 'bisque','black','blanchedalmond','blue',
113
+ # 'blueviolet','brown','burlywood','cadetblue',
114
+ # 'chartreuse','chocolate','coral','cornflowerblue',
115
+ # 'cornsilk','crimson','cyan','darkblue','darkcyan',
116
+ # 'darkgoldenrod','darkgray','darkgreen',
117
+ # 'darkkhaki','darkmagenta','darkolivegreen','darkorange',
118
+ # 'darkorchid','darkred','darksalmon','darkseagreen',
119
+ # 'darkslateblue','darkslategray',
120
+ # 'darkturquoise','darkviolet','deeppink','deepskyblue',
121
+ # 'dimgray','dodgerblue',
122
+ # 'forestgreen','fuchsia','gainsboro',
123
+ # 'gold','goldenrod','gray','green',
124
+ # 'greenyellow','honeydew','hotpink','indianred','indigo',
125
+ # 'ivory','khaki','lavender','lavenderblush','lawngreen',
126
+ # 'lemonchiffon','lightblue','lightcoral','lightcyan',
127
+ # 'lightgoldenrodyellow','lightgray',
128
+ # 'lightgreen','lightpink','lightsalmon','lightseagreen',
129
+ # 'lightskyblue','lightslategray',
130
+ # 'lightsteelblue','lightyellow','lime','limegreen',
131
+ # 'linen','magenta','maroon','mediumaquamarine',
132
+ # 'mediumblue','mediumorchid','mediumpurple',
133
+ # 'mediumseagreen','mediumslateblue','mediumspringgreen',
134
+ # 'mediumturquoise','mediumvioletred','midnightblue',
135
+ # 'mintcream','mistyrose','moccasin','navy',
136
+ # 'oldlace','olive','olivedrab','orange','orangered',
137
+ # 'orchid','palegoldenrod','palegreen','paleturquoise',
138
+ # 'palevioletred','papayawhip','peachpuff','peru','pink',
139
+ # 'plum','powderblue','purple','red','rosybrown',
140
+ # 'royalblue','saddlebrown','salmon','sandybrown',
141
+ # 'seagreen','seashell','sienna','silver','skyblue',
142
+ # 'slateblue','slategray','slategrey','snow','springgreen',
143
+ # 'steelblue','tan','teal','thistle','tomato','turquoise',
144
+ # 'violet','wheat','yellow','yellowgreen'
145
+ #]
146
+
147
  def build_parents(tree, visit_order, node_id2plot_id):
148
  parents = [None]
149
  parent_plot_ids = [None]
 
239
  frames = [go.Frame(data=graph_obj) for graph_obj in graph_objs]
240
  # show them with streamlit
241
 
242
+ st.markdown("""
243
+ I trained a
244
+ [Histogram-based Gradient Boosting Classification Tree](https://scikit-learn.org/stable/modules/ensemble.html#histogram-based-gradient-boosting)
245
+ on some data.
246
+ That algorithm looks at its mistakes and tries to avoid those mistakes the next time around.
247
 
248
+ To do that, it starts off with a decision tree.
249
+ From there, it looks at the points that tree got wrong and makes another decision tree that tries
250
+ to get those points right.
251
+ Then it looks at that second tree's mistakes and makes another tree that tries to fix those mistakes.
252
+ And so on.
253
+
254
+ My model ends up with 10 trees.
255
+ I've plotted the progression of those trees as an animated series of tree maps.
256
+ The boxes are color-coded by which feature the decision tree is using to make that split, and I've labeled each one with the exact decision boundary of that split.
257
+ It takes a second to get going after you hit "Play."
258
+
259
+ I recommend expanding the plot by clicking the arrows in the top right corner since Streamlit makes the plot really small.
260
+
261
+ """)
262
+
263
+
264
+ st.markdown('## My Trees')
 
 
 
265
 
266
  # Maybe just show a Plotly animated chart
267
  # https://plotly.com/python/animations/#using-a-slider-and-buttons
 
312
  )
313
  st.plotly_chart(ani_fig)
314
 
315
+ st.markdown("""
316
+ This actually turned out to be a lot harder than I thought it would be.
317
+ """)
318
+
319
+ st.markdown('# Check out each tree!')
320
 
321
+ # This works the way I want
322
+ # but the plot is tiny
323
+ # also it recalculates all of the plots
324
+ # every time the slider value changes
325
+ #
326
+ # I tried to cache the plots but build_plot() takes
327
+ # a DataFrame which is mutable and therefore unhashable I guess
328
+ # so it won't let me cache that function
329
+ # I could pack the dataframe bytes to smuggle them past that check
330
+ # but whatever
331
+ idx = st.slider(
332
+ label='Which tree do you want to see?',
333
+ min_value=0,
334
+ max_value=len(figures)-1,
335
+ value=0,
336
+ step=1
337
+ )
338
+ st.plotly_chart(figures[idx])
339
+ st.markdown(f'## Tree {idx}')
340
+ st.dataframe(trees[idx])
341
 
342
  if __name__=='__main__':
343
  main()