Robin Genolet committed
Commit 0b6f5b0 · 1 Parent(s): 6ed9cc0

fix: default values

Files changed (1):
  1. app.py +6 -77
app.py CHANGED
@@ -85,12 +85,12 @@ def display_streamlit_sidebar():
 
     model_name_or_path = form.selectbox("Select model", gptq_model_options())
 
-    temperature = form.slider(label="Temperature", min_value=0.0, max_value=1.0, step=0.01, value=0.01)
-    do_sample = form.checkbox('do_sample')
-    top_p = form.slider(label="top_p", min_value=0.0, max_value=1.0, step=0.01, value=0.95)
-    top_k = form.slider(label="top_k", min_value=1, max_value=1000, step=1, value=40)
-    max_new_tokens = form.slider(label="max_new_tokens", min_value=32, max_value=512, step=1, value=32)
-    repetition_penalty = form.slider(label="repetition_penalty", min_value=0.0, max_value=1.0, step=0.01, value=0.95)
+    temperature = form.slider(label="Temperature", min_value=0.0, max_value=1.0, step=0.01, value=st.session_state["temperature"])
+    do_sample = form.checkbox('do_sample', value=st.session_state["do_sample"])
+    top_p = form.slider(label="top_p", min_value=0.0, max_value=1.0, step=0.01, value=st.session_state["top_p"])
+    top_k = form.slider(label="top_k", min_value=1, max_value=1000, step=1, value=st.session_state["top_k"])
+    max_new_tokens = form.slider(label="max_new_tokens", min_value=32, max_value=512, step=1, value=st.session_state["max_new_tokens"])
+    repetition_penalty = form.slider(label="repetition_penalty", min_value=0.0, max_value=5.0, step=0.01, value=st.session_state["repetition_penalty"])
 
     temperature = form.slider('Temperature (0 = deterministic, 1 = more freedom)', min_value=0.0,
                               max_value=1.0, value=st.session_state['temperature'], step=0.1)
@@ -115,74 +115,6 @@ def display_streamlit_sidebar():
     st.session_state["repetition_penalty"] = repetition_penalty
 
     st.rerun()
-
-
-def to_str(text):
-    res = str(text)
-    if res == "nan":
-        return " "
-    return " " + res
-
-
-def set_df_prompts(path, sheet_name):
-    df_prompts = pd.read_excel(path, sheet_name, header=None)
-    for i in range(3, df_prompts.shape[0]):
-        df_prompts.iloc[2] += df_prompts.iloc[i].apply(to_str)
-
-    df_prompts = df_prompts.T
-    df_prompts = df_prompts[[0, 1, 2]]
-    df_prompts[0] = df_prompts[0].astype(str)
-    df_prompts[1] = df_prompts[1].astype(str)
-    df_prompts[2] = df_prompts[2].astype(str)
-
-    df_prompts.columns = ["Questionnaire", "Used Guideline", "Prompt"]
-    df_prompts = df_prompts[1:]
-    st.session_state["df_prompts"] = df_prompts
-
-
-def handle_nbq_click(c):
-    question_without_source = re.sub(r'\[.*\]', '', c)
-    question_without_source = question_without_source.strip()
-    st.session_state['doctor_question'] = question_without_source
-
-
-def get_doctor_question_value():
-    if 'doctor_question' in st.session_state:
-        return st.session_state['doctor_question']
-
-    return ''
-
-
-def update_chat_history(dr_question, patient_reply):
-    print("update_chat_history" + str(dr_question) + " - " + str(patient_reply) + '...\n')
-    if dr_question is not None:
-        dr_msg = {
-            "role": "Doctor",
-            "content": dr_question
-        }
-        st.session_state["chat_history_array"].append(dr_msg)
-
-    if patient_reply is not None:
-        patient_msg = {
-            "role": "Patient",
-            "content": patient_reply
-        }
-        st.session_state["chat_history_array"].append(patient_msg)
-
-    return st.session_state["chat_history_array"]
-
-
-def get_chat_history_string(chat_history):
-    res = ''
-    for i in chat_history:
-        if i["role"] == "Doctor":
-            res += '**Doctor**: ' + str(i["content"].strip()) + " \n "
-        elif i["role"] == "Patient":
-            res += '**Patient**: ' + str(i["content"].strip()) + " \n\n "
-        else:
-            raise Exception('Unknown role: ' + str(i["role"]))
-
-    return res
 
 
 def init_session_state():
@@ -202,9 +134,6 @@ def get_genders():
     return ['Male', 'Female']
 
 
-
-
-
 def display_session_overview():
    st.subheader('History of LLM queries')
    st.write(st.session_state["llm_messages"])
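
The updated sliders now take their initial values from st.session_state, so those keys must be populated before display_streamlit_sidebar() builds the form, otherwise the first render would raise a KeyError. The commit does not show how the keys are seeded; the sketch below is one plausible way init_session_state() could do it, and the default values simply mirror the numbers that were hard-coded before this commit (whether app.py actually uses these values is an assumption).

    import streamlit as st

    # Hypothetical defaults -- taken from the previously hard-coded slider values,
    # not from the actual init_session_state() in app.py.
    GENERATION_DEFAULTS = {
        "temperature": 0.01,
        "do_sample": False,
        "top_p": 0.95,
        "top_k": 40,
        "max_new_tokens": 32,
        "repetition_penalty": 0.95,
    }


    def init_session_state():
        # Seed each generation parameter once; later reruns keep the user's last choice.
        for key, default in GENERATION_DEFAULTS.items():
            if key not in st.session_state:
                st.session_state[key] = default

Seeded this way, the widgets keep reflecting st.session_state across reruns, and the submit handler shown in the second hunk (st.session_state["repetition_penalty"] = repetition_penalty followed by st.rerun()) writes the user's choices back into the same keys.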