Spaces:
Runtime error
Runtime error
upgraded code with newer version of gradio
Browse files
app.py
CHANGED
@@ -1,74 +1,53 @@
|
|
1 |
import gradio as gr
|
2 |
import pickle
|
3 |
-
from sklearn import preprocessing
|
4 |
import pandas as pd
|
5 |
|
|
|
6 |
filename = 'knn_model.sav'
|
7 |
-
|
8 |
loaded_model = pickle.load(open(filename, 'rb'))
|
9 |
|
10 |
-
|
11 |
-
|
12 |
def hptension(hp):
|
13 |
-
|
14 |
-
return 1
|
15 |
-
else:
|
16 |
-
return 0
|
17 |
|
18 |
def ht_dis(ht):
|
19 |
-
|
20 |
-
return 1
|
21 |
-
else:
|
22 |
-
return 0
|
23 |
-
|
24 |
|
25 |
def gender_select(gen):
|
26 |
-
|
27 |
-
|
28 |
-
else:
|
29 |
-
return 0
|
30 |
-
|
31 |
def age_group_selector(age_grp):
|
32 |
-
|
33 |
-
|
34 |
-
elif age_grp =='33-48': return 2
|
35 |
-
elif age_grp =='49-64': return 3
|
36 |
-
else: return 4
|
37 |
-
|
38 |
|
39 |
def smoker_cat(smoke):
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
'bmi': [input_bmi/49],
|
56 |
-
'smoking_status': [smoker_cat(input_smoking_status)],
|
57 |
-
}
|
58 |
|
59 |
vector = pd.DataFrame(series)
|
60 |
|
|
|
61 |
result = loaded_model.predict(vector)
|
62 |
-
if result[0] == 1
|
63 |
-
return "Risk of having stroke is high"
|
64 |
-
else:
|
65 |
-
return "Risk of having stroke is low"
|
66 |
|
|
|
67 |
css = """
|
68 |
footer {display:none !important}
|
69 |
.output-markdown{display:none !important}
|
70 |
footer {visibility: hidden}
|
71 |
-
|
72 |
.gr-button-lg {
|
73 |
z-index: 14;
|
74 |
width: 113px;
|
@@ -107,25 +86,41 @@ footer {visibility: hidden}
|
|
107 |
transition: box-shadow 200ms ease 0s, background 200ms ease 0s !important;
|
108 |
box-shadow: rgb(0 0 0 / 23%) 0px 1px 7px 0px !important;
|
109 |
}
|
110 |
-
|
111 |
"""
|
112 |
|
113 |
-
|
|
|
114 |
with gr.Row():
|
115 |
-
input_gender = gr.Radio(["male", "female"],label="Gender")
|
116 |
-
input_hypertension = gr.Radio(["yes", "no"],label="Hypertension")
|
117 |
-
input_heart_disease = gr.Radio(["yes", "no"],label="Heart
|
|
|
118 |
with gr.Row():
|
119 |
-
input_age_group = gr.Dropdown(['0-16','17-32','33-48','49-64','64+'],label='Age Group')
|
120 |
-
input_smoking_status = gr.Dropdown(['formerly smoked', 'never smoked', 'smokes', 'Prefer not to say'],label='Smoker')
|
|
|
121 |
with gr.Row():
|
122 |
-
input_avg_glucose_level =
|
|
|
123 |
with gr.Row():
|
124 |
-
input_bmi =
|
|
|
125 |
with gr.Row():
|
126 |
-
stroke = gr.Textbox(label='Chances of
|
127 |
-
btn_ins = gr.Button(value="Submit")
|
128 |
-
btn_ins.click(fn=predict_insurance, inputs=[input_gender,input_age_group,input_hypertension,input_heart_disease,
|
129 |
-
input_avg_glucose_level,input_bmi,input_smoking_status], outputs=[stroke])
|
130 |
|
131 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
import pickle
|
|
|
3 |
import pandas as pd
|
4 |
|
5 |
+
# Load the trained model
|
6 |
filename = 'knn_model.sav'
|
|
|
7 |
loaded_model = pickle.load(open(filename, 'rb'))
|
8 |
|
9 |
+
# Helper functions for input transformations
|
|
|
10 |
def hptension(hp):
|
11 |
+
return 1 if hp == 'yes' else 0
|
|
|
|
|
|
|
12 |
|
13 |
def ht_dis(ht):
|
14 |
+
return 1 if ht == 'yes' else 0
|
|
|
|
|
|
|
|
|
15 |
|
16 |
def gender_select(gen):
|
17 |
+
return 1 if gen == 'male' else 0
|
18 |
+
|
|
|
|
|
|
|
19 |
def age_group_selector(age_grp):
|
20 |
+
age_map = {'0-16': 0, '17-32': 1, '33-48': 2, '49-64': 3, '64+': 4}
|
21 |
+
return age_map.get(age_grp, 0)
|
|
|
|
|
|
|
|
|
22 |
|
23 |
def smoker_cat(smoke):
|
24 |
+
smoke_map = {'formerly smoked': 0, 'never smoked': 1, 'smokes': 2, 'Prefer not to say': 3}
|
25 |
+
return smoke_map.get(smoke, 3)
|
26 |
+
|
27 |
+
# Prediction function
|
28 |
+
def predict_insurance(input_gender, input_age_group, input_hypertension, input_heart_disease, input_avg_glucose_level, input_bmi, input_smoking_status):
|
29 |
+
# Prepare the input data
|
30 |
+
series = {
|
31 |
+
'gender': [gender_select(input_gender)],
|
32 |
+
'age_band': [age_group_selector(input_age_group)],
|
33 |
+
'hypertension': [hptension(input_hypertension)],
|
34 |
+
'heart_disease': [ht_dis(input_heart_disease)],
|
35 |
+
'avg_glucose_level': [input_avg_glucose_level / 272],
|
36 |
+
'bmi': [input_bmi / 49],
|
37 |
+
'smoking_status': [smoker_cat(input_smoking_status)],
|
38 |
+
}
|
|
|
|
|
|
|
39 |
|
40 |
vector = pd.DataFrame(series)
|
41 |
|
42 |
+
# Perform prediction
|
43 |
result = loaded_model.predict(vector)
|
44 |
+
return "Risk of having stroke is high" if result[0] == 1 else "Risk of having stroke is low"
|
|
|
|
|
|
|
45 |
|
46 |
+
# CSS to hide footer and markdown elements
|
47 |
css = """
|
48 |
footer {display:none !important}
|
49 |
.output-markdown{display:none !important}
|
50 |
footer {visibility: hidden}
|
|
|
51 |
.gr-button-lg {
|
52 |
z-index: 14;
|
53 |
width: 113px;
|
|
|
86 |
transition: box-shadow 200ms ease 0s, background 200ms ease 0s !important;
|
87 |
box-shadow: rgb(0 0 0 / 23%) 0px 1px 7px 0px !important;
|
88 |
}
|
|
|
89 |
"""
|
90 |
|
91 |
+
# Gradio app layout
|
92 |
+
with gr.Blocks(title="Brain Stroke Prediction | Data Science Dojo", css=css) as demo:
|
93 |
with gr.Row():
|
94 |
+
input_gender = gr.Radio(["male", "female"], label="Gender")
|
95 |
+
input_hypertension = gr.Radio(["yes", "no"], label="Hypertension")
|
96 |
+
input_heart_disease = gr.Radio(["yes", "no"], label="Heart Disease")
|
97 |
+
|
98 |
with gr.Row():
|
99 |
+
input_age_group = gr.Dropdown(['0-16', '17-32', '33-48', '49-64', '64+'], label='Age Group')
|
100 |
+
input_smoking_status = gr.Dropdown(['formerly smoked', 'never smoked', 'smokes', 'Prefer not to say'], label='Smoker')
|
101 |
+
|
102 |
with gr.Row():
|
103 |
+
input_avg_glucose_level = gr.Slider(0, 270, label='Average Glucose Level')
|
104 |
+
|
105 |
with gr.Row():
|
106 |
+
input_bmi = gr.Slider(0, 45, label='BMI Range')
|
107 |
+
|
108 |
with gr.Row():
|
109 |
+
stroke = gr.Textbox(label='Chances of Stroke', interactive=False)
|
|
|
|
|
|
|
110 |
|
111 |
+
btn_ins = gr.Button(value="Submit")
|
112 |
+
btn_ins.click(
|
113 |
+
fn=predict_insurance,
|
114 |
+
inputs=[
|
115 |
+
input_gender,
|
116 |
+
input_age_group,
|
117 |
+
input_hypertension,
|
118 |
+
input_heart_disease,
|
119 |
+
input_avg_glucose_level,
|
120 |
+
input_bmi,
|
121 |
+
input_smoking_status,
|
122 |
+
],
|
123 |
+
outputs=[stroke]
|
124 |
+
)
|
125 |
+
|
126 |
+
demo.launch(debug=True)
|