minusquare commited on
Commit
9c68205
·
verified ·
1 Parent(s): 9ba38e6

Upload folder using huggingface_hub

Browse files
.ipynb_checkpoints/gradio_hearttack_app-checkpoint.py CHANGED
@@ -4,18 +4,24 @@ import joblib
4
  import numpy as np
5
  from sklearn.preprocessing import StandardScaler
6
  import pandas as pd
 
 
7
 
8
  # Load the model and the scaler
9
  model = joblib.load('best_XGB.pkl')
10
- scaler = joblib.load('scaler.pkl') # Load the scaler if you saved it during training
11
  cutoff = 0.42 # Custom cutoff probability
12
 
13
- # Define the prediction function with preprocessing and scaling
 
 
 
14
  def predict_heart_attack(Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes, totChol, sysBP, diaBP, BMI, heartRate, glucose):
15
  # Define feature names in the same order as the training data
16
  feature_names = ['Gender', 'age', 'cigsPerDay', 'BPMeds', 'prevalentHyp', 'diabetes', 'totChol', 'sysBP', 'diaBP', 'BMI', 'heartRate', 'glucose']
 
17
  # Create a DataFrame with the correct feature names for prediction
18
- features = pd.DataFrame([[Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes, totChol, sysBP, diaBP, BMI, heartRate, glucose]], columns=feature_names)
19
 
20
  # Standardize the features (scaling)
21
  scaled_features = scaler.transform(features)
@@ -29,11 +35,19 @@ def predict_heart_attack(Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes
29
  else:
30
  prediction_class = 0
31
 
 
 
 
 
 
 
 
 
32
  result = f"Predicted Probability: {proba[0]*100:.2f}%. Predicted Class with cutoff {cutoff}: {prediction_class}"
33
 
34
- return result
35
 
36
- # Create the Gradio interface with preprocessing and prediction logic
37
  with gr.Blocks() as app:
38
  with gr.Row():
39
  with gr.Column():
@@ -41,14 +55,14 @@ with gr.Blocks() as app:
41
  cigsPerDay = gr.Slider(0, 40, step=1, label="Cigarettes per Day")
42
  prevalentHyp = gr.Radio([0, 1], label="Prevalent Hypertension (0=No, 1=Yes)")
43
  totChol = gr.Slider(100, 400, step=1, label="Total Cholesterol in mg/dl")
44
- diaBP = gr.Slider(60, 120, step=1, label="Diastolic/Higher BP")
45
  heartRate = gr.Slider(50, 120, step=1, label="Heart Rate")
46
 
47
  with gr.Column():
48
  age = gr.Slider(20, 80, step=1, label="Age (years)")
49
  BPMeds = gr.Radio([0, 1], label="On BP Medications (0=No, 1=Yes)")
50
  diabetes = gr.Radio([0, 1], label="Diabetes (0=No, 1=Yes)")
51
- sysBP = gr.Slider(90, 200, step=1, label="Systolic BP/Lower BP")
52
  BMI = gr.Slider(15, 40, step=0.1, label="Body Mass Index (BMI) in kg/m2")
53
  glucose = gr.Slider(50, 250, step=1, label="Fasting Glucose Level")
54
 
@@ -59,8 +73,11 @@ with gr.Blocks() as app:
59
  with gr.Row():
60
  prediction_output = gr.Textbox(label="", interactive=False, elem_id="prediction_output")
61
 
 
 
 
62
  # Link inputs and prediction output
63
  submit_btn = gr.Button("Submit")
64
- submit_btn.click(fn=predict_heart_attack, inputs=[Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes, totChol, sysBP, diaBP, BMI, heartRate, glucose], outputs=prediction_output)
65
 
66
  app.launch(share = True)
 
4
  import numpy as np
5
  from sklearn.preprocessing import StandardScaler
6
  import pandas as pd
7
+ import shap
8
+ import matplotlib.pyplot as plt
9
 
10
  # Load the model and the scaler
11
  model = joblib.load('best_XGB.pkl')
12
+ scaler = joblib.load('scaler.pkl') # Ensure the scaler is saved and loaded with the same scikit-learn version
13
  cutoff = 0.42 # Custom cutoff probability
14
 
15
+ # Use TreeExplainer for XGBoost models
16
+ explainer = shap.TreeExplainer(model)
17
+
18
+ # Define the prediction function with preprocessing, scaling, and SHAP analysis
19
  def predict_heart_attack(Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes, totChol, sysBP, diaBP, BMI, heartRate, glucose):
20
  # Define feature names in the same order as the training data
21
  feature_names = ['Gender', 'age', 'cigsPerDay', 'BPMeds', 'prevalentHyp', 'diabetes', 'totChol', 'sysBP', 'diaBP', 'BMI', 'heartRate', 'glucose']
22
+
23
  # Create a DataFrame with the correct feature names for prediction
24
+ features = pd.DataFrame([[Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes, totChol, sysBP, diaBP, BMI, heartRate, glucose]], columns=feature_names)
25
 
26
  # Standardize the features (scaling)
27
  scaled_features = scaler.transform(features)
 
35
  else:
36
  prediction_class = 0
37
 
38
+ # Generate SHAP values for the prediction using the explainer
39
+ shap_values = explainer(features)
40
+
41
+ # Plot SHAP values
42
+ plt.figure(figsize=(8, 6))
43
+ shap.waterfall_plot(shap_values[0]) # Using the SHAP Explanation object
44
+ plt.savefig('shap_plot.png') # Save SHAP plot to a file
45
+
46
  result = f"Predicted Probability: {proba[0]*100:.2f}%. Predicted Class with cutoff {cutoff}: {prediction_class}"
47
 
48
+ return result, 'shap_plot.png' # Return the prediction and SHAP plot
49
 
50
+ # Create the Gradio interface with preprocessing, prediction, and SHAP visualization
51
  with gr.Blocks() as app:
52
  with gr.Row():
53
  with gr.Column():
 
55
  cigsPerDay = gr.Slider(0, 40, step=1, label="Cigarettes per Day")
56
  prevalentHyp = gr.Radio([0, 1], label="Prevalent Hypertension (0=No, 1=Yes)")
57
  totChol = gr.Slider(100, 400, step=1, label="Total Cholesterol in mg/dl")
58
+ diaBP = gr.Slider(60, 120, step=1, label="Diastolic/Lower BP")
59
  heartRate = gr.Slider(50, 120, step=1, label="Heart Rate")
60
 
61
  with gr.Column():
62
  age = gr.Slider(20, 80, step=1, label="Age (years)")
63
  BPMeds = gr.Radio([0, 1], label="On BP Medications (0=No, 1=Yes)")
64
  diabetes = gr.Radio([0, 1], label="Diabetes (0=No, 1=Yes)")
65
+ sysBP = gr.Slider(90, 200, step=1, label="Systolic BP/Higher BP")
66
  BMI = gr.Slider(15, 40, step=0.1, label="Body Mass Index (BMI) in kg/m2")
67
  glucose = gr.Slider(50, 250, step=1, label="Fasting Glucose Level")
68
 
 
73
  with gr.Row():
74
  prediction_output = gr.Textbox(label="", interactive=False, elem_id="prediction_output")
75
 
76
+ with gr.Row():
77
+ shap_plot_output = gr.Image(label="SHAP Analysis")
78
+
79
  # Link inputs and prediction output
80
  submit_btn = gr.Button("Submit")
81
+ submit_btn.click(fn=predict_heart_attack, inputs=[Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes, totChol, sysBP, diaBP, BMI, heartRate, glucose], outputs=[prediction_output, shap_plot_output])
82
 
83
  app.launch(share = True)
.ipynb_checkpoints/gradio_hearttack_app_old-checkpoint.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import xgboost as xgb
3
+ import joblib
4
+ import numpy as np
5
+ from sklearn.preprocessing import StandardScaler
6
+ import pandas as pd
7
+ import shap
8
+ import matplotlib.pyplot as plt
9
+
10
+ # Load the model and the scaler
11
+ model = joblib.load('best_XGB.pkl')
12
+ scaler = joblib.load('scaler.pkl') # Load the scaler that was saved during training
13
+ cutoff = 0.42 # Custom cutoff probability
14
+
15
+ # Load SHAP explainer based on your XGBoost model
16
+ explainer = shap.Explainer(model)
17
+
18
+ # Define the prediction function with preprocessing, scaling, and SHAP analysis
19
+ def predict_heart_attack(Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes, totChol, sysBP, diaBP, BMI, heartRate, glucose):
20
+ # Define feature names in the same order as the training data
21
+ feature_names = ['Gender', 'age', 'cigsPerDay', 'BPMeds', 'prevalentHyp', 'diabetes', 'totChol', 'sysBP', 'diaBP', 'BMI', 'heartRate', 'glucose']
22
+
23
+ # Create a DataFrame with the correct feature names for prediction
24
+ features = pd.DataFrame([[Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes, totChol, sysBP, diaBP, BMI, heartRate, glucose]], columns=feature_names)
25
+
26
+ # Standardize the features (scaling)
27
+ scaled_features = scaler.transform(features)
28
+
29
+ # Predict probabilities
30
+ proba = model.predict_proba(scaled_features)[:, 1] # Probability of class 1 (heart attack)
31
+
32
+ # Apply custom cutoff
33
+ if proba[0] >= cutoff:
34
+ prediction_class = 1
35
+ else:
36
+ prediction_class = 0
37
+
38
+ # Generate SHAP values for the prediction
39
+ shap_values = explainer(scaled_features)
40
+
41
+ # Plot SHAP values
42
+ plt.figure(figsize=(8, 6))
43
+ shap.waterfall_plot(shap_values[0])
44
+ plt.savefig('shap_plot.png') # Save SHAP plot to a file
45
+
46
+ result = f"Predicted Probability: {proba[0]*100:.2f}%. Predicted Class with cutoff {cutoff}: {prediction_class}"
47
+
48
+ return result, 'shap_plot.png' # Return the prediction and SHAP plot
49
+
50
+ # Create the Gradio interface with preprocessing, prediction, and SHAP visualization
51
+ with gr.Blocks() as app:
52
+ with gr.Row():
53
+ with gr.Column():
54
+ Gender = gr.Radio([0, 1], label="Gender (0=Female, 1=Male)")
55
+ cigsPerDay = gr.Slider(0, 40, step=1, label="Cigarettes per Day")
56
+ prevalentHyp = gr.Radio([0, 1], label="Prevalent Hypertension (0=No, 1=Yes)")
57
+ totChol = gr.Slider(100, 400, step=1, label="Total Cholesterol in mg/dl")
58
+ diaBP = gr.Slider(60, 120, step=1, label="Diastolic/Lower BP")
59
+ heartRate = gr.Slider(50, 120, step=1, label="Heart Rate")
60
+
61
+ with gr.Column():
62
+ age = gr.Slider(20, 80, step=1, label="Age (years)")
63
+ BPMeds = gr.Radio([0, 1], label="On BP Medications (0=No, 1=Yes)")
64
+ diabetes = gr.Radio([0, 1], label="Diabetes (0=No, 1=Yes)")
65
+ sysBP = gr.Slider(90, 200, step=1, label="Systolic BP/Higher BP")
66
+ BMI = gr.Slider(15, 40, step=0.1, label="Body Mass Index (weight in kg/ height in meter squared)(BMI) in kg/m2")
67
+ glucose = gr.Slider(50, 250, step=1, label="Fasting Glucose Level")
68
+
69
+ # Center-aligned prediction output
70
+ with gr.Row():
71
+ gr.HTML("<div style='text-align: center; width: 100%'>Heart Attack Prediction</div>")
72
+
73
+ with gr.Row():
74
+ prediction_output = gr.Textbox(label="", interactive=False, elem_id="prediction_output")
75
+
76
+ with gr.Row():
77
+ shap_plot_output = gr.Image(label="SHAP Analysis")
78
+
79
+ # Link inputs and prediction output
80
+ submit_btn = gr.Button("Submit")
81
+ submit_btn.click(fn=predict_heart_attack, inputs=[Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes, totChol, sysBP, diaBP, BMI, heartRate, glucose], outputs=[prediction_output, shap_plot_output])
82
+
83
+ app.launch()
.ipynb_checkpoints/requirements-checkpoint.txt CHANGED
@@ -1,247 +1,7 @@
1
- absl-py==2.1.0
2
- aiofiles==23.2.1
3
- alembic==1.13.3
4
- altair==5.3.0
5
- aniso8601==9.0.1
6
- annotated-types==0.7.0
7
- anyio==4.6.2.post1
8
- asn1crypto==1.5.1
9
- asttokens==2.4.1
10
- astunparse==1.6.3
11
- attrs==23.2.0
12
- Automat==22.10.0
13
- bayesian-optimization==1.4.3
14
- beautifulsoup4==4.12.3
15
- blinker==1.8.2
16
- cachetools==5.3.3
17
- certifi==2024.2.2
18
- cffi==1.16.0
19
- charset-normalizer==3.3.2
20
- chromedriver-autoinstaller==0.6.4
21
- click==8.1.7
22
- cloudpickle==3.1.0
23
- colorama==0.4.6
24
- comm==0.2.2
25
- constantly==23.10.4
26
- contourpy==1.2.1
27
- convertdate==2.4.0
28
- cryptography==43.0.0
29
- cssselect==1.2.0
30
- cycler==0.12.1
31
- Cython==3.0.10
32
- dash==2.17.0
33
- dash-core-components==2.0.0
34
- dash-html-components==2.0.0
35
- dash-table==5.0.0
36
- databricks-sdk==0.34.0
37
- dearpygui==1.11.1
38
- debugpy==1.8.1
39
- decorator==5.1.1
40
- defusedxml==0.7.1
41
- Deprecated==1.2.14
42
- dnspython==2.6.1
43
- docker==7.1.0
44
- docutils==0.21.2
45
- et-xmlfile==1.1.0
46
- executing==2.0.1
47
- fakeredis==2.23.2
48
- fastapi==0.115.2
49
- fastjsonschema==2.19.1
50
- ffmpy==0.4.0
51
- filelock==3.15.4
52
- Flask==3.0.3
53
- flatbuffers==24.3.25
54
- fonttools==4.51.0
55
- frozendict==2.4.4
56
- fsspec==2024.9.0
57
- gast==0.5.4
58
- gitdb==4.0.11
59
- GitPython==3.1.43
60
- google-auth==2.35.0
61
- google-pasta==0.2.0
62
  gradio==5.1.0
63
  gradio_client==1.4.0
64
- graphene==3.3
65
- graphql-core==3.2.5
66
- graphql-relay==3.2.0
67
- graphviz==0.20.3
68
- greenlet==3.0.3
69
- grpcio==1.64.1
70
- gunicorn==23.0.0
71
- h11==0.14.0
72
- h5py==3.11.0
73
- holidays==0.53
74
- html5lib==1.1
75
- httpcore==1.0.6
76
- httpx==0.27.2
77
- huggingface-hub==0.25.2
78
- hyperlink==21.0.0
79
- idna==3.7
80
- imbalanced-learn==0.12.4
81
- imblearn==0.0
82
- importlib_metadata==7.1.0
83
- incremental==24.7.0
84
- install==1.3.5
85
- ipykernel==6.29.4
86
- ipython==8.24.0
87
- itemadapter==0.9.0
88
- itemloaders==1.3.1
89
- itsdangerous==2.2.0
90
- jedi==0.19.1
91
- Jinja2==3.1.4
92
- jmespath==1.0.1
93
  joblib==1.4.2
94
- jsonschema==4.22.0
95
- jsonschema-specifications==2023.12.1
96
- jupyter_client==8.6.1
97
- jupyter_core==5.7.2
98
- keras==3.3.3
99
- Kivy==2.3.0
100
- Kivy-Garden==0.1.5
101
- kiwisolver==1.4.5
102
- libclang==18.1.1
103
- lxml==5.2.2
104
- Mako==1.3.5
105
- Markdown==3.6
106
- markdown-it-py==3.0.0
107
- MarkupSafe==2.1.5
108
- matplotlib==3.8.4
109
- matplotlib-inline==0.1.7
110
- mdurl==0.1.2
111
- ml-dtypes==0.3.2
112
- mlflow==2.17.0
113
- mlflow-skinny==2.17.0
114
- mlxtend==0.23.1
115
- multitasking==0.0.11
116
- namex==0.0.8
117
- nbformat==5.10.4
118
- nest-asyncio==1.6.0
119
- networkx==3.3
120
  numpy==1.26.4
121
- nvidia-cublas-cu12==12.3.4.1
122
- nvidia-cuda-cupti-cu12==12.3.101
123
- nvidia-cuda-nvcc-cu12==12.3.107
124
- nvidia-cuda-nvrtc-cu12==12.3.107
125
- nvidia-cuda-runtime-cu12==12.3.101
126
- nvidia-cudnn-cu12==8.9.7.29
127
- nvidia-cufft-cu12==11.0.12.1
128
- nvidia-curand-cu12==10.3.4.107
129
- nvidia-cusolver-cu12==11.5.4.101
130
- nvidia-cusparse-cu12==12.2.0.103
131
- nvidia-nccl-cu12==2.19.3
132
- nvidia-nvjitlink-cu12==12.3.101
133
- openpyxl==3.1.2
134
- opentelemetry-api==1.27.0
135
- opentelemetry-sdk==1.27.0
136
- opentelemetry-semantic-conventions==0.48b0
137
- opt-einsum==3.3.0
138
- optree==0.11.0
139
- orjson==3.10.7
140
- outcome==1.3.0.post0
141
- packaging==24.0
142
  pandas==2.2.2
143
- parsel==1.9.1
144
- parso==0.8.4
145
- patsy==0.5.6
146
- peewee==3.17.5
147
- pexpect==4.9.0
148
- pg8000==1.31.2
149
- pillow==10.3.0
150
- platformdirs==4.2.1
151
- plotly==5.22.0
152
- pmdarima==2.0.4
153
- prompt-toolkit==3.0.43
154
- Protego==0.3.1
155
- protobuf==4.25.3
156
- psutil==5.9.8
157
- ptyprocess==0.7.0
158
- pure-eval==0.2.2
159
- pyarrow==16.1.0
160
- pyasn1==0.6.0
161
- pyasn1_modules==0.4.0
162
- pycparser==2.22
163
- pydantic==2.9.2
164
- pydantic_core==2.23.4
165
- pydeck==0.9.1
166
- PyDispatcher==2.0.7
167
- pydot==2.0.0
168
- pydub==0.25.1
169
- Pygments==2.18.0
170
- PyMeeus==0.5.12
171
- pymongo==4.7.3
172
- pyOpenSSL==24.2.1
173
- pyparsing==3.1.2
174
- PySocks==1.7.1
175
- pystan==2.19.1.1
176
- python-dateutil==2.9.0.post0
177
- python-dotenv==1.0.1
178
- python-multipart==0.0.12
179
- pytz==2024.1
180
- PyYAML==6.0.2
181
- pyzmq==26.0.3
182
- queuelib==1.7.0
183
- redis==5.0.6
184
- referencing==0.35.1
185
- requests==2.31.0
186
- requests-file==2.1.0
187
- retrying==1.3.4
188
- rich==13.7.1
189
- rpds-py==0.18.1
190
- rsa==4.9
191
- ruff==0.7.0
192
  scikit-learn==1.4.2
193
- scipy==1.13.0
194
- scramp==1.4.5
195
- Scrapy==2.11.2
196
- seaborn==0.13.2
197
- selenium==4.23.1
198
- semantic-version==2.10.0
199
- service-identity==24.1.0
200
- shellingham==1.5.4
201
- six==1.16.0
202
- smmap==5.0.1
203
- sniffio==1.3.1
204
- sortedcontainers==2.4.0
205
- soupsieve==2.5
206
- SQLAlchemy==2.0.31
207
- sqlparse==0.5.1
208
- stack-data==0.6.3
209
- starlette==0.40.0
210
- statsmodels==0.14.2
211
- streamlit==1.36.0
212
- tenacity==8.3.0
213
- tensorboard==2.17.1
214
- tensorboard-data-server==0.7.2
215
- tensorflow==2.17.0
216
- tensorflow-io-gcs-filesystem==0.37.0
217
- termcolor==2.4.0
218
- threadpoolctl==3.5.0
219
- tldextract==5.1.2
220
- toml==0.10.2
221
- tomlkit==0.12.0
222
- toolz==0.12.1
223
- tornado==6.4
224
- tqdm==4.66.5
225
- traitlets==5.14.3
226
- trio==0.26.0
227
- trio-websocket==0.11.1
228
- Twisted==24.3.0
229
- typer==0.12.5
230
- typing_extensions==4.11.0
231
- tzdata==2024.1
232
- urllib3==2.2.1
233
- uvicorn==0.32.0
234
- w3lib==2.2.1
235
- watchdog==4.0.1
236
- wcwidth==0.2.13
237
- webdriver-manager==4.0.2
238
- webencodings==0.5.1
239
- websocket-client==1.8.0
240
- websockets==12.0
241
- Werkzeug==3.0.3
242
- wrapt==1.16.0
243
- wsproto==1.2.0
244
- xgboost==2.0.3
245
- yfinance==0.2.40
246
- zipp==3.19.0
247
- zope.interface==6.4.post2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  gradio==5.1.0
2
  gradio_client==1.4.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  joblib==1.4.2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  numpy==1.26.4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  pandas==2.2.2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  scikit-learn==1.4.2
7
+ xgboost==2.0.3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
Untitled.ipynb CHANGED
@@ -1,43 +1,9 @@
1
  {
2
  "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": 1,
6
- "id": "b1110063-e160-456d-ae0b-80d9cae3b8a5",
7
- "metadata": {},
8
- "outputs": [
9
- {
10
- "name": "stderr",
11
- "output_type": "stream",
12
- "text": [
13
- "/home/ecube/basicds_py311/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
14
- " from .autonotebook import tqdm as notebook_tqdm\n"
15
- ]
16
- }
17
- ],
18
- "source": [
19
- "import gradio as gr\n",
20
- "import xgboost as xgb\n",
21
- "import joblib\n",
22
- "import numpy as np\n",
23
- "from sklearn.preprocessing import StandardScaler\n",
24
- "import pandas as pd"
25
- ]
26
- },
27
- {
28
- "cell_type": "code",
29
- "execution_count": 2,
30
- "id": "5e2655ee-1663-44f1-b05c-708b32c23e6a",
31
- "metadata": {},
32
- "outputs": [],
33
- "source": [
34
- "!pip freeze >> requirements.txt"
35
- ]
36
- },
37
  {
38
  "cell_type": "code",
39
  "execution_count": null,
40
- "id": "c71ec40a-634c-4735-bba9-31da888e3a5b",
41
  "metadata": {},
42
  "outputs": [],
43
  "source": []
@@ -45,9 +11,9 @@
45
  ],
46
  "metadata": {
47
  "kernelspec": {
48
- "display_name": "basicds_py311",
49
  "language": "python",
50
- "name": "basicds_py311"
51
  },
52
  "language_info": {
53
  "codemirror_mode": {
 
1
  {
2
  "cells": [
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  {
4
  "cell_type": "code",
5
  "execution_count": null,
6
+ "id": "0ab0b011-35e0-4bef-a713-eb4a49b4e8a1",
7
  "metadata": {},
8
  "outputs": [],
9
  "source": []
 
11
  ],
12
  "metadata": {
13
  "kernelspec": {
14
+ "display_name": "heart_disease_prediction",
15
  "language": "python",
16
+ "name": "heart_disease_prediction"
17
  },
18
  "language_info": {
19
  "codemirror_mode": {
gradio_hearttack_app.py CHANGED
@@ -4,18 +4,24 @@ import joblib
4
  import numpy as np
5
  from sklearn.preprocessing import StandardScaler
6
  import pandas as pd
 
 
7
 
8
  # Load the model and the scaler
9
  model = joblib.load('best_XGB.pkl')
10
- scaler = joblib.load('scaler.pkl') # Load the scaler if you saved it during training
11
  cutoff = 0.42 # Custom cutoff probability
12
 
13
- # Define the prediction function with preprocessing and scaling
 
 
 
14
  def predict_heart_attack(Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes, totChol, sysBP, diaBP, BMI, heartRate, glucose):
15
  # Define feature names in the same order as the training data
16
  feature_names = ['Gender', 'age', 'cigsPerDay', 'BPMeds', 'prevalentHyp', 'diabetes', 'totChol', 'sysBP', 'diaBP', 'BMI', 'heartRate', 'glucose']
 
17
  # Create a DataFrame with the correct feature names for prediction
18
- features = pd.DataFrame([[Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes, totChol, sysBP, diaBP, BMI, heartRate, glucose]], columns=feature_names)
19
 
20
  # Standardize the features (scaling)
21
  scaled_features = scaler.transform(features)
@@ -29,11 +35,19 @@ def predict_heart_attack(Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes
29
  else:
30
  prediction_class = 0
31
 
 
 
 
 
 
 
 
 
32
  result = f"Predicted Probability: {proba[0]*100:.2f}%. Predicted Class with cutoff {cutoff}: {prediction_class}"
33
 
34
- return result
35
 
36
- # Create the Gradio interface with preprocessing and prediction logic
37
  with gr.Blocks() as app:
38
  with gr.Row():
39
  with gr.Column():
@@ -59,8 +73,11 @@ with gr.Blocks() as app:
59
  with gr.Row():
60
  prediction_output = gr.Textbox(label="", interactive=False, elem_id="prediction_output")
61
 
 
 
 
62
  # Link inputs and prediction output
63
  submit_btn = gr.Button("Submit")
64
- submit_btn.click(fn=predict_heart_attack, inputs=[Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes, totChol, sysBP, diaBP, BMI, heartRate, glucose], outputs=prediction_output)
65
 
66
  app.launch(share = True)
 
4
  import numpy as np
5
  from sklearn.preprocessing import StandardScaler
6
  import pandas as pd
7
+ import shap
8
+ import matplotlib.pyplot as plt
9
 
10
  # Load the model and the scaler
11
  model = joblib.load('best_XGB.pkl')
12
+ scaler = joblib.load('scaler.pkl') # Ensure the scaler is saved and loaded with the same scikit-learn version
13
  cutoff = 0.42 # Custom cutoff probability
14
 
15
+ # Use TreeExplainer for XGBoost models
16
+ explainer = shap.TreeExplainer(model)
17
+
18
+ # Define the prediction function with preprocessing, scaling, and SHAP analysis
19
  def predict_heart_attack(Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes, totChol, sysBP, diaBP, BMI, heartRate, glucose):
20
  # Define feature names in the same order as the training data
21
  feature_names = ['Gender', 'age', 'cigsPerDay', 'BPMeds', 'prevalentHyp', 'diabetes', 'totChol', 'sysBP', 'diaBP', 'BMI', 'heartRate', 'glucose']
22
+
23
  # Create a DataFrame with the correct feature names for prediction
24
+ features = pd.DataFrame([[Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes, totChol, sysBP, diaBP, BMI, heartRate, glucose]], columns=feature_names)
25
 
26
  # Standardize the features (scaling)
27
  scaled_features = scaler.transform(features)
 
35
  else:
36
  prediction_class = 0
37
 
38
+ # Generate SHAP values for the prediction using the explainer
39
+ shap_values = explainer(features)
40
+
41
+ # Plot SHAP values
42
+ plt.figure(figsize=(8, 6))
43
+ shap.waterfall_plot(shap_values[0]) # Using the SHAP Explanation object
44
+ plt.savefig('shap_plot.png') # Save SHAP plot to a file
45
+
46
  result = f"Predicted Probability: {proba[0]*100:.2f}%. Predicted Class with cutoff {cutoff}: {prediction_class}"
47
 
48
+ return result, 'shap_plot.png' # Return the prediction and SHAP plot
49
 
50
+ # Create the Gradio interface with preprocessing, prediction, and SHAP visualization
51
  with gr.Blocks() as app:
52
  with gr.Row():
53
  with gr.Column():
 
73
  with gr.Row():
74
  prediction_output = gr.Textbox(label="", interactive=False, elem_id="prediction_output")
75
 
76
+ with gr.Row():
77
+ shap_plot_output = gr.Image(label="SHAP Analysis")
78
+
79
  # Link inputs and prediction output
80
  submit_btn = gr.Button("Submit")
81
+ submit_btn.click(fn=predict_heart_attack, inputs=[Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes, totChol, sysBP, diaBP, BMI, heartRate, glucose], outputs=[prediction_output, shap_plot_output])
82
 
83
  app.launch(share = True)
gradio_hearttack_app_old.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import xgboost as xgb
3
+ import joblib
4
+ import numpy as np
5
+ from sklearn.preprocessing import StandardScaler
6
+ import pandas as pd
7
+ import shap
8
+ import matplotlib.pyplot as plt
9
+
10
+ # Load the model and the scaler
11
+ model = joblib.load('best_XGB.pkl')
12
+ scaler = joblib.load('scaler.pkl') # Load the scaler that was saved during training
13
+ cutoff = 0.42 # Custom cutoff probability
14
+
15
+ # Load SHAP explainer based on your XGBoost model
16
+ explainer = shap.Explainer(model)
17
+
18
+ # Define the prediction function with preprocessing, scaling, and SHAP analysis
19
+ def predict_heart_attack(Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes, totChol, sysBP, diaBP, BMI, heartRate, glucose):
20
+ # Define feature names in the same order as the training data
21
+ feature_names = ['Gender', 'age', 'cigsPerDay', 'BPMeds', 'prevalentHyp', 'diabetes', 'totChol', 'sysBP', 'diaBP', 'BMI', 'heartRate', 'glucose']
22
+
23
+ # Create a DataFrame with the correct feature names for prediction
24
+ features = pd.DataFrame([[Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes, totChol, sysBP, diaBP, BMI, heartRate, glucose]], columns=feature_names)
25
+
26
+ # Standardize the features (scaling)
27
+ scaled_features = scaler.transform(features)
28
+
29
+ # Predict probabilities
30
+ proba = model.predict_proba(scaled_features)[:, 1] # Probability of class 1 (heart attack)
31
+
32
+ # Apply custom cutoff
33
+ if proba[0] >= cutoff:
34
+ prediction_class = 1
35
+ else:
36
+ prediction_class = 0
37
+
38
+ # Generate SHAP values for the prediction
39
+ shap_values = explainer(scaled_features)
40
+
41
+ # Plot SHAP values
42
+ plt.figure(figsize=(8, 6))
43
+ shap.waterfall_plot(shap_values[0])
44
+ plt.savefig('shap_plot.png') # Save SHAP plot to a file
45
+
46
+ result = f"Predicted Probability: {proba[0]*100:.2f}%. Predicted Class with cutoff {cutoff}: {prediction_class}"
47
+
48
+ return result, 'shap_plot.png' # Return the prediction and SHAP plot
49
+
50
+ # Create the Gradio interface with preprocessing, prediction, and SHAP visualization
51
+ with gr.Blocks() as app:
52
+ with gr.Row():
53
+ with gr.Column():
54
+ Gender = gr.Radio([0, 1], label="Gender (0=Female, 1=Male)")
55
+ cigsPerDay = gr.Slider(0, 40, step=1, label="Cigarettes per Day")
56
+ prevalentHyp = gr.Radio([0, 1], label="Prevalent Hypertension (0=No, 1=Yes)")
57
+ totChol = gr.Slider(100, 400, step=1, label="Total Cholesterol in mg/dl")
58
+ diaBP = gr.Slider(60, 120, step=1, label="Diastolic/Lower BP")
59
+ heartRate = gr.Slider(50, 120, step=1, label="Heart Rate")
60
+
61
+ with gr.Column():
62
+ age = gr.Slider(20, 80, step=1, label="Age (years)")
63
+ BPMeds = gr.Radio([0, 1], label="On BP Medications (0=No, 1=Yes)")
64
+ diabetes = gr.Radio([0, 1], label="Diabetes (0=No, 1=Yes)")
65
+ sysBP = gr.Slider(90, 200, step=1, label="Systolic BP/Higher BP")
66
+ BMI = gr.Slider(15, 40, step=0.1, label="Body Mass Index (weight in kg/ height in meter squared)(BMI) in kg/m2")
67
+ glucose = gr.Slider(50, 250, step=1, label="Fasting Glucose Level")
68
+
69
+ # Center-aligned prediction output
70
+ with gr.Row():
71
+ gr.HTML("<div style='text-align: center; width: 100%'>Heart Attack Prediction</div>")
72
+
73
+ with gr.Row():
74
+ prediction_output = gr.Textbox(label="", interactive=False, elem_id="prediction_output")
75
+
76
+ with gr.Row():
77
+ shap_plot_output = gr.Image(label="SHAP Analysis")
78
+
79
+ # Link inputs and prediction output
80
+ submit_btn = gr.Button("Submit")
81
+ submit_btn.click(fn=predict_heart_attack, inputs=[Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes, totChol, sysBP, diaBP, BMI, heartRate, glucose], outputs=[prediction_output, shap_plot_output])
82
+
83
+ app.launch()
requirements.txt CHANGED
@@ -1,7 +1,10 @@
 
1
  gradio==5.1.0
2
  gradio_client==1.4.0
3
  joblib==1.4.2
4
  numpy==1.26.4
5
- pandas
 
6
  scikit-learn==1.4.2
7
- xgboost
 
 
1
+ cloudpickle==3.1.0
2
  gradio==5.1.0
3
  gradio_client==1.4.0
4
  joblib==1.4.2
5
  numpy==1.26.4
6
+ pandas==2.2.2
7
+ shap==0.46.0
8
  scikit-learn==1.4.2
9
+ slicer==0.0.8
10
+ xgboost==2.0.3
shap_plot.png ADDED
test_cases ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ Gender 1 age 62.729653 cigsPerDay 0 BPMeds 0 prevalentHyp 1 diabetes 0 totChol 172 sysBP 144 diaBP 84 BMI 26 heartRate 63 glucose 78.3 Predicted Probability (first observation): 59% Predicted Class with cutoff 0.42: 1 Stored y_test, y_proba, and custom predictions to disk.
2
+
3
+ Gender 1 age 49.000000 cigsPerDay 0 BPMeds 0 prevalentHyp 1 diabetes 0 totChol 170 sysBP 112 diaBP 79 BMI 21 heartRate 60 glucose 80
4
+ Predicted Probability (fifth observation): 0.13% Predicted Class with cutoff 0.42: 0
5
+