Spaces:
Sleeping
Sleeping
Upload 4 files
Browse files- .gitattributes +1 -0
- app.py +83 -0
- best_XGB.pkl +3 -0
- requirements.txt +11 -0
- scaler.pkl +0 -0
.gitattributes
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
best_XGB.pkl filter=lfs diff=lfs merge=lfs -text
|
app.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import xgboost as xgb
|
3 |
+
import joblib
|
4 |
+
import numpy as np
|
5 |
+
from sklearn.preprocessing import StandardScaler
|
6 |
+
import pandas as pd
|
7 |
+
import shap
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
|
10 |
+
# Load the model and the scaler
|
11 |
+
model = joblib.load('best_XGB.pkl')
|
12 |
+
scaler = joblib.load('scaler.pkl') # Ensure the scaler is saved and loaded with the same scikit-learn version
|
13 |
+
cutoff = 42 # Custom cutoff probability
|
14 |
+
|
15 |
+
# Use TreeExplainer for XGBoost models
|
16 |
+
explainer = shap.TreeExplainer(model)
|
17 |
+
|
18 |
+
# Define the prediction function with preprocessing, scaling, and SHAP analysis
|
19 |
+
def predict_heart_attack(Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes, totChol, sysBP, diaBP, BMI, heartRate, glucose):
|
20 |
+
# Define feature names in the same order as the training data
|
21 |
+
feature_names = ['Gender', 'age', 'cigsPerDay', 'BPMeds', 'prevalentHyp', 'diabetes', 'totChol', 'sysBP', 'diaBP', 'BMI', 'heartRate', 'glucose']
|
22 |
+
|
23 |
+
# Create a DataFrame with the correct feature names for prediction
|
24 |
+
features = pd.DataFrame([[Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes, totChol, sysBP, diaBP, BMI, heartRate, glucose]], columns=feature_names)
|
25 |
+
|
26 |
+
# Standardize the features (scaling)
|
27 |
+
scaled_features = scaler.transform(features)
|
28 |
+
|
29 |
+
# Predict probabilities
|
30 |
+
proba = model.predict_proba(scaled_features)[:, 1] # Probability of class 1 (heart attack)
|
31 |
+
|
32 |
+
# Apply custom cutoff
|
33 |
+
if proba[0]*100 >= cutoff:
|
34 |
+
prediction_class = "Heart_Attack_Risk.Consult your doctor"
|
35 |
+
else:
|
36 |
+
prediction_class = "No_Heart_Attack_Risk.Still make regular checkup "
|
37 |
+
|
38 |
+
# Generate SHAP values for the prediction using the explainer
|
39 |
+
shap_values = explainer(features)
|
40 |
+
|
41 |
+
# Plot SHAP values
|
42 |
+
plt.figure(figsize=(8, 6))
|
43 |
+
shap.waterfall_plot(shap_values[0]) # Using the SHAP Explanation object
|
44 |
+
plt.savefig('shap_plot.png') # Save SHAP plot to a file
|
45 |
+
|
46 |
+
result = f"Predicted Probability: {proba[0]*100:.2f}%. Predicted Class with cutoff {cutoff}%: {prediction_class}"
|
47 |
+
|
48 |
+
return result, 'shap_plot.png' # Return the prediction and SHAP plot
|
49 |
+
|
50 |
+
# Create the Gradio interface with preprocessing, prediction, and SHAP visualization
|
51 |
+
with gr.Blocks() as app:
|
52 |
+
with gr.Row():
|
53 |
+
with gr.Column():
|
54 |
+
Gender = gr.Radio([0, 1], label="Gender (0=Female, 1=Male)")
|
55 |
+
cigsPerDay = gr.Slider(0, 40, step=1, label="Cigarettes per Day")
|
56 |
+
prevalentHyp = gr.Radio([0, 1], label="Prevalent Hypertension (0=No, 1=Yes)")
|
57 |
+
totChol = gr.Slider(100, 400, step=1, label="Total Cholesterol in mg/dl")
|
58 |
+
diaBP = gr.Slider(60, 120, step=1, label="Diastolic/Lower BP")
|
59 |
+
heartRate = gr.Slider(50, 120, step=1, label="Heart Rate")
|
60 |
+
|
61 |
+
with gr.Column():
|
62 |
+
age = gr.Slider(20, 80, step=1, label="Age (years)")
|
63 |
+
BPMeds = gr.Radio([0, 1], label="On BP Medications (0=No, 1=Yes)")
|
64 |
+
diabetes = gr.Radio([0, 1], label="Diabetes (0=No, 1=Yes)")
|
65 |
+
sysBP = gr.Slider(90, 200, step=1, label="Systolic BP/Higher BP")
|
66 |
+
BMI = gr.Slider(15, 40, step=0.1, label="Body Mass Index (BMI) in kg/m2")
|
67 |
+
glucose = gr.Slider(50, 250, step=1, label="Fasting Glucose Level")
|
68 |
+
|
69 |
+
# Center-aligned prediction output
|
70 |
+
with gr.Row():
|
71 |
+
gr.HTML("<div style='text-align: center; width: 100%'>Heart Attack Prediction</div>")
|
72 |
+
|
73 |
+
with gr.Row():
|
74 |
+
prediction_output = gr.Textbox(label="", interactive=False, elem_id="prediction_output")
|
75 |
+
|
76 |
+
with gr.Row():
|
77 |
+
shap_plot_output = gr.Image(label="SHAP Analysis")
|
78 |
+
|
79 |
+
# Link inputs and prediction output
|
80 |
+
submit_btn = gr.Button("Submit")
|
81 |
+
submit_btn.click(fn=predict_heart_attack, inputs=[Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes, totChol, sysBP, diaBP, BMI, heartRate, glucose], outputs=[prediction_output, shap_plot_output])
|
82 |
+
|
83 |
+
app.launch(share = True)
|
best_XGB.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3d5422417966b1cc081d0b6e9772e6cee262ee72f203853be4a28d53859fbbcf
|
3 |
+
size 1347623
|
requirements.txt
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
cloudpickle==3.1.0
|
2 |
+
gradio==5.1.0
|
3 |
+
gradio_client==1.4.0
|
4 |
+
joblib==1.4.2
|
5 |
+
numpy==1.26.4
|
6 |
+
pandas==2.2.2
|
7 |
+
shap==0.46.0
|
8 |
+
scikit-learn==1.4.2
|
9 |
+
slicer==0.0.8
|
10 |
+
xgboost==2.0.3
|
11 |
+
matplotlib
|
scaler.pkl
ADDED
Binary file (941 Bytes). View file
|
|