minusquare commited on
Commit
d7a3122
·
verified ·
1 Parent(s): d17f035

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -8
app.py CHANGED
@@ -15,6 +15,29 @@ cutoff = 42 # Custom cutoff probability
15
  # Use TreeExplainer for XGBoost models
16
  explainer = shap.TreeExplainer(model)
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  # Define the prediction function with preprocessing, scaling, and SHAP analysis
19
  def predict_heart_attack(Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes, totChol, sysBP, diaBP, BMI, heartRate, glucose):
20
  # Define feature names in the same order as the training data
@@ -30,10 +53,10 @@ def predict_heart_attack(Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes
30
  proba = model.predict_proba(scaled_features)[:, 1] # Probability of class 1 (heart attack)
31
 
32
  # Apply custom cutoff
33
- if proba[0]*100 >= cutoff:
34
  prediction_class = "Heart_Attack_Risk.Consult your doctor"
35
  else:
36
- prediction_class = "No_Heart_Attack_Risk.Still make regular checkup "
37
 
38
  # Generate SHAP values for the prediction using the explainer
39
  shap_values = explainer(features)
@@ -43,9 +66,12 @@ def predict_heart_attack(Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes
43
  shap.waterfall_plot(shap_values[0]) # Using the SHAP Explanation object
44
  plt.savefig('shap_plot.png') # Save SHAP plot to a file
45
 
46
- result = f"Predicted Probability: {proba[0]*100:.2f}%. Predicted Class with cutoff {cutoff}%: {prediction_class}"
 
 
 
47
 
48
- return result, 'shap_plot.png' # Return the prediction and SHAP plot
49
 
50
  # Create the Gradio interface with preprocessing, prediction, and SHAP visualization
51
  with gr.Blocks() as app:
@@ -66,9 +92,12 @@ with gr.Blocks() as app:
66
  BMI = gr.Slider(15, 40, step=0.1, label="Body Mass Index (BMI) in kg/m2")
67
  glucose = gr.Slider(50, 250, step=1, label="Fasting Glucose Level")
68
 
69
- # Center-aligned prediction output
 
 
 
70
  with gr.Row():
71
- gr.HTML("<div style='text-align: center; width: 100%'>Heart Attack Prediction</div>")
72
 
73
  with gr.Row():
74
  prediction_output = gr.Textbox(label="", interactive=False, elem_id="prediction_output")
@@ -78,6 +107,6 @@ with gr.Blocks() as app:
78
 
79
  # Link inputs and prediction output
80
  submit_btn = gr.Button("Submit")
81
- submit_btn.click(fn=predict_heart_attack, inputs=[Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes, totChol, sysBP, diaBP, BMI, heartRate, glucose], outputs=[prediction_output, shap_plot_output])
82
 
83
- app.launch(share = True)
 
15
  # Use TreeExplainer for XGBoost models
16
  explainer = shap.TreeExplainer(model)
17
 
18
+ # Define the function to draw the semicircular scale
19
+ def draw_scale(probability):
20
+ fig, ax = plt.subplots(figsize=(6, 3))
21
+
22
+ # Plot the semicircular scale
23
+ ax.barh(0, 1, color='green', left=0, height=0.3)
24
+ ax.barh(0, 1, color='red', left=cutoff / 100, height=0.3)
25
+
26
+ # Add arrow indicator based on predicted probability
27
+ arrow_position = probability / 100
28
+ color = 'green' if probability < cutoff else 'red'
29
+ ax.annotate('', xy=(arrow_position, 0.15), xytext=(arrow_position, 0.3),
30
+ arrowprops=dict(facecolor=color, shrink=0.05))
31
+
32
+ # Remove axes and add labels
33
+ ax.set_xlim(0, 1)
34
+ ax.set_ylim(-0.5, 0.5)
35
+ ax.axis('off')
36
+
37
+ # Save the image
38
+ plt.savefig('scale_plot.png', bbox_inches='tight')
39
+ plt.close()
40
+
41
  # Define the prediction function with preprocessing, scaling, and SHAP analysis
42
  def predict_heart_attack(Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes, totChol, sysBP, diaBP, BMI, heartRate, glucose):
43
  # Define feature names in the same order as the training data
 
53
  proba = model.predict_proba(scaled_features)[:, 1] # Probability of class 1 (heart attack)
54
 
55
  # Apply custom cutoff
56
+ if proba[0] * 100 >= cutoff:
57
  prediction_class = "Heart_Attack_Risk.Consult your doctor"
58
  else:
59
+ prediction_class = "No_Heart_Attack_Risk.Still make regular checkup"
60
 
61
  # Generate SHAP values for the prediction using the explainer
62
  shap_values = explainer(features)
 
66
  shap.waterfall_plot(shap_values[0]) # Using the SHAP Explanation object
67
  plt.savefig('shap_plot.png') # Save SHAP plot to a file
68
 
69
+ # Draw semicircular scale
70
+ draw_scale(proba[0] * 100)
71
+
72
+ result = f"Predicted Probability: {proba[0] * 100:.2f}%. Predicted Class with cutoff {cutoff}%: {prediction_class}"
73
 
74
+ return result, 'scale_plot.png', 'shap_plot.png' # Return the prediction, scale, and SHAP plot
75
 
76
  # Create the Gradio interface with preprocessing, prediction, and SHAP visualization
77
  with gr.Blocks() as app:
 
92
  BMI = gr.Slider(15, 40, step=0.1, label="Body Mass Index (BMI) in kg/m2")
93
  glucose = gr.Slider(50, 250, step=1, label="Fasting Glucose Level")
94
 
95
+ # Display disclaimer in red uppercase letters above the scale
96
+ gr.HTML("<div style='text-align: center; color: red; font-weight: bold; font-size: 16px;'>RESULTS ARE NOT A SUBSTITUTE FOR ADVICE OF QUALIFIED MEDICAL PROFESSIONAL</div>")
97
+
98
+ # Center-aligned prediction output and semicircular scale
99
  with gr.Row():
100
+ scale_output = gr.Image(label="Risk Indicator Scale")
101
 
102
  with gr.Row():
103
  prediction_output = gr.Textbox(label="", interactive=False, elem_id="prediction_output")
 
107
 
108
  # Link inputs and prediction output
109
  submit_btn = gr.Button("Submit")
110
+ submit_btn.click(fn=predict_heart_attack, inputs=[Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes, totChol, sysBP, diaBP, BMI, heartRate, glucose], outputs=[prediction_output, scale_output, shap_plot_output])
111
 
112
+ app.launch(share=True)