gaspar-avit commited on
Commit
74a5af8
·
1 Parent(s): ba1f774

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -30
app.py CHANGED
@@ -7,6 +7,7 @@ Created on Fri Mar 31 17:45:36 2023
7
  import os
8
  import streamlit as st
9
 
 
10
  from htbuilder import HtmlElement, div, hr, a, p, styles
11
  from htbuilder.units import percent, px
12
  from catboost import CatBoostClassifier
@@ -77,47 +78,72 @@ def footer():
77
  layout(*myargs)
78
 
79
 
80
- def update_prediction():
81
  """Callback to automatically update prediction if button has already been
82
  clicked"""
83
  if is_clicked:
84
- launch_prediction()
85
 
86
 
87
- def input_layout():
 
 
 
 
 
 
88
 
89
  input_expander = st.expander('Input parameters', True)
90
  with input_expander:
91
  # Row 1
92
  col_age, col_sex = st.columns(2)
93
  with col_age:
94
- st.slider('Age', 18, 75, on_change=update_prediction())
 
95
  with col_sex:
96
- st.radio('Gender', ['Female', 'Male'],
97
- on_change=update_prediction())
98
- # st.write('div.row-widget.stRadio > div{flex-direction: row \
99
- # justify-content: center}', unsafe_allow_html=True)
100
 
101
  # Row 2
102
  col_height, col_weight = st.columns(2)
103
  with col_height:
104
- st.slider(
105
- 'Height', 140, 200, on_change=update_prediction())
 
106
  with col_weight:
107
- st.slider(
108
- 'Weight', 40, 140, on_change=update_prediction())
 
109
 
110
  # Row 3
111
  col_ap_hi, col_ap_lo = st.columns(2)
112
  with col_ap_hi:
113
- st.slider(
114
- 'Systolic blood pressure', 90, 200, on_change=update_prediction())
 
115
  with col_ap_lo:
116
- st.slider(
117
- 'Diastolic blood pressure', 50, 120, on_change=update_prediction())
118
-
 
119
  st.write("")
120
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  ###############################
122
  ## --------- MAIN ---------- ##
123
  ###############################
@@ -144,24 +170,23 @@ if __name__ == "__main__":
144
  ## --------------------------- ##
145
 
146
  # Load classification model
147
- model = CatBoostClassifier() # parameters not required.
148
- model.load_model('./model.cbm')
149
 
150
- # Define inputs
151
- input_layout()
152
-
153
- ## Create button to trigger poster generation
154
  buffer1, col1, buffer2 = st.columns([1.3, 1, 1])
155
  is_clicked = col1.button(label="Generate predictions")
156
-
157
  st.text("")
158
  st.text("")
159
-
160
-
161
- ## Generate poster
162
  if is_clicked:
163
- st.write("Work in progress!")
164
- # poster = generate_poster(data[data.title_year==session.selected_movie])
165
-
166
  st.text("")
167
  st.text("")
 
7
  import os
8
  import streamlit as st
9
 
10
+ from streamlit import session_state as session
11
  from htbuilder import HtmlElement, div, hr, a, p, styles
12
  from htbuilder.units import percent, px
13
  from catboost import CatBoostClassifier
 
78
  layout(*myargs)
79
 
80
 
81
+ def update_prediction(input_data):
82
  """Callback to automatically update prediction if button has already been
83
  clicked"""
84
  if is_clicked:
85
+ launch_prediction(input_data)
86
 
87
 
88
+ def get_input_data():
89
+ """
90
+ Generate input layout and get input values.
91
+
92
+ -return: DataFrame with input data.
93
+ """
94
+ session.input_data = pd.DataFrame()
95
 
96
  input_expander = st.expander('Input parameters', True)
97
  with input_expander:
98
  # Row 1
99
  col_age, col_sex = st.columns(2)
100
  with col_age:
101
+ session.input_data['age'] = st.slider(
102
+ 'Age', 18, 75, on_change=update_prediction(session.input_data))
103
  with col_sex:
104
+ session.input_data['sex'] = st.radio(
105
+ 'Sex', ['Female', 'Male'],
106
+ on_change=update_prediction(session.input_data))
 
107
 
108
  # Row 2
109
  col_height, col_weight = st.columns(2)
110
  with col_height:
111
+ session.input_data['height'] = st.slider(
112
+ 'Height', 140, 200,
113
+ on_change=update_prediction(session.input_data))
114
  with col_weight:
115
+ session.input_data['weight'] = st.slider(
116
+ 'Weight', 40, 140,
117
+ on_change=update_prediction(session.input_data))
118
 
119
  # Row 3
120
  col_ap_hi, col_ap_lo = st.columns(2)
121
  with col_ap_hi:
122
+ session.input_data['ap_hi'] = st.slider(
123
+ 'Systolic blood pressure', 90, 200,
124
+ on_change=update_prediction(session.input_data))
125
  with col_ap_lo:
126
+ session.input_data['ap_lo'] = st.slider(
127
+ 'Diastolic blood pressure', 50, 120,
128
+ on_change=update_prediction(session.input_data))
129
+
130
  st.write("")
131
 
132
+ return session.input_data
133
+
134
+
135
+ def generate_prediction(input_data):
136
+ """
137
+ Generate prediction of cardiovascular disease probability based on input
138
+ data.
139
+
140
+ -param input_data: DataFrame with input data
141
+
142
+ -return: prediction of cardiovascular disease probability
143
+ """
144
+ return MODEL.predict(input_data)
145
+
146
+
147
  ###############################
148
  ## --------- MAIN ---------- ##
149
  ###############################
 
170
  ## --------------------------- ##
171
 
172
  # Load classification model
173
+ MODEL = CatBoostClassifier()
174
+ MODEL.load_model('./model.cbm')
175
 
176
+ # Get inputs
177
+ session.input_data = get_input_data()
178
+
179
+ # Create button to trigger poster generation
180
  buffer1, col1, buffer2 = st.columns([1.3, 1, 1])
181
  is_clicked = col1.button(label="Generate predictions")
182
+
183
  st.text("")
184
  st.text("")
185
+
186
+ # Generate poster
 
187
  if is_clicked:
188
+ prediction = generate_prediction(session.input_data)
189
+ st.write(prediction)
190
+
191
  st.text("")
192
  st.text("")