zetavg commited on
Commit
9d46857
·
unverified ·
1 Parent(s): 726fa4d

update inference ui model warning message

Browse files
llama_lora/ui/inference_ui.py CHANGED
@@ -190,6 +190,23 @@ def reload_selections(current_lora_model, current_prompt_template):
190
  gr.Dropdown.update(choices=available_template_names_with_none, value=current_prompt_template))
191
 
192
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
  def handle_prompt_template_change(prompt_template, lora_model):
194
  prompter = Prompter(prompt_template)
195
  var_names = prompter.get_variable_names()
@@ -203,37 +220,32 @@ def handle_prompt_template_change(prompt_template, lora_model):
203
 
204
  model_prompt_template_message_update = gr.Markdown.update(
205
  "", visible=False)
206
- lora_mode_info = get_info_of_available_lora_model(lora_model)
207
- if lora_mode_info and isinstance(lora_mode_info, dict):
208
- model_base_model = lora_mode_info.get("base_model")
209
- model_prompt_template = lora_mode_info.get("prompt_template")
210
- if model_base_model and model_base_model != Global.base_model_name:
211
- model_prompt_template_message_update = gr.Markdown.update(
212
- f"⚠️ This model was trained on top of base model `{model_base_model}`, it might not work properly with the selected base model `{Global.base_model_name}`.", visible=True)
213
- elif model_prompt_template and model_prompt_template != prompt_template:
214
- model_prompt_template_message_update = gr.Markdown.update(
215
- f"This model was trained with prompt template `{model_prompt_template}`.", visible=True)
216
 
217
  return [model_prompt_template_message_update] + gr_updates
218
 
219
 
220
  def handle_lora_model_change(lora_model, prompt_template):
221
  lora_mode_info = get_info_of_available_lora_model(lora_model)
222
- if not lora_mode_info:
223
- return gr.Markdown.update("", visible=False), prompt_template
224
-
225
- if not isinstance(lora_mode_info, dict):
226
- return gr.Markdown.update("", visible=False), prompt_template
227
 
228
- model_prompt_template = lora_mode_info.get("prompt_template")
229
- if not model_prompt_template:
230
- return gr.Markdown.update("", visible=False), prompt_template
 
 
 
231
 
232
- available_template_names = get_available_template_names()
233
- if model_prompt_template in available_template_names:
234
- return gr.Markdown.update("", visible=False), model_prompt_template
 
 
 
235
 
236
- return gr.Markdown.update(f"Trained with prompt template `{model_prompt_template}`", visible=True), prompt_template
237
 
238
 
239
  def update_prompt_preview(prompt_template,
 
190
  gr.Dropdown.update(choices=available_template_names_with_none, value=current_prompt_template))
191
 
192
 
193
+ def get_warning_message_for_lora_model_and_prompt_template(lora_model, prompt_template):
194
+ messages = []
195
+
196
+ lora_mode_info = get_info_of_available_lora_model(lora_model)
197
+
198
+ if lora_mode_info and isinstance(lora_mode_info, dict):
199
+ model_base_model = lora_mode_info.get("base_model")
200
+ if model_base_model and model_base_model != Global.base_model_name:
201
+ messages.append(f"⚠️ This model was trained on top of base model `{model_base_model}`, it might not work properly with the selected base model `{Global.base_model_name}`.")
202
+
203
+ model_prompt_template = lora_mode_info.get("prompt_template")
204
+ if model_prompt_template and model_prompt_template != prompt_template:
205
+ messages.append(f"This model was trained with prompt template `{model_prompt_template}`.")
206
+
207
+ return " ".join(messages)
208
+
209
+
210
  def handle_prompt_template_change(prompt_template, lora_model):
211
  prompter = Prompter(prompt_template)
212
  var_names = prompter.get_variable_names()
 
220
 
221
  model_prompt_template_message_update = gr.Markdown.update(
222
  "", visible=False)
223
+ warning_message = get_warning_message_for_lora_model_and_prompt_template(lora_model, prompt_template)
224
+ if warning_message:
225
+ model_prompt_template_message_update = gr.Markdown.update(
226
+ warning_message, visible=True)
 
 
 
 
 
 
227
 
228
  return [model_prompt_template_message_update] + gr_updates
229
 
230
 
231
  def handle_lora_model_change(lora_model, prompt_template):
232
  lora_mode_info = get_info_of_available_lora_model(lora_model)
 
 
 
 
 
233
 
234
+ if lora_mode_info and isinstance(lora_mode_info, dict):
235
+ model_prompt_template = lora_mode_info.get("prompt_template")
236
+ if model_prompt_template:
237
+ available_template_names = get_available_template_names()
238
+ if model_prompt_template in available_template_names:
239
+ prompt_template = model_prompt_template
240
 
241
+ model_prompt_template_message_update = gr.Markdown.update(
242
+ "", visible=False)
243
+ warning_message = get_warning_message_for_lora_model_and_prompt_template(lora_model, prompt_template)
244
+ if warning_message:
245
+ model_prompt_template_message_update = gr.Markdown.update(
246
+ warning_message, visible=True)
247
 
248
+ return model_prompt_template_message_update, prompt_template
249
 
250
 
251
  def update_prompt_preview(prompt_template,
llama_lora/ui/main_page.py CHANGED
@@ -248,12 +248,18 @@ def main_page_custom_css():
248
  #inference_lora_model_prompt_template_message:not(.hidden) + #inference_lora_model {
249
  padding-bottom: 28px;
250
  }
 
 
 
 
 
 
 
 
251
  #inference_lora_model_group > #inference_lora_model_prompt_template_message {
252
- position: absolute;
253
- bottom: 8px;
254
- left: 20px;
255
- z-index: 61;
256
- width: 999px;
257
  font-size: 12px;
258
  opacity: 0.7;
259
  }
@@ -608,7 +614,7 @@ def main_page_custom_css():
608
  }
609
 
610
  @media screen and (max-width: 392px) {
611
- #inference_lora_model, #finetune_template {
612
  border-bottom-left-radius: 0;
613
  border-bottom-right-radius: 0;
614
  }
 
248
  #inference_lora_model_prompt_template_message:not(.hidden) + #inference_lora_model {
249
  padding-bottom: 28px;
250
  }
251
+ #inference_lora_model_group {
252
+ flex-direction: column-reverse;
253
+ border-width: var(--block-border-width);
254
+ border-color: var(--block-border-color);
255
+ }
256
+ #inference_lora_model_group #inference_lora_model {
257
+ border: 0;
258
+ }
259
  #inference_lora_model_group > #inference_lora_model_prompt_template_message {
260
+ padding: var(--block-padding) !important;
261
+ margin-top: -50px !important;
262
+ margin-left: 4px !important;
 
 
263
  font-size: 12px;
264
  opacity: 0.7;
265
  }
 
614
  }
615
 
616
  @media screen and (max-width: 392px) {
617
+ #inference_lora_model, #inference_lora_model_group, #finetune_template {
618
  border-bottom-left-radius: 0;
619
  border-bottom-right-radius: 0;
620
  }