"""Gradio app that predicts heart-disease status from patient features.

At start-up the UCI Heart Disease dataset (``fetch_ucirepo(id=45)``) is
fetched, missing values are mean-imputed, features are standardized, and an
XGBoost multi-class booster is either loaded from ``MODEL_PATH`` or trained on
SMOTE-balanced data and saved there.
"""

import os

import gradio as gr
import numpy as np
import xgboost as xgb
from imblearn.over_sampling import SMOTE
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler
from ucimlrepo import fetch_ucirepo

# Path for saving/loading the trained booster.
MODEL_PATH = "heart_disease_model.json"

# Column order shared by training and ``predict``. Selecting the columns by
# name keeps the training matrix aligned with ``predict``'s argument order.
FEATURE_NAMES = [
    "age",
    "sex",
    "cp",
    "trestbps",
    "chol",
    "fbs",
    "restecg",
    "thalach",
    "exang",
    "oldpeak",
    "slope",
    "ca",
    "thal",
]

# Load and preprocess the data. Fitting on a plain float ndarray (instead of
# the DataFrame) lets ``predict`` pass an ndarray without sklearn's
# "X does not have valid feature names" warning on every request.
heart_disease = fetch_ucirepo(id=45)
X = heart_disease.data.features[FEATURE_NAMES].to_numpy(dtype=float)
y = np.ravel(heart_disease.data.targets)

# Mean-impute missing values; the same fitted imputer also fills form fields
# left blank in ``predict``.
imputer = SimpleImputer(strategy="mean")
X = imputer.fit_transform(X)

scaler = StandardScaler()
X = scaler.fit_transform(X)


def _train_model(features, labels):
    """Train a multi-class XGBoost booster on SMOTE-balanced ``features``."""
    # Oversampling is only needed for training, so it lives here rather than
    # running on every start-up even when a saved model is loaded.
    X_resampled, y_resampled = SMOTE(random_state=42).fit_resample(
        features, labels
    )
    dtrain = xgb.DMatrix(X_resampled, label=y_resampled)
    params = {
        "objective": "multi:softmax",
        "num_class": len(np.unique(y_resampled)),
        "eval_metric": "mlogloss",
        "learning_rate": 0.1,
        "max_depth": 5,
        "subsample": 0.8,
        "colsample_bytree": 0.8,
    }
    return xgb.train(params=params, dtrain=dtrain, num_boost_round=100)


# Load the pre-trained model if present; otherwise train and save it.
if os.path.exists(MODEL_PATH):
    model = xgb.Booster()
    model.load_model(MODEL_PATH)
else:
    model = _train_model(X, y)
    model.save_model(MODEL_PATH)


def _to_float(value):
    """Convert one form value to float; blank/``None`` becomes NaN.

    NaN is the imputer's missing-value marker, so a field the user leaves
    empty is filled with the training mean instead of crashing the request.
    """
    if value is None or value == "":
        return np.nan
    return float(value)


def predict(
    age, sex, cp, trestbps, chol, fbs, restecg, thalach, exang, oldpeak, slope, ca, thal
):
    """Predict the heart-disease class for a single patient.

    Each argument is one raw form value: a number, a numeric string from a
    ``gr.Radio`` such as ``"1"``, or ``None`` when the field was left blank.
    The arguments follow ``FEATURE_NAMES`` order.

    Returns:
        The class predicted by the booster, i.e. a value of the dataset's
        target column, as an ``int``.
    """
    raw_values = [
        age,
        sex,
        cp,
        trestbps,
        chol,
        fbs,
        restecg,
        thalach,
        exang,
        oldpeak,
        slope,
        ca,
        thal,
    ]
    # One row, float dtype, NaN for blanks -> same layout the imputer was fit on.
    features = np.array([[_to_float(v) for v in raw_values]], dtype=float)

    # Apply the exact preprocessing used for training.
    features = scaler.transform(imputer.transform(features))

    # ``multi:softmax`` makes the booster return the class index directly.
    prediction = model.predict(xgb.DMatrix(features))
    return int(prediction[0])


# Gradio input components, in the same order as ``predict``'s parameters.
# Categorical codes follow the UCI Cleveland encoding the model is trained on
# (cp 1-4, slope 1-3, thal 3/6/7), not the 0-based Kaggle variant.
feature_inputs = [
    gr.Number(label="Age (years)"),
    gr.Radio(label="Sex", choices=["0", "1"], type="value"),  # Male: 1, Female: 0
    gr.Radio(
        label=(
            "Chest Pain Type (cp): 1 = typical angina, 2 = atypical angina, "
            "3 = non-anginal pain, 4 = asymptomatic"
        ),
        choices=["1", "2", "3", "4"],
        type="value",
    ),
    gr.Number(label="Resting Blood Pressure (mm Hg)"),
    gr.Number(label="Serum Cholestoral (mg/dl)"),
    gr.Radio(
        label="Fasting Blood Sugar > 120 mg/dl (fbs)", choices=["0", "1"], type="value"
    ),
    gr.Radio(
        label="Resting ECG Results (restecg)", choices=["0", "1", "2"], type="value"
    ),
    gr.Number(label="Maximum Heart Rate Achieved (thalach)"),
    gr.Radio(label="Exercise Induced Angina (exang)", choices=["0", "1"], type="value"),
    gr.Number(label="ST Depression Induced by Exercise (oldpeak)"),
    gr.Radio(
        label=(
            "Slope of the Peak Exercise ST Segment (slope): "
            "1 = upsloping, 2 = flat, 3 = downsloping"
        ),
        choices=["1", "2", "3"],
        type="value",
    ),
    gr.Number(label="Number of Major Vessels Colored by Fluoroscopy (ca)"),
    gr.Radio(
        label=(
            "Thalassemia (thal): 3 = normal, 6 = fixed defect, "
            "7 = reversable defect"
        ),
        choices=["3", "6", "7"],
        type="value",
    ),
]

# Define the Gradio interface.
interface = gr.Interface(
    fn=predict,
    inputs=feature_inputs,
    outputs="label",
    title="Heart Disease Prediction",
    description=(
        "Predicts heart disease based on patient information. "
        "Provide the required features to get a diagnosis prediction."
    ),
)

if __name__ == "__main__":
    interface.launch()