minusquare committed on
Commit
3d119e0
·
verified ·
1 Parent(s): 6b4be23

Upload 4 files

Browse files
Files changed (5) hide show
  1. .gitattributes +1 -0
  2. app.py +83 -0
  3. best_XGB.pkl +3 -0
  4. requirements.txt +11 -0
  5. scaler.pkl +0 -0
.gitattributes ADDED
@@ -0,0 +1 @@
 
 
1
+ best_XGB.pkl filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import xgboost as xgb
3
+ import joblib
4
+ import numpy as np
5
+ from sklearn.preprocessing import StandardScaler
6
+ import pandas as pd
7
+ import shap
8
+ import matplotlib.pyplot as plt
9
+
10
# Decision threshold in percent: predict_heart_attack compares
# (probability * 100) against this value.
cutoff = 42

# Deserialize the fitted classifier and the feature scaler from the
# working directory. The scaler pickle should be read with the same
# scikit-learn version that wrote it.
model = joblib.load("best_XGB.pkl")
scaler = joblib.load("scaler.pkl")

# Tree-specific SHAP explainer bound to the loaded XGBoost model
explainer = shap.TreeExplainer(model)
17
+
18
+ # Define the prediction function with preprocessing, scaling, and SHAP analysis
19
def predict_heart_attack(Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes, totChol, sysBP, diaBP, BMI, heartRate, glucose):
    """Score one patient and explain the score with a SHAP waterfall plot.

    The twelve arguments are the raw (unscaled) feature values, in the
    column order the scaler and the model consume them. They are
    standardized with the module-level ``scaler``, scored with ``model``
    and explained with ``explainer``.

    Args:
        Gender: 0 or 1.
        age: Age in years.
        cigsPerDay: Cigarettes smoked per day.
        BPMeds: 0 or 1, on blood-pressure medication.
        prevalentHyp: 0 or 1, prevalent hypertension.
        diabetes: 0 or 1.
        totChol: Total cholesterol.
        sysBP: Systolic blood pressure.
        diaBP: Diastolic blood pressure.
        BMI: Body mass index.
        heartRate: Heart rate.
        glucose: Glucose level.

    Returns:
        A ``(result, plot_path)`` tuple: ``result`` is the human-readable
        probability and class, ``plot_path`` is the path of a PNG file
        holding the SHAP waterfall plot for this prediction.

    Raises:
        gr.Error: If any input is ``None`` (e.g. a radio left unselected).
    """
    import tempfile  # local import: only needed for the per-request plot file

    # Feature names in the same order as the training data
    feature_names = ['Gender', 'age', 'cigsPerDay', 'BPMeds', 'prevalentHyp', 'diabetes', 'totChol', 'sysBP', 'diaBP', 'BMI', 'heartRate', 'glucose']
    raw_values = [Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes, totChol, sysBP, diaBP, BMI, heartRate, glucose]

    # Radios without a selection arrive as None; fail with a readable
    # message instead of pushing missing values through the pipeline.
    missing = [name for name, value in zip(feature_names, raw_values) if value is None]
    if missing:
        raise gr.Error(f"Please provide a value for: {', '.join(missing)}")

    # Single-row frame carrying the feature names for the scaler
    features = pd.DataFrame([raw_values], columns=feature_names)

    # Standardize the features (scaling)
    scaled_features = scaler.transform(features)

    # Probability of class 1 (heart attack)
    proba = model.predict_proba(scaled_features)[:, 1]

    # Apply custom cutoff (cutoff is expressed in percent)
    if proba[0]*100 >= cutoff:
        prediction_class = "Heart_Attack_Risk.Consult your doctor"
    else:
        prediction_class = "No_Heart_Attack_Risk.Still make regular checkup "

    # Explain the SAME input the model scored: the scaled row. Explaining
    # the raw row would attribute a different point in feature space.
    scaled_frame = pd.DataFrame(scaled_features, columns=feature_names)
    scaled_row = explainer(scaled_frame)[0]

    # Keep the attributions from the scaled row but show the raw values
    # next to each feature name so the plot stays readable.
    explanation = shap.Explanation(
        values=scaled_row.values,
        base_values=scaled_row.base_values,
        data=features.iloc[0].to_numpy(),
        feature_names=feature_names,
    )

    # Reserve a unique file per request so concurrent users cannot
    # overwrite each other's plot. The handle is closed before saving.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as handle:
        plot_path = handle.name

    plt.figure(figsize=(8, 6))
    shap.waterfall_plot(explanation, show=False)  # draw only; do not call plt.show()
    fig = plt.gcf()
    try:
        fig.savefig(plot_path, bbox_inches="tight")
    finally:
        # Release the figure; otherwise every request leaks one.
        plt.close(fig)

    result = f"Predicted Probability: {proba[0]*100:.2f}%. Predicted Class with cutoff {cutoff}%: {prediction_class}"

    return result, plot_path
49
+
50
# Gradio UI: twelve patient inputs in two columns, then the prediction text
# and the SHAP image, all wired to predict_heart_attack.
# NOTE(review): the nesting below was reconstructed from a flattened view of
# the file — confirm the layout against the running app.
with gr.Blocks() as app:
    with gr.Row():
        with gr.Column():
            Gender = gr.Radio([0, 1], label="Gender (0=Female, 1=Male)")
            cigsPerDay = gr.Slider(0, 40, step=1, label="Cigarettes per Day")
            prevalentHyp = gr.Radio([0, 1], label="Prevalent Hypertension (0=No, 1=Yes)")
            totChol = gr.Slider(100, 400, step=1, label="Total Cholesterol in mg/dl")
            diaBP = gr.Slider(60, 120, step=1, label="Diastolic/Lower BP")
            heartRate = gr.Slider(50, 120, step=1, label="Heart Rate")

        with gr.Column():
            age = gr.Slider(20, 80, step=1, label="Age (years)")
            BPMeds = gr.Radio([0, 1], label="On BP Medications (0=No, 1=Yes)")
            diabetes = gr.Radio([0, 1], label="Diabetes (0=No, 1=Yes)")
            sysBP = gr.Slider(90, 200, step=1, label="Systolic BP/Higher BP")
            BMI = gr.Slider(15, 40, step=0.1, label="Body Mass Index (BMI) in kg/m2")
            glucose = gr.Slider(50, 250, step=1, label="Fasting Glucose Level")

    # Center-aligned heading above the prediction output
    with gr.Row():
        gr.HTML("<div style='text-align: center; width: 100%'>Heart Attack Prediction</div>")

    with gr.Row():
        prediction_output = gr.Textbox(label="", interactive=False, elem_id="prediction_output")

    with gr.Row():
        shap_plot_output = gr.Image(label="SHAP Analysis")

    # The input order must match predict_heart_attack's parameter order,
    # which differs from the on-screen column order.
    submit_btn = gr.Button("Submit")
    submit_btn.click(
        fn=predict_heart_attack,
        inputs=[Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes, totChol, sysBP, diaBP, BMI, heartRate, glucose],
        outputs=[prediction_output, shap_plot_output],
    )

# Start the server only when run as a script, so importing this module
# (tests, tooling) does not launch a web server as a side effect.
if __name__ == "__main__":
    app.launch(share=True)
best_XGB.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3d5422417966b1cc081d0b6e9772e6cee262ee72f203853be4a28d53859fbbcf
3
+ size 1347623
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ cloudpickle==3.1.0
2
+ gradio==5.1.0
3
+ gradio_client==1.4.0
4
+ joblib==1.4.2
5
+ numpy==1.26.4
6
+ pandas==2.2.2
7
+ shap==0.46.0
8
+ scikit-learn==1.4.2
9
+ slicer==0.0.8
10
+ xgboost==2.0.3
11
+ matplotlib
scaler.pkl ADDED
Binary file (941 Bytes). View file