Spaces:
Runtime error
Runtime error
import time
from collections import deque

import joblib
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import streamlit as st
# Feature names, positioned so that FEATS[k] names the feature whose
# `feature_idx` in a tree's node table is k (see build_labels_colors).
# Order matters: do not sort or regroup these entries.
FEATS = [
    'srcip', 'sport', 'dstip', 'dsport', 'proto',
    # 'state' was dropped when the model was trained
    'dur', 'sbytes', 'dbytes', 'sttl', 'dttl',
    'sloss', 'dloss', 'service', 'Sload', 'Dload',
    'Spkts', 'Dpkts', 'swin', 'dwin', 'stcpb',
    'dtcpb', 'smeansz', 'dmeansz', 'trans_depth', 'res_bdy_len',
    'Sjit', 'Djit', 'Stime', 'Ltime', 'Sintpkt',
    'Dintpkt', 'tcprtt', 'synack', 'ackdat', 'is_sm_ips_ports',
    'ct_state_ttl', 'ct_flw_http_mthd', 'is_ftp_login', 'ct_ftp_cmd',
    'ct_srv_src', 'ct_srv_dst', 'ct_dst_ltm', 'ct_src_ltm',
    'ct_src_dport_ltm', 'ct_dst_sport_ltm', 'ct_dst_src_ltm',
]
# Palette generated from mokole.com/palette.html
# COLORS[k] is the treemap box color used for splits on FEATS[k], so this
# list needs at least as many entries as FEATS and its order is significant.
COLORS = [
    '#808080', '#2f4f4f', '#556b2f', '#8b4513', '#6b8e23', '#2e8b57',
    '#800000', '#191970', '#006400', '#b8860b', '#4682b4', '#d2691e',
    '#9acd32', '#20b2aa', '#cd5c5c', '#00008b', '#32cd32', '#8fbc8f',
    '#800080', '#b03060', '#d2b48c', '#ff4500', '#ffa500', '#ffff00',
    '#c71585', '#0000cd', '#00ff00', '#00ff7f', '#dc143c', '#00ffff',
    '#00bfff', '#f4a460', '#9370db', '#a020f0', '#adff2f', '#ff6347',
    '#da70d6', '#b0c4de', '#ff00ff', '#f0e68c', '#6495ed', '#dda0dd',
    '#afeeee', '#98fb98', '#7fffd4', '#ffb6c1',
]
#COLORS = [ | |
# 'aliceblue','aqua','aquamarine','azure', | |
# 'bisque','black','blanchedalmond','blue', | |
# 'blueviolet','brown','burlywood','cadetblue', | |
# 'chartreuse','chocolate','coral','cornflowerblue', | |
# 'cornsilk','crimson','cyan','darkblue','darkcyan', | |
# 'darkgoldenrod','darkgray','darkgreen', | |
# 'darkkhaki','darkmagenta','darkolivegreen','darkorange', | |
# 'darkorchid','darkred','darksalmon','darkseagreen', | |
# 'darkslateblue','darkslategray', | |
# 'darkturquoise','darkviolet','deeppink','deepskyblue', | |
# 'dimgray','dodgerblue', | |
# 'forestgreen','fuchsia','gainsboro', | |
# 'gold','goldenrod','gray','green', | |
# 'greenyellow','honeydew','hotpink','indianred','indigo', | |
# 'ivory','khaki','lavender','lavenderblush','lawngreen', | |
# 'lemonchiffon','lightblue','lightcoral','lightcyan', | |
# 'lightgoldenrodyellow','lightgray', | |
# 'lightgreen','lightpink','lightsalmon','lightseagreen', | |
# 'lightskyblue','lightslategray', | |
# 'lightsteelblue','lightyellow','lime','limegreen', | |
# 'linen','magenta','maroon','mediumaquamarine', | |
# 'mediumblue','mediumorchid','mediumpurple', | |
# 'mediumseagreen','mediumslateblue','mediumspringgreen', | |
# 'mediumturquoise','mediumvioletred','midnightblue', | |
# 'mintcream','mistyrose','moccasin','navy', | |
# 'oldlace','olive','olivedrab','orange','orangered', | |
# 'orchid','palegoldenrod','palegreen','paleturquoise', | |
# 'palevioletred','papayawhip','peachpuff','peru','pink', | |
# 'plum','powderblue','purple','red','rosybrown', | |
# 'royalblue','saddlebrown','salmon','sandybrown', | |
# 'seagreen','seashell','sienna','silver','skyblue', | |
# 'slateblue','slategray','slategrey','snow','springgreen', | |
# 'steelblue','tan','teal','thistle','tomato','turquoise', | |
# 'violet','wheat','yellow','yellowgreen' | |
#] | |
def build_parents(tree, visit_order, node_id2plot_id):
    """
    Find the parent of every visited node and which side it hangs off.

    Args:
        tree: DataFrame indexed by node id with 'left' and 'right' columns
            holding the child node ids of each row.
        visit_order: node ids in the order they will be plotted; the first
            entry is treated as the root and gets no parent.
        node_id2plot_id: mapping from node id to its position in the plot.

    Returns:
        Three lists aligned with `visit_order`:
        parents (node id of the parent), parent_plot_ids (the parent's plot
        id as a string) and directions ('l' or 'r'). The root's entry is
        None in all three.

    Raises:
        KeyError: if a visited node is not listed as a child of any row.
    """
    # Index the child columns once instead of scanning the whole table for
    # every node. setdefault keeps the first matching row, mirroring a
    # "first hit" lookup on the table.
    left_parent = {}
    right_parent = {}
    for node_id, left, right in zip(tree.index, tree['left'], tree['right']):
        left_parent.setdefault(int(left), node_id)
        right_parent.setdefault(int(right), node_id)

    parents = [None]
    parent_plot_ids = [None]
    directions = [None]
    for child in visit_order[1:]:
        key = int(child)
        # A match in the 'right' column takes priority over 'left'.
        if key in right_parent:
            parent, direction = right_parent[key], 'r'
        else:
            parent, direction = left_parent[key], 'l'
        parents.append(parent)
        parent_plot_ids.append(str(node_id2plot_id[parent]))
        directions.append(direction)
    return parents, parent_plot_ids, directions
def build_labels_colors(tree, visit_order, parents, parent_plot_ids, directions):
    """
    Build the treemap label and box color for every visited node.

    Each non-root node is described by the split of its parent: the feature
    name and threshold the parent tests, plus the side of that split the
    node sits on.

    Args:
        tree: DataFrame indexed by node id with 'feature_idx' and
            'num_threshold' columns.
        visit_order: node ids in plot order.
        parents: parent node id for each visited node (None for the root).
        parent_plot_ids: the parent's plot id as a string, per visited node.
        directions: 'l' or 'r' per visited node.

    Returns:
        (labels, colors): two lists. The first entry of each belongs to the
        root (a fixed title and 'white'); one further entry is appended for
        every visited node whose id is not 0.
    """
    labels = ['Histogram Gradient-Boosted Decision Tree']
    colors = ['white']
    for i, parent, parent_plot_id, direction in zip(
        visit_order,
        parents,
        parent_plot_ids,
        directions,
    ):
        # skip the root: it has no parent split to describe
        if i == 0:
            continue
        parent_id = int(parent)
        feat_idx = int(tree.loc[parent_id, 'feature_idx'])
        feat = FEATS[feat_idx]
        thresh = tree.loc[parent_id, 'num_threshold']
        side, comparison = ('L', '<=') if direction == 'l' else ('R', '>')
        labels.append(f"[{parent_plot_id}.{side}] {feat} {comparison} {thresh}")
        # COLORS is positionally aligned with FEATS, so the feature index
        # picks the color directly; no need to search FEATS for the name.
        colors.append(COLORS[feat_idx])
    return labels, colors
def build_plot(tree):
    """
    Turn one tree's node table into a plotly Treemap trace.

    Args:
        tree: DataFrame indexed by node id with 'left', 'right',
            'feature_idx', 'num_threshold' and 'count' columns.

    Returns:
        A `go.Treemap` whose ids, labels, parents, colors and values are all
        listed in breadth-first visit order.
    """
    # https://stackoverflow.com/questions/64393535/python-plotly-treemap-ids-format-and-how-to-display-multiple-duplicated-labels-i
    # if you use `ids`, then `parents` has to be in terms of `ids`
    visit_order = breadth_first_traverse(tree)
    node_id2plot_id = {node: i for i, node in enumerate(visit_order)}
    parents, parent_plot_ids, directions = build_parents(tree, visit_order, node_id2plot_id)
    labels, colors = build_labels_colors(tree, visit_order, parents, parent_plot_ids, directions)
    # this should just be ['0', '1', '2', . . .]
    plot_ids = [str(node_id2plot_id[x]) for x in visit_order]
    # Every other per-node list here is in visit order, so the counts must be
    # too. Taking the column as-is would give table (node id) order and
    # attach box sizes to the wrong nodes whenever the two orders differ.
    values = tree.loc[[int(node) for node in visit_order], 'count'].to_numpy()
    return go.Treemap(
        values=values,
        labels=labels,
        ids=plot_ids,
        parents=parent_plot_ids,
        marker_colors=colors,
    )
def breadth_first_traverse(tree):
    """
    Return the node ids of `tree` in breadth-first (level-by-level) order.

    https://www.101computing.net/breadth-first-traversal-of-a-binary-tree/
    Iterative version makes more sense since the whole tree is in a table
    instead of just nodes and pointers.

    Args:
        tree: DataFrame indexed by node id with 'left' and 'right' columns
            holding child node ids. A value of 0 means "no child on this
            side"; node 0 is the root.

    Returns:
        List of node ids starting with 0; within a level, a node's left
        child comes before its right child.
    """
    # deque gives O(1) pops from the front; list.pop(0) shifts every element.
    pending = deque([0])
    visited_nodes = []
    while pending:
        cur = pending.popleft()
        visited_nodes.append(cur)
        for side in ('left', 'right'):
            child = tree.loc[cur, side]
            if child != 0:
                pending.append(child)
    return visited_nodes
def main():
    """
    Render the Streamlit page: intro text, an animated treemap that steps
    through every tree, then a slider to inspect one tree and its node table.
    """
    # Load the fitted model and turn each boosting iteration into a node table.
    # Only element [0] of each entry in `_predictors` is used.
    # NOTE(review): presumably one predictor per class/output - confirm.
    hgb = joblib.load('hgb_classifier.joblib')
    trees = []
    for predictors in hgb._predictors:
        trees.append(pd.DataFrame(predictors[0].nodes))

    # One treemap trace per tree, reused for the standalone figures and for
    # the animation frames.
    graph_objs = [build_plot(tree) for tree in trees]
    figures = [go.Figure(trace) for trace in graph_objs]
    frames = [go.Frame(data=trace) for trace in graph_objs]

    # Intro text
    st.markdown("""
I trained a
[Histogram-based Gradient Boosting Classification Tree](https://scikit-learn.org/stable/modules/ensemble.html#histogram-based-gradient-boosting)
on some data.
That algoritm looks at its mistakes and tries to avoid those mistakes the next time around.
To do that, it starts off with a decision tree.
From there, it looks at the points that tree got wrong and makes another decision tree that tries
to get those points right.
Then it looks at that second tree's mistakes and makes another tree that tries to fix those mistakes.
And so on.
My model ends up with 10 trees.
I've plotted the progression of those trees as an animated series of tree maps.
The boxes are color-coded by which feature the decision tree is using to make that split and I've labeled each one with the exact decision boundary of that split.
It takes a second to get going after you hit "Play."
I recommend expanding the plot by clicking the arrows in the top right corner since Streamlit makes the plot really small.
""")
    st.markdown('## My Trees')

    # Animated chart. The animation options are documented mostly on the JS
    # side of plotly; these were the useful references:
    #   https://plotly.com/python/animations/#using-a-slider-and-buttons
    #   https://raw.githubusercontent.com/plotly/plotly.js/master/dist/plot-schema.json
    #   https://plotly.com/javascript/animations/
    #   https://plotly.com/python-api-reference/generated/plotly.graph_objects.layout.updatemenu.html?highlight=updatemenu
    #   https://towardsdatascience.com/basic-animation-with-matplotlib-and-plotly-5eef4ad6c5aa
    #     (the one that finally showed how to set the speed)
    #   https://plotly.com/python/custom-buttons/#reference
    play_button = {
        'label': 'Play',
        'method': 'animate',
        # None as the first arg animates through all frames; the dict sets
        # how long each frame and each transition lasts.
        'args': [None, {
            'frame': {'duration': 5000},
            'transition': {'duration': 2500},
        }],
    }
    # https://plotly.com/python/reference/layout/updatemenus/
    button_menu = {
        'type': 'buttons',
        # Streamlit breaks the background color of the active button in
        # dark mode, so never mark a button as active and pin the button
        # background, font and border colors explicitly.
        'showactive': False,
        'bgcolor': '#fff',
        'font': {'color': '#000'},
        'bordercolor': '#000',
        'buttons': [play_button],
    }
    ani_fig = go.Figure(
        data=graph_objs[0],
        frames=frames,
        layout=go.Layout(updatemenus=[button_menu]),
    )
    st.plotly_chart(ani_fig)

    st.markdown("""
This actually turned out to be a lot harder than I thought it would be.
""")
    st.markdown('# Check out each tree!')

    # Per-tree viewer. Known limitations: the plot renders small, and every
    # plot is rebuilt whenever the slider moves. Caching build_plot() was
    # attempted but it takes a DataFrame, which the cache would not hash.
    last_tree = len(figures) - 1
    idx = st.slider(
        label='Which tree do you want to see?',
        min_value=0,
        max_value=last_tree,
        value=0,
        step=1,
    )
    st.plotly_chart(figures[idx])
    st.markdown(f'## Tree {idx}')
    st.dataframe(trees[idx])
# Build the page only when this file is run as a script.
if __name__ == "__main__":
    main()