Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -15,6 +15,29 @@ cutoff = 42 # Custom cutoff probability
|
|
15 |
# Use TreeExplainer for XGBoost models
|
16 |
explainer = shap.TreeExplainer(model)
|
17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
# Define the prediction function with preprocessing, scaling, and SHAP analysis
|
19 |
def predict_heart_attack(Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes, totChol, sysBP, diaBP, BMI, heartRate, glucose):
|
20 |
# Define feature names in the same order as the training data
|
@@ -30,10 +53,10 @@ def predict_heart_attack(Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes
|
|
30 |
proba = model.predict_proba(scaled_features)[:, 1] # Probability of class 1 (heart attack)
|
31 |
|
32 |
# Apply custom cutoff
|
33 |
-
if proba[0]*100 >= cutoff:
|
34 |
prediction_class = "Heart_Attack_Risk.Consult your doctor"
|
35 |
else:
|
36 |
-
prediction_class = "No_Heart_Attack_Risk.Still make regular checkup
|
37 |
|
38 |
# Generate SHAP values for the prediction using the explainer
|
39 |
shap_values = explainer(features)
|
@@ -43,9 +66,12 @@ def predict_heart_attack(Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes
|
|
43 |
shap.waterfall_plot(shap_values[0]) # Using the SHAP Explanation object
|
44 |
plt.savefig('shap_plot.png') # Save SHAP plot to a file
|
45 |
|
46 |
-
|
|
|
|
|
|
|
47 |
|
48 |
-
return result, 'shap_plot.png' # Return the prediction and SHAP plot
|
49 |
|
50 |
# Create the Gradio interface with preprocessing, prediction, and SHAP visualization
|
51 |
with gr.Blocks() as app:
|
@@ -66,9 +92,12 @@ with gr.Blocks() as app:
|
|
66 |
BMI = gr.Slider(15, 40, step=0.1, label="Body Mass Index (BMI) in kg/m2")
|
67 |
glucose = gr.Slider(50, 250, step=1, label="Fasting Glucose Level")
|
68 |
|
69 |
-
#
|
|
|
|
|
|
|
70 |
with gr.Row():
|
71 |
-
gr.
|
72 |
|
73 |
with gr.Row():
|
74 |
prediction_output = gr.Textbox(label="", interactive=False, elem_id="prediction_output")
|
@@ -78,6 +107,6 @@ with gr.Blocks() as app:
|
|
78 |
|
79 |
# Link inputs and prediction output
|
80 |
submit_btn = gr.Button("Submit")
|
81 |
-
submit_btn.click(fn=predict_heart_attack, inputs=[Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes, totChol, sysBP, diaBP, BMI, heartRate, glucose], outputs=[prediction_output, shap_plot_output])
|
82 |
|
83 |
-
app.launch(share
|
|
|
15 |
# Use TreeExplainer for XGBoost models
|
16 |
explainer = shap.TreeExplainer(model)
|
17 |
|
18 |
+
# Define the function to draw the semicircular scale
|
19 |
+
def draw_scale(probability):
|
20 |
+
fig, ax = plt.subplots(figsize=(6, 3))
|
21 |
+
|
22 |
+
# Plot the semicircular scale
|
23 |
+
ax.barh(0, 1, color='green', left=0, height=0.3)
|
24 |
+
ax.barh(0, 1, color='red', left=cutoff / 100, height=0.3)
|
25 |
+
|
26 |
+
# Add arrow indicator based on predicted probability
|
27 |
+
arrow_position = probability / 100
|
28 |
+
color = 'green' if probability < cutoff else 'red'
|
29 |
+
ax.annotate('', xy=(arrow_position, 0.15), xytext=(arrow_position, 0.3),
|
30 |
+
arrowprops=dict(facecolor=color, shrink=0.05))
|
31 |
+
|
32 |
+
# Remove axes and add labels
|
33 |
+
ax.set_xlim(0, 1)
|
34 |
+
ax.set_ylim(-0.5, 0.5)
|
35 |
+
ax.axis('off')
|
36 |
+
|
37 |
+
# Save the image
|
38 |
+
plt.savefig('scale_plot.png', bbox_inches='tight')
|
39 |
+
plt.close()
|
40 |
+
|
41 |
# Define the prediction function with preprocessing, scaling, and SHAP analysis
|
42 |
def predict_heart_attack(Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes, totChol, sysBP, diaBP, BMI, heartRate, glucose):
|
43 |
# Define feature names in the same order as the training data
|
|
|
53 |
proba = model.predict_proba(scaled_features)[:, 1] # Probability of class 1 (heart attack)
|
54 |
|
55 |
# Apply custom cutoff
|
56 |
+
if proba[0] * 100 >= cutoff:
|
57 |
prediction_class = "Heart_Attack_Risk.Consult your doctor"
|
58 |
else:
|
59 |
+
prediction_class = "No_Heart_Attack_Risk.Still make regular checkup"
|
60 |
|
61 |
# Generate SHAP values for the prediction using the explainer
|
62 |
shap_values = explainer(features)
|
|
|
66 |
shap.waterfall_plot(shap_values[0]) # Using the SHAP Explanation object
|
67 |
plt.savefig('shap_plot.png') # Save SHAP plot to a file
|
68 |
|
69 |
+
# Draw semicircular scale
|
70 |
+
draw_scale(proba[0] * 100)
|
71 |
+
|
72 |
+
result = f"Predicted Probability: {proba[0] * 100:.2f}%. Predicted Class with cutoff {cutoff}%: {prediction_class}"
|
73 |
|
74 |
+
return result, 'scale_plot.png', 'shap_plot.png' # Return the prediction, scale, and SHAP plot
|
75 |
|
76 |
# Create the Gradio interface with preprocessing, prediction, and SHAP visualization
|
77 |
with gr.Blocks() as app:
|
|
|
92 |
BMI = gr.Slider(15, 40, step=0.1, label="Body Mass Index (BMI) in kg/m2")
|
93 |
glucose = gr.Slider(50, 250, step=1, label="Fasting Glucose Level")
|
94 |
|
95 |
+
# Display disclaimer in red uppercase letters above the scale
|
96 |
+
gr.HTML("<div style='text-align: center; color: red; font-weight: bold; font-size: 16px;'>RESULTS ARE NOT A SUBSTITUTE FOR ADVICE OF QUALIFIED MEDICAL PROFESSIONAL</div>")
|
97 |
+
|
98 |
+
# Center-aligned prediction output and semicircular scale
|
99 |
with gr.Row():
|
100 |
+
scale_output = gr.Image(label="Risk Indicator Scale")
|
101 |
|
102 |
with gr.Row():
|
103 |
prediction_output = gr.Textbox(label="", interactive=False, elem_id="prediction_output")
|
|
|
107 |
|
108 |
# Link inputs and prediction output
|
109 |
submit_btn = gr.Button("Submit")
|
110 |
+
submit_btn.click(fn=predict_heart_attack, inputs=[Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes, totChol, sysBP, diaBP, BMI, heartRate, glucose], outputs=[prediction_output, scale_output, shap_plot_output])
|
111 |
|
112 |
+
app.launch(share=True)
|