FranciscoLozDataScience commited on
Commit
fabb2c0
·
1 Parent(s): aa94aff

cleaned up code

Browse files
Files changed (2) hide show
  1. app.py +1 -76
  2. model.py +9 -13
app.py CHANGED
@@ -18,17 +18,6 @@ def predict(
18
  '''
19
  Predict the label for the data inputed
20
  '''
21
- # # Combine the input data into a NumPy array
22
- # input_array = np.array([
23
- # age, height, weight,
24
- # waist, eye_L, eye_R,
25
- # hear_L, hear_R, systolic,
26
- # relaxation, fasting_blood_sugar, cholesterol,
27
- # triglyceride, HDL, LDL,
28
- # hemoglobin, urine_protein,
29
- # serum_creatinine, AST, ALT,
30
- # Gtp, dental_caries
31
- # ])
32
 
33
  # Create a dictionary with input data and dataset var names
34
  input_data = {
@@ -56,11 +45,10 @@ def predict(
56
  "dental caries": dental_caries
57
  }
58
 
59
- # Convert the dictionary to a pandas DataFrame
60
  input_df = pd.DataFrame(input_data, index=[0])
61
 
62
  #predict
63
- # label = MODEL.predict(input_array)
64
  label = MODEL.predict(input_df)
65
 
66
  return label
@@ -82,18 +70,8 @@ def load_interface():
82
  Configure Gradio interface
83
  '''
84
 
85
- #example inputs
86
- ex=[ #TODO: delete if file works
87
- [20,85,135,190,30,125,53,126,0.1,9.9,0.1,9.9,1,2,1,2,79,240,40,140,55,505,72,371,16,405,4,618,1,1660,4.9,20.9,1,6,0.1,10.3,6,1311,1,2062,1,999,0,1],
88
- [40,170,65,75.1,1.0,0.9,1,1,120,70,102,225,260,41,132,15.7,1,0.8,24,26,32,0,45,170,75,89.0,0.7,1.2,1,1,100,67,96,258,345,49,140,15.7,1,1.1,26,28,138,0,30],
89
- [180,90,94.0,1.0,0.8,1,1,115,72,88,177,103,53,103,13.5,1,1.0,19,29,30,0,60,170,65,78.0,1.5,1.0,1,1,110,70,87,190,210,45,103,14.7,1,0.8,21,21,19,0,55],
90
- [175,60,75.0,1.0,1.0,1,1,100,64,93,186,80,86,84,15.4,3,1.0,39,20,35,0,40,160,55,69.0,1.5,1.5,1,1,112,78,90,177,68,78,85,12.4,1,0.5,15,9,14,0,55],
91
- [175,60,80.0,1.2,1.5,1.5,1,1,137,89,80,199,35,68,124,16.0,1,1.1,23,19,17,0,55,160,50,68.0,0.8,0.5,1,1,137,87,90,176,36,67,102,13.6,1,0.7,15,14,13,0]
92
- ]
93
-
94
  #set blocks
95
  info_page = gr.Blocks()
96
- # model_page = gr.Blocks()
97
 
98
  with info_page:
99
  # set title and description
@@ -141,59 +119,6 @@ def load_interface():
141
  """
142
  )
143
 
144
- # with model_page:
145
- # # set title and description
146
- # gr.Markdown(
147
- # """
148
- # # Interact with the Ensemble Classifier Model
149
- # Enter sample bio data to predict smoking status.\n
150
- # **Medical Disclaimer**: The predictions provided by this model are for educational purposes only and should not be considered a substitute for professional medical advice.
151
- # """)
152
-
153
- # #set inputs in rows of 3
154
- # with gr.Row():
155
- # age = gr.Number(label="Age", precision=0, minimum=0)
156
- # height = gr.Number(label="Height(cm)", precision=0, minimum=0)
157
- # weight = gr.Number(label="Weight(kg)", precision=0, minimum=0)
158
- # with gr.Row():
159
- # waist = gr.Number(label="Waist(cm)", minimum=0, info="Waist circumference length")
160
- # eye_L = gr.Number(label="Visual acuity of the left eye, measured in diopters (D)", minimum=0)
161
- # eye_R = gr.Number(label="Visual acuity of the right eye, measured in diopters (D)", minimum=0)
162
- # with gr.Row():
163
- # hear_L = gr.Radio(label="Is there any hearing ability in the left ear?",choices=[("Yes",1),("No",2)])
164
- # hear_R = gr.Radio(label="Is there any hearing ability in the right ear?",choices=[("Yes",1),("No",2)])
165
- # systolic = gr.Number(label="Systolic(mmHg)", precision=0, minimum=0, info="Blood Pressure")
166
- # with gr.Row():
167
- # relaxation = gr.Number(label="Relaxation(mmHg)", precision=0, minimum=0, info="Blood Pressure")
168
- # fasting_blood_sugar = gr.Number(label="Fasting Blood Sugar(mg/dL)", precision=0, minimum=0, info="the concentration of glucose (sugar) in the bloodstream after an extended period of fasting")
169
- # cholesterol = gr.Number(label="Total Cholesterol(mg/dL)", precision=0, minimum=0, info="Total amount of cholesterol present in the blood")
170
- # with gr.Row():
171
- # triglyceride = gr.Number(label="Triglyceride(mg/dL)", precision=0, minimum=0, info="A type of fat (lipid) found in blood")
172
- # HDL = gr.Number(label="High-Density Lipoprotein(mg/dL) ", precision=0, minimum=0, info="It is commonly referred to as 'good cholesterol'")
173
- # LDL = gr.Number(label="Low-Density Lipoprotein(mg/dL) ", precision=0, minimum=0, info="It is commonly referred to as 'bad cholesterol'")
174
- # with gr.Row():
175
- # hemoglobin = gr.Number(label="Hemoglobin(g/dL)", minimum=0, info="a protein found in red blood cells that is responsible for carrying oxygen from the lungs to the tissues and organs of the body")
176
- # urine_protein = gr.Radio(label="Does urine contain excessive traces of protein?",choices=[("Yes",2),("No",1)], info="when excessive protein is detected in the urine, it may indicate a problem with kidney function or other underlying health conditions.")
177
- # serum_creatinine = gr.Number(label="Serum creatinine(mg/dL)", minimum=0, info="Serum creatinine levels are commonly measured through a blood test and are used to assess kidney function")
178
- # with gr.Row():
179
- # AST = gr.Number(label="Aspartate Aminotransferase(IU/L)", precision=0, minimum=0, info="glutamic oxaloacetic transaminase type; AST is released into the bloodstream when cells are damaged or destroyed, such as during injury or disease affecting organs rich in AST.")
180
- # ALT = gr.Number(label="Alanine Aminotransferase(IU/L)", precision=0, minimum=0, info="glutamic oxaloacetic transaminase type; ALT is primarily found in the liver cells, and increased levels of ALT in the blood can indicate liver damage or disease")
181
- # Gtp = gr.Number(label="Gamma-glutamyl Transferase(IU/L)", precision=0, minimum=0, info="Elevated levels of GGT in the blood can indicate liver disease or bile duct obstruction. GGT levels are often measured alongside other liver function tests to assess liver health and function.")
182
- # dental_caries = gr.Radio(label="Are there any signs of dental cavities?",choices=[("Yes",1),("No",0)])
183
-
184
- # #set button row
185
- # with gr.Row():
186
- # pred_btn = gr.Button("Predict")
187
- # clear_btn = gr.Button("Clear")
188
-
189
- # #set label txt box
190
- # smoker_label = gr.Label(label="Predicted Label")
191
-
192
- # #set event listeners
193
- # inputs = [age, height, weight, waist, eye_L, eye_R, hear_L, hear_R, systolic, relaxation, fasting_blood_sugar, cholesterol, triglyceride, HDL, LDL, hemoglobin, urine_protein, serum_creatinine, AST, ALT, Gtp, dental_caries]
194
- # pred_btn.click(fn=predict, inputs=inputs, outputs=smoker_label)
195
- # clear_btn.click(lambda: [None]*22, outputs=inputs)
196
-
197
  age = gr.Number(label="Age", precision=0, minimum=0)
198
  height = gr.Number(label="Height(cm)", precision=0, minimum=0)
199
  weight = gr.Number(label="Weight(kg)", precision=0, minimum=0)
 
18
  '''
19
  Predict the label for the data inputed
20
  '''
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  # Create a dictionary with input data and dataset var names
23
  input_data = {
 
45
  "dental caries": dental_caries
46
  }
47
 
48
+ # Convert to DataFrame
49
  input_df = pd.DataFrame(input_data, index=[0])
50
 
51
  #predict
 
52
  label = MODEL.predict(input_df)
53
 
54
  return label
 
70
  Configure Gradio interface
71
  '''
72
 
 
 
 
 
 
 
 
 
 
73
  #set blocks
74
  info_page = gr.Blocks()
 
75
 
76
  with info_page:
77
  # set title and description
 
119
  """
120
  )
121
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  age = gr.Number(label="Age", precision=0, minimum=0)
123
  height = gr.Number(label="Height(cm)", precision=0, minimum=0)
124
  weight = gr.Number(label="Weight(kg)", precision=0, minimum=0)
model.py CHANGED
@@ -9,6 +9,14 @@ from sklearn.neighbors import KNeighborsClassifier
9
  from sklearn.svm import SVC
10
 
11
  class SmokerModel:
 
 
 
 
 
 
 
 
12
  def __init__(self, model_path, scaler_path):
13
  self.model = load(model_path)
14
  self.scaler = load(scaler_path)
@@ -31,7 +39,7 @@ class SmokerModel:
31
 
32
  return new_data_scaled
33
 
34
- def predict(self, X: np.ndarray) -> str: #TODO: change type to pd df
35
  """
36
  Make a prediction on one sample using the loaded model.
37
 
@@ -47,22 +55,10 @@ class SmokerModel:
47
  # scale the data
48
  X_scaled = self.scale(X)
49
 
50
- # Check if the array is 1-dimensional aka one sample
51
- # if len(X_scaled.shape) != 1:
52
- # raise ValueError("Input array must be one-dimensional (one sample), but got a shape of {}".format(X.shape))
53
- # return
54
-
55
  #check array only has one sample
56
  if X.shape[0] != 1:
57
  raise ValueError("Input array must contain only one sample, but {} samples were found".format(X.shape[0]))
58
  return
59
-
60
- # Reshape the array
61
- # X = X.reshape(1, -1)
62
- # X_scaled = X_scaled.reshape(1, -1)
63
-
64
- # # scale the data
65
- # X_scaled = self.scale(X)
66
 
67
  # Now, use the scaled data to make predictions using the loaded model
68
  array = self.model.predict(X_scaled)
 
9
  from sklearn.svm import SVC
10
 
11
  class SmokerModel:
12
+ """
13
+ Smoker Model Class that can predict new instances
14
+
15
+ INPUTS
16
+ ---
17
+ model_path: the path to the model file
18
+ scaler_path: the path to the min max scaler file
19
+ """
20
  def __init__(self, model_path, scaler_path):
21
  self.model = load(model_path)
22
  self.scaler = load(scaler_path)
 
39
 
40
  return new_data_scaled
41
 
42
+ def predict(self, X: pd.DataFrame) -> str:
43
  """
44
  Make a prediction on one sample using the loaded model.
45
 
 
55
  # scale the data
56
  X_scaled = self.scale(X)
57
 
 
 
 
 
 
58
  #check array only has one sample
59
  if X.shape[0] != 1:
60
  raise ValueError("Input array must contain only one sample, but {} samples were found".format(X.shape[0]))
61
  return
 
 
 
 
 
 
 
62
 
63
  # Now, use the scaled data to make predictions using the loaded model
64
  array = self.model.predict(X_scaled)