Spaces:
Sleeping
Sleeping
import gradio as gr | |
import xgboost as xgb | |
import joblib | |
import numpy as np | |
from sklearn.preprocessing import StandardScaler | |
import pandas as pd | |
import shap | |
import matplotlib.pyplot as plt | |
# Load the model and the scaler | |
model = joblib.load('best_XGB.pkl') | |
scaler = joblib.load('scaler.pkl') # Ensure the scaler is saved and loaded with the same scikit-learn version | |
cutoff = 42 # Custom cutoff probability | |
# Use TreeExplainer for XGBoost models | |
explainer = shap.TreeExplainer(model) | |
# Define the prediction function with preprocessing, scaling, and SHAP analysis | |
def predict_heart_attack(Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes, totChol, sysBP, diaBP, BMI, heartRate, glucose): | |
# Define feature names in the same order as the training data | |
feature_names = ['Gender', 'age', 'cigsPerDay', 'BPMeds', 'prevalentHyp', 'diabetes', 'totChol', 'sysBP', 'diaBP', 'BMI', 'heartRate', 'glucose'] | |
# Create a DataFrame with the correct feature names for prediction | |
features = pd.DataFrame([[Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes, totChol, sysBP, diaBP, BMI, heartRate, glucose]], columns=feature_names) | |
# Standardize the features (scaling) | |
scaled_features = scaler.transform(features) | |
# Predict probabilities | |
proba = model.predict_proba(scaled_features)[:, 1] # Probability of class 1 (heart attack) | |
# Apply custom cutoff | |
if proba[0]*100 >= cutoff: | |
prediction_class = "Heart_Attack_Risk.Consult your doctor" | |
else: | |
prediction_class = "No_Heart_Attack_Risk.Still make regular checkup " | |
# Generate SHAP values for the prediction using the explainer | |
shap_values = explainer(features) | |
# Plot SHAP values | |
plt.figure(figsize=(8, 6)) | |
shap.waterfall_plot(shap_values[0]) # Using the SHAP Explanation object | |
plt.savefig('shap_plot.png') # Save SHAP plot to a file | |
result = f"Predicted Probability: {proba[0]*100:.2f}%. Predicted Class with cutoff {cutoff}%: {prediction_class}" | |
return result, 'shap_plot.png' # Return the prediction and SHAP plot | |
# Create the Gradio interface with preprocessing, prediction, and SHAP visualization | |
with gr.Blocks() as app: | |
with gr.Row(): | |
with gr.Column(): | |
Gender = gr.Radio([0, 1], label="Gender (0=Female, 1=Male)") | |
cigsPerDay = gr.Slider(0, 40, step=1, label="Cigarettes per Day") | |
prevalentHyp = gr.Radio([0, 1], label="Prevalent Hypertension (0=No, 1=Yes)") | |
totChol = gr.Slider(100, 400, step=1, label="Total Cholesterol in mg/dl") | |
diaBP = gr.Slider(60, 120, step=1, label="Diastolic/Lower BP") | |
heartRate = gr.Slider(50, 120, step=1, label="Heart Rate") | |
with gr.Column(): | |
age = gr.Slider(20, 80, step=1, label="Age (years)") | |
BPMeds = gr.Radio([0, 1], label="On BP Medications (0=No, 1=Yes)") | |
diabetes = gr.Radio([0, 1], label="Diabetes (0=No, 1=Yes)") | |
sysBP = gr.Slider(90, 200, step=1, label="Systolic BP/Higher BP") | |
BMI = gr.Slider(15, 40, step=0.1, label="Body Mass Index (BMI) in kg/m2") | |
glucose = gr.Slider(50, 250, step=1, label="Fasting Glucose Level") | |
# Center-aligned prediction output | |
with gr.Row(): | |
gr.HTML("<div style='text-align: center; width: 100%'>Heart Attack Prediction</div>") | |
with gr.Row(): | |
prediction_output = gr.Textbox(label="", interactive=False, elem_id="prediction_output") | |
with gr.Row(): | |
shap_plot_output = gr.Image(label="SHAP Analysis") | |
# Link inputs and prediction output | |
submit_btn = gr.Button("Submit") | |
submit_btn.click(fn=predict_heart_attack, inputs=[Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes, totChol, sysBP, diaBP, BMI, heartRate, glucose], outputs=[prediction_output, shap_plot_output]) | |
app.launch(share = True) | |