File size: 5,016 Bytes
3d119e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d7a3122
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3d119e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d7a3122
3d119e0
 
d7a3122
3d119e0
 
 
 
 
 
 
 
 
d7a3122
 
 
 
3d119e0
d7a3122
3d119e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d7a3122
 
 
 
3d119e0
d7a3122
3d119e0
 
 
 
 
 
 
 
 
d7a3122
3d119e0
d7a3122
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import gradio as gr
import xgboost as xgb
import joblib
import numpy as np
from sklearn.preprocessing import StandardScaler
import pandas as pd
import shap
import matplotlib.pyplot as plt

# Decision threshold, expressed in percent, applied to the predicted
# class-1 probability (compared against probability * 100).
cutoff = 42

# Training-time artefacts. The scaler has to be unpickled with the same
# scikit-learn version that saved it.
model = joblib.load('best_XGB.pkl')
scaler = joblib.load('scaler.pkl')

# TreeExplainer is SHAP's explainer for tree ensembles such as XGBoost.
explainer = shap.TreeExplainer(model)

# Define the function to draw the semicircular scale
def draw_scale(probability):
    fig, ax = plt.subplots(figsize=(6, 3))
    
    # Plot the semicircular scale
    ax.barh(0, 1, color='green', left=0, height=0.3)
    ax.barh(0, 1, color='red', left=cutoff / 100, height=0.3)
    
    # Add arrow indicator based on predicted probability
    arrow_position = probability / 100
    color = 'green' if probability < cutoff else 'red'
    ax.annotate('', xy=(arrow_position, 0.15), xytext=(arrow_position, 0.3),
                arrowprops=dict(facecolor=color, shrink=0.05))
    
    # Remove axes and add labels
    ax.set_xlim(0, 1)
    ax.set_ylim(-0.5, 0.5)
    ax.axis('off')
    
    # Save the image
    plt.savefig('scale_plot.png', bbox_inches='tight')
    plt.close()

# Define the prediction function with preprocessing, scaling, and SHAP analysis
def predict_heart_attack(Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes, totChol, sysBP, diaBP, BMI, heartRate, glucose):
    # Define feature names in the same order as the training data
    feature_names = ['Gender', 'age', 'cigsPerDay', 'BPMeds', 'prevalentHyp', 'diabetes', 'totChol', 'sysBP', 'diaBP', 'BMI', 'heartRate', 'glucose']
    
    # Create a DataFrame with the correct feature names for prediction
    features = pd.DataFrame([[Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes, totChol, sysBP, diaBP, BMI, heartRate, glucose]], columns=feature_names)
    
    # Standardize the features (scaling)
    scaled_features = scaler.transform(features)
    
    # Predict probabilities
    proba = model.predict_proba(scaled_features)[:, 1]  # Probability of class 1 (heart attack)
    
    # Apply custom cutoff
    if proba[0] * 100 >= cutoff:
        prediction_class = "Heart_Attack_Risk.Consult your doctor"
    else:
        prediction_class = "No_Heart_Attack_Risk.Still make regular checkup"
    
    # Generate SHAP values for the prediction using the explainer
    shap_values = explainer(features)

    # Plot SHAP values
    plt.figure(figsize=(8, 6))
    shap.waterfall_plot(shap_values[0])  # Using the SHAP Explanation object
    plt.savefig('shap_plot.png')  # Save SHAP plot to a file
    
    # Draw semicircular scale
    draw_scale(proba[0] * 100)
    
    result = f"Predicted Probability: {proba[0] * 100:.2f}%. Predicted Class with cutoff {cutoff}%: {prediction_class}"
    
    return result, 'scale_plot.png', 'shap_plot.png'  # Return the prediction, scale, and SHAP plot

# Create the Gradio interface with preprocessing, prediction, and SHAP visualization
with gr.Blocks() as app:
    with gr.Row():
        with gr.Column():
            Gender = gr.Radio([0, 1], label="Gender (0=Female, 1=Male)")
            cigsPerDay = gr.Slider(0, 40, step=1, label="Cigarettes per Day")
            prevalentHyp = gr.Radio([0, 1], label="Prevalent Hypertension (0=No, 1=Yes)")
            totChol = gr.Slider(100, 400, step=1, label="Total Cholesterol in mg/dl")
            diaBP = gr.Slider(60, 120, step=1, label="Diastolic/Lower BP")
            heartRate = gr.Slider(50, 120, step=1, label="Heart Rate")
        
        with gr.Column():
            age = gr.Slider(20, 80, step=1, label="Age (years)")
            BPMeds = gr.Radio([0, 1], label="On BP Medications (0=No, 1=Yes)")
            diabetes = gr.Radio([0, 1], label="Diabetes (0=No, 1=Yes)")
            sysBP = gr.Slider(90, 200, step=1, label="Systolic BP/Higher BP")
            BMI = gr.Slider(15, 40, step=0.1, label="Body Mass Index (BMI) in kg/m2")
            glucose = gr.Slider(50, 250, step=1, label="Fasting Glucose Level")
    
    # Display disclaimer in red uppercase letters above the scale
    gr.HTML("<div style='text-align: center; color: red; font-weight: bold; font-size: 16px;'>RESULTS ARE NOT A SUBSTITUTE FOR ADVICE OF QUALIFIED MEDICAL PROFESSIONAL</div>")
    
    # Center-aligned prediction output and semicircular scale
    with gr.Row():
        scale_output = gr.Image(label="Risk Indicator Scale")
    
    with gr.Row():
        prediction_output = gr.Textbox(label="", interactive=False, elem_id="prediction_output")
    
    with gr.Row():
        shap_plot_output = gr.Image(label="SHAP Analysis")

    # Link inputs and prediction output
    submit_btn = gr.Button("Submit")
    submit_btn.click(fn=predict_heart_attack, inputs=[Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes, totChol, sysBP, diaBP, BMI, heartRate, glucose], outputs=[prediction_output, scale_output, shap_plot_output])

app.launch(share=True)