# Author: datasciencedojo
# Commit 22021ad (verified): upgraded code to a newer version of Gradio.
import gradio as gr
import pickle
import pandas as pd
# Load the trained model.
# NOTE: unpickling can execute arbitrary code, so only ever point this at a
# model file from a trusted source.
filename = 'knn_model.sav'
# Use a context manager so the file handle is closed once the model is read.
with open(filename, 'rb') as model_file:
    loaded_model = pickle.load(model_file)
# Helper functions for input transformations
def hptension(hp):
    """Encode the hypertension answer: 'yes' -> 1, anything else (incl. None) -> 0."""
    if hp == 'yes':
        return 1
    return 0
def ht_dis(ht):
    """Encode the heart-disease answer as an int flag: 'yes' -> 1, otherwise 0."""
    has_heart_disease = ht == 'yes'
    return int(has_heart_disease)
def gender_select(gen):
    """Encode gender for the model: 'male' -> 1, any other value -> 0."""
    if gen != 'male':
        return 0
    return 1
def age_group_selector(age_grp):
    """Map an age-group label to its ordinal band (0-4); unknown labels map to 0."""
    labels = ('0-16', '17-32', '33-48', '49-64', '64+')
    band_of = {label: band for band, label in enumerate(labels)}
    try:
        return band_of[age_grp]
    except KeyError:
        return 0
def smoker_cat(smoke):
    """Map a smoking-status label to its category code.

    Unknown values (including None) fall back to the code for
    'Prefer not to say' (3).
    """
    statuses = ('formerly smoked', 'never smoked', 'smokes', 'Prefer not to say')
    code_for = dict(zip(statuses, range(len(statuses))))
    fallback = code_for['Prefer not to say']
    return code_for.get(smoke, fallback)
# Prediction function
def predict_insurance(
    input_gender,
    input_age_group,
    input_hypertension,
    input_heart_disease,
    input_avg_glucose_level,
    input_bmi,
    input_smoking_status,
):
    """Predict stroke risk from the form inputs and return a readable verdict.

    Despite its name, this function predicts stroke risk, not insurance.

    Args:
        input_gender: 'male' or 'female' from the Gender radio.
        input_age_group: age-band label from the Age Group dropdown.
        input_hypertension: 'yes' or 'no'.
        input_heart_disease: 'yes' or 'no'.
        input_avg_glucose_level: value of the glucose slider.
        input_bmi: value of the BMI slider.
        input_smoking_status: smoking-status label. It may be left empty;
            smoker_cat then falls back to 'Prefer not to say'.

    Returns:
        "Risk of having stroke is high" if the model predicts class 1,
        otherwise "Risk of having stroke is low".

    Raises:
        gr.Error: if a required input was left empty. The UI then shows a
            message instead of a prediction built from silent defaults.
    """
    # Every input except smoking status is required. Otherwise the encoders
    # silently turn an unanswered question into a default category (e.g. an
    # empty age group becomes band 0, an empty radio becomes 0), and a None
    # slider value would crash the numeric scaling below.
    required = {
        'Gender': input_gender,
        'Age Group': input_age_group,
        'Hypertension': input_hypertension,
        'Heart Disease': input_heart_disease,
        'Average Glucose Level': input_avg_glucose_level,
        'BMI Range': input_bmi,
    }
    missing = [label for label, value in required.items() if value in (None, '')]
    if missing:
        raise gr.Error(f"Please provide a value for: {', '.join(missing)}")

    # Divisors used to scale the numeric features. Presumably these are the
    # normalisation constants used when the model was trained. TODO confirm
    # against the training pipeline.
    glucose_scale = 272
    bmi_scale = 49

    # Column names/order presumably must match the model's training features.
    series = {
        'gender': [gender_select(input_gender)],
        'age_band': [age_group_selector(input_age_group)],
        'hypertension': [hptension(input_hypertension)],
        'heart_disease': [ht_dis(input_heart_disease)],
        'avg_glucose_level': [input_avg_glucose_level / glucose_scale],
        'bmi': [input_bmi / bmi_scale],
        'smoking_status': [smoker_cat(input_smoking_status)],
    }
    vector = pd.DataFrame(series)

    # Single-row input, so only the first prediction is relevant.
    result = loaded_model.predict(vector)
    return "Risk of having stroke is high" if result[0] == 1 else "Risk of having stroke is low"
# CSS that hides the Gradio footer and markdown output and styles the large buttons
# NOTE(review): `.gr-button-lg` and `.output-markdown` are class names from
# older Gradio releases. Confirm they still exist in the Gradio version this
# app runs on; if they do not, those rules have no effect.
# The :hover rule overrides only what changes on hover (background and shadow).
# The base `.gr-button-lg` rule still matches while hovering, so repeating its
# other properties there is unnecessary.
css = """
footer {display:none !important}
.output-markdown{display:none !important}
.gr-button-lg {
  z-index: 14;
  width: 113px;
  height: 30px;
  left: 0px;
  top: 0px;
  padding: 0px;
  cursor: pointer !important;
  background: none rgb(17, 20, 45) !important;
  border: none !important;
  text-align: center !important;
  font-size: 14px !important;
  font-weight: 500 !important;
  color: rgb(255, 255, 255) !important;
  line-height: 1 !important;
  border-radius: 6px !important;
  transition: box-shadow 200ms ease 0s, background 200ms ease 0s !important;
  box-shadow: none !important;
}
.gr-button-lg:hover{
  background: none rgb(66, 133, 244) !important;
  box-shadow: rgb(0 0 0 / 23%) 0px 1px 7px 0px !important;
}
"""
# Gradio app layout
with gr.Blocks(title="Brain Stroke Prediction | Data Science Dojo", css=css) as demo:
    # Categorical yes/no and gender answers.
    with gr.Row():
        input_gender = gr.Radio(["male", "female"], label="Gender")
        input_hypertension = gr.Radio(["yes", "no"], label="Hypertension")
        input_heart_disease = gr.Radio(["yes", "no"], label="Heart Disease")
    # Choices must match the keys handled by age_group_selector / smoker_cat.
    with gr.Row():
        input_age_group = gr.Dropdown(['0-16', '17-32', '33-48', '49-64', '64+'], label='Age Group')
        input_smoking_status = gr.Dropdown(['formerly smoked', 'never smoked', 'smokes', 'Prefer not to say'], label='Smoker')
    with gr.Row():
        input_avg_glucose_level = gr.Slider(0, 270, label='Average Glucose Level')
    with gr.Row():
        input_bmi = gr.Slider(0, 45, label='BMI Range')
    with gr.Row():
        stroke = gr.Textbox(label='Chances of Stroke', interactive=False)
    btn_ins = gr.Button(value="Submit")
    # The inputs list order must match predict_insurance's parameter order.
    btn_ins.click(
        fn=predict_insurance,
        inputs=[
            input_gender,
            input_age_group,
            input_hypertension,
            input_heart_disease,
            input_avg_glucose_level,
            input_bmi,
            input_smoking_status,
        ],
        outputs=[stroke]
    )

# Start the server only when this file is run as a script. Importing the
# module just builds `demo` without launching anything.
if __name__ == "__main__":
    demo.launch(debug=True)