|
import gradio as gr |
|
import numpy as np |
|
import xgboost as xgb |
|
from ucimlrepo import fetch_ucirepo |
|
from sklearn.impute import SimpleImputer |
|
from sklearn.preprocessing import StandardScaler |
|
from imblearn.over_sampling import SMOTE |
|
import os |
|
|
|
|
|
# On-disk cache for the trained booster: when this file exists the model is
# loaded from it instead of being retrained (delete it to force retraining).
MODEL_PATH = "heart_disease_model.json"

# UCI Heart Disease dataset (ucimlrepo id=45).
heart_disease = fetch_ucirepo(id=45)

# Convert the feature table to a plain float ndarray before fitting the
# transformers. predict() feeds them bare ndarrays; fitting on a DataFrame
# would make scikit-learn warn "X does not have valid feature names" on
# every single prediction.
X = heart_disease.data.features.to_numpy(dtype=float)

# Flatten the single-column target table to a 1-D label vector.
y = np.ravel(heart_disease.data.targets)

# Replace missing (NaN) entries with column means, then standardize.
# Both fitted transformers are reused by predict(), so inference applies
# exactly the statistics learned here.
imputer = SimpleImputer(strategy="mean")
X = imputer.fit_transform(X)
scaler = StandardScaler()
X = scaler.fit_transform(X)

# Oversample minority classes so training sees a balanced label set;
# the fixed random_state keeps the resampled data reproducible.
smote = SMOTE(random_state=42)
X_resampled, y_resampled = smote.fit_resample(X, y)
|
|
|
|
|
if not os.path.exists(MODEL_PATH):
    # No cached booster yet: train one on the SMOTE-balanced data and
    # persist it so subsequent launches can skip training.
    dtrain = xgb.DMatrix(X_resampled, label=y_resampled)
    param_grid = dict(
        objective="multi:softmax",
        num_class=len(np.unique(y_resampled)),
        eval_metric="mlogloss",
        learning_rate=0.1,
        max_depth=5,
        subsample=0.8,
        colsample_bytree=0.8,
    )
    model = xgb.train(params=param_grid, dtrain=dtrain, num_boost_round=100)
    model.save_model(MODEL_PATH)
else:
    # Reuse the previously saved booster instead of retraining.
    model = xgb.Booster()
    model.load_model(MODEL_PATH)
|
|
|
|
|
|
|
def predict(
    age, sex, cp, trestbps, chol, fbs, restecg, thalach, exang, oldpeak, slope, ca, thal
):
    """Predict the heart-disease class for a single patient.

    Arguments arrive from the Gradio widgets: numeric fields as numbers and
    categorical fields as the string value of the selected radio button.
    A field left empty (``None`` or ``""``) is passed on as NaN, so the
    fitted mean imputer fills it in. Without this, the conversion would
    crash and partially filled forms would get no prediction.

    Returns:
        int: the class predicted by the ``multi:softmax`` booster, in the
        same coding as the training targets.
    """
    # Order must match the columns of the training feature matrix.
    raw_values = (
        age, sex, cp, trestbps, chol, fbs, restecg,
        thalach, exang, oldpeak, slope, ca, thal,
    )
    row = [
        np.nan if value is None or value == "" else float(value)
        for value in raw_values
    ]
    features = np.array(row, dtype=float).reshape(1, -1)

    # Apply the same imputation and scaling that were fitted on training data.
    features = scaler.transform(imputer.transform(features))

    prediction = model.predict(xgb.DMatrix(features))
    # multi:softmax returns the class index as a float; expose it as int.
    return int(prediction[0])
|
|
|
|
|
|
|
# Input widgets, in the same order as predict()'s parameters and the
# dataset's feature columns. Categorical choices must use the numeric codes
# found in the UCI Cleveland data (ucimlrepo id=45). Any other code is a
# value the model never saw during training.
feature_inputs = [
    gr.Number(label="Age (years)"),
    gr.Radio(label="Sex", choices=["0", "1"], type="value"),
    # Dataset codes cp as 1-4 (not 0-3); 4 is the most common category.
    gr.Radio(
        label=(
            "Chest Pain Type (cp): 1 = typical angina, 2 = atypical angina, "
            "3 = non-anginal pain, 4 = asymptomatic"
        ),
        choices=["1", "2", "3", "4"],
        type="value",
    ),
    gr.Number(label="Resting Blood Pressure (mm Hg)"),
    gr.Number(label="Serum Cholestoral (mg/dl)"),
    gr.Radio(
        label="Fasting Blood Sugar > 120 mg/dl (fbs)", choices=["0", "1"], type="value"
    ),
    gr.Radio(
        label="Resting ECG Results (restecg)", choices=["0", "1", "2"], type="value"
    ),
    gr.Number(label="Maximum Heart Rate Achieved (thalach)"),
    gr.Radio(label="Exercise Induced Angina (exang)", choices=["0", "1"], type="value"),
    gr.Number(label="ST Depression Induced by Exercise (oldpeak)"),
    # Dataset codes slope as 1-3 (not 0-2).
    gr.Radio(
        label=(
            "Slope of the Peak Exercise ST Segment (slope): "
            "1 = upsloping, 2 = flat, 3 = downsloping"
        ),
        choices=["1", "2", "3"],
        type="value",
    ),
    gr.Number(label="Number of Major Vessels Colored by Fluoroscopy (ca)"),
    # Dataset codes thal as 3/6/7, so none of the former 0-3 choices except 3 existed.
    gr.Radio(
        label=(
            "Thalassemia (thal): 3 = normal, 6 = fixed defect, "
            "7 = reversible defect"
        ),
        choices=["3", "6", "7"],
        type="value",
    ),
]
|
|
|
|
|
|
|
# Web UI: the input widgets feed predict(), and its result is rendered with
# Gradio's "label" output component.
interface = gr.Interface(
    predict,
    feature_inputs,
    "label",
    description=(
        "Predicts heart disease based on patient information. Provide the "
        "required features to get a diagnosis prediction."
    ),
    title="Heart Disease Prediction",
)

# Start the server only when run as a script, not when imported.
if __name__ == "__main__":
    interface.launch()
|
|