Spaces: commit 6d2a41d ("Update app.py"), parent b1f3eeb

app.py CHANGED
@@ -7,9 +7,10 @@ import random
 import time
 import uuid
 import websocket
-
+from websocket import WebSocketConnectionClosedException
 import gradio as gr
 import requests
+import logging
 
 from fastchat.conversation import SeparatorStyle
 from fastchat.constants import (
@@ -42,41 +43,40 @@ from fastchat.utils import (
 
 logger = build_logger("gradio_web_server", "gradio_web_server.log")
 
-
-
-
+no_change_dropdown = gr.Dropdown.update()
+no_change_slider = gr.Slider.update()
+no_change_textbox = gr.Textbox.update()
 no_change_btn = gr.Button.update()
 enable_btn = gr.Button.update(interactive=True)
 disable_btn = gr.Button.update(interactive=False)
 
-
-
-
+enable_moderation = False
+concurrency_count = 10
+model_list_mode = 'reload'
+allow_running = 1
 
-
-
-
+midware_url = "http://119.8.43.169:8080/api/v1/chat/models"
+worker_addr = 'http://119.8.43.169:8080/api/v1/chat'
+chat_token = 'abc'
+ft_list_job_url = "http://119.8.43.169:8080/api/v1/job"
+ft_submit_job_url = "http://119.8.43.169:8080/api/v1/job"
+ft_remove_job_url = "http://119.8.43.169:8080/api/v1/job/"
+ft_console_log_url = "ws://119.8.43.169:8080/api/v1/log/"
 
-# allow_running = 5
-# ft_list_job_url = "http://49.0.247.41:30139/api/v1/job"
-# ft_submit_job_url = "http://49.0.247.41:30139/api/v1/job"
-# ft_remove_job_url = "http://49.0.247.41:30139/api/v1/job/"
-# ft_console_log_url = "ws://49.0.247.41:30139/api/v1/log/"
 
+# enable_moderation = True if os.environ.get('enable_moderation', default='False')=="True" else False
+# concurrency_count = int(os.environ.get('concurrency_count', default='10'))
+# model_list_mode = os.environ.get('model_list_mode', default='reload')
 
-
-
-
+# midware_url = os.environ.get('midware_url', default='')
+# chat_token = os.environ.get('chat_token', default='')
+# worker_addr = os.environ.get('worker_addr', default='')
 
-
-
-
-
-
-ft_list_job_url = os.environ.get('ft_console_log_url', default='')
-ft_submit_job_url = os.environ.get('ft_console_log_url', default='')
-ft_remove_job_url = os.environ.get('ft_console_log_url', default='')
-ft_console_log_url = os.environ.get('ft_console_log_url', default='')
+# allow_running = int(os.environ.get('allow_running', default='1'))
+# ft_list_job_url = os.environ.get('ft_list_job_url', default='')
+# ft_submit_job_url = os.environ.get('ft_submit_job_url', default='')
+# ft_remove_job_url = os.environ.get('ft_remove_job_url', default='')
+# ft_console_log_url = os.environ.get('ft_console_log_url', default='')
 
 
 headers = {"User-Agent": "FastChat Client", "PRIVATE-TOKEN": chat_token}
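A note on the new no_change_* sentinels: in Gradio 3.x a handler must return one value per declared output, and Component.update() with no arguments means "leave this output as it is". A minimal self-contained sketch of the pattern (the component and handler names here are illustrative, not from the commit):

    import gradio as gr

    no_change_textbox = gr.Textbox.update()   # sentinel: output stays unchanged

    def shout(text):
        if not text:
            return no_change_textbox          # skip this update entirely
        return gr.Textbox.update(value=text.upper())

    with gr.Blocks() as demo:
        box = gr.Textbox(label="Input")
        box.submit(shout, box, box)
    demo.launch()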
@@ -89,6 +89,39 @@ The service is a research preview intended for non-commercial use only, subject
 ip_expiration_dict = defaultdict(lambda: 0)
 
 
+def is_legal_char(c):
+    if c.isalnum():
+        return True
+    if c in "！？。＂＃＄％＆＇（）＊＋，－／：；＜＝＞＠［＼］＾＿｀｛｜｝～｟｠｢｣､、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏.":
+        return True
+    if c in '!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~':
+        return True
+    return False
+
+
+def str_filter(s):
+    for _ in range(2):
+        if len(s) > 0 and (not is_legal_char(s[-1])):
+            s = s[:-1]
+    return s
+
+
+def str_not_int(s):
+    try:
+        int(s)
+        return False
+    except ValueError:
+        return True
+
+
+def str_not_float(s):
+    try:
+        float(s)
+        return False
+    except ValueError:
+        return True
+
+
 class State:
     def __init__(self, model_name):
         self.conv = get_conversation_template(model_name)
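The intent of these four helpers is easiest to see from edge cases. A hedged usage sketch (expected values inferred from the definitions in this hunk):

    # str_filter trims at most two trailing characters that are neither
    # alphanumeric nor listed ASCII/fullwidth punctuation, e.g. a torn
    # multi-byte character at a streaming chunk boundary.
    assert str_filter("hello.") == "hello."      # '.' is listed punctuation
    assert str_filter("hello\ufffd") == "hello"  # stray replacement char dropped
    # str_not_int / str_not_float invert the usual parse checks; they back
    # the hyperparameter validation in ft_submit_click below.
    assert str_not_int("16") is False
    assert str_not_int("2e-5") is True           # a float literal, not an int
    assert str_not_float("2e-5") is False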
@@ -122,26 +155,17 @@ def get_conv_log_filename():
 
 
 def get_model_list(midware_url):
-
-
+    try:
+        ret = requests.get(midware_url, headers={"PRIVATE-TOKEN": chat_token}, timeout=5)
+        models = ret.json()["data"]
+    except requests.exceptions.RequestException:
+        models = ["CANNOT GET MODEL"]
 
     priority = {k: f"___{i:02d}" for i, k in enumerate(model_info)}
    models.sort(key=lambda x: priority.get(x, x))
     logger.info(f"Models: {models}")
     return models
 
-df_headers = [
-    "Job Name",
-    "Create By",
-    "Create At",
-    "Model",
-    "Dataset",
-    "Status",
-    "HPs"
-]
-values= [["task111", "Tom", "20230829 14:30", "Vicuna", "cat", "Done", "{\"epochs\": \"1\", \"train_batch_size\": \"2\",\"eval_batch_size\": \"3\" ,\"train_batch_size\": \"2\",\"train_batch_size\": \"2\"}"],
-         ["task222", "Jerry", "20230829 15:30", "Vicuna", "dog", "Doing", "{\"train_batch_size\": \"2\", \"train_batch_size\": \"2\",\"train_batch_size\": \"2\" ,\"train_batch_size\": \"2\",\"train_batch_size\": \"2\"}"],
-         ["task333", "Somebody", "20230830 15:30", "Vicuna", "cat", "Error", "{\"train_batch_size\": \"2\", \"train_batch_size\": \"2\",\"train_batch_size\": \"2\" ,\"train_batch_size\": \"2\",\"train_batch_size\": \"2\"}"]]
 
 def load_demo_single(models, url_params):
     selected_model = models[0] if len(models) > 0 else ""
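The `___{i:02d}` keys in get_model_list are a sorting trick: models listed in model_info get keys such as `___00` that order before any ordinary name, so they float to the top in model_info order while unknown models fall back to plain alphabetical order. A small illustration (the model_info contents are assumed):

    model_info = {"vicuna-7b-v1.5-16k": None, "vicuna-13b": None}  # assumed shape
    priority = {k: f"___{i:02d}" for i, k in enumerate(model_info)}

    models = ["zzz-model", "vicuna-13b", "vicuna-7b-v1.5-16k"]
    models.sort(key=lambda x: priority.get(x, x))
    print(models)  # ['vicuna-7b-v1.5-16k', 'vicuna-13b', 'zzz-model']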
@@ -283,9 +307,9 @@ def model_worker_stream_iter(
 ):
     # Make requests
     gen_params = {
-        "model": model_name,
-        "prompt": prompt,
-        "temperature": temperature,
+        "model_name": model_name,
+        "question": prompt,
+        "temperature": 1e-6,
         "repetition_penalty": repetition_penalty,
         "top_p": top_p,
         "max_new_tokens": max_new_tokens,
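This hunk renames the request keys for the midware worker and pins temperature to 1e-6, so the UI's temperature slider no longer influences sampling and decoding becomes effectively greedy. An illustrative payload (the concrete values are assumptions, not from the commit):

    gen_params = {
        "model_name": "vicuna-7b-v1.5-16k",  # assumed example value
        "question": "Hello!",                # assumed example value
        "temperature": 1e-6,                 # pinned: near-greedy decoding
        "repetition_penalty": 1.0,
        "top_p": 0.7,
        "max_new_tokens": 512,
    }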
@@ -384,6 +408,7 @@ def bot_response(state, temperature, top_p, max_new_tokens, request: gr.Request)
             output = data["text"].strip()
             if "vicuna" in model_name:
                 output = post_process_code(output)
+            output = str_filter(output)
             conv.update_last_message(output + "▌")
             yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
         else:
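Ordering matters in this hunk: str_filter runs before the typing-cursor glyph ▌ is appended, so a partially decoded character at the end of a streamed chunk is trimmed before the message is rendered. A hedged illustration:

    chunk = "The answer is 42\ufffd"   # '\ufffd' stands in for a torn character
    clean = str_filter(chunk)          # -> "The answer is 42"
    display = clean + "▌"              # what the chatbot row actually shows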
@@ -504,42 +529,9 @@ def get_model_description_md(models):
 
 
 def build_single_model_ui(models, add_promotion_links=False):
-    global_notice_markdown = f"""
-# Vicuna runs on Ascend
-## What does this space do
-This is a space that providing a demo for users to try vicuna big model on Ascend 910B hardware. Using this space you can chat/finetune with vicuna.
-## What is changed
-We modified some opensource libraries to make thems run well on Ascend. It includes: fastchat, torch_npu, deepspeed-npu.
-## What is not changed
-1. The Vicuna model is not changed. All the model runs here are from lmsys.
-2. All the libraries are not changed, excepet the ones mentioned above.
-## What hardware are used
-1. This web page is hosted on huggingface wih the free resource(2U16G)
-2. The chat/fietune function is hosted on a Kunpeng920(CPU) + Asend 910B(NPU) machine.
-## Useful link
-- [Ascend home page](https://www.hiascend.com/)
-- [Ascend related library](https://github.com/ascend)
-"""
-
-    dialog_notice_markdown = f"""
-# Chat with Vicuna (Ascend Backended)
-
-### Notice
-This space is originally from [FastChat](https://github.com/lm-sys/FastChat), but the backend computational hardware is Ascend.
-
-### Choose a model to chat with
-"""
-    finetune_notice_markdown = f"""
-# Finetune with Ascend
-### Finetuning with Ascend
-### Access to Finetuning
-Because of the limited computational resources, you will need a token to finetune models. Send an E-mail to [email protected] to apply for a token.
-"""
-    gr.Markdown(global_notice_markdown)
     with gr.Column():
         with gr.Tab("🧠 模型对话 Dialog"):
             state = gr.State()
-            gr.Markdown(dialog_notice_markdown, elem_id="dialog_notice_markdown")
 
             with gr.Row(elem_id="model_selector_row"):
                 model_selector = gr.Dropdown(
@@ -642,134 +634,120 @@ Because of the limited computational resources, you will need a token to finetun
             [state, temperature, top_p, max_output_tokens],
             [state, chatbot] + btn_list,
         )
-
-        gr.Markdown(finetune_notice_markdown)
-        ft_selected_row_data = gr.State()
-        ft_latest_running_cnt = gr.State()
-        df_headers = [
-            "Job Name",
-            "Create By",
-            "Create At",
-            "Model",
-            "Dataset",
-            "Status",
-            "HPs"
-        ]
-        values= [["task111", "Tom", "20230829 14:30", "Vicuna", "cat", "Done", "{\"epochs\": \"1\", \"train_batch_size\": \"2\",\"eval_batch_size\": \"3\" ,\"train_batch_size\": \"2\",\"train_batch_size\": \"2\"}"],
-                 ["task222", "Jerry", "20230829 15:30", "Vicuna", "dog", "Doing", "{\"train_batch_size\": \"2\", \"train_batch_size\": \"2\",\"train_batch_size\": \"2\" ,\"train_batch_size\": \"2\",\"train_batch_size\": \"2\"}"],
-                 ["task333", "Somebody", "20230830 15:30", "Vicuna", "cat", "Error", "{\"train_batch_size\": \"2\", \"train_batch_size\": \"2\",\"train_batch_size\": \"2\" ,\"train_batch_size\": \"2\",\"train_batch_size\": \"2\"}"]]
-        ft_jobs_info = gr.Dataframe(
-            headers=df_headers,
-            type='array',
-            datatype=["str", "str", "str", "str", "str", "str", "str"],
-            value=values,
-            interactive=False,
-        )
-        with gr.Row():
-            ft_show_btn = gr.Button(value="Show Logs")
-            ft_refresh_btn = gr.Button(value="Refresh")
-            ft_remove_btn = gr.Button(value="Remove Running")
-        with gr.Row():
-            with gr.Column(scale=1):
-                ft_user_name = gr.Textbox(value="", label="User Name")
-                ft_model = gr.Dropdown(["vicuna-7b-v1.5-16k"], value="vicuna-7b-v1.5-16k", label="Model", interactive=True)
-                ft_dataset_name = gr.Dropdown(["cat", "dog", "bird"], value="cat", label="Dataset", interactive=True)
-                ft_token = gr.Textbox(value="", label="Finetune token")
-                ft_submit_btn = gr.Button(value="Submit")
-                ft_cease_btn = gr.Button(value="Cease Streaming")
-            with gr.Column(scale=1):
-                ft_epochs = gr.Slider(
-                    minimum=1,
-                    maximum=3,
-                    value=3,
-                    step=1,
-                    interactive=True,
-                    label="epochs",
-                )
-                ft_train_batch_size = gr.Textbox(value="2", label="train batch size", interactive=True)
-                ft_eval_batch_size = gr.Textbox(value="2", label="eval batch size", interactive=True)
-                ft_gradient_accumulation_steps = gr.Textbox(value="16", label="gradient accumulation steps", interactive=True)
-                ft_learning_rate = gr.Textbox(value="2e-5", label="learning rate", interactive=True)
-                ft_weight_decay = gr.Textbox(value="0.", label="weight decay", interactive=True)
-                ft_model_max_length = gr.Textbox(value="1024", label="model max length", interactive=True)
-            with gr.Column(scale=8):
-                ft_console = gr.Textbox(value="", lines=28, label="Console", interactive=False)
-        ft_jobs_info.select(ft_jobs_info_select, [ft_jobs_info, ft_model, ft_dataset_name, ft_epochs, ft_train_batch_size, ft_eval_batch_size, ft_gradient_accumulation_steps, ft_learning_rate, ft_weight_decay, ft_model_max_length], [ft_selected_row_data, ft_model, ft_dataset_name, ft_epochs, ft_train_batch_size, ft_eval_batch_size, ft_gradient_accumulation_steps, ft_learning_rate, ft_weight_decay, ft_model_max_length])
+
 
-
-        ft_remove_btn.click(ft_remove_click, [ft_selected_row_data, ft_token], ft_console)
-        ft_refresh_btn.click(ft_refresh_click, None, [ft_jobs_info, ft_latest_running_cnt])
-
-        ft_submit_evt = ft_submit_btn.click(ft_submit_click, [ft_latest_running_cnt, ft_user_name, ft_model, ft_dataset_name, ft_token, ft_epochs, ft_train_batch_size, ft_eval_batch_size, ft_gradient_accumulation_steps, ft_learning_rate, ft_weight_decay, ft_model_max_length], [ft_jobs_info, ft_latest_running_cnt, ft_console])
-        ft_cease_btn.click(ft_cease_click, ft_console, ft_console, cancels=[ft_submit_evt, ft_show_evt])
-
-    return state, model_selector, chatbot, textbox, send_btn, button_row, parameter_row, ft_jobs_info, ft_latest_running_cnt
+    return state, model_selector, chatbot, textbox, send_btn, button_row, parameter_row
 
 
 def ft_get_job_data():
-    response = requests.get(ft_list_job_url)
-    res_lst = []
     running = 0
-
-
-
-
+    res_lst = []
+    try:
+        r = requests.get(ft_list_job_url, timeout=8)
+        for d in r.json():
+            if isinstance(d['status'], str) and d['status'].lower() == "running":
+                running += 1
+            res_lst.append([d['jobName'], d['username'], d['created_at'], d['model'], d['dataset'], d['status'], json.dumps(d['parameter'])])
+        res_lst = sorted(res_lst,key=(lambda x:x[2]), reverse=True)
+        res_lst = sorted(res_lst,key=(lambda x:x[5]), reverse=True)
+    except requests.exceptions.RequestException:
+        logger.info(f"Get job list fail")
+        res_lst.append([])
     return res_lst, running
 
 
 def ft_refresh_click():
     return ft_get_job_data()
 
+
 def ft_cease_click(ft_console):
     output = ft_console + "\n" + "** Streaming output ceased by user **"
     return output
 
+
+def console_generator(addr, sleep_time):
+    total_str = ""
+    ws = websocket.WebSocket()
+    ws.connect(addr)
+    while True:
+        try:
+            new_str = ws.recv()
+            total_str = total_str + new_str
+            time.sleep(sleep_time)
+            yield total_str
+        except WebSocketConnectionClosedException:
+            ws.close()
+            break
+    ws.close()
+
+
 def ft_submit_click(ft_latest_running_cnt, ft_user_name, ft_model, ft_dataset_name, ft_token, ft_epochs, ft_train_batch_size, ft_eval_batch_size, ft_gradient_accumulation_steps, ft_learning_rate, ft_weight_decay, ft_model_max_length):
-    if
+    if ft_user_name == "":
+        gr.Warning(f"Submit fail, empty username.")
+        res_lst, running = ft_get_job_data()
+        return res_lst, running, no_change_textbox
+    if str_not_int(ft_train_batch_size) or str_not_int(ft_eval_batch_size) or str_not_int(ft_gradient_accumulation_steps) or str_not_float(ft_learning_rate) or str_not_float(ft_weight_decay) or str_not_int(ft_model_max_length):
+        gr.Warning(f"Submit fail, check the types. [learning rate] and [weight decay] should be float, others HPs should be int.")
+        res_lst, running = ft_get_job_data()
+        return res_lst, running, no_change_textbox
+    if ft_latest_running_cnt < int(allow_running):
         midware_header = {'Content-Type': 'application/json'}
         hps_json = {
             "epochs": str(ft_epochs),
-            "train_batch_size": ft_train_batch_size,
-            "eval_batch_size": ft_eval_batch_size,
-            "gradient_accumulation_steps": ft_gradient_accumulation_steps,
-            "learning_rate": ft_learning_rate,
-            "weight_decay": ft_weight_decay,
-            "model_max_length": ft_model_max_length
+            "train_batch_size": str(ft_train_batch_size),
+            "eval_batch_size": str(ft_eval_batch_size),
+            "gradient_accumulation_steps": str(ft_gradient_accumulation_steps),
+            "learning_rate": str(ft_learning_rate),
+            "weight_decay": str(ft_weight_decay),
+            "model_max_length": str(ft_model_max_length)
         }
         json_data = {
             "dataset": ft_dataset_name,
             "model": ft_model,
-            "parameter":
+            "parameter": hps_json,
             "secret": ft_token,
             "username": ft_user_name
         }
-
-
-
-
+        try:
+            r = requests.post(ft_submit_job_url, json=json_data, headers=midware_header, timeout=120)
+            job_name = r.json()["jobName"]
+            gr.Info(f"Job {job_name} submit success.")
+            res_lst, running = ft_get_job_data()
+            total_str = ""
+            for s in console_generator(ft_console_log_url + job_name, 1):
+                total_str = s
+                yield res_lst, running, s
+            res_lst, running = ft_get_job_data()
+            yield res_lst, running, total_str
+        except requests.exceptions.RequestException:
+            gr.Warning(f"Connection Failure.")
+            res_lst, running = ft_get_job_data()
+            return res_lst, running, ""
     else:
-        gr.
-
+        gr.Warning(f"Only allow {str(allow_running)} job(s) running simultaneously, please wait.")
+        res_lst, running = ft_get_job_data()
+        return res_lst, running, no_change_textbox
+
 
 def ft_show_click(ft_selected_row_data):
-    s
-
-
-    time.sleep(0.01)
-    yield s
+    for s in console_generator(ft_console_log_url + ft_selected_row_data[0], 0.2):
+        yield s
+
 
 def ft_remove_click(ft_selected_row_data, ft_token):
     status = ft_selected_row_data[5]
     if isinstance(status, str) and status.lower() == "running":
-
-
+        r = requests.delete(ft_remove_job_url + ft_selected_row_data[0], json={"secret": ft_token})
+        if r.status_code == 200:
+            gr.Info("Remove success.")
         else:
-
+            gr.Warning(f"Remove fail. {r.status_code} {r.reason}.")
    else:
-        gr.
-        return
+        gr.Warning("Remove fail. Can only remove a running job.")
+    return ft_get_job_data()
 
-
+
+def ft_jobs_info_select(ft_jobs_info, evt: gr.SelectData):
     selected_row = ft_jobs_info[evt.index[0]]
     if evt.index[1] in (3, 4, 6):
         try:
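The console_generator / ft_show_click pair relies on a Gradio 3.x behavior: a click handler that is a Python generator repaints its output component on every yield, which is what streams the job log into the Console textbox. A minimal self-contained sketch of the same pattern, with a fake log source standing in for the websocket (all names here are illustrative):

    import time
    import gradio as gr

    def fake_console():
        buf = ""
        for i in range(5):
            buf += f"log line {i}\n"
            time.sleep(0.2)
            yield buf                  # each yield repaints the textbox

    with gr.Blocks() as demo:
        console = gr.Textbox(lines=8, label="Console", interactive=False)
        gr.Button("Show Logs").click(fake_console, None, console)
    demo.launch()

This is also why ft_cease_click can stop a stream: the cancels= wiring on the cease button (visible in the removed UI block above) cancels the pending generator events.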
@@ -779,11 +757,12 @@ def ft_jobs_info_select(ft_jobs_info, ft_model, ft_dataset_name, ft_epochs, ft_t
         return [selected_row, selected_row[3], selected_row[4], Hps.get('epochs', ''), Hps.get('train_batch_size', ''), Hps.get('eval_batch_size', ''),
                 Hps.get('gradient_accumulation_steps', ''), Hps.get('learning_rate', ''), Hps.get('weight_decay', ''), Hps.get('model_max_length', '')]
     else:
-        return [selected_row,
+        return [selected_row, no_change_dropdown, no_change_dropdown, no_change_slider, no_change_textbox, no_change_textbox, no_change_textbox, no_change_textbox, no_change_textbox, no_change_textbox]
+
 
 def build_demo(models):
     with gr.Blocks(
-        title="
+        title="Vicuna (Ascend Backended)",
         theme=gr.themes.Base(),
         css = block_css
     ) as demo:
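The rewritten ft_jobs_info_select leans on Gradio's event payload instead of extra inputs: a gr.SelectData argument is injected automatically, and evt.index carries the (row, col) of the clicked cell. A stripped-down sketch (the table values are illustrative):

    import gradio as gr

    def on_select(df, evt: gr.SelectData):
        row, col = evt.index               # position of the clicked cell
        return f"row {row}: {df[row]}"

    with gr.Blocks() as demo:
        df = gr.Dataframe(headers=["Job Name", "Status"],
                          value=[["task111", "Done"], ["task222", "Doing"]],
                          type="array", interactive=False)
        picked = gr.Textbox(label="Selected row")
        df.select(on_select, df, picked)
    demo.launch()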
@@ -796,8 +775,6 @@ def build_demo(models):
         send_btn,
         button_row,
         parameter_row,
-        ft_jobs_info,
-        ft_latest_running_cnt,
     ) = build_single_model_ui(models)
 
     if model_list_mode not in ["once", "reload"]:
@@ -816,14 +793,6 @@ def build_demo(models):
         ],
         _js=get_window_url_params_js,
     )
-    demo.load(
-        ft_get_job_data,
-        None,
-        [
-            ft_jobs_info,
-            ft_latest_running_cnt,
-        ]
-    )
 
     return demo
 