Commit
·
4b2a917
1
Parent(s):
2f0ea00
update model to increase accuracy
Browse files
main.py
CHANGED
@@ -1,10 +1,18 @@
|
|
1 |
import gradio as gr
|
2 |
import numpy as np
|
3 |
import xgboost as xgb
|
4 |
-
from ucimlrepo import fetch_ucirepo
|
5 |
from sklearn.impute import SimpleImputer
|
6 |
from sklearn.preprocessing import StandardScaler
|
7 |
from imblearn.over_sampling import SMOTE
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
import os
|
9 |
|
10 |
# Paths for saving/loading the model
|
@@ -15,6 +23,7 @@ heart_disease = fetch_ucirepo(id=45)
|
|
15 |
X = heart_disease.data.features
|
16 |
y = np.ravel(heart_disease.data.targets)
|
17 |
|
|
|
18 |
imputer = SimpleImputer(strategy="mean")
|
19 |
X = imputer.fit_transform(X)
|
20 |
scaler = StandardScaler()
|
@@ -22,26 +31,58 @@ X = scaler.fit_transform(X)
|
|
22 |
smote = SMOTE(random_state=42)
|
23 |
X_resampled, y_resampled = smote.fit_resample(X, y)
|
24 |
|
|
|
|
|
|
|
|
|
|
|
25 |
# Train or load the model
|
26 |
if os.path.exists(MODEL_PATH):
|
27 |
# Load pre-trained model
|
28 |
model = xgb.Booster()
|
29 |
model.load_model(MODEL_PATH)
|
30 |
else:
|
31 |
-
#
|
32 |
-
dtrain = xgb.DMatrix(X_resampled, label=y_resampled)
|
33 |
param_grid = {
|
34 |
-
"
|
35 |
-
"
|
36 |
-
"
|
37 |
-
"
|
38 |
-
"
|
39 |
-
"
|
40 |
-
"
|
|
|
41 |
}
|
42 |
-
|
43 |
-
|
44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
|
46 |
|
47 |
# Define prediction function
|
@@ -49,7 +90,7 @@ def predict(
|
|
49 |
age, sex, cp, trestbps, chol, fbs, restecg, thalach, exang, oldpeak, slope, ca, thal
|
50 |
):
|
51 |
# Convert string values to numeric where needed
|
52 |
-
sex = int(sex)
|
53 |
cp = int(cp)
|
54 |
fbs = int(fbs)
|
55 |
restecg = int(restecg)
|
@@ -85,7 +126,7 @@ def predict(
|
|
85 |
return int(prediction[0])
|
86 |
|
87 |
|
88 |
-
#
|
89 |
feature_inputs = [
|
90 |
gr.Number(label="Age (years)"),
|
91 |
gr.Radio(label="Sex", choices=["0", "1"], type="value"), # Male: 1, Female: 0
|
@@ -110,8 +151,6 @@ feature_inputs = [
|
|
110 |
gr.Radio(label="Thalassemia (thal)", choices=["0", "1", "2", "3"], type="value"),
|
111 |
]
|
112 |
|
113 |
-
|
114 |
-
# Define the Gradio interface
|
115 |
interface = gr.Interface(
|
116 |
fn=predict,
|
117 |
inputs=feature_inputs,
|
|
|
1 |
import gradio as gr
|
2 |
import numpy as np
|
3 |
import xgboost as xgb
|
|
|
4 |
from sklearn.impute import SimpleImputer
|
5 |
from sklearn.preprocessing import StandardScaler
|
6 |
from imblearn.over_sampling import SMOTE
|
7 |
+
from sklearn.model_selection import train_test_split, GridSearchCV
|
8 |
+
from sklearn.metrics import (
|
9 |
+
accuracy_score,
|
10 |
+
precision_score,
|
11 |
+
recall_score,
|
12 |
+
f1_score,
|
13 |
+
confusion_matrix,
|
14 |
+
)
|
15 |
+
from ucimlrepo import fetch_ucirepo
|
16 |
import os
|
17 |
|
18 |
# Paths for saving/loading the model
|
|
|
23 |
X = heart_disease.data.features
|
24 |
y = np.ravel(heart_disease.data.targets)
|
25 |
|
26 |
+
# Preprocessing pipeline
|
27 |
imputer = SimpleImputer(strategy="mean")
|
28 |
X = imputer.fit_transform(X)
|
29 |
scaler = StandardScaler()
|
|
|
31 |
smote = SMOTE(random_state=42)
|
32 |
X_resampled, y_resampled = smote.fit_resample(X, y)
|
33 |
|
34 |
+
# Train-test split
|
35 |
+
X_train, X_test, y_train, y_test = train_test_split(
|
36 |
+
X_resampled, y_resampled, test_size=0.2, random_state=42, stratify=y_resampled
|
37 |
+
)
|
38 |
+
|
39 |
# Train or load the model
|
40 |
if os.path.exists(MODEL_PATH):
|
41 |
# Load pre-trained model
|
42 |
model = xgb.Booster()
|
43 |
model.load_model(MODEL_PATH)
|
44 |
else:
|
45 |
+
# Hyperparameter tuning
|
|
|
46 |
param_grid = {
|
47 |
+
"max_depth": [4, 5, 6],
|
48 |
+
"learning_rate": [0.01, 0.05, 0.1],
|
49 |
+
"n_estimators": [100, 200, 300],
|
50 |
+
"subsample": [0.8, 1.0],
|
51 |
+
"colsample_bytree": [0.8, 1.0],
|
52 |
+
"gamma": [0, 1, 5],
|
53 |
+
"lambda": [1, 2, 3],
|
54 |
+
"alpha": [0, 1],
|
55 |
}
|
56 |
+
|
57 |
+
model = xgb.XGBClassifier(use_label_encoder=False, eval_metric="mlogloss")
|
58 |
+
grid_search = GridSearchCV(
|
59 |
+
estimator=model, param_grid=param_grid, scoring="accuracy", cv=5, verbose=1
|
60 |
+
)
|
61 |
+
grid_search.fit(X_train, y_train)
|
62 |
+
|
63 |
+
# Best model
|
64 |
+
best_model = grid_search.best_estimator_
|
65 |
+
best_model.save_model(MODEL_PATH)
|
66 |
+
|
67 |
+
# Load the best model
|
68 |
+
model = xgb.Booster()
|
69 |
+
model.load_model(MODEL_PATH)
|
70 |
+
|
71 |
+
# Evaluate model
|
72 |
+
X_test_dmatrix = xgb.DMatrix(X_test)
|
73 |
+
y_pred = model.predict(X_test_dmatrix)
|
74 |
+
accuracy = accuracy_score(y_test, y_pred)
|
75 |
+
precision = precision_score(y_test, y_pred, average="weighted")
|
76 |
+
recall = recall_score(y_test, y_pred, average="weighted")
|
77 |
+
f1 = f1_score(y_test, y_pred, average="weighted")
|
78 |
+
conf_matrix = confusion_matrix(y_test, y_pred)
|
79 |
+
|
80 |
+
print(f"Accuracy: {accuracy * 100:.2f}%")
|
81 |
+
print(f"Precision: {precision:.2f}")
|
82 |
+
print(f"Recall: {recall:.2f}")
|
83 |
+
print(f"F1 Score: {f1:.2f}")
|
84 |
+
print("Confusion Matrix:")
|
85 |
+
print(conf_matrix)
|
86 |
|
87 |
|
88 |
# Define prediction function
|
|
|
90 |
age, sex, cp, trestbps, chol, fbs, restecg, thalach, exang, oldpeak, slope, ca, thal
|
91 |
):
|
92 |
# Convert string values to numeric where needed
|
93 |
+
sex = int(sex)
|
94 |
cp = int(cp)
|
95 |
fbs = int(fbs)
|
96 |
restecg = int(restecg)
|
|
|
126 |
return int(prediction[0])
|
127 |
|
128 |
|
129 |
+
# Gradio interface
|
130 |
feature_inputs = [
|
131 |
gr.Number(label="Age (years)"),
|
132 |
gr.Radio(label="Sex", choices=["0", "1"], type="value"), # Male: 1, Female: 0
|
|
|
151 |
gr.Radio(label="Thalassemia (thal)", choices=["0", "1", "2", "3"], type="value"),
|
152 |
]
|
153 |
|
|
|
|
|
154 |
interface = gr.Interface(
|
155 |
fn=predict,
|
156 |
inputs=feature_inputs,
|