zetavg commited on
Commit
4ac0d6a
·
unverified ·
1 Parent(s): f79f0d5

add possible missing model configs and update ui

Browse files
llama_lora/models.py CHANGED
@@ -89,6 +89,11 @@ def load_base_model():
89
  base_model, device_map={"": device}, low_cpu_mem_usage=True
90
  )
91
 
 
 
 
 
 
92
 
93
  def unload_models():
94
  del Global.loaded_base_model
 
89
  base_model, device_map={"": device}, low_cpu_mem_usage=True
90
  )
91
 
92
+ # unwind broken decapoda-research config
93
+ model.config.pad_token_id = tokenizer.pad_token_id = 0 # unk
94
+ model.config.bos_token_id = 1
95
+ model.config.eos_token_id = 2
96
+
97
 
98
  def unload_models():
99
  del Global.loaded_base_model
llama_lora/ui/inference_ui.py CHANGED
@@ -16,6 +16,8 @@ from ..utils.callbacks import Iteratorize, Stream
16
 
17
  device = get_device()
18
 
 
 
19
 
20
  def do_inference(
21
  lora_model_name,
@@ -29,6 +31,7 @@ def do_inference(
29
  repetition_penalty=1.2,
30
  max_new_tokens=128,
31
  stream_output=False,
 
32
  progress=gr.Progress(track_tqdm=True),
33
  ):
34
  try:
@@ -47,7 +50,7 @@ def do_inference(
47
  message = f"Currently in UI dev mode, not running actual inference.\n\nLoRA model: {lora_model_name}\n\nYour prompt is:\n\n{prompt}"
48
  print(message)
49
  time.sleep(1)
50
- yield message
51
  return
52
 
53
  if lora_model_name == "None":
@@ -102,7 +105,10 @@ def do_inference(
102
  if output[-1] in [tokenizer.eos_token_id]:
103
  break
104
 
105
- yield prompter.get_response(decoded_output)
 
 
 
106
  return # early return for stream_output
107
 
108
  # Without streaming
@@ -116,7 +122,10 @@ def do_inference(
116
  )
117
  s = generation_output.sequences[0]
118
  output = tokenizer.decode(s)
119
- yield prompter.get_response(output)
 
 
 
120
 
121
  except Exception as e:
122
  raise gr.Error(e)
@@ -249,11 +258,17 @@ def inference_ui():
249
  elem_id="inference_max_new_tokens"
250
  )
251
 
252
- stream_output = gr.Checkbox(
253
- label="Stream Output",
254
- elem_id="inference_stream_output",
255
- value=True
256
- )
 
 
 
 
 
 
257
 
258
  with gr.Column():
259
  with gr.Row():
@@ -267,6 +282,23 @@ def inference_ui():
267
  inference_output = gr.Textbox(
268
  lines=12, label="Output", elem_id="inference_output")
269
  inference_output.style(show_copy_button=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
 
271
  reload_selections_button.click(
272
  reload_selections,
@@ -291,8 +323,9 @@ def inference_ui():
291
  repetition_penalty,
292
  max_new_tokens,
293
  stream_output,
 
294
  ],
295
- outputs=inference_output,
296
  api_name="inference"
297
  )
298
  stop_btn.click(fn=None, inputs=None, outputs=None,
 
16
 
17
  device = get_device()
18
 
19
+ default_show_raw = True
20
+
21
 
22
  def do_inference(
23
  lora_model_name,
 
31
  repetition_penalty=1.2,
32
  max_new_tokens=128,
33
  stream_output=False,
34
+ show_raw=False,
35
  progress=gr.Progress(track_tqdm=True),
36
  ):
37
  try:
 
50
  message = f"Currently in UI dev mode, not running actual inference.\n\nLoRA model: {lora_model_name}\n\nYour prompt is:\n\n{prompt}"
51
  print(message)
52
  time.sleep(1)
53
+ yield message, '[0]'
54
  return
55
 
56
  if lora_model_name == "None":
 
105
  if output[-1] in [tokenizer.eos_token_id]:
106
  break
107
 
108
+ raw_output = None
109
+ if show_raw:
110
+ raw_output = str(output)
111
+ yield prompter.get_response(decoded_output), raw_output
112
  return # early return for stream_output
113
 
114
  # Without streaming
 
122
  )
123
  s = generation_output.sequences[0]
124
  output = tokenizer.decode(s)
125
+ raw_output = None
126
+ if show_raw:
127
+ raw_output = str(s)
128
+ yield prompter.get_response(output), raw_output
129
 
130
  except Exception as e:
131
  raise gr.Error(e)
 
258
  elem_id="inference_max_new_tokens"
259
  )
260
 
261
+ with gr.Row():
262
+ stream_output = gr.Checkbox(
263
+ label="Stream Output",
264
+ elem_id="inference_stream_output",
265
+ value=True
266
+ )
267
+ show_raw = gr.Checkbox(
268
+ label="Show Raw",
269
+ elem_id="inference_show_raw",
270
+ value=default_show_raw
271
+ )
272
 
273
  with gr.Column():
274
  with gr.Row():
 
282
  inference_output = gr.Textbox(
283
  lines=12, label="Output", elem_id="inference_output")
284
  inference_output.style(show_copy_button=True)
285
+ with gr.Accordion(
286
+ "Raw Output",
287
+ open=False,
288
+ visible=default_show_raw,
289
+ elem_id="inference_inference_raw_output_accordion"
290
+ ) as raw_output_group:
291
+ inference_raw_output = gr.Code(
292
+ label="Raw Output",
293
+ show_label=False,
294
+ language="json",
295
+ interactive=False,
296
+ elem_id="inference_raw_output")
297
+
298
+ show_raw.change(
299
+ fn=lambda show_raw: gr.Accordion.update(visible=show_raw),
300
+ inputs=[show_raw],
301
+ outputs=[raw_output_group])
302
 
303
  reload_selections_button.click(
304
  reload_selections,
 
323
  repetition_penalty,
324
  max_new_tokens,
325
  stream_output,
326
+ show_raw,
327
  ],
328
+ outputs=[inference_output, inference_raw_output],
329
  api_name="inference"
330
  )
331
  stop_btn.click(fn=None, inputs=None, outputs=None,
llama_lora/ui/main_page.py CHANGED
@@ -5,6 +5,7 @@ from ..models import get_model_with_lora
5
 
6
  from .inference_ui import inference_ui
7
  from .finetune_ui import finetune_ui
 
8
 
9
  from .js_scripts import popperjs_core_code, tippy_js_code
10
 
@@ -25,6 +26,8 @@ def main_page():
25
  inference_ui()
26
  with gr.Tab("Fine-tuning"):
27
  finetune_ui()
 
 
28
  info = []
29
  if Global.version:
30
  info.append(f"LLaMA-LoRA `{Global.version}`")
@@ -100,6 +103,10 @@ def main_page_custom_css():
100
  font-weight: 100;
101
  }
102
 
 
 
 
 
103
  .textbox_that_is_only_used_to_display_a_label {
104
  border: 0 !important;
105
  box-shadow: none !important;
@@ -143,7 +150,8 @@ def main_page_custom_css():
143
  box-shadow: none;
144
  }
145
 
146
- #inference_output > .wrap {
 
147
  /* allow users to select text while generation is still in progress */
148
  pointer-events: none;
149
  }
 
5
 
6
  from .inference_ui import inference_ui
7
  from .finetune_ui import finetune_ui
8
+ from .tokenizer_ui import tokenizer_ui
9
 
10
  from .js_scripts import popperjs_core_code, tippy_js_code
11
 
 
26
  inference_ui()
27
  with gr.Tab("Fine-tuning"):
28
  finetune_ui()
29
+ with gr.Tab("Tokenizer"):
30
+ tokenizer_ui()
31
  info = []
32
  if Global.version:
33
  info.append(f"LLaMA-LoRA `{Global.version}`")
 
103
  font-weight: 100;
104
  }
105
 
106
+ .error-message, .error-message p {
107
+ color: var(--error-text-color) !important;
108
+ }
109
+
110
  .textbox_that_is_only_used_to_display_a_label {
111
  border: 0 !important;
112
  box-shadow: none !important;
 
150
  box-shadow: none;
151
  }
152
 
153
+ #inference_output > .wrap,
154
+ #inference_raw_output > .wrap {
155
  /* allow users to select text while generation is still in progress */
156
  pointer-events: none;
157
  }
llama_lora/ui/tokenizer_ui.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import time
3
+ import json
4
+
5
+ from ..globals import Global
6
+ from ..models import get_tokenizer
7
+
8
+
9
+ def handle_decode(encoded_tokens_json):
10
+ try:
11
+ encoded_tokens = json.loads(encoded_tokens_json)
12
+ if Global.ui_dev_mode:
13
+ return f"Not actually decoding tokens in UI dev mode.", gr.Markdown.update("", visible=False)
14
+ tokenizer = get_tokenizer()
15
+ decoded_tokens = tokenizer.decode(encoded_tokens)
16
+ return decoded_tokens, gr.Markdown.update("", visible=False)
17
+ except Exception as e:
18
+ return "", gr.Markdown.update("Error: " + str(e), visible=True)
19
+
20
+
21
+ def handle_encode(decoded_tokens):
22
+ try:
23
+ if Global.ui_dev_mode:
24
+ return f"[\"Not actually encoding tokens in UI dev mode.\"]", gr.Markdown.update("", visible=False)
25
+ tokenizer = get_tokenizer()
26
+ result = tokenizer(decoded_tokens)
27
+ encoded_tokens_json = json.dumps(result['input_ids'], indent=2)
28
+ return encoded_tokens_json, gr.Markdown.update("", visible=False)
29
+ except Exception as e:
30
+ return "", gr.Markdown.update("Error: " + str(e), visible=True)
31
+
32
+
33
+ def tokenizer_ui():
34
+ with gr.Blocks() as tokenizer_ui_blocks:
35
+ with gr.Row():
36
+ with gr.Column():
37
+ encoded_tokens = gr.Code(
38
+ label="Encoded Tokens (JSON)",
39
+ language="json",
40
+ value=sample_encoded_tokens_value,
41
+ elem_id="tokenizer_encoded_tokens_input_textbox")
42
+ decode_btn = gr.Button("Decode ➡️")
43
+ encoded_tokens_error_message = gr.Markdown(
44
+ "", visible=False, elem_classes="error-message")
45
+ with gr.Column():
46
+ decoded_tokens = gr.Code(
47
+ label="Decoded Tokens",
48
+ value=sample_decoded_text_value,
49
+ elem_id="tokenizer_decoded_text_input_textbox")
50
+ encode_btn = gr.Button("⬅️ Encode")
51
+ decoded_tokens_error_message = gr.Markdown(
52
+ "", visible=False, elem_classes="error-message")
53
+ stop_btn = gr.Button("Stop")
54
+
55
+ decoding = decode_btn.click(
56
+ fn=handle_decode,
57
+ inputs=[encoded_tokens],
58
+ outputs=[decoded_tokens, encoded_tokens_error_message],
59
+ )
60
+ encoding = encode_btn.click(
61
+ fn=handle_encode,
62
+ inputs=[decoded_tokens],
63
+ outputs=[encoded_tokens, decoded_tokens_error_message],
64
+ )
65
+ stop_btn.click(fn=None, inputs=None, outputs=None, cancels=[decoding, encoding])
66
+
67
+ tokenizer_ui_blocks.load(_js="""
68
+ function tokenizer_ui_blocks_js() {
69
+ }
70
+ """)
71
+
72
+
73
+ sample_encoded_tokens_value = """
74
+ [
75
+ 15043,
76
+ 3186,
77
+ 29889
78
+ ]
79
+ """
80
+
81
+ sample_decoded_text_value = """
82
+ """