File size: 3,892 Bytes
3d119e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import gradio as gr
import xgboost as xgb
import joblib
import numpy as np
from sklearn.preprocessing import StandardScaler
import pandas as pd
import shap
import matplotlib.pyplot as plt

# Render figures off-screen with the non-interactive Agg backend: plots are only
# needed as PNG files for the web UI, and they are produced inside a server
# callback (presumably on a worker thread), where GUI backends are unsafe and
# plt.show() could block or discard the figure before it is saved.
plt.switch_backend('Agg')

# Load the model and the scaler.
# NOTE: joblib.load unpickles arbitrary Python objects; only load trusted,
# locally produced files here.
model = joblib.load('best_XGB.pkl')
scaler = joblib.load('scaler.pkl')  # Ensure the scaler is saved and loaded with the same scikit-learn version
cutoff = 42  # Custom cutoff, in percent: class-1 probabilities >= 42% are reported as at risk

# Use TreeExplainer for XGBoost models
explainer = shap.TreeExplainer(model)

# Define the prediction function with preprocessing, scaling, and SHAP analysis
def predict_heart_attack(Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes, totChol, sysBP, diaBP, BMI, heartRate, glucose):
    """Predict heart-attack risk for one patient and explain it with SHAP.

    The twelve arguments are the raw values from the input form, in the same
    order as ``feature_names`` below, which must match the column order the
    scaler and model were fitted on.

    Uses the module-level ``scaler``, ``model``, ``explainer`` and ``cutoff``.

    Returns:
        tuple[str, str]: a summary of the predicted class-1 probability and the
        class chosen with ``cutoff`` (a percentage), and the path of a PNG file
        holding the SHAP waterfall plot for this prediction.

    Raises:
        gr.Error: if any input is missing (e.g. a radio button left
            unselected), so the UI shows a readable message instead of
            failing later during scaling/prediction.
    """
    # Feature names in the same order as the training data
    feature_names = ['Gender', 'age', 'cigsPerDay', 'BPMeds', 'prevalentHyp', 'diabetes',
                     'totChol', 'sysBP', 'diaBP', 'BMI', 'heartRate', 'glucose']
    raw_values = [Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes,
                  totChol, sysBP, diaBP, BMI, heartRate, glucose]

    missing = [name for name, value in zip(feature_names, raw_values) if value is None]
    if missing:
        raise gr.Error(f"Please provide a value for: {', '.join(missing)}")

    features = pd.DataFrame([raw_values], columns=feature_names)

    # The model is fed standardized features, so both the prediction and the
    # SHAP explanation must be computed on the scaled values.
    scaled_features = scaler.transform(features)
    scaled_frame = pd.DataFrame(scaled_features, columns=feature_names)

    proba = model.predict_proba(scaled_features)[:, 1]  # Probability of class 1 (heart attack)
    risk_pct = proba[0] * 100

    # Apply custom cutoff (expressed in percent)
    if risk_pct >= cutoff:
        prediction_class = "Heart_Attack_Risk.Consult your doctor"
    else:
        prediction_class = "No_Heart_Attack_Risk.Still make regular checkup "

    # Explain the scaled input the model actually saw, but label the plot with
    # the raw (unscaled) values the user entered so it stays readable.
    scaled_explanation = explainer(scaled_frame)[0]
    explanation = shap.Explanation(
        values=scaled_explanation.values,
        base_values=scaled_explanation.base_values,
        data=features.iloc[0].to_numpy(dtype=float),
        feature_names=feature_names,
    )
    plot_path = _save_shap_waterfall(explanation)

    result = f"Predicted Probability: {risk_pct:.2f}%. Predicted Class with cutoff {cutoff}%: {prediction_class}"

    return result, plot_path  # Return the prediction and SHAP plot


def _save_shap_waterfall(explanation, path='shap_plot.png'):
    """Render a single-row SHAP Explanation as a waterfall plot saved to *path*; return *path*.

    The figure is always closed afterwards so repeated requests do not
    accumulate open matplotlib figures.
    """
    fig = plt.figure(figsize=(8, 6))
    try:
        # show=False keeps the figure alive for savefig and avoids calling
        # plt.show() inside the server process.
        shap.waterfall_plot(explanation, show=False)
        plt.savefig(path, bbox_inches='tight')
    finally:
        plt.close(fig)
    # NOTE(review): a fixed filename is shared by every request; this is only
    # safe while the click handler processes one request at a time — confirm
    # the app's queue/concurrency settings before raising concurrency.
    return path

# Create the Gradio interface with preprocessing, prediction, and SHAP visualization
# Assemble the web UI: two columns of patient inputs, a centred title, the text
# result, the SHAP plot, and a Submit button wired to predict_heart_attack.
with gr.Blocks() as app:
    with gr.Row():
        # Left-hand input column
        with gr.Column():
            Gender = gr.Radio(choices=[0, 1], label="Gender (0=Female, 1=Male)")
            cigsPerDay = gr.Slider(minimum=0, maximum=40, step=1, label="Cigarettes per Day")
            prevalentHyp = gr.Radio(choices=[0, 1], label="Prevalent Hypertension (0=No, 1=Yes)")
            totChol = gr.Slider(minimum=100, maximum=400, step=1, label="Total Cholesterol in mg/dl")
            diaBP = gr.Slider(minimum=60, maximum=120, step=1, label="Diastolic/Lower BP")
            heartRate = gr.Slider(minimum=50, maximum=120, step=1, label="Heart Rate")

        # Right-hand input column
        with gr.Column():
            age = gr.Slider(minimum=20, maximum=80, step=1, label="Age (years)")
            BPMeds = gr.Radio(choices=[0, 1], label="On BP Medications (0=No, 1=Yes)")
            diabetes = gr.Radio(choices=[0, 1], label="Diabetes (0=No, 1=Yes)")
            sysBP = gr.Slider(minimum=90, maximum=200, step=1, label="Systolic BP/Higher BP")
            BMI = gr.Slider(minimum=15, maximum=40, step=0.1, label="Body Mass Index (BMI) in kg/m2")
            glucose = gr.Slider(minimum=50, maximum=250, step=1, label="Fasting Glucose Level")

    # Title row above the outputs, centred with inline CSS
    with gr.Row():
        gr.HTML(value="<div style='text-align: center; width: 100%'>Heart Attack Prediction</div>")

    with gr.Row():
        prediction_output = gr.Textbox(label="", interactive=False, elem_id="prediction_output")

    with gr.Row():
        shap_plot_output = gr.Image(label="SHAP Analysis")

    # Order must match the parameter order of predict_heart_attack
    form_inputs = [
        Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes,
        totChol, sysBP, diaBP, BMI, heartRate, glucose,
    ]
    submit_btn = gr.Button(value="Submit")
    submit_btn.click(
        fn=predict_heart_attack,
        inputs=form_inputs,
        outputs=[prediction_output, shap_plot_output],
    )

app.launch(share=True)