zetavg committed
Commit 750c900 · unverified · 1 parent: 9c78439

support flagging inference outputs

llama_lora/lib/inference.py CHANGED
@@ -66,14 +66,14 @@ def generate(
         with generate_with_streaming(**generate_params) as generator:
             for output in generator:
                 decoded_output = tokenizer.decode(output, skip_special_tokens=skip_special_tokens)
-                yield decoded_output, output
+                yield decoded_output, output, False
                 if output[-1] in [tokenizer.eos_token_id]:
                     break
 
         if generation_output:
             output = generation_output.sequences[0]
             decoded_output = tokenizer.decode(output, skip_special_tokens=skip_special_tokens)
-            yield decoded_output, output
+            yield decoded_output, output, True
 
         return  # early return for stream_output
 
@@ -82,5 +82,5 @@ def generate(
     generation_output = model.generate(**generate_params)
     output = generation_output.sequences[0]
     decoded_output = tokenizer.decode(output, skip_special_tokens=skip_special_tokens)
-    yield decoded_output, output
+    yield decoded_output, output, True
     return
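
Note on the change above: `generate` now yields a third element, `completed`, alongside the decoded text and the raw output, so callers can tell streamed partial results (`False`) from the final one (`True`). A minimal sketch of the new tuple contract, not part of the commit, using a hypothetical stand-in generator instead of the real model-backed one:

# Hypothetical stand-in for llama_lora.lib.inference.generate; the real function
# streams token tensors from the model, this one just streams words.
def fake_generate(words):
    partial = []
    for word in words:
        partial.append(word)
        # streamed chunks carry completed=False
        yield " ".join(partial), list(partial), False
    # the last yield carries completed=True
    yield " ".join(partial), list(partial), True


for decoded_output, output, completed in fake_generate(["Hello", "world", "!"]):
    status = "final" if completed else "partial"
    print(f"[{status}] {decoded_output!r}")
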
llama_lora/ui/inference_ui.py CHANGED
@@ -1,4 +1,5 @@
 import gradio as gr
+import os
 import time
 import json
 
@@ -21,13 +22,21 @@ default_show_raw = True
 inference_output_lines = 12
 
 
+class LoggingItem:
+    def __init__(self, label):
+        self.label = label
+
+    def deserialize(self, value, **kwargs):
+        return value
+
+
 def prepare_inference(lora_model_name, progress=gr.Progress(track_tqdm=True)):
     base_model_name = Global.base_model_name
 
     try:
         get_tokenizer(base_model_name)
         get_model(base_model_name, lora_model_name)
-        return ("", "")
+        return ("", "", gr.Textbox.update(visible=False))
 
     except Exception as e:
         raise gr.Error(e)
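
Why the `LoggingItem` class works as a stand-in: `gr.CSVLogger` (Gradio 3.x) only needs each "component" it logs to expose a `.label` (used for the CSV header) and a `.deserialize(value, ...)` method (called on each flagged value before writing), so duck-typing that interface lets plain strings be logged without creating real UI components. A small sketch of the idea, assuming Gradio 3.x behavior:

import tempfile

import gradio as gr


class LoggingItem:
    """Bare-minimum stand-in for a Gradio component, good enough for CSVLogger."""

    def __init__(self, label):
        self.label = label

    def deserialize(self, value, **kwargs):
        return value


logger = gr.CSVLogger()
logger.setup([LoggingItem("Prompt"), LoggingItem("Output")], tempfile.mkdtemp())
# Appends one row to log.csv in the flagging dir; returns the number of rows so far.
print(logger.flag(["Hello", "Hi there!"], flag_option="Up Vote", username=None))
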
@@ -65,6 +74,31 @@ def do_inference(
     prompter = Prompter(prompt_template)
     prompt = prompter.generate_prompt(variables)
 
+    generation_config = GenerationConfig(
+        # to avoid ValueError('`temperature` has to be a strictly positive float, but is 2')
+        temperature=float(temperature),
+        top_p=top_p,
+        top_k=top_k,
+        repetition_penalty=repetition_penalty,
+        num_beams=num_beams,
+        # https://github.com/huggingface/transformers/issues/22405#issuecomment-1485527953
+        do_sample=temperature > 0,
+    )
+
+    def get_output_for_flagging(output, raw_output, completed=True):
+        return json.dumps({
+            'base_model': base_model_name,
+            'adaptor_model': lora_model_name,
+            'prompt': prompt,
+            'output': output,
+            'completed': completed,
+            'raw_output': raw_output,
+            'max_new_tokens': max_new_tokens,
+            'prompt_template': prompt_template,
+            'prompt_template_variables': variables,
+            'generation_config': generation_config.to_dict(),
+        })
+
     if Global.ui_dev_mode:
         message = f"Hi, I’m currently in UI-development mode and do not have access to resources to process your request. However, this behavior is similar to what will actually happen, so you can try and see how it will work!\n\nBase model: {base_model_name}\nLoRA model: {lora_model_name}\n\nThe following is your prompt:\n\n{prompt}"
         print(message)
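
The `GenerationConfig` block is hoisted above the UI-dev-mode branch (its old location is removed further down) because `get_output_for_flagging` embeds `generation_config.to_dict()` in its JSON payload, and that payload is produced in dev mode too. A short sketch, not taken from the commit, of why the round trip works (values are illustrative):

import json

from transformers import GenerationConfig

generation_config = GenerationConfig(
    temperature=0.7,       # illustrative values only
    top_p=0.75,
    top_k=40,
    repetition_penalty=1.2,
    num_beams=1,
    do_sample=True,
)

# to_dict() returns plain Python types, so the flagging payload stays JSON-serializable.
payload = json.dumps({"generation_config": generation_config.to_dict()})
print(json.loads(payload)["generation_config"]["temperature"])  # 0.7
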
@@ -83,35 +117,50 @@ def do_inference(
                     out += "\n"
                     yield out
 
+            output = ""
             for partial_sentence in word_generator(message):
+                output = partial_sentence
                 yield (
                     gr.Textbox.update(
-                        value=partial_sentence, lines=inference_output_lines),
+                        value=output,
+                        lines=inference_output_lines),
                     json.dumps(
-                        list(range(len(partial_sentence.split()))), indent=2)
+                        list(range(len(output.split()))),
+                        indent=2),
+                    gr.Textbox.update(
+                        value=get_output_for_flagging(
+                            output, "", completed=False),
+                        visible=True)
                 )
                 time.sleep(0.05)
 
+            yield (
+                gr.Textbox.update(
+                    value=output,
+                    lines=inference_output_lines),
+                json.dumps(
+                    list(range(len(output.split()))),
+                    indent=2),
+                gr.Textbox.update(
+                    value=get_output_for_flagging(
+                        output, "", completed=True),
+                    visible=True)
+            )
+
             return
         time.sleep(1)
         yield (
             gr.Textbox.update(value=message, lines=inference_output_lines),
-            json.dumps(list(range(len(message.split()))), indent=2)
+            json.dumps(list(range(len(message.split()))), indent=2),
+            gr.Textbox.update(
+                value=get_output_for_flagging(message, ""),
+                visible=True)
         )
         return
 
     tokenizer = get_tokenizer(base_model_name)
     model = get_model(base_model_name, lora_model_name)
 
-    generation_config = GenerationConfig(
-        temperature=float(temperature),  # to avoid ValueError('`temperature` has to be a strictly positive float, but is 2')
-        top_p=top_p,
-        top_k=top_k,
-        repetition_penalty=repetition_penalty,
-        num_beams=num_beams,
-        do_sample=temperature > 0,  # https://github.com/huggingface/transformers/issues/22405#issuecomment-1485527953
-    )
-
     def ui_generation_stopping_criteria(input_ids, score, **kwargs):
         if Global.should_stop_generating:
             return True
@@ -129,10 +178,8 @@ def do_inference(
         'stream_output': stream_output
     }
 
-    for (decoded_output, output) in generate(**generation_args):
-        raw_output_str = None
-        if show_raw:
-            raw_output_str = str(output)
+    for (decoded_output, output, completed) in generate(**generation_args):
+        raw_output_str = str(output)
         response = prompter.get_response(decoded_output)
 
         if Global.should_stop_generating:
@@ -141,7 +188,12 @@ def do_inference(
         yield (
             gr.Textbox.update(
                 value=response, lines=inference_output_lines),
-            raw_output_str)
+            raw_output_str,
+            gr.Textbox.update(
+                value=get_output_for_flagging(
+                    decoded_output, raw_output_str, completed=completed),
+                visible=True)
+        )
 
         if Global.should_stop_generating:
             # If the user stops the generation, and then clicks the
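
The pattern introduced in the hunks above: `do_inference` is a generator-style Gradio event handler that now drives three outputs per yield, namely the visible output textbox, the raw-output textbox, and the hidden flagging textbox that gets revealed and filled with the JSON payload. A stripped-down sketch of that shape (hypothetical `fake_do_inference`, assuming Gradio 3.x):

import json
import time

import gradio as gr


def fake_do_inference(prompt):
    words = []
    for word in prompt.split():
        words.append(word)
        partial = " ".join(words)
        yield (
            gr.Textbox.update(value=partial),        # visible output
            json.dumps(list(range(len(words)))),     # stand-in "raw" output
            gr.Textbox.update(                       # flagging payload, now visible
                value=json.dumps({"output": partial, "completed": False}),
                visible=True),
        )
        time.sleep(0.05)


with gr.Blocks() as demo:
    prompt_box = gr.Textbox(label="Prompt")
    output_box = gr.Textbox(label="Output")
    raw_box = gr.Textbox(label="Raw Output")
    flag_src = gr.Textbox(visible=False)
    gr.Button("Generate").click(
        fake_do_inference, inputs=[prompt_box],
        outputs=[output_box, raw_box, flag_src])

# demo.queue().launch()  # generator handlers need the queue enabled
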
@@ -199,11 +251,13 @@ def get_warning_message_for_lora_model_and_prompt_template(lora_model, prompt_te
     if lora_mode_info and isinstance(lora_mode_info, dict):
         model_base_model = lora_mode_info.get("base_model")
         if model_base_model and model_base_model != Global.base_model_name:
-            messages.append(f"⚠️ This model was trained on top of base model `{model_base_model}`, it might not work properly with the selected base model `{Global.base_model_name}`.")
+            messages.append(
+                f"⚠️ This model was trained on top of base model `{model_base_model}`, it might not work properly with the selected base model `{Global.base_model_name}`.")
 
         model_prompt_template = lora_mode_info.get("prompt_template")
         if model_prompt_template and model_prompt_template != prompt_template:
-            messages.append(f"This model was trained with prompt template `{model_prompt_template}`.")
+            messages.append(
+                f"This model was trained with prompt template `{model_prompt_template}`.")
 
     return " ".join(messages)
 
@@ -221,7 +275,8 @@ def handle_prompt_template_change(prompt_template, lora_model):
 
     model_prompt_template_message_update = gr.Markdown.update(
         "", visible=False)
-    warning_message = get_warning_message_for_lora_model_and_prompt_template(lora_model, prompt_template)
+    warning_message = get_warning_message_for_lora_model_and_prompt_template(
+        lora_model, prompt_template)
     if warning_message:
         model_prompt_template_message_update = gr.Markdown.update(
             warning_message, visible=True)
@@ -241,7 +296,8 @@ def handle_lora_model_change(lora_model, prompt_template):
 
     model_prompt_template_message_update = gr.Markdown.update(
        "", visible=False)
-    warning_message = get_warning_message_for_lora_model_and_prompt_template(lora_model, prompt_template)
+    warning_message = get_warning_message_for_lora_model_and_prompt_template(
+        lora_model, prompt_template)
     if warning_message:
         model_prompt_template_message_update = gr.Markdown.update(
             warning_message, visible=True)
@@ -260,6 +316,56 @@ def update_prompt_preview(prompt_template,
 
 
 def inference_ui():
+    flagging_dir = os.path.join(Global.data_dir, "flagging", "inference")
+    if not os.path.exists(flagging_dir):
+        os.makedirs(flagging_dir)
+
+    flag_callback = gr.CSVLogger()
+    flag_components = [
+        LoggingItem("Base Model"),
+        LoggingItem("Adaptor Model"),
+        LoggingItem("Type"),
+        LoggingItem("Prompt"),
+        LoggingItem("Output"),
+        LoggingItem("Completed"),
+        LoggingItem("Config"),
+        LoggingItem("Raw Output"),
+        LoggingItem("Max New Tokens"),
+        LoggingItem("Prompt Template"),
+        LoggingItem("Prompt Template Variables"),
+        LoggingItem("Generation Config"),
+    ]
+    flag_callback.setup(flag_components, flagging_dir)
+
+    def get_flag_callback_args(output_for_flagging_str, flag_type):
+        output_for_flagging = json.loads(output_for_flagging_str)
+        generation_config = output_for_flagging.get("generation_config", {})
+        config = []
+        if generation_config.get('do_sample', False):
+            config.append(
+                f"Temperature: {generation_config.get('temperature')}")
+            config.append(f"Top P: {generation_config.get('top_p')}")
+            config.append(f"Top K: {generation_config.get('top_k')}")
+        num_beams = generation_config.get('num_beams', 1)
+        if num_beams > 1:
+            config.append(f"Beams: {generation_config.get('num_beams')}")
+        config.append(f"RP: {generation_config.get('repetition_penalty')}")
+        return [
+            output_for_flagging.get("base_model", ""),
+            output_for_flagging.get("adaptor_model", ""),
+            flag_type,
+            output_for_flagging.get("prompt", ""),
+            output_for_flagging.get("output", ""),
+            str(output_for_flagging.get("completed", "")),
+            ", ".join(config),
+            output_for_flagging.get("raw_output", ""),
+            str(output_for_flagging.get("max_new_tokens", "")),
+            output_for_flagging.get("prompt_template", ""),
+            json.dumps(output_for_flagging.get(
+                "prompt_template_variables", "")),
+            json.dumps(output_for_flagging.get("generation_config", "")),
+        ]
+
     things_that_might_timeout = []
 
     with gr.Blocks() as inference_ui_blocks:
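
For the CSV, `get_flag_callback_args` (added above) flattens the generation config into a short human-readable "Config" column; note that the returned row has to line up with the `LoggingItem` labels passed to `flag_callback.setup`. The summarizing logic, restated here as a standalone sketch:

def summarize_generation_config(generation_config: dict) -> str:
    # Mirrors the "Config" column built by get_flag_callback_args.
    config = []
    if generation_config.get("do_sample", False):
        config.append(f"Temperature: {generation_config.get('temperature')}")
        config.append(f"Top P: {generation_config.get('top_p')}")
        config.append(f"Top K: {generation_config.get('top_k')}")
    if generation_config.get("num_beams", 1) > 1:
        config.append(f"Beams: {generation_config.get('num_beams')}")
    config.append(f"RP: {generation_config.get('repetition_penalty')}")
    return ", ".join(config)


print(summarize_generation_config({
    "do_sample": True, "temperature": 0.7, "top_p": 0.75, "top_k": 40,
    "num_beams": 1, "repetition_penalty": 1.2,
}))
# -> Temperature: 0.7, Top P: 0.75, Top K: 40, RP: 1.2
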
@@ -387,6 +493,47 @@ def inference_ui():
                 inference_output = gr.Textbox(
                     lines=inference_output_lines, label="Output", elem_id="inference_output")
                 inference_output.style(show_copy_button=True)
+
+                with gr.Row(elem_id="inference_flagging_group"):
+                    output_for_flagging = gr.Textbox(
+                        interactive=False, visible=False,
+                        elem_id="inference_output_for_flagging")
+                    flag_btn = gr.Button(
+                        "Flag", elem_id="inference_flag_btn")
+                    flag_up_btn = gr.Button(
+                        "👍", elem_id="inference_flag_up_btn")
+                    flag_down_btn = gr.Button(
+                        "👎", elem_id="inference_flag_down_btn")
+                    flag_output = gr.Markdown(
+                        "", elem_id="inference_flag_output")
+                    flag_btn.click(
+                        lambda d: (flag_callback.flag(
+                            get_flag_callback_args(d, "Flag"),
+                            flag_option="Flag",
+                            username=None
+                        ), "")[1],
+                        inputs=[output_for_flagging],
+                        outputs=[flag_output],
+                        preprocess=False)
+                    flag_up_btn.click(
+                        lambda d: (flag_callback.flag(
+                            get_flag_callback_args(d, "👍"),
+                            flag_option="Up Vote",
+                            username=None
+                        ), "")[1],
+                        inputs=[output_for_flagging],
+                        outputs=[flag_output],
+                        preprocess=False)
+                    flag_down_btn.click(
+                        lambda d: (flag_callback.flag(
+                            get_flag_callback_args(d, "👎"),
+                            flag_option="Down Vote",
+                            username=None
+                        ), "")[1],
+                        inputs=[output_for_flagging],
+                        outputs=[flag_output],
+                        preprocess=False)
+
                 with gr.Accordion(
                     "Raw Output",
                     open=not default_show_raw,
@@ -400,7 +547,8 @@ def inference_ui():
                         interactive=False,
                         elem_id="inference_raw_output")
 
-        reload_selected_models_btn = gr.Button("", elem_id="inference_reload_selected_models_btn")
+        reload_selected_models_btn = gr.Button(
+            "", elem_id="inference_reload_selected_models_btn")
 
         show_raw_change_event = show_raw.change(
            fn=lambda show_raw: gr.Accordion.update(visible=show_raw),
@@ -440,7 +588,8 @@ def inference_ui():
         generate_event = generate_btn.click(
             fn=prepare_inference,
             inputs=[lora_model],
-            outputs=[inference_output, inference_raw_output],
+            outputs=[inference_output,
+                     inference_raw_output, output_for_flagging],
         ).then(
             fn=do_inference,
             inputs=[
@@ -457,7 +606,8 @@ def inference_ui():
                 stream_output,
                 show_raw,
             ],
-            outputs=[inference_output, inference_raw_output],
+            outputs=[inference_output,
+                     inference_raw_output, output_for_flagging],
             api_name="inference"
         )
         stop_btn.click(
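
About the button wiring above: with `preprocess=False` each click handler receives the hidden textbox's raw JSON string, and the `(flag_callback.flag(...), "")[1]` tuple trick performs the flagging side effect while returning an empty string for the `flag_output` Markdown. A tiny sketch of that idiom, with a hypothetical stand-in `flag` function:

def flag(data):
    # Stand-in for flag_callback.flag(get_flag_callback_args(data, ...));
    # CSVLogger.flag returns the row count, which the lambda deliberately discards.
    print("flagged:", data)
    return 1


handler = lambda d: (flag(d), "")[1]  # run flag(d) for its side effect, return ""

print(repr(handler('{"output": "hi"}')))  # prints the side effect, then ''
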
 
llama_lora/ui/main_page.py CHANGED
@@ -398,6 +398,45 @@ def main_page_custom_css():
         bottom: 16px;
     }
 
+    #inference_flagging_group {
+        position: relative;
+    }
+    #inference_flag_output {
+        min-height: 1px !important;
+        position: absolute;
+        top: 0;
+        bottom: 0;
+        right: 0;
+        pointer-events: none;
+        opacity: 0.5;
+    }
+    #inference_flag_output .wrap {
+        top: 0;
+        bottom: 0;
+        right: 0;
+        justify-content: center;
+        align-items: flex-end;
+        padding: 4px !important;
+    }
+    #inference_flag_output .wrap svg {
+        display: none;
+    }
+    .form:has(> #inference_output_for_flagging),
+    #inference_output_for_flagging {
+        display: none;
+    }
+    #inference_flagging_group:has(#inference_output_for_flagging.hidden) {
+        opacity: 0.5;
+        pointer-events: none;
+    }
+    #inference_flag_up_btn, #inference_flag_down_btn {
+        min-width: 44px;
+        flex-grow: 1;
+    }
+    #inference_flag_btn {
+        flex-grow: 2;
+    }
+
     #dataset_plain_text_input_variables_separator textarea,
     #dataset_plain_text_input_and_output_separator textarea,
     #dataset_plain_text_data_separator textarea {