Spaces:
Sleeping
Sleeping
File size: 5,016 Bytes
3d119e0 d7a3122 3d119e0 d7a3122 3d119e0 d7a3122 3d119e0 d7a3122 3d119e0 d7a3122 3d119e0 d7a3122 3d119e0 d7a3122 3d119e0 d7a3122 3d119e0 d7a3122 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
import gradio as gr
import xgboost as xgb
import joblib
import numpy as np
from sklearn.preprocessing import StandardScaler
import pandas as pd
import shap
import matplotlib.pyplot as plt
# Decision threshold in percent: a predicted probability at or above this
# value is reported as "at risk".
cutoff = 42

# Trained artifacts stored next to this script. The scaler has to be
# unpickled with the same scikit-learn version that produced it.
scaler = joblib.load('scaler.pkl')
model = joblib.load('best_XGB.pkl')

# Tree-specific SHAP explainer built on the loaded XGBoost model.
explainer = shap.TreeExplainer(model)
# Draw the horizontal green/red risk scale with an arrow at the prediction
def draw_scale(probability):
    """Render the risk indicator bar and save it as ``scale_plot.png``.

    The bar covers 0-100 %: green from 0 up to the module-level ``cutoff``
    and red from ``cutoff`` to 100. An arrow above the bar marks
    ``probability`` and is coloured green when the value is below the
    cutoff, red otherwise.

    Args:
        probability: Predicted risk in percent. Values outside 0-100 are
            clamped for positioning so the arrow always stays on the scale.

    Returns:
        None. The image is written to ``scale_plot.png`` in the working
        directory, overwriting any previous file.
    """
    threshold = cutoff / 100
    fig, ax = plt.subplots(figsize=(6, 3))
    try:
        # Two adjacent segments instead of a full-width green bar with an
        # over-long red bar painted on top of it.
        ax.barh(0, threshold, color='green', left=0, height=0.3)
        ax.barh(0, 1 - threshold, color='red', left=threshold, height=0.3)

        # Arrow pointing down onto the top edge of the bar (bar spans
        # y = -0.15 .. 0.15).
        arrow_position = min(max(probability, 0), 100) / 100
        color = 'green' if probability < cutoff else 'red'
        ax.annotate(
            '',
            xy=(arrow_position, 0.15),
            xytext=(arrow_position, 0.3),
            arrowprops=dict(facecolor=color, shrink=0.05),
        )

        # Fixed viewport, no ticks or frame.
        ax.set_xlim(0, 1)
        ax.set_ylim(-0.5, 0.5)
        ax.axis('off')

        # Save through the figure handle rather than pyplot's implicit
        # "current figure", which may be a different figure if other
        # plotting code ran in between.
        fig.savefig('scale_plot.png', bbox_inches='tight')
    finally:
        # Always release this figure, even if drawing or saving failed.
        plt.close(fig)
# Prediction callback: scaling, probability, custom cutoff, SHAP explanation
def predict_heart_attack(Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes, totChol, sysBP, diaBP, BMI, heartRate, glucose):
    """Score one patient and produce the verdict text plus two plot files.

    The twelve inputs are assembled into a single-row DataFrame (column
    order fixed by ``feature_names``), standardised with the module-level
    ``scaler`` and scored with the module-level ``model``. The class-1
    probability is compared against ``cutoff`` (in percent) to choose the
    verdict string.

    The SHAP explanation is computed on the *scaled* row, i.e. the exact
    input the model scored, so the attributions match the reported
    probability. The raw, unscaled values are attached to the explanation
    only for display, so the waterfall plot shows them in their original
    units.

    Args:
        Gender, BPMeds, prevalentHyp, diabetes: 0/1 flags.
        age, cigsPerDay, totChol, sysBP, diaBP, BMI, heartRate, glucose:
            Numeric measurements.

    Returns:
        A 3-tuple ``(result_text, 'scale_plot.png', 'shap_plot.png')``.
        Both image files are (re)written in the working directory on every
        call.
    """
    # Feature names in the same order as the training data.
    feature_names = ['Gender', 'age', 'cigsPerDay', 'BPMeds', 'prevalentHyp', 'diabetes', 'totChol', 'sysBP', 'diaBP', 'BMI', 'heartRate', 'glucose']
    features = pd.DataFrame([[Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes, totChol, sysBP, diaBP, BMI, heartRate, glucose]], columns=feature_names)

    # Standardise, then take the probability of class 1 (heart attack).
    scaled_features = scaler.transform(features)
    proba = model.predict_proba(scaled_features)[:, 1]
    risk_percent = proba[0] * 100

    # Apply the custom cutoff.
    if risk_percent >= cutoff:
        prediction_class = "Heart_Attack_Risk.Consult your doctor"
    else:
        prediction_class = "No_Heart_Attack_Risk.Still make regular checkup"

    # Explain the row the model actually scored (scaled), not the raw one.
    scaled_frame = pd.DataFrame(scaled_features, columns=feature_names)
    scored_row = explainer(scaled_frame)[0]
    # Same attributions, but labelled with the raw measurements so the plot
    # does not show z-scores.
    display_row = shap.Explanation(
        values=scored_row.values,
        base_values=scored_row.base_values,
        data=features.iloc[0].to_numpy(dtype=float),
        feature_names=feature_names,
    )

    fig = plt.figure(figsize=(8, 6))
    try:
        # show=False keeps the figure alive so it can be saved afterwards.
        shap.waterfall_plot(display_row, show=False)
        # 'tight' keeps the feature labels on the left from being cropped.
        plt.savefig('shap_plot.png', bbox_inches='tight')
    finally:
        # Without this, one figure would leak per request.
        plt.close(fig)

    # Draw the risk indicator scale (writes 'scale_plot.png').
    draw_scale(risk_percent)

    result = f"Predicted Probability: {risk_percent:.2f}%. Predicted Class with cutoff {cutoff}%: {prediction_class}"
    return result, 'scale_plot.png', 'shap_plot.png'
# Gradio UI: two columns of patient inputs, then the disclaimer, the risk
# scale, the textual verdict and the SHAP plot, followed by the submit button.
with gr.Blocks() as app:
    with gr.Row():
        # Left-hand input column
        with gr.Column():
            Gender = gr.Radio(choices=[0, 1], label="Gender (0=Female, 1=Male)")
            cigsPerDay = gr.Slider(minimum=0, maximum=40, step=1, label="Cigarettes per Day")
            prevalentHyp = gr.Radio(choices=[0, 1], label="Prevalent Hypertension (0=No, 1=Yes)")
            totChol = gr.Slider(minimum=100, maximum=400, step=1, label="Total Cholesterol in mg/dl")
            diaBP = gr.Slider(minimum=60, maximum=120, step=1, label="Diastolic/Lower BP")
            heartRate = gr.Slider(minimum=50, maximum=120, step=1, label="Heart Rate")
        # Right-hand input column
        with gr.Column():
            age = gr.Slider(minimum=20, maximum=80, step=1, label="Age (years)")
            BPMeds = gr.Radio(choices=[0, 1], label="On BP Medications (0=No, 1=Yes)")
            diabetes = gr.Radio(choices=[0, 1], label="Diabetes (0=No, 1=Yes)")
            sysBP = gr.Slider(minimum=90, maximum=200, step=1, label="Systolic BP/Higher BP")
            BMI = gr.Slider(minimum=15, maximum=40, step=0.1, label="Body Mass Index (BMI) in kg/m2")
            glucose = gr.Slider(minimum=50, maximum=250, step=1, label="Fasting Glucose Level")

    # Red, bold, centred disclaimer shown above the scale
    gr.HTML(
        "<div style='text-align: center; color: red; font-weight: bold; font-size: 16px;'>RESULTS ARE NOT A SUBSTITUTE FOR ADVICE OF QUALIFIED MEDICAL PROFESSIONAL</div>"
    )

    # Outputs, one per row: scale image, verdict text, SHAP image
    with gr.Row():
        scale_output = gr.Image(label="Risk Indicator Scale")
    with gr.Row():
        prediction_output = gr.Textbox(label="", interactive=False, elem_id="prediction_output")
    with gr.Row():
        shap_plot_output = gr.Image(label="SHAP Analysis")

    # The input order must match the parameter order of predict_heart_attack.
    input_components = [
        Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes,
        totChol, sysBP, diaBP, BMI, heartRate, glucose,
    ]
    output_components = [prediction_output, scale_output, shap_plot_output]

    submit_btn = gr.Button("Submit")
    submit_btn.click(
        fn=predict_heart_attack,
        inputs=input_components,
        outputs=output_components,
    )

# Start the server and request a public share link.
app.launch(share=True)
|