AashishNKumar commited on
Commit
5e10f6c
·
0 Parent(s):

initial commit

Browse files
Files changed (4) hide show
  1. .gitignore +2 -0
  2. heart_disease_model.json +0 -0
  3. main.py +127 -0
  4. requirements.txt +24 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .venv
2
+ .ropeproject
heart_disease_model.json ADDED
The diff for this file is too large to render. See raw diff
 
main.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import xgboost as xgb
4
+ from ucimlrepo import fetch_ucirepo
5
+ from sklearn.impute import SimpleImputer
6
+ from sklearn.preprocessing import StandardScaler
7
+ from imblearn.over_sampling import SMOTE
8
+ import os
9
+
10
+ # Paths for saving/loading the model
11
+ MODEL_PATH = "heart_disease_model.json"
12
+
13
+ # Load and preprocess the data
14
+ heart_disease = fetch_ucirepo(id=45)
15
+ X = heart_disease.data.features
16
+ y = np.ravel(heart_disease.data.targets)
17
+
18
+ imputer = SimpleImputer(strategy="mean")
19
+ X = imputer.fit_transform(X)
20
+ scaler = StandardScaler()
21
+ X = scaler.fit_transform(X)
22
+ smote = SMOTE(random_state=42)
23
+ X_resampled, y_resampled = smote.fit_resample(X, y)
24
+
25
+ # Train or load the model
26
+ if os.path.exists(MODEL_PATH):
27
+ # Load pre-trained model
28
+ model = xgb.Booster()
29
+ model.load_model(MODEL_PATH)
30
+ else:
31
+ # Train the model
32
+ dtrain = xgb.DMatrix(X_resampled, label=y_resampled)
33
+ param_grid = {
34
+ "objective": "multi:softmax",
35
+ "num_class": len(np.unique(y_resampled)),
36
+ "eval_metric": "mlogloss",
37
+ "learning_rate": 0.1,
38
+ "max_depth": 5,
39
+ "subsample": 0.8,
40
+ "colsample_bytree": 0.8,
41
+ }
42
+ model = xgb.train(params=param_grid, dtrain=dtrain, num_boost_round=100)
43
+ # Save the model
44
+ model.save_model(MODEL_PATH)
45
+
46
+
47
+ # Define prediction function
48
+ def predict(
49
+ age, sex, cp, trestbps, chol, fbs, restecg, thalach, exang, oldpeak, slope, ca, thal
50
+ ):
51
+ # Convert string values to numeric where needed
52
+ sex = int(sex) # Convert '1: Male' -> 1
53
+ cp = int(cp)
54
+ fbs = int(fbs)
55
+ restecg = int(restecg)
56
+ exang = int(exang)
57
+ slope = int(slope)
58
+ thal = int(thal)
59
+
60
+ # Combine inputs into a single feature list
61
+ features = np.array(
62
+ [
63
+ age,
64
+ sex,
65
+ cp,
66
+ trestbps,
67
+ chol,
68
+ fbs,
69
+ restecg,
70
+ thalach,
71
+ exang,
72
+ oldpeak,
73
+ slope,
74
+ ca,
75
+ thal,
76
+ ]
77
+ ).reshape(1, -1)
78
+
79
+ # Preprocess the inputs
80
+ features = scaler.transform(imputer.transform(features))
81
+
82
+ # Predict using the trained model
83
+ dmatrix = xgb.DMatrix(features)
84
+ prediction = model.predict(dmatrix)
85
+ return int(prediction[0])
86
+
87
+
88
+ # Update Gradio interface to return numeric values
89
+ feature_inputs = [
90
+ gr.Number(label="Age (years)"),
91
+ gr.Radio(label="Sex", choices=["0", "1"], type="value"), # Male: 1, Female: 0
92
+ gr.Radio(label="Chest Pain Type (cp)", choices=["0", "1", "2", "3"], type="value"),
93
+ gr.Number(label="Resting Blood Pressure (mm Hg)"),
94
+ gr.Number(label="Serum Cholestoral (mg/dl)"),
95
+ gr.Radio(
96
+ label="Fasting Blood Sugar > 120 mg/dl (fbs)", choices=["0", "1"], type="value"
97
+ ),
98
+ gr.Radio(
99
+ label="Resting ECG Results (restecg)", choices=["0", "1", "2"], type="value"
100
+ ),
101
+ gr.Number(label="Maximum Heart Rate Achieved (thalach)"),
102
+ gr.Radio(label="Exercise Induced Angina (exang)", choices=["0", "1"], type="value"),
103
+ gr.Number(label="ST Depression Induced by Exercise (oldpeak)"),
104
+ gr.Radio(
105
+ label="Slope of the Peak Exercise ST Segment (slope)",
106
+ choices=["0", "1", "2"],
107
+ type="value",
108
+ ),
109
+ gr.Number(label="Number of Major Vessels Colored by Fluoroscopy (ca)"),
110
+ gr.Radio(label="Thalassemia (thal)", choices=["0", "1", "2", "3"], type="value"),
111
+ ]
112
+
113
+
114
+ # Define the Gradio interface
115
+ interface = gr.Interface(
116
+ fn=predict,
117
+ inputs=feature_inputs,
118
+ outputs="label",
119
+ title="Heart Disease Prediction",
120
+ description=(
121
+ "Predicts heart disease based on patient information. "
122
+ "Provide the required features to get a diagnosis prediction."
123
+ ),
124
+ )
125
+
126
+ if __name__ == "__main__":
127
+ interface.launch()
requirements.txt ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ certifi==2024.12.14
2
+ contourpy==1.3.1
3
+ cycler==0.12.1
4
+ fonttools==4.55.3
5
+ imbalanced-learn==0.13.0
6
+ joblib==1.4.2
7
+ kiwisolver==1.4.8
8
+ matplotlib==3.10.0
9
+ numpy==2.2.1
10
+ nvidia-nccl-cu12==2.23.4
11
+ packaging==24.2
12
+ pandas==2.2.3
13
+ pillow==11.0.0
14
+ pyparsing==3.2.0
15
+ python-dateutil==2.9.0.post0
16
+ pytz==2024.2
17
+ scikit-learn==1.6.0
18
+ scipy==1.14.1
19
+ six==1.17.0
20
+ sklearn-compat==0.1.3
21
+ threadpoolctl==3.5.0
22
+ tzdata==2024.2
23
+ ucimlrepo==0.0.7
24
+ xgboost==2.1.3