Spaces:
Runtime error
Runtime error
zetavg
committed on
add possible missing model configs and update ui
Browse files
- llama_lora/models.py +5 -0
- llama_lora/ui/inference_ui.py +42 -9
- llama_lora/ui/main_page.py +9 -1
- llama_lora/ui/tokenizer_ui.py +82 -0
llama_lora/models.py
CHANGED
@@ -89,6 +89,11 @@ def load_base_model():
|
|
89 |
base_model, device_map={"": device}, low_cpu_mem_usage=True
|
90 |
)
|
91 |
|
|
|
|
|
|
|
|
|
|
|
92 |
|
93 |
def unload_models():
|
94 |
del Global.loaded_base_model
|
|
|
89 |
base_model, device_map={"": device}, low_cpu_mem_usage=True
|
90 |
)
|
91 |
|
92 |
+
# unwind broken decapoda-research config
|
93 |
+
model.config.pad_token_id = tokenizer.pad_token_id = 0 # unk
|
94 |
+
model.config.bos_token_id = 1
|
95 |
+
model.config.eos_token_id = 2
|
96 |
+
|
97 |
|
98 |
def unload_models():
|
99 |
del Global.loaded_base_model
|
llama_lora/ui/inference_ui.py
CHANGED
@@ -16,6 +16,8 @@ from ..utils.callbacks import Iteratorize, Stream
|
|
16 |
|
17 |
device = get_device()
|
18 |
|
|
|
|
|
19 |
|
20 |
def do_inference(
|
21 |
lora_model_name,
|
@@ -29,6 +31,7 @@ def do_inference(
|
|
29 |
repetition_penalty=1.2,
|
30 |
max_new_tokens=128,
|
31 |
stream_output=False,
|
|
|
32 |
progress=gr.Progress(track_tqdm=True),
|
33 |
):
|
34 |
try:
|
@@ -47,7 +50,7 @@ def do_inference(
|
|
47 |
message = f"Currently in UI dev mode, not running actual inference.\n\nLoRA model: {lora_model_name}\n\nYour prompt is:\n\n{prompt}"
|
48 |
print(message)
|
49 |
time.sleep(1)
|
50 |
-
yield message
|
51 |
return
|
52 |
|
53 |
if lora_model_name == "None":
|
@@ -102,7 +105,10 @@ def do_inference(
|
|
102 |
if output[-1] in [tokenizer.eos_token_id]:
|
103 |
break
|
104 |
|
105 |
-
|
|
|
|
|
|
|
106 |
return # early return for stream_output
|
107 |
|
108 |
# Without streaming
|
@@ -116,7 +122,10 @@ def do_inference(
|
|
116 |
)
|
117 |
s = generation_output.sequences[0]
|
118 |
output = tokenizer.decode(s)
|
119 |
-
|
|
|
|
|
|
|
120 |
|
121 |
except Exception as e:
|
122 |
raise gr.Error(e)
|
@@ -249,11 +258,17 @@ def inference_ui():
|
|
249 |
elem_id="inference_max_new_tokens"
|
250 |
)
|
251 |
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
257 |
|
258 |
with gr.Column():
|
259 |
with gr.Row():
|
@@ -267,6 +282,23 @@ def inference_ui():
|
|
267 |
inference_output = gr.Textbox(
|
268 |
lines=12, label="Output", elem_id="inference_output")
|
269 |
inference_output.style(show_copy_button=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
270 |
|
271 |
reload_selections_button.click(
|
272 |
reload_selections,
|
@@ -291,8 +323,9 @@ def inference_ui():
|
|
291 |
repetition_penalty,
|
292 |
max_new_tokens,
|
293 |
stream_output,
|
|
|
294 |
],
|
295 |
-
outputs=inference_output,
|
296 |
api_name="inference"
|
297 |
)
|
298 |
stop_btn.click(fn=None, inputs=None, outputs=None,
|
|
|
16 |
|
17 |
device = get_device()
|
18 |
|
19 |
+
default_show_raw = True
|
20 |
+
|
21 |
|
22 |
def do_inference(
|
23 |
lora_model_name,
|
|
|
31 |
repetition_penalty=1.2,
|
32 |
max_new_tokens=128,
|
33 |
stream_output=False,
|
34 |
+
show_raw=False,
|
35 |
progress=gr.Progress(track_tqdm=True),
|
36 |
):
|
37 |
try:
|
|
|
50 |
message = f"Currently in UI dev mode, not running actual inference.\n\nLoRA model: {lora_model_name}\n\nYour prompt is:\n\n{prompt}"
|
51 |
print(message)
|
52 |
time.sleep(1)
|
53 |
+
yield message, '[0]'
|
54 |
return
|
55 |
|
56 |
if lora_model_name == "None":
|
|
|
105 |
if output[-1] in [tokenizer.eos_token_id]:
|
106 |
break
|
107 |
|
108 |
+
raw_output = None
|
109 |
+
if show_raw:
|
110 |
+
raw_output = str(output)
|
111 |
+
yield prompter.get_response(decoded_output), raw_output
|
112 |
return # early return for stream_output
|
113 |
|
114 |
# Without streaming
|
|
|
122 |
)
|
123 |
s = generation_output.sequences[0]
|
124 |
output = tokenizer.decode(s)
|
125 |
+
raw_output = None
|
126 |
+
if show_raw:
|
127 |
+
raw_output = str(s)
|
128 |
+
yield prompter.get_response(output), raw_output
|
129 |
|
130 |
except Exception as e:
|
131 |
raise gr.Error(e)
|
|
|
258 |
elem_id="inference_max_new_tokens"
|
259 |
)
|
260 |
|
261 |
+
with gr.Row():
|
262 |
+
stream_output = gr.Checkbox(
|
263 |
+
label="Stream Output",
|
264 |
+
elem_id="inference_stream_output",
|
265 |
+
value=True
|
266 |
+
)
|
267 |
+
show_raw = gr.Checkbox(
|
268 |
+
label="Show Raw",
|
269 |
+
elem_id="inference_show_raw",
|
270 |
+
value=default_show_raw
|
271 |
+
)
|
272 |
|
273 |
with gr.Column():
|
274 |
with gr.Row():
|
|
|
282 |
inference_output = gr.Textbox(
|
283 |
lines=12, label="Output", elem_id="inference_output")
|
284 |
inference_output.style(show_copy_button=True)
|
285 |
+
with gr.Accordion(
|
286 |
+
"Raw Output",
|
287 |
+
open=False,
|
288 |
+
visible=default_show_raw,
|
289 |
+
elem_id="inference_inference_raw_output_accordion"
|
290 |
+
) as raw_output_group:
|
291 |
+
inference_raw_output = gr.Code(
|
292 |
+
label="Raw Output",
|
293 |
+
show_label=False,
|
294 |
+
language="json",
|
295 |
+
interactive=False,
|
296 |
+
elem_id="inference_raw_output")
|
297 |
+
|
298 |
+
show_raw.change(
|
299 |
+
fn=lambda show_raw: gr.Accordion.update(visible=show_raw),
|
300 |
+
inputs=[show_raw],
|
301 |
+
outputs=[raw_output_group])
|
302 |
|
303 |
reload_selections_button.click(
|
304 |
reload_selections,
|
|
|
323 |
repetition_penalty,
|
324 |
max_new_tokens,
|
325 |
stream_output,
|
326 |
+
show_raw,
|
327 |
],
|
328 |
+
outputs=[inference_output, inference_raw_output],
|
329 |
api_name="inference"
|
330 |
)
|
331 |
stop_btn.click(fn=None, inputs=None, outputs=None,
|
llama_lora/ui/main_page.py
CHANGED
@@ -5,6 +5,7 @@ from ..models import get_model_with_lora
|
|
5 |
|
6 |
from .inference_ui import inference_ui
|
7 |
from .finetune_ui import finetune_ui
|
|
|
8 |
|
9 |
from .js_scripts import popperjs_core_code, tippy_js_code
|
10 |
|
@@ -25,6 +26,8 @@ def main_page():
|
|
25 |
inference_ui()
|
26 |
with gr.Tab("Fine-tuning"):
|
27 |
finetune_ui()
|
|
|
|
|
28 |
info = []
|
29 |
if Global.version:
|
30 |
info.append(f"LLaMA-LoRA `{Global.version}`")
|
@@ -100,6 +103,10 @@ def main_page_custom_css():
|
|
100 |
font-weight: 100;
|
101 |
}
|
102 |
|
|
|
|
|
|
|
|
|
103 |
.textbox_that_is_only_used_to_display_a_label {
|
104 |
border: 0 !important;
|
105 |
box-shadow: none !important;
|
@@ -143,7 +150,8 @@ def main_page_custom_css():
|
|
143 |
box-shadow: none;
|
144 |
}
|
145 |
|
146 |
-
#inference_output > .wrap
|
|
|
147 |
/* allow users to select text while generation is still in progress */
|
148 |
pointer-events: none;
|
149 |
}
|
|
|
5 |
|
6 |
from .inference_ui import inference_ui
|
7 |
from .finetune_ui import finetune_ui
|
8 |
+
from .tokenizer_ui import tokenizer_ui
|
9 |
|
10 |
from .js_scripts import popperjs_core_code, tippy_js_code
|
11 |
|
|
|
26 |
inference_ui()
|
27 |
with gr.Tab("Fine-tuning"):
|
28 |
finetune_ui()
|
29 |
+
with gr.Tab("Tokenizer"):
|
30 |
+
tokenizer_ui()
|
31 |
info = []
|
32 |
if Global.version:
|
33 |
info.append(f"LLaMA-LoRA `{Global.version}`")
|
|
|
103 |
font-weight: 100;
|
104 |
}
|
105 |
|
106 |
+
.error-message, .error-message p {
|
107 |
+
color: var(--error-text-color) !important;
|
108 |
+
}
|
109 |
+
|
110 |
.textbox_that_is_only_used_to_display_a_label {
|
111 |
border: 0 !important;
|
112 |
box-shadow: none !important;
|
|
|
150 |
box-shadow: none;
|
151 |
}
|
152 |
|
153 |
+
#inference_output > .wrap,
|
154 |
+
#inference_raw_output > .wrap {
|
155 |
/* allow users to select text while generation is still in progress */
|
156 |
pointer-events: none;
|
157 |
}
|
llama_lora/ui/tokenizer_ui.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import time
|
3 |
+
import json
|
4 |
+
|
5 |
+
from ..globals import Global
|
6 |
+
from ..models import get_tokenizer
|
7 |
+
|
8 |
+
|
9 |
+
def handle_decode(encoded_tokens_json):
    """Decode a JSON array of token ids into text for the Tokenizer tab.

    Args:
        encoded_tokens_json: JSON text taken from the "Encoded Tokens (JSON)"
            editor.

    Returns:
        A ``(decoded_text, error_update)`` pair, in the order of the outputs
        wired to the decode button. On success the error Markdown is cleared
        and hidden; on any failure the text is empty and the error Markdown is
        made visible with the exception message.
    """
    try:
        # Parsed before the dev-mode check, so malformed JSON is reported
        # even when no tokenizer is going to be used.
        encoded_tokens = json.loads(encoded_tokens_json)
        if Global.ui_dev_mode:
            return (
                "Not actually decoding tokens in UI dev mode.",
                gr.Markdown.update("", visible=False),
            )
        tokenizer = get_tokenizer()
        decoded_tokens = tokenizer.decode(encoded_tokens)
        return decoded_tokens, gr.Markdown.update("", visible=False)
    except Exception as e:
        # UI boundary: surface the failure in the error Markdown instead of
        # letting the exception propagate out of the event handler.
        return "", gr.Markdown.update("Error: " + str(e), visible=True)
|
19 |
+
|
20 |
+
|
21 |
+
def handle_encode(decoded_tokens):
    """Encode text into a JSON array of token ids for the Tokenizer tab.

    Args:
        decoded_tokens: Text taken from the "Decoded Tokens" editor.

    Returns:
        An ``(encoded_tokens_json, error_update)`` pair, in the order of the
        outputs wired to the encode button. On success the JSON is the
        tokenizer result's ``input_ids`` dumped with a 2-space indent and the
        error Markdown is cleared and hidden; on any failure the JSON is empty
        and the error Markdown is made visible with the exception message.
    """
    try:
        if Global.ui_dev_mode:
            # Placeholder is itself a JSON array so the JSON editor can
            # still display it.
            return (
                "[\"Not actually encoding tokens in UI dev mode.\"]",
                gr.Markdown.update("", visible=False),
            )
        tokenizer = get_tokenizer()
        result = tokenizer(decoded_tokens)
        encoded_tokens_json = json.dumps(result['input_ids'], indent=2)
        return encoded_tokens_json, gr.Markdown.update("", visible=False)
    except Exception as e:
        # UI boundary: surface the failure in the error Markdown instead of
        # letting the exception propagate out of the event handler.
        return "", gr.Markdown.update("Error: " + str(e), visible=True)
|
31 |
+
|
32 |
+
|
33 |
+
def tokenizer_ui():
    """Build the Tokenizer tab: a token-id editor and a text editor side by side.

    The left column holds the JSON token-id editor, a "Decode" button and an
    initially hidden error message; the right column holds the text editor,
    an "Encode" button and its own hidden error message. A "Stop" button
    cancels an in-flight decode or encode event.
    """
    with gr.Blocks() as tokenizer_ui_blocks:
        with gr.Row():
            with gr.Column():
                encoded_tokens = gr.Code(
                    label="Encoded Tokens (JSON)",
                    language="json",
                    value=sample_encoded_tokens_value,
                    elem_id="tokenizer_encoded_tokens_input_textbox")
                decode_btn = gr.Button("Decode ➡️")
                encoded_tokens_error_message = gr.Markdown(
                    "", visible=False, elem_classes="error-message")
            with gr.Column():
                decoded_tokens = gr.Code(
                    label="Decoded Tokens",
                    value=sample_decoded_text_value,
                    elem_id="tokenizer_decoded_text_input_textbox")
                encode_btn = gr.Button("⬅️ Encode")
                decoded_tokens_error_message = gr.Markdown(
                    "", visible=False, elem_classes="error-message")
        # NOTE(review): the original nesting of the Stop button was lost in
        # extraction; it is placed below the row here -- confirm.
        stop_btn = gr.Button("Stop")

        # Each handler writes its result into the opposite editor and its
        # error into the message under the editor it read from.
        decoding = decode_btn.click(
            fn=handle_decode,
            inputs=[encoded_tokens],
            outputs=[decoded_tokens, encoded_tokens_error_message],
        )
        encoding = encode_btn.click(
            fn=handle_encode,
            inputs=[decoded_tokens],
            outputs=[encoded_tokens, decoded_tokens_error_message],
        )
        stop_btn.click(fn=None, inputs=None, outputs=None, cancels=[decoding, encoding])

    # Page-load JS hook; the function body is currently empty.
    # NOTE(review): whether this call sat inside or outside the `with` block
    # cannot be recovered from the extraction -- confirm.
    tokenizer_ui_blocks.load(_js="""
    function tokenizer_ui_blocks_js() {
    }
    """)
|
71 |
+
|
72 |
+
|
73 |
+
# Initial content of the "Encoded Tokens (JSON)" editor: a JSON array of
# three token ids.
# NOTE(review): any indentation inside the literal was lost in extraction;
# the lines are kept flush-left, which parses to the same JSON.
sample_encoded_tokens_value = """
[
15043,
3186,
29889
]
"""

# Initial content of the "Decoded Tokens" editor: a single newline.
sample_decoded_text_value = """
"""
|