File size: 4,818 Bytes
c52a337 6fc156b c52a337 b984ae4 c52a337 b984ae4 c52a337 b984ae4 c52a337 b984ae4 c52a337 b984ae4 c52a337 b984ae4 c52a337 b984ae4 c52a337 b984ae4 c52a337 b984ae4 c52a337 b984ae4 c52a337 b984ae4 c52a337 b984ae4 c52a337 b984ae4 532ebbc b984ae4 f995639 b54f628 f995639 b984ae4 c52a337 532ebbc c52a337 b984ae4 c52a337 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
import io
import pickle
import streamlit as st
import joblib
import shap
import pandas as pd
import matplotlib.pyplot as plt
# Load the trained LightGBM model and the preprocessing objects saved alongside it.
# NOTE: pickle.load / joblib.load can execute arbitrary code while unpickling —
# only deploy these files from a trusted source.
with open('lgb1_model.pkl', 'rb') as f:
    lgb1 = pickle.load(f)
categorical_features = joblib.load("categorical_features.joblib")
encoder = joblib.load("encoder.joblib")

# Sidebar option to select the dashboard. Every label offered here must match an
# `option == ...` branch of the page dispatch, otherwise that page is unreachable
# (previously "Information about training data" had a branch but no menu entry).
option = st.sidebar.selectbox(
    "Which dashboard?",
    ("Model information", "Stroke prediction", "Information about training data"),
)
st.title(option)
def _collect_inputs():
    """Render the patient-input widgets and return the answers as one feature row.

    Returns:
        dict mapping the raw (pre-encoding) column names to the widget values.
    """
    gender = st.selectbox("Select gender: ", ["Male", "Female", 'Other'])
    # NOTE(review): these spellings must match encoder.categories_ exactly
    # (e.g. "Self_employed") — confirm against the fitted encoder.
    work_type = st.selectbox(
        "Work type: ",
        ["Private", "Self_employed", 'children', 'Govt_job', 'Never_worked'],
    )
    residence_status = st.selectbox("Residence status: ", ["Urban", "Rural"])
    smoking_status = st.selectbox(
        "Smoking status: ",
        ["Unknown", "formerly smoked", 'never smoked', 'smokes'],
    )
    age = st.slider("Input age: ", 0, 120)
    hypertension = st.select_slider("Do you have hypertension: ", [0, 1])
    heart_disease = st.select_slider("Do you have heart disease: ", [0, 1])
    ever_married = st.select_slider("Have you ever married? ", [0, 1])
    avg_glucose_level = st.slider("Average glucose level: ", 50, 280)
    bmi = st.slider("Input Bmi: ", 10, 100)
    return {
        "gender": gender,
        "work_type": work_type,
        "Residence_type": residence_status,
        "smoking_status": smoking_status,
        "age": age,
        "hypertension": hypertension,
        "heart_disease": heart_disease,
        "ever_married": ever_married,
        "avg_glucose_level": avg_glucose_level,
        "bmi": bmi,
    }


def _encode_inputs(data):
    """Turn one raw feature dict into the single-row frame the model predicts on.

    The columns listed in ``categorical_features`` are replaced by the encoder's
    output columns (named via ``get_feature_names_out``); other columns pass through.
    """
    X = pd.DataFrame([data])
    encoded = encoder.transform(X[categorical_features])
    # Some encoders (scikit-learn's OneHotEncoder by default) return a SciPy
    # sparse matrix, which pd.DataFrame cannot consume directly — densify it.
    if hasattr(encoded, "toarray"):
        encoded = encoded.toarray()
    feature_names = encoder.get_feature_names_out(input_features=categorical_features)
    encoded_df = pd.DataFrame(encoded, columns=feature_names, index=X.index)
    # NOTE(review): assumes this column order (numeric first, then encoded)
    # matches the order the model was trained with — confirm.
    return pd.concat([X.drop(columns=categorical_features), encoded_df], axis=1)


def _positive_class_shap(explainer, shap_values):
    """Return ``(base_value, row_values)`` for the positive class of the first row.

    TreeExplainer's output layout for binary models differs across SHAP
    releases: a list ``[class_0, class_1]``, a 3-D array with a trailing class
    axis, or a single 2-D array already in positive-class (log-odds) units.
    ``expected_value`` varies the same way (length-2, length-1, or scalar).
    """
    if isinstance(shap_values, list):
        values = shap_values[1]
    elif getattr(shap_values, "ndim", 2) == 3:
        values = shap_values[..., 1]
    else:
        values = shap_values

    base_value = explainer.expected_value
    if isinstance(base_value, (list, tuple)) or getattr(base_value, "ndim", 0) > 0:
        base_value = base_value[1] if len(base_value) > 1 else base_value[0]
    return base_value, values[0]


def _render_force_plot(base_value, row_values, features):
    """Draw a SHAP force plot with matplotlib and display it in the app as a PNG."""
    shap.force_plot(base_value, row_values, features=features, matplotlib=True, show=False)
    fig = plt.gcf()
    buf = io.BytesIO()
    try:
        # 200 dpi is still ~4x the 1100 px display width; the old 800 dpi
        # produced a huge PNG on every click for no visible gain.
        fig.savefig(buf, format="png", dpi=200, bbox_inches="tight")
    finally:
        # Streamlit reruns the whole script on each interaction; close the
        # figure so pyplot does not accumulate one open figure per prediction.
        plt.close(fig)
    buf.seek(0)
    st.image(buf, width=1100)


def get_pred():
    """
    Function to display the stroke probability calculator and Shap force plot.

    Renders the input widgets; once "Predict" is pressed, encodes the inputs,
    shows the predicted stroke probability and a SHAP force plot of how each
    feature pushed this single prediction.
    """
    st.header("Stroke probability calculator ")
    data = _collect_inputs()
    if not st.button("Predict"):
        return

    X_encoded = _encode_inputs(data)
    # Column 1 of predict_proba corresponds to lgb1.classes_[1]; assumes a
    # binary model whose positive class (stroke) is the second class.
    probability = lgb1.predict_proba(X_encoded)[0, 1]
    st.subheader(f"The predicted probability of stroke is {probability:.3f}.")
    st.subheader("If the result is higher than 0.3, we advise you to see a doctor.")

    explainer = shap.TreeExplainer(lgb1)
    base_value, row_values = _positive_class_shap(explainer, explainer.shap_values(X_encoded))
    st.header("Shap forceplot")
    st.subheader("Features values impact on model made prediction")
    _render_force_plot(base_value, row_values, X_encoded.iloc[0, :])
# Data-provenance disclaimer shown on both the dataset page and the prediction page.
_DISCLAIMER = (
    "Disclaimer: This project is based on data from a single American hospital. "
    "For this model to be more relevant for predicting your health, it has to be "
    "trained on data from your population."
)

# Render the page chosen in the sidebar; the options are mutually exclusive.
if option == "Information about training data":
    st.header("Stroke Prediction Dataset")
    st.subheader(
        "According to the World Health Organization (WHO) stroke is the 2nd leading "
        "cause of death globally, responsible for approximately 11% of total deaths. "
        "This dataset is used to predict whether a patient is likely to get stroke "
        "based on the input parameters like gender, age, various diseases, and "
        "smoking status. Each row in the data provides relevant information about "
        "the patient."
    )
    st.subheader(_DISCLAIMER)
    st.subheader(" Stroke dataset has 5110 records and 12 features.")
    st.subheader(" Correlation between features:")
    st.image(r'Correlation.png')
    st.subheader("Features Shap values and how they affect Target variable: Stroke")
    st.image(r'Shap_Values.png')
elif option == "Stroke prediction":
    get_pred()
    st.subheader(_DISCLAIMER)
elif option == "Model information":
    st.header("Light gradient boosting model")
    st.subheader("First tree of light gradient boosting model and how it makes decisions")
    st.image(r'lgbm_tree.png')
    st.subheader("Shap values visualization of how features contribute to model prediction")
    st.image(r'lgbm_model_shap_evaluation.png')