datasciencedojo commited on
Commit
22021ad
·
verified ·
1 Parent(s): 56edd62

upgraded code with newer version of gradio

Browse files
Files changed (1) hide show
  1. app.py +56 -61
app.py CHANGED
@@ -1,74 +1,53 @@
1
  import gradio as gr
2
  import pickle
3
- from sklearn import preprocessing
4
  import pandas as pd
5
 
 
6
  filename = 'knn_model.sav'
7
-
8
  loaded_model = pickle.load(open(filename, 'rb'))
9
 
10
-
11
-
12
  def hptension(hp):
13
- if hp == 'yes':
14
- return 1
15
- else:
16
- return 0
17
 
18
  def ht_dis(ht):
19
- if ht == 'yes':
20
- return 1
21
- else:
22
- return 0
23
-
24
 
25
  def gender_select(gen):
26
- if gen == 'male':
27
- return 1
28
- else:
29
- return 0
30
-
31
  def age_group_selector(age_grp):
32
- if age_grp == '0-16': return 0
33
- elif age_grp =='17-32': return 1
34
- elif age_grp =='33-48': return 2
35
- elif age_grp =='49-64': return 3
36
- else: return 4
37
-
38
 
39
  def smoker_cat(smoke):
40
- if smoke == 'formerly smoked': return 0
41
- elif smoke =='never smoked': return 1
42
- elif smoke =='smokes': return 2
43
- else: return 3
44
-
45
-
46
- def predict_insurance(input_gender,input_age_group,input_hypertension,input_heart_disease,input_avg_glucose_level,input_bmi,input_smoking_status):
47
-
48
- input_gender,input_age_group,input_hypertension,input_heart_disease,input_avg_glucose_level,input_bmi,input_smoking_status = input_gender,input_age_group,input_hypertension,input_heart_disease,input_avg_glucose_level,input_bmi,input_smoking_status
49
-
50
- series = {'gender': [gender_select(input_gender)],
51
- 'age_band': [age_group_selector(input_age_group)],
52
- 'hypertension': [hptension(input_hypertension)],
53
- 'heart_disease': [ht_dis(input_heart_disease)],
54
- 'avg_glucose_level': [input_avg_glucose_level /272],
55
- 'bmi': [input_bmi/49],
56
- 'smoking_status': [smoker_cat(input_smoking_status)],
57
- }
58
 
59
  vector = pd.DataFrame(series)
60
 
 
61
  result = loaded_model.predict(vector)
62
- if result[0] == 1:
63
- return "Risk of having stroke is high"
64
- else:
65
- return "Risk of having stroke is low"
66
 
 
67
  css = """
68
  footer {display:none !important}
69
  .output-markdown{display:none !important}
70
  footer {visibility: hidden}
71
-
72
  .gr-button-lg {
73
  z-index: 14;
74
  width: 113px;
@@ -107,25 +86,41 @@ footer {visibility: hidden}
107
  transition: box-shadow 200ms ease 0s, background 200ms ease 0s !important;
108
  box-shadow: rgb(0 0 0 / 23%) 0px 1px 7px 0px !important;
109
  }
110
-
111
  """
112
 
113
- with gr.Blocks(title="Brain Stroke Prediction | Data Science Dojo", css = css) as demo:
 
114
  with gr.Row():
115
- input_gender = gr.Radio(["male", "female"],label="Gender")
116
- input_hypertension = gr.Radio(["yes", "no"],label="Hypertension")
117
- input_heart_disease = gr.Radio(["yes", "no"],label="Heart disease")
 
118
  with gr.Row():
119
- input_age_group = gr.Dropdown(['0-16','17-32','33-48','49-64','64+'],label='Age Group')
120
- input_smoking_status = gr.Dropdown(['formerly smoked', 'never smoked', 'smokes', 'Prefer not to say'],label='Smoker')
 
121
  with gr.Row():
122
- input_avg_glucose_level = gr.Slider(0, 270,label='Average Glucose Level')
 
123
  with gr.Row():
124
- input_bmi = gr.Slider(0, 45,label='BMI Range')
 
125
  with gr.Row():
126
- stroke = gr.Textbox(label='Chances of stroke')
127
- btn_ins = gr.Button(value="Submit")
128
- btn_ins.click(fn=predict_insurance, inputs=[input_gender,input_age_group,input_hypertension,input_heart_disease,
129
- input_avg_glucose_level,input_bmi,input_smoking_status], outputs=[stroke])
130
 
131
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import pickle
 
3
  import pandas as pd
4
 
5
+ # Load the trained model
6
  filename = 'knn_model.sav'
 
7
  loaded_model = pickle.load(open(filename, 'rb'))
8
 
9
+ # Helper functions for input transformations
 
10
  def hptension(hp):
11
+ return 1 if hp == 'yes' else 0
 
 
 
12
 
13
  def ht_dis(ht):
14
+ return 1 if ht == 'yes' else 0
 
 
 
 
15
 
16
  def gender_select(gen):
17
+ return 1 if gen == 'male' else 0
18
+
 
 
 
19
  def age_group_selector(age_grp):
20
+ age_map = {'0-16': 0, '17-32': 1, '33-48': 2, '49-64': 3, '64+': 4}
21
+ return age_map.get(age_grp, 0)
 
 
 
 
22
 
23
  def smoker_cat(smoke):
24
+ smoke_map = {'formerly smoked': 0, 'never smoked': 1, 'smokes': 2, 'Prefer not to say': 3}
25
+ return smoke_map.get(smoke, 3)
26
+
27
+ # Prediction function
28
+ def predict_insurance(input_gender, input_age_group, input_hypertension, input_heart_disease, input_avg_glucose_level, input_bmi, input_smoking_status):
29
+ # Prepare the input data
30
+ series = {
31
+ 'gender': [gender_select(input_gender)],
32
+ 'age_band': [age_group_selector(input_age_group)],
33
+ 'hypertension': [hptension(input_hypertension)],
34
+ 'heart_disease': [ht_dis(input_heart_disease)],
35
+ 'avg_glucose_level': [input_avg_glucose_level / 272],
36
+ 'bmi': [input_bmi / 49],
37
+ 'smoking_status': [smoker_cat(input_smoking_status)],
38
+ }
 
 
 
39
 
40
  vector = pd.DataFrame(series)
41
 
42
+ # Perform prediction
43
  result = loaded_model.predict(vector)
44
+ return "Risk of having stroke is high" if result[0] == 1 else "Risk of having stroke is low"
 
 
 
45
 
46
+ # CSS to hide footer and markdown elements
47
  css = """
48
  footer {display:none !important}
49
  .output-markdown{display:none !important}
50
  footer {visibility: hidden}
 
51
  .gr-button-lg {
52
  z-index: 14;
53
  width: 113px;
 
86
  transition: box-shadow 200ms ease 0s, background 200ms ease 0s !important;
87
  box-shadow: rgb(0 0 0 / 23%) 0px 1px 7px 0px !important;
88
  }
 
89
  """
90
 
91
+ # Gradio app layout
92
+ with gr.Blocks(title="Brain Stroke Prediction | Data Science Dojo", css=css) as demo:
93
  with gr.Row():
94
+ input_gender = gr.Radio(["male", "female"], label="Gender")
95
+ input_hypertension = gr.Radio(["yes", "no"], label="Hypertension")
96
+ input_heart_disease = gr.Radio(["yes", "no"], label="Heart Disease")
97
+
98
  with gr.Row():
99
+ input_age_group = gr.Dropdown(['0-16', '17-32', '33-48', '49-64', '64+'], label='Age Group')
100
+ input_smoking_status = gr.Dropdown(['formerly smoked', 'never smoked', 'smokes', 'Prefer not to say'], label='Smoker')
101
+
102
  with gr.Row():
103
+ input_avg_glucose_level = gr.Slider(0, 270, label='Average Glucose Level')
104
+
105
  with gr.Row():
106
+ input_bmi = gr.Slider(0, 45, label='BMI Range')
107
+
108
  with gr.Row():
109
+ stroke = gr.Textbox(label='Chances of Stroke', interactive=False)
 
 
 
110
 
111
+ btn_ins = gr.Button(value="Submit")
112
+ btn_ins.click(
113
+ fn=predict_insurance,
114
+ inputs=[
115
+ input_gender,
116
+ input_age_group,
117
+ input_hypertension,
118
+ input_heart_disease,
119
+ input_avg_glucose_level,
120
+ input_bmi,
121
+ input_smoking_status,
122
+ ],
123
+ outputs=[stroke]
124
+ )
125
+
126
+ demo.launch(debug=True)