from shap_plots import shap_summary_plot, shap_dependence_plot
import plotly.tools as tls
import dash_core_components as dcc
import pandas as pd
import numpy as np
import xgboost
import shap
import matplotlib
import plotly.graph_objs as go
# matplotlib itself is imported unguarded above and LinearSegmentedColormap is
# needed at import time below, so swallowing an ImportError here would only
# turn it into a confusing NameError later.
import matplotlib.pyplot as pl
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.ticker import MaxNLocator
from sklearn import preprocessing

# Two-stop linear colour ramp: #1E88E5 (blue) at 0.0 -> #F52757 (red) at 1.0,
# fully opaque throughout.
cdict1 = {
    'red': ((0.0, 0.11764705882352941, 0.11764705882352941),
            (1.0, 0.9607843137254902, 0.9607843137254902)),
    'green': ((0.0, 0.5333333333333333, 0.5333333333333333),
              (1.0, 0.15294117647058825, 0.15294117647058825)),
    'blue': ((0.0, 0.8980392156862745, 0.8980392156862745),
             (1.0, 0.3411764705882353, 0.3411764705882353)),
    'alpha': ((0.0, 1, 1),
              (0.5, 1, 1),
              (1.0, 1, 1)),
}
red_blue = LinearSegmentedColormap('RedBlue', cdict1)


def matplotlib_to_plotly(cmap, pl_entries):
    """Sample a matplotlib colormap into a Plotly colorscale.

    Parameters
    ----------
    cmap : callable
        A matplotlib colormap; called with a float in [0, 1] and expected to
        return an RGBA tuple of floats in [0, 1].
    pl_entries : int
        Number of evenly spaced samples to take; must be at least 2 so the
        scale spans both 0.0 and 1.0.

    Returns
    -------
    list of [float, str]
        ``[[position, 'rgb(r, g, b)'], ...]`` with integer channels 0-255.

    Raises
    ------
    ValueError
        If ``pl_entries`` is less than 2.
    """
    if pl_entries < 2:
        raise ValueError('pl_entries must be at least 2, got {!r}'.format(pl_entries))
    h = 1.0 / (pl_entries - 1)
    pl_colorscale = []
    for k in range(pl_entries):
        channels = (np.array(cmap(k * h)[:3]) * 255).astype(np.uint8)
        # Convert to plain ints before formatting: on NumPy >= 2 the repr of
        # np.uint8 is 'np.uint8(30)', which str(tuple) would embed verbatim.
        r, g, b = (int(c) for c in channels)
        pl_colorscale.append([k * h, 'rgb({}, {}, {})'.format(r, g, b)])
    return pl_colorscale


# From here on `red_blue` is the Plotly colorscale form of the colormap above.
red_blue = matplotlib_to_plotly(red_blue, 255)


def summary_plot_plotly_fig(shap_values, df_shap, feature_names, max_display=8):
    """Build a Plotly figure from the matplotlib SHAP summary plot.

    The figure returned by ``shap_summary_plot`` is converted with
    ``plotly.tools.mpl_to_plotly`` and then restyled: every odd-numbered trace
    gets one ``df_shap`` index label as hover text and the top-ranked feature
    names as its ``y`` values, and a Low/High colour bar plus a 'Gene' caption
    are added.

    Parameters
    ----------
    shap_values : array-like, 2-D
        SHAP values; column sums of absolute values rank the features. The
        last column is excluded from the ranking (presumably the bias /
        expected-value column -- TODO confirm against the producer).
    df_shap : pandas.DataFrame
        Data passed to the summary plot; its index labels are used, in
        order, as hover text for the odd-numbered traces.
    feature_names : array-like of str
        Feature names, indexed by column position.
    max_display : int, default 8
        Maximum number of features drawn by the summary plot and used as
        y labels.

    Returns
    -------
    plotly.graph_objs.Figure
    """
    # Use the caller's limit; previously the plot was hard-coded to 20
    # features while the labels below used `max_display`, so they disagreed.
    mpl_fig = shap_summary_plot(shap_values, df_shap, feature_names=feature_names,
                                max_display=max_display)
    plotly_fig = tls.mpl_to_plotly(mpl_fig)
    plotly_fig['layout'] = {'xaxis': {'title': 'SHAP value (impact on model output)'}}

    # Rank by total |SHAP| (ascending), dropping the last column, and keep the
    # top `max_display`. An explicit start index avoids the `[-0:]` pitfall,
    # which would select every feature when max_display is 0.
    feature_order = np.argsort(np.sum(np.abs(shap_values), axis=0)[:-1])
    n_keep = min(max_display, len(feature_order))
    feature_order = feature_order[len(feature_order) - n_keep:]
    y_labels = np.asarray(feature_names)[feature_order]

    # Iterate index *labels* directly; looking them up positionally only
    # worked for a default RangeIndex.
    hover_labels = iter(df_shap.index)
    for i in range(1, len(plotly_fig['data']), 2):
        trace = plotly_fig['data'][i]
        trace['name'] = ''
        trace['text'] = next(hover_labels)
        trace['hoverinfo'] = 'text'
        trace['y'] = y_labels

    # Invisible trace whose only purpose is to render a Low/High colour bar.
    colorbar_trace = go.Scatter(
        x=[None], y=[None], mode='markers',
        marker=dict(
            colorscale=red_blue, showscale=True, cmin=-5, cmax=5,
            colorbar=dict(thickness=5, tickvals=[-5, 5],
                          ticktext=['Low', 'High'], outlinewidth=0),
        ),
        hoverinfo='none',
    )
    plotly_fig['layout']['showlegend'] = False
    plotly_fig['layout']['hovermode'] = 'closest'
    plotly_fig['layout']['height'] = 600
    plotly_fig['layout']['width'] = 500
    plotly_fig['layout']['xaxis'].update(zeroline=True, showline=True,
                                         ticklen=4, showgrid=False)
    plotly_fig['layout']['yaxis'].update(dict(visible=True))
    plotly_fig.add_trace(colorbar_trace)
    plotly_fig.layout.update(
        annotations=[dict(
            x=1.18, align="right", valign="top", text='Gene',
            showarrow=False, xref="paper", yref="paper",
            xanchor="right", yanchor="middle", textangle=-90,
            font=dict(family='Calibri', size=14),
        )],
        margin=dict(t=20),
    )
    return plotly_fig