Spaces:
Runtime error
Runtime error
Commit
·
89634f3
1
Parent(s):
0d87645
Update app.py
Browse files
app.py
CHANGED
@@ -50,36 +50,43 @@ no_change_btn = gr.Button.update()
|
|
50 |
enable_btn = gr.Button.update(interactive=True)
|
51 |
disable_btn = gr.Button.update(interactive=False)
|
52 |
|
53 |
-
enable_moderation = False
|
54 |
-
concurrency_count = 10
|
55 |
-
model_list_mode = 'reload'
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
worker_addr =
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
|
67 |
-
# enable_moderation = True if os.environ.get('enable_moderation', default='False')=="True" else False
|
68 |
-
# concurrency_count = int(os.environ.get('concurrency_count', default='10'))
|
69 |
-
# model_list_mode = os.environ.get('model_list_mode', default='reload')
|
70 |
|
71 |
-
|
72 |
-
|
73 |
-
|
|
|
|
|
|
|
74 |
|
75 |
-
|
76 |
-
# ft_list_job_url = os.environ.get('ft_list_job_url', default='')
|
77 |
-
# ft_submit_job_url = os.environ.get('ft_submit_job_url', default='')
|
78 |
-
# ft_remove_job_url = os.environ.get('ft_remove_job_url', default='')
|
79 |
-
# ft_console_log_url = os.environ.get('ft_console_log_url', default='')
|
80 |
|
81 |
|
82 |
-
headers = {"User-Agent": "FastChat Client", "PRIVATE-TOKEN":
|
83 |
|
84 |
learn_more_md = """
|
85 |
### License
|
@@ -155,14 +162,20 @@ def get_conv_log_filename():
|
|
155 |
|
156 |
|
157 |
def get_model_list(midware_url):
|
|
|
|
|
|
|
|
|
158 |
try:
|
159 |
-
ret = requests.get(midware_url, headers={"PRIVATE-TOKEN":
|
160 |
-
|
|
|
|
|
|
|
|
|
161 |
except requests.exceptions.RequestException:
|
162 |
models = ["CANNOT GET MODEL"]
|
163 |
-
|
164 |
-
priority = {k: f"___{i:02d}" for i, k in enumerate(model_info)}
|
165 |
-
models.sort(key=lambda x: priority.get(x, x))
|
166 |
logger.info(f"Models: {models}")
|
167 |
return models
|
168 |
|
@@ -202,47 +215,16 @@ def load_demo(url_params, request: gr.Request):
|
|
202 |
|
203 |
return load_demo_single(models, url_params)
|
204 |
|
205 |
-
|
206 |
-
def vote_last_response(state, vote_type, model_selector, request: gr.Request):
|
207 |
-
with open(get_conv_log_filename(), "a") as fout:
|
208 |
-
data = {
|
209 |
-
"tstamp": round(time.time(), 4),
|
210 |
-
"type": vote_type,
|
211 |
-
"model": model_selector,
|
212 |
-
"state": state.dict(),
|
213 |
-
"ip": request.client.host,
|
214 |
-
}
|
215 |
-
fout.write(json.dumps(data) + "\n")
|
216 |
-
|
217 |
-
|
218 |
-
def upvote_last_response(state, model_selector, request: gr.Request):
|
219 |
-
logger.info(f"upvote. ip: {request.client.host}")
|
220 |
-
vote_last_response(state, "upvote", model_selector, request)
|
221 |
-
return ("",) + (disable_btn,) * 3
|
222 |
-
|
223 |
-
|
224 |
-
def downvote_last_response(state, model_selector, request: gr.Request):
|
225 |
-
logger.info(f"downvote. ip: {request.client.host}")
|
226 |
-
vote_last_response(state, "downvote", model_selector, request)
|
227 |
-
return ("",) + (disable_btn,) * 3
|
228 |
-
|
229 |
-
|
230 |
-
def flag_last_response(state, model_selector, request: gr.Request):
|
231 |
-
logger.info(f"flag. ip: {request.client.host}")
|
232 |
-
vote_last_response(state, "flag", model_selector, request)
|
233 |
-
return ("",) + (disable_btn,) * 3
|
234 |
-
|
235 |
-
|
236 |
def regenerate(state, request: gr.Request):
|
237 |
logger.info(f"regenerate. ip: {request.client.host}")
|
238 |
state.conv.update_last_message(None)
|
239 |
-
return (state, state.to_gradio_chatbot(), "") + (disable_btn,) *
|
240 |
|
241 |
|
242 |
def clear_history(request: gr.Request):
|
243 |
logger.info(f"clear_history. ip: {request.client.host}")
|
244 |
state = None
|
245 |
-
return (state, [], "") + (disable_btn,) *
|
246 |
|
247 |
|
248 |
def add_text(state, model_selector, text, request: gr.Request):
|
@@ -254,12 +236,12 @@ def add_text(state, model_selector, text, request: gr.Request):
|
|
254 |
|
255 |
if len(text) <= 0:
|
256 |
state.skip_next = True
|
257 |
-
return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) *
|
258 |
|
259 |
if ip_expiration_dict[ip] < time.time():
|
260 |
logger.info(f"inactive. ip: {request.client.host}. text: {text}")
|
261 |
state.skip_next = True
|
262 |
-
return (state, state.to_gradio_chatbot(), INACTIVE_MSG) + (no_change_btn,) *
|
263 |
|
264 |
if enable_moderation:
|
265 |
flagged = violates_moderation(text)
|
@@ -268,7 +250,7 @@ def add_text(state, model_selector, text, request: gr.Request):
|
|
268 |
state.skip_next = True
|
269 |
return (state, state.to_gradio_chatbot(), MODERATION_MSG) + (
|
270 |
no_change_btn,
|
271 |
-
) *
|
272 |
|
273 |
conv = state.conv
|
274 |
if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT:
|
@@ -276,12 +258,12 @@ def add_text(state, model_selector, text, request: gr.Request):
|
|
276 |
state.skip_next = True
|
277 |
return (state, state.to_gradio_chatbot(), CONVERSATION_LIMIT_MSG) + (
|
278 |
no_change_btn,
|
279 |
-
) *
|
280 |
|
281 |
text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off
|
282 |
conv.append_message(conv.roles[0], text)
|
283 |
conv.append_message(conv.roles[1], None)
|
284 |
-
return (state, state.to_gradio_chatbot(), "") + (disable_btn,) *
|
285 |
|
286 |
|
287 |
def post_process_code(code):
|
@@ -343,7 +325,7 @@ def bot_response(state, temperature, top_p, max_new_tokens, request: gr.Request)
|
|
343 |
if state.skip_next:
|
344 |
# This generate call is skipped due to invalid inputs
|
345 |
state.skip_next = False
|
346 |
-
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) *
|
347 |
return
|
348 |
|
349 |
conv, model_name = state.conv, state.model_name
|
@@ -370,9 +352,6 @@ def bot_response(state, temperature, top_p, max_new_tokens, request: gr.Request)
|
|
370 |
yield (
|
371 |
state,
|
372 |
state.to_gradio_chatbot(),
|
373 |
-
disable_btn,
|
374 |
-
disable_btn,
|
375 |
-
disable_btn,
|
376 |
enable_btn,
|
377 |
enable_btn,
|
378 |
)
|
@@ -400,7 +379,7 @@ def bot_response(state, temperature, top_p, max_new_tokens, request: gr.Request)
|
|
400 |
)
|
401 |
|
402 |
conv.update_last_message("▌")
|
403 |
-
yield (state, state.to_gradio_chatbot()) + (disable_btn,) *
|
404 |
|
405 |
try:
|
406 |
for data in stream_iter:
|
@@ -410,14 +389,11 @@ def bot_response(state, temperature, top_p, max_new_tokens, request: gr.Request)
|
|
410 |
output = post_process_code(output)
|
411 |
output = str_filter(output)
|
412 |
conv.update_last_message(output + "▌")
|
413 |
-
yield (state, state.to_gradio_chatbot()) + (disable_btn,) *
|
414 |
else:
|
415 |
output = data["text"] + f"\n\n(error_code: {data['error_code']})"
|
416 |
conv.update_last_message(output)
|
417 |
yield (state, state.to_gradio_chatbot()) + (
|
418 |
-
disable_btn,
|
419 |
-
disable_btn,
|
420 |
-
disable_btn,
|
421 |
enable_btn,
|
422 |
enable_btn,
|
423 |
)
|
@@ -429,9 +405,6 @@ def bot_response(state, temperature, top_p, max_new_tokens, request: gr.Request)
|
|
429 |
f"(error_code: {ErrorCode.GRADIO_REQUEST_ERROR}, {e})"
|
430 |
)
|
431 |
yield (state, state.to_gradio_chatbot()) + (
|
432 |
-
disable_btn,
|
433 |
-
disable_btn,
|
434 |
-
disable_btn,
|
435 |
enable_btn,
|
436 |
enable_btn,
|
437 |
)
|
@@ -442,9 +415,6 @@ def bot_response(state, temperature, top_p, max_new_tokens, request: gr.Request)
|
|
442 |
f"(error_code: {ErrorCode.GRADIO_STREAM_UNKNOWN_ERROR}, {e})"
|
443 |
)
|
444 |
yield (state, state.to_gradio_chatbot()) + (
|
445 |
-
disable_btn,
|
446 |
-
disable_btn,
|
447 |
-
disable_btn,
|
448 |
enable_btn,
|
449 |
enable_btn,
|
450 |
)
|
@@ -452,7 +422,7 @@ def bot_response(state, temperature, top_p, max_new_tokens, request: gr.Request)
|
|
452 |
|
453 |
# Delete "▌"
|
454 |
conv.update_last_message(conv.messages[-1][-1][:-1])
|
455 |
-
yield (state, state.to_gradio_chatbot()) + (enable_btn,) *
|
456 |
|
457 |
finish_tstamp = time.time()
|
458 |
logger.info(f"{output}")
|
@@ -532,7 +502,6 @@ def build_single_model_ui(models, add_promotion_links=False):
|
|
532 |
with gr.Column():
|
533 |
with gr.Tab("🧠 模型对话 Dialog"):
|
534 |
state = gr.State()
|
535 |
-
|
536 |
with gr.Row(elem_id="model_selector_row"):
|
537 |
model_selector = gr.Dropdown(
|
538 |
choices=models,
|
@@ -560,12 +529,12 @@ def build_single_model_ui(models, add_promotion_links=False):
|
|
560 |
send_btn = gr.Button(value="Send", visible=False)
|
561 |
|
562 |
with gr.Row(visible=False) as button_row:
|
563 |
-
upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
|
564 |
-
downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
|
565 |
-
flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
|
566 |
regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
|
567 |
clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)
|
568 |
-
|
|
|
|
|
|
|
569 |
with gr.Accordion("Parameters", open=False, visible=False) as parameter_row:
|
570 |
temperature = gr.Slider(
|
571 |
minimum=0.0,
|
@@ -595,22 +564,7 @@ def build_single_model_ui(models, add_promotion_links=False):
|
|
595 |
gr.Markdown(learn_more_md)
|
596 |
|
597 |
# Register listeners
|
598 |
-
btn_list = [
|
599 |
-
upvote_btn.click(
|
600 |
-
upvote_last_response,
|
601 |
-
[state, model_selector],
|
602 |
-
[textbox, upvote_btn, downvote_btn, flag_btn],
|
603 |
-
)
|
604 |
-
downvote_btn.click(
|
605 |
-
downvote_last_response,
|
606 |
-
[state, model_selector],
|
607 |
-
[textbox, upvote_btn, downvote_btn, flag_btn],
|
608 |
-
)
|
609 |
-
flag_btn.click(
|
610 |
-
flag_last_response,
|
611 |
-
[state, model_selector],
|
612 |
-
[textbox, upvote_btn, downvote_btn, flag_btn],
|
613 |
-
)
|
614 |
regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
|
615 |
bot_response,
|
616 |
[state, temperature, top_p, max_output_tokens],
|
@@ -634,8 +588,6 @@ def build_single_model_ui(models, add_promotion_links=False):
|
|
634 |
[state, temperature, top_p, max_output_tokens],
|
635 |
[state, chatbot] + btn_list,
|
636 |
)
|
637 |
-
|
638 |
-
|
639 |
return state, model_selector, chatbot, textbox, send_btn, button_row, parameter_row
|
640 |
|
641 |
|
@@ -643,16 +595,22 @@ def ft_get_job_data():
|
|
643 |
running = 0
|
644 |
res_lst = []
|
645 |
try:
|
646 |
-
r = requests.get(ft_list_job_url, timeout=8)
|
|
|
|
|
|
|
647 |
for d in r.json():
|
648 |
if isinstance(d['status'], str) and d['status'].lower() == "running":
|
649 |
-
running += 1
|
650 |
-
|
|
|
|
|
|
|
|
|
651 |
res_lst = sorted(res_lst,key=(lambda x:x[2]), reverse=True)
|
652 |
res_lst = sorted(res_lst,key=(lambda x:x[5]), reverse=True)
|
653 |
except requests.exceptions.RequestException:
|
654 |
logger.info(f"Get job list fail")
|
655 |
-
res_lst.append([])
|
656 |
return res_lst, running
|
657 |
|
658 |
|
@@ -668,7 +626,7 @@ def ft_cease_click(ft_console):
|
|
668 |
def console_generator(addr, sleep_time):
|
669 |
total_str = ""
|
670 |
ws = websocket.WebSocket()
|
671 |
-
ws.connect(addr)
|
672 |
while True:
|
673 |
try:
|
674 |
new_str = ws.recv()
|
@@ -691,7 +649,7 @@ def ft_submit_click(ft_latest_running_cnt, ft_user_name, ft_model, ft_dataset_na
|
|
691 |
res_lst, running = ft_get_job_data()
|
692 |
return res_lst, running, no_change_textbox
|
693 |
if ft_latest_running_cnt < int(allow_running):
|
694 |
-
midware_header = {
|
695 |
hps_json = {
|
696 |
"epochs": str(ft_epochs),
|
697 |
"train_batch_size": str(ft_train_batch_size),
|
@@ -702,10 +660,9 @@ def ft_submit_click(ft_latest_running_cnt, ft_user_name, ft_model, ft_dataset_na
|
|
702 |
"model_max_length": str(ft_model_max_length)
|
703 |
}
|
704 |
json_data = {
|
705 |
-
"dataset": ft_dataset_name,
|
706 |
"model": ft_model,
|
707 |
"parameter": hps_json,
|
708 |
-
"secret": ft_token,
|
709 |
"username": ft_user_name
|
710 |
}
|
711 |
try:
|
@@ -737,7 +694,7 @@ def ft_show_click(ft_selected_row_data):
|
|
737 |
def ft_remove_click(ft_selected_row_data, ft_token):
|
738 |
status = ft_selected_row_data[5]
|
739 |
if isinstance(status, str) and status.lower() == "running":
|
740 |
-
r = requests.delete(ft_remove_job_url + ft_selected_row_data[0],
|
741 |
if r.status_code == 200:
|
742 |
gr.Info("Remove success.")
|
743 |
else:
|
@@ -760,6 +717,13 @@ def ft_jobs_info_select(ft_jobs_info, evt: gr.SelectData):
|
|
760 |
return [selected_row, no_change_dropdown, no_change_dropdown, no_change_slider, no_change_textbox, no_change_textbox, no_change_textbox, no_change_textbox, no_change_textbox, no_change_textbox]
|
761 |
|
762 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
763 |
def build_demo(models):
|
764 |
with gr.Blocks(
|
765 |
title="Vicuna Test",
|
|
|
50 |
enable_btn = gr.Button.update(interactive=True)
|
51 |
disable_btn = gr.Button.update(interactive=False)
|
52 |
|
53 |
+
enable_moderation = True if os.environ.get('enable_moderation', default='False')=="True" else False
|
54 |
+
concurrency_count = int(os.environ.get('concurrency_count', default='10'))
|
55 |
+
model_list_mode = os.environ.get('model_list_mode', default='reload')
|
56 |
+
|
57 |
+
midware_url = os.environ.get('midware_url', default='')
|
58 |
+
preset_token = os.environ.get('preset_token', default='')
|
59 |
+
worker_addr = os.environ.get('worker_addr', default='')
|
60 |
+
|
61 |
+
allow_running = int(os.environ.get('allow_running', default='1'))
|
62 |
+
ft_list_job_url = os.environ.get('ft_list_job_url', default='')
|
63 |
+
ft_submit_job_url = os.environ.get('ft_submit_job_url', default='')
|
64 |
+
ft_remove_job_url = os.environ.get('ft_remove_job_url', default='')
|
65 |
+
ft_console_log_url = os.environ.get('ft_console_log_url', default='')
|
66 |
+
|
67 |
+
dataset_sample = {
|
68 |
+
"english": {
|
69 |
+
"train": ["abcdef"],
|
70 |
+
"valid": ["zxcvbn"]
|
71 |
+
},
|
72 |
+
"cat": {
|
73 |
+
"train": ["aaaaaa"],
|
74 |
+
"valid": ["bbbbbb"]
|
75 |
+
}
|
76 |
+
}
|
77 |
|
|
|
|
|
|
|
78 |
|
79 |
+
dataset_to_midware_name = {
|
80 |
+
"english": "english",
|
81 |
+
"cat": "cat",
|
82 |
+
"dog": "dog",
|
83 |
+
"bird": "bird"
|
84 |
+
}
|
85 |
|
86 |
+
hps_keys = ["epochs", "train_batch_size", "eval_batch_size", "gradient_accumulation_steps", "learning_rate", "weight_decay", "model_max_length"]
|
|
|
|
|
|
|
|
|
87 |
|
88 |
|
89 |
+
headers = {"User-Agent": "FastChat Client", "PRIVATE-TOKEN": preset_token}
|
90 |
|
91 |
learn_more_md = """
|
92 |
### License
|
|
|
162 |
|
163 |
|
164 |
def get_model_list(midware_url):
|
165 |
+
setted_model_order = {
|
166 |
+
"vicuna-7b-v1.5-16k": 10,
|
167 |
+
"vicuna-13b-v1.5": 90,
|
168 |
+
}
|
169 |
try:
|
170 |
+
ret = requests.get(midware_url, headers={"PRIVATE-TOKEN": preset_token}, timeout=5)
|
171 |
+
if "code" in ret.json() and "invalid" in ret.json()["code"]:
|
172 |
+
gr.Warning("Invalid preset token.")
|
173 |
+
models = ["CANNOT GET MODEL"]
|
174 |
+
else:
|
175 |
+
models = ret.json()["data"]
|
176 |
except requests.exceptions.RequestException:
|
177 |
models = ["CANNOT GET MODEL"]
|
178 |
+
models = sorted(models, key=lambda x: setted_model_order.get(x, 100))
|
|
|
|
|
179 |
logger.info(f"Models: {models}")
|
180 |
return models
|
181 |
|
|
|
215 |
|
216 |
return load_demo_single(models, url_params)
|
217 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
218 |
def regenerate(state, request: gr.Request):
|
219 |
logger.info(f"regenerate. ip: {request.client.host}")
|
220 |
state.conv.update_last_message(None)
|
221 |
+
return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 2
|
222 |
|
223 |
|
224 |
def clear_history(request: gr.Request):
|
225 |
logger.info(f"clear_history. ip: {request.client.host}")
|
226 |
state = None
|
227 |
+
return (state, [], "") + (disable_btn,) * 2
|
228 |
|
229 |
|
230 |
def add_text(state, model_selector, text, request: gr.Request):
|
|
|
236 |
|
237 |
if len(text) <= 0:
|
238 |
state.skip_next = True
|
239 |
+
return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 2
|
240 |
|
241 |
if ip_expiration_dict[ip] < time.time():
|
242 |
logger.info(f"inactive. ip: {request.client.host}. text: {text}")
|
243 |
state.skip_next = True
|
244 |
+
return (state, state.to_gradio_chatbot(), INACTIVE_MSG) + (no_change_btn,) * 2
|
245 |
|
246 |
if enable_moderation:
|
247 |
flagged = violates_moderation(text)
|
|
|
250 |
state.skip_next = True
|
251 |
return (state, state.to_gradio_chatbot(), MODERATION_MSG) + (
|
252 |
no_change_btn,
|
253 |
+
) * 2
|
254 |
|
255 |
conv = state.conv
|
256 |
if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT:
|
|
|
258 |
state.skip_next = True
|
259 |
return (state, state.to_gradio_chatbot(), CONVERSATION_LIMIT_MSG) + (
|
260 |
no_change_btn,
|
261 |
+
) * 2
|
262 |
|
263 |
text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off
|
264 |
conv.append_message(conv.roles[0], text)
|
265 |
conv.append_message(conv.roles[1], None)
|
266 |
+
return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 2
|
267 |
|
268 |
|
269 |
def post_process_code(code):
|
|
|
325 |
if state.skip_next:
|
326 |
# This generate call is skipped due to invalid inputs
|
327 |
state.skip_next = False
|
328 |
+
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 2
|
329 |
return
|
330 |
|
331 |
conv, model_name = state.conv, state.model_name
|
|
|
352 |
yield (
|
353 |
state,
|
354 |
state.to_gradio_chatbot(),
|
|
|
|
|
|
|
355 |
enable_btn,
|
356 |
enable_btn,
|
357 |
)
|
|
|
379 |
)
|
380 |
|
381 |
conv.update_last_message("▌")
|
382 |
+
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 2
|
383 |
|
384 |
try:
|
385 |
for data in stream_iter:
|
|
|
389 |
output = post_process_code(output)
|
390 |
output = str_filter(output)
|
391 |
conv.update_last_message(output + "▌")
|
392 |
+
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 2
|
393 |
else:
|
394 |
output = data["text"] + f"\n\n(error_code: {data['error_code']})"
|
395 |
conv.update_last_message(output)
|
396 |
yield (state, state.to_gradio_chatbot()) + (
|
|
|
|
|
|
|
397 |
enable_btn,
|
398 |
enable_btn,
|
399 |
)
|
|
|
405 |
f"(error_code: {ErrorCode.GRADIO_REQUEST_ERROR}, {e})"
|
406 |
)
|
407 |
yield (state, state.to_gradio_chatbot()) + (
|
|
|
|
|
|
|
408 |
enable_btn,
|
409 |
enable_btn,
|
410 |
)
|
|
|
415 |
f"(error_code: {ErrorCode.GRADIO_STREAM_UNKNOWN_ERROR}, {e})"
|
416 |
)
|
417 |
yield (state, state.to_gradio_chatbot()) + (
|
|
|
|
|
|
|
418 |
enable_btn,
|
419 |
enable_btn,
|
420 |
)
|
|
|
422 |
|
423 |
# Delete "▌"
|
424 |
conv.update_last_message(conv.messages[-1][-1][:-1])
|
425 |
+
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 2
|
426 |
|
427 |
finish_tstamp = time.time()
|
428 |
logger.info(f"{output}")
|
|
|
502 |
with gr.Column():
|
503 |
with gr.Tab("🧠 模型对话 Dialog"):
|
504 |
state = gr.State()
|
|
|
505 |
with gr.Row(elem_id="model_selector_row"):
|
506 |
model_selector = gr.Dropdown(
|
507 |
choices=models,
|
|
|
529 |
send_btn = gr.Button(value="Send", visible=False)
|
530 |
|
531 |
with gr.Row(visible=False) as button_row:
|
|
|
|
|
|
|
532 |
regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
|
533 |
clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)
|
534 |
+
gr.Examples(
|
535 |
+
examples=["如何变得富有?", "你能用Python写一段快速排序吗?", "How to be rich?", "Can you write a quicksort code in Python?"],
|
536 |
+
inputs=textbox,
|
537 |
+
)
|
538 |
with gr.Accordion("Parameters", open=False, visible=False) as parameter_row:
|
539 |
temperature = gr.Slider(
|
540 |
minimum=0.0,
|
|
|
564 |
gr.Markdown(learn_more_md)
|
565 |
|
566 |
# Register listeners
|
567 |
+
btn_list = [regenerate_btn, clear_btn]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
568 |
regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
|
569 |
bot_response,
|
570 |
[state, temperature, top_p, max_output_tokens],
|
|
|
588 |
[state, temperature, top_p, max_output_tokens],
|
589 |
[state, chatbot] + btn_list,
|
590 |
)
|
|
|
|
|
591 |
return state, model_selector, chatbot, textbox, send_btn, button_row, parameter_row
|
592 |
|
593 |
|
|
|
595 |
running = 0
|
596 |
res_lst = []
|
597 |
try:
|
598 |
+
r = requests.get(ft_list_job_url, headers={"PRIVATE-TOKEN": preset_token}, timeout=8)
|
599 |
+
if "code" in r.json() and "invalid" in r.json()["code"]:
|
600 |
+
gr.Warning("Invalid preset token.")
|
601 |
+
return res_lst, running
|
602 |
for d in r.json():
|
603 |
if isinstance(d['status'], str) and d['status'].lower() == "running":
|
604 |
+
running += 1
|
605 |
+
hps = dict()
|
606 |
+
for key in hps_keys:
|
607 |
+
if key in d['parameter']:
|
608 |
+
hps[key] = d['parameter'][key]
|
609 |
+
res_lst.append([d['jobName'], d['username'], d['created_at'], d['model'], d['dataset'], d['status'], json.dumps(hps)])
|
610 |
res_lst = sorted(res_lst,key=(lambda x:x[2]), reverse=True)
|
611 |
res_lst = sorted(res_lst,key=(lambda x:x[5]), reverse=True)
|
612 |
except requests.exceptions.RequestException:
|
613 |
logger.info(f"Get job list fail")
|
|
|
614 |
return res_lst, running
|
615 |
|
616 |
|
|
|
626 |
def console_generator(addr, sleep_time):
|
627 |
total_str = ""
|
628 |
ws = websocket.WebSocket()
|
629 |
+
ws.connect(addr, header={"PRIVATE-TOKEN": preset_token})
|
630 |
while True:
|
631 |
try:
|
632 |
new_str = ws.recv()
|
|
|
649 |
res_lst, running = ft_get_job_data()
|
650 |
return res_lst, running, no_change_textbox
|
651 |
if ft_latest_running_cnt < int(allow_running):
|
652 |
+
midware_header = {"FINETUNE-SECRET": ft_token, "PRIVATE-TOKEN": preset_token}
|
653 |
hps_json = {
|
654 |
"epochs": str(ft_epochs),
|
655 |
"train_batch_size": str(ft_train_batch_size),
|
|
|
660 |
"model_max_length": str(ft_model_max_length)
|
661 |
}
|
662 |
json_data = {
|
663 |
+
"dataset": dataset_to_midware_name[ft_dataset_name],
|
664 |
"model": ft_model,
|
665 |
"parameter": hps_json,
|
|
|
666 |
"username": ft_user_name
|
667 |
}
|
668 |
try:
|
|
|
694 |
def ft_remove_click(ft_selected_row_data, ft_token):
|
695 |
status = ft_selected_row_data[5]
|
696 |
if isinstance(status, str) and status.lower() == "running":
|
697 |
+
r = requests.delete(ft_remove_job_url + ft_selected_row_data[0], headers={'FINETUNE-SECRET': ft_token, "PRIVATE-TOKEN": preset_token})
|
698 |
if r.status_code == 200:
|
699 |
gr.Info("Remove success.")
|
700 |
else:
|
|
|
717 |
return [selected_row, no_change_dropdown, no_change_dropdown, no_change_slider, no_change_textbox, no_change_textbox, no_change_textbox, no_change_textbox, no_change_textbox, no_change_textbox]
|
718 |
|
719 |
|
720 |
+
def ft_dataset_preview_click(ft_dataset_name):
|
721 |
+
value = dataset_sample.get(ft_dataset_name, {})
|
722 |
+
return gr.JSON.update(value=value, visible=True)
|
723 |
+
|
724 |
+
def ft_hide_dataset_click():
|
725 |
+
return gr.JSON.update(visible=False)
|
726 |
+
|
727 |
def build_demo(models):
|
728 |
with gr.Blocks(
|
729 |
title="Vicuna Test",
|