# gradio_app / app.py — uploaded by minusquare
# Last commit: "Update app.py" (d7a3122, verified)
import gradio as gr
import xgboost as xgb
import joblib
import numpy as np
from sklearn.preprocessing import StandardScaler
import pandas as pd
import shap
import matplotlib.pyplot as plt
# --- Model artifacts -------------------------------------------------------
# Paths are resolved relative to the current working directory.
# NOTE: joblib.load unpickles, so only load these bundled, trusted files.
# The scaler must be unpickled with the same scikit-learn version that
# pickled it.
_MODEL_PATH = 'best_XGB.pkl'
_SCALER_PATH = 'scaler.pkl'

model = joblib.load(_MODEL_PATH)
scaler = joblib.load(_SCALER_PATH)

# Decision threshold in percent: a predicted probability >= 42% is flagged
# as at-risk.
cutoff = 42

# SHAP explainer for the tree-based XGBoost model. It is built once at
# start-up and reused for every request.
explainer = shap.TreeExplainer(model)
# Define the function to draw the semicircular scale
def draw_scale(probability):
fig, ax = plt.subplots(figsize=(6, 3))
# Plot the semicircular scale
ax.barh(0, 1, color='green', left=0, height=0.3)
ax.barh(0, 1, color='red', left=cutoff / 100, height=0.3)
# Add arrow indicator based on predicted probability
arrow_position = probability / 100
color = 'green' if probability < cutoff else 'red'
ax.annotate('', xy=(arrow_position, 0.15), xytext=(arrow_position, 0.3),
arrowprops=dict(facecolor=color, shrink=0.05))
# Remove axes and add labels
ax.set_xlim(0, 1)
ax.set_ylim(-0.5, 0.5)
ax.axis('off')
# Save the image
plt.savefig('scale_plot.png', bbox_inches='tight')
plt.close()
# Define the prediction function with preprocessing, scaling, and SHAP analysis
def predict_heart_attack(Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes, totChol, sysBP, diaBP, BMI, heartRate, glucose):
# Define feature names in the same order as the training data
feature_names = ['Gender', 'age', 'cigsPerDay', 'BPMeds', 'prevalentHyp', 'diabetes', 'totChol', 'sysBP', 'diaBP', 'BMI', 'heartRate', 'glucose']
# Create a DataFrame with the correct feature names for prediction
features = pd.DataFrame([[Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes, totChol, sysBP, diaBP, BMI, heartRate, glucose]], columns=feature_names)
# Standardize the features (scaling)
scaled_features = scaler.transform(features)
# Predict probabilities
proba = model.predict_proba(scaled_features)[:, 1] # Probability of class 1 (heart attack)
# Apply custom cutoff
if proba[0] * 100 >= cutoff:
prediction_class = "Heart_Attack_Risk.Consult your doctor"
else:
prediction_class = "No_Heart_Attack_Risk.Still make regular checkup"
# Generate SHAP values for the prediction using the explainer
shap_values = explainer(features)
# Plot SHAP values
plt.figure(figsize=(8, 6))
shap.waterfall_plot(shap_values[0]) # Using the SHAP Explanation object
plt.savefig('shap_plot.png') # Save SHAP plot to a file
# Draw semicircular scale
draw_scale(proba[0] * 100)
result = f"Predicted Probability: {proba[0] * 100:.2f}%. Predicted Class with cutoff {cutoff}%: {prediction_class}"
return result, 'scale_plot.png', 'shap_plot.png' # Return the prediction, scale, and SHAP plot
# Create the Gradio interface with preprocessing, prediction, and SHAP visualization
with gr.Blocks() as app:
with gr.Row():
with gr.Column():
Gender = gr.Radio([0, 1], label="Gender (0=Female, 1=Male)")
cigsPerDay = gr.Slider(0, 40, step=1, label="Cigarettes per Day")
prevalentHyp = gr.Radio([0, 1], label="Prevalent Hypertension (0=No, 1=Yes)")
totChol = gr.Slider(100, 400, step=1, label="Total Cholesterol in mg/dl")
diaBP = gr.Slider(60, 120, step=1, label="Diastolic/Lower BP")
heartRate = gr.Slider(50, 120, step=1, label="Heart Rate")
with gr.Column():
age = gr.Slider(20, 80, step=1, label="Age (years)")
BPMeds = gr.Radio([0, 1], label="On BP Medications (0=No, 1=Yes)")
diabetes = gr.Radio([0, 1], label="Diabetes (0=No, 1=Yes)")
sysBP = gr.Slider(90, 200, step=1, label="Systolic BP/Higher BP")
BMI = gr.Slider(15, 40, step=0.1, label="Body Mass Index (BMI) in kg/m2")
glucose = gr.Slider(50, 250, step=1, label="Fasting Glucose Level")
# Display disclaimer in red uppercase letters above the scale
gr.HTML("<div style='text-align: center; color: red; font-weight: bold; font-size: 16px;'>RESULTS ARE NOT A SUBSTITUTE FOR ADVICE OF QUALIFIED MEDICAL PROFESSIONAL</div>")
# Center-aligned prediction output and semicircular scale
with gr.Row():
scale_output = gr.Image(label="Risk Indicator Scale")
with gr.Row():
prediction_output = gr.Textbox(label="", interactive=False, elem_id="prediction_output")
with gr.Row():
shap_plot_output = gr.Image(label="SHAP Analysis")
# Link inputs and prediction output
submit_btn = gr.Button("Submit")
submit_btn.click(fn=predict_heart_attack, inputs=[Gender, age, cigsPerDay, BPMeds, prevalentHyp, diabetes, totChol, sysBP, diaBP, BMI, heartRate, glucose], outputs=[prediction_output, scale_output, shap_plot_output])
app.launch(share=True)