celestialli committed
Commit 6d2a41d · 1 Parent(s): b1f3eeb

Update app.py

Files changed (1)
  1. app.py +150 -181
app.py CHANGED
@@ -7,9 +7,10 @@ import random
 import time
 import uuid
 import websocket
-
+from websocket import WebSocketConnectionClosedException
 import gradio as gr
 import requests
+import logging

 from fastchat.conversation import SeparatorStyle
 from fastchat.constants import (
@@ -42,41 +43,40 @@ from fastchat.utils import (

 logger = build_logger("gradio_web_server", "gradio_web_server.log")

-PRESET_ANSWERS = "刚到美国的时候,觉得美国人像傻子一样,到处都是漏洞。任何地方的厕所都有免费纸,有些人定期去扯很多回家,纸都不用买。快餐店的饮料,有的可以无限续杯,有些几个人买一份饮料,接回来灌到各自的杯子里;等等。尽管美国有许多“漏洞”,但作为超级大国,显然能带给人以无尽的故事与思考。我来分享一下哪些是去了美国才知道的事,主题主要围绕着生活、衣食住行、文化冲击、教育医疗等展开叙说,本文有5千字左右,你也可以跳到感兴趣的部分阅读。美国的城市风貌与基础设施1、去到了美国才知道,纽约的城市样貌跟我想象的发达不一样,真实的纽约街景是很嘈杂和市井。例如,在曼哈顿区路旁,随处可见的小摊位,卖鲜花的、卖各种小食、卖自制首饰的,卖艺术品等等。我留意一下,发现每个路边摊都有合法的营业执照。"
-
-
+no_change_dropdown = gr.Dropdown.update()
+no_change_slider = gr.Slider.update()
+no_change_textbox = gr.Textbox.update()
 no_change_btn = gr.Button.update()
 enable_btn = gr.Button.update(interactive=True)
 disable_btn = gr.Button.update(interactive=False)

-# enable_moderation = False
-# concurrency_count = 10
-# model_list_mode = 'reload'
-
-# midware_url = "http://159.138.58.253:8080/api/v1/chat/models"
-# chat_token = 'abc'
-# worker_addr = 'http://159.138.58.253:8080/api/v1/chat'
-
-# allow_running = 5
-# ft_list_job_url = "http://49.0.247.41:30139/api/v1/job"
-# ft_submit_job_url = "http://49.0.247.41:30139/api/v1/job"
-# ft_remove_job_url = "http://49.0.247.41:30139/api/v1/job/"
-# ft_console_log_url = "ws://49.0.247.41:30139/api/v1/log/"
-
-enable_moderation = True if os.environ.get('enable_moderation', default='False')=="True" else False
-concurrency_count = int(os.environ.get('concurrency_count', default='10'))
-model_list_mode = os.environ.get('model_list_mode', default='reload')
-
-midware_url = os.environ['midware_url']
-chat_token = os.environ.get('chat_token', default='')
-worker_addr = os.environ.get('worker_addr', default='')
-
-allow_running = int(os.environ.get('allow_running', default='1'))
-ft_list_job_url = os.environ.get('ft_console_log_url', default='')
-ft_submit_job_url = os.environ.get('ft_console_log_url', default='')
-ft_remove_job_url = os.environ.get('ft_console_log_url', default='')
-ft_console_log_url = os.environ.get('ft_console_log_url', default='')
+enable_moderation = False
+concurrency_count = 10
+model_list_mode = 'reload'
+allow_running = 1
+
+midware_url = "http://119.8.43.169:8080/api/v1/chat/models"
+worker_addr = 'http://119.8.43.169:8080/api/v1/chat'
+chat_token = 'abc'
+ft_list_job_url = "http://119.8.43.169:8080/api/v1/job"
+ft_submit_job_url = "http://119.8.43.169:8080/api/v1/job"
+ft_remove_job_url = "http://119.8.43.169:8080/api/v1/job/"
+ft_console_log_url = "ws://119.8.43.169:8080/api/v1/log/"
+
+# enable_moderation = True if os.environ.get('enable_moderation', default='False')=="True" else False
+# concurrency_count = int(os.environ.get('concurrency_count', default='10'))
+# model_list_mode = os.environ.get('model_list_mode', default='reload')
+
+# midware_url = os.environ.get('midware_url', default='')
+# chat_token = os.environ.get('chat_token', default='')
+# worker_addr = os.environ.get('worker_addr', default='')
+
+# allow_running = int(os.environ.get('allow_running', default='1'))
+# ft_list_job_url = os.environ.get('ft_list_job_url', default='')
+# ft_submit_job_url = os.environ.get('ft_submit_job_url', default='')
+# ft_remove_job_url = os.environ.get('ft_remove_job_url', default='')
+# ft_console_log_url = os.environ.get('ft_console_log_url', default='')


 headers = {"User-Agent": "FastChat Client", "PRIVATE-TOKEN": chat_token}
@@ -89,6 +89,39 @@ The service is a research preview intended for non-commercial use only, subject
 ip_expiration_dict = defaultdict(lambda: 0)


+def is_legal_char(c):
+    if c.isalnum():
+        return True
+    if c in "!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏.":
+        return True
+    if c in '!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~':
+        return True
+    return False
+
+
+def str_filter(s):
+    for _ in range(2):
+        if len(s) > 0 and (not is_legal_char(s[-1])):
+            s = s[:-1]
+    return s
+
+
+def str_not_int(s):
+    try:
+        int(s)
+        return False
+    except ValueError:
+        return True
+
+
+def str_not_float(s):
+    try:
+        float(s)
+        return False
+    except ValueError:
+        return True
+
+
 class State:
     def __init__(self, model_name):
         self.conv = get_conversation_template(model_name)
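
For illustration, a minimal sketch (not part of the commit) of what the new stream-sanitizing helpers do, with made-up inputs:

    # str_filter trims at most two trailing characters that are neither
    # alphanumeric nor in the recognized CJK/ASCII punctuation sets, e.g. the
    # tail of a partially decoded multi-byte token in a stream chunk.
    assert str_filter("hello, world") == "hello, world"    # legal tail, unchanged
    assert str_filter("hello\ufffd\ufffd") == "hello"      # two replacement chars trimmed
    # str_not_int / str_not_float gate the finetune HP textboxes:
    assert str_not_int("2e-5") and not str_not_float("2e-5")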
@@ -122,26 +155,17 @@ def get_conv_log_filename():


 def get_model_list(midware_url):
-    ret = requests.get(midware_url, headers={"PRIVATE-TOKEN": chat_token})
-    models = ret.json()["data"]
+    try:
+        ret = requests.get(midware_url, headers={"PRIVATE-TOKEN": chat_token}, timeout=5)
+        models = ret.json()["data"]
+    except requests.exceptions.RequestException:
+        models = ["CANNOT GET MODEL"]

     priority = {k: f"___{i:02d}" for i, k in enumerate(model_info)}
     models.sort(key=lambda x: priority.get(x, x))
     logger.info(f"Models: {models}")
     return models

-df_headers = [
-    "Job Name",
-    "Create By",
-    "Create At",
-    "Model",
-    "Dataset",
-    "Status",
-    "HPs"
-]
-values= [["task111", "Tom", "20230829 14:30", "Vicuna", "cat", "Done", "{\"epochs\": \"1\", \"train_batch_size\": \"2\",\"eval_batch_size\": \"3\" ,\"train_batch_size\": \"2\",\"train_batch_size\": \"2\"}"],
-         ["task222", "Jerry", "20230829 15:30", "Vicuna", "dog", "Doing", "{\"train_batch_size\": \"2\", \"train_batch_size\": \"2\",\"train_batch_size\": \"2\" ,\"train_batch_size\": \"2\",\"train_batch_size\": \"2\"}"],
-         ["task333", "Somebody", "20230830 15:30", "Vicuna", "cat", "Error", "{\"train_batch_size\": \"2\", \"train_batch_size\": \"2\",\"train_batch_size\": \"2\" ,\"train_batch_size\": \"2\",\"train_batch_size\": \"2\"}"]]

 def load_demo_single(models, url_params):
     selected_model = models[0] if len(models) > 0 else ""
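
Illustrative only (the midware response schema is an assumption; only the "data" key is visible in this diff): with the midware unreachable, get_model_list now degrades instead of raising at startup:

    # Assumed response shape: {"data": ["vicuna-7b-v1.5-16k", ...]}
    # On timeout/connection errors the model dropdown still gets one entry:
    #   get_model_list("http://unreachable.invalid/api/v1/chat/models")
    #   -> ["CANNOT GET MODEL"]
    # so load_demo_single can still pick a (placeholder) first model.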
@@ -283,9 +307,9 @@ def model_worker_stream_iter(
 ):
     # Make requests
     gen_params = {
-        "model": model_name,
-        "prompt": prompt,
-        "temperature": temperature,
+        "model_name": model_name,
+        "question": prompt,
+        "temperature": 1e-6,
         "repetition_penalty": repetition_penalty,
         "top_p": top_p,
         "max_new_tokens": max_new_tokens,
@@ -384,6 +408,7 @@ def bot_response(state, temperature, top_p, max_new_tokens, request: gr.Request)
                 output = data["text"].strip()
                 if "vicuna" in model_name:
                     output = post_process_code(output)
+                output = str_filter(output)
                 conv.update_last_message(output + "▌")
                 yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
             else:
@@ -504,42 +529,9 @@ def get_model_description_md(models):


 def build_single_model_ui(models, add_promotion_links=False):
-    global_notice_markdown = f"""
-# Vicuna runs on Ascend
-## What does this space do
-This is a space that providing a demo for users to try vicuna big model on Ascend 910B hardware. Using this space you can chat/finetune with vicuna.
-## What is changed
-We modified some opensource libraries to make thems run well on Ascend. It includes: fastchat, torch_npu, deepspeed-npu.
-## What is not changed
-1. The Vicuna model is not changed. All the model runs here are from lmsys.
-2. All the libraries are not changed, excepet the ones mentioned above.
-## What hardware are used
-1. This web page is hosted on huggingface wih the free resource(2U16G)
-2. The chat/fietune function is hosted on a Kunpeng920(CPU) + Asend 910B(NPU) machine.
-## Useful link
-- [Ascend home page](https://www.hiascend.com/)
-- [Ascend related library](https://github.com/ascend)
-"""
-
-    dialog_notice_markdown = f"""
-# Chat with Vicuna (Ascend Backended)
-
-### Notice
-This space is originally from [FastChat](https://github.com/lm-sys/FastChat), but the backend computational hardware is Ascend.
-
-### Choose a model to chat with
-"""
-    finetune_notice_markdown = f"""
-# Finetune with Ascend
-### Finetuning with Ascend
-### Access to Finetuning
-Because of the limited computational resources, you will need a token to finetune models. Send an E-mail to [email protected] to apply for a token.
-"""
-    gr.Markdown(global_notice_markdown)
     with gr.Column():
         with gr.Tab("🧠 模型对话 Dialog"):
             state = gr.State()
-            gr.Markdown(dialog_notice_markdown, elem_id="dialog_notice_markdown")

             with gr.Row(elem_id="model_selector_row"):
                 model_selector = gr.Dropdown(
@@ -642,134 +634,120 @@ Because of the limited computational resources, you will need a token to finetun
             [state, temperature, top_p, max_output_tokens],
             [state, chatbot] + btn_list,
         )
-        with gr.Tab("🎚️ 模型微调 Finetune"):
-            gr.Markdown(finetune_notice_markdown)
-            ft_selected_row_data = gr.State()
-            ft_latest_running_cnt = gr.State()
-            df_headers = [
-                "Job Name",
-                "Create By",
-                "Create At",
-                "Model",
-                "Dataset",
-                "Status",
-                "HPs"
-            ]
-            values= [["task111", "Tom", "20230829 14:30", "Vicuna", "cat", "Done", "{\"epochs\": \"1\", \"train_batch_size\": \"2\",\"eval_batch_size\": \"3\" ,\"train_batch_size\": \"2\",\"train_batch_size\": \"2\"}"],
-                     ["task222", "Jerry", "20230829 15:30", "Vicuna", "dog", "Doing", "{\"train_batch_size\": \"2\", \"train_batch_size\": \"2\",\"train_batch_size\": \"2\" ,\"train_batch_size\": \"2\",\"train_batch_size\": \"2\"}"],
-                     ["task333", "Somebody", "20230830 15:30", "Vicuna", "cat", "Error", "{\"train_batch_size\": \"2\", \"train_batch_size\": \"2\",\"train_batch_size\": \"2\" ,\"train_batch_size\": \"2\",\"train_batch_size\": \"2\"}"]]
-            ft_jobs_info = gr.Dataframe(
-                headers=df_headers,
-                type='array',
-                datatype=["str", "str", "str", "str", "str", "str", "str"],
-                value=values,
-                interactive=False,
-            )
-            with gr.Row():
-                ft_show_btn = gr.Button(value="Show Logs")
-                ft_refresh_btn = gr.Button(value="Refresh")
-                ft_remove_btn = gr.Button(value="Remove Running")
-            with gr.Row():
-                with gr.Column(scale=1):
-                    ft_user_name = gr.Textbox(value="", label="User Name")
-                    ft_model = gr.Dropdown(["vicuna-7b-v1.5-16k"], value="vicuna-7b-v1.5-16k", label="Model", interactive=True)
-                    ft_dataset_name = gr.Dropdown(["cat", "dog", "bird"], value="cat", label="Dataset", interactive=True)
-                    ft_token = gr.Textbox(value="", label="Finetune token")
-                    ft_submit_btn = gr.Button(value="Submit")
-                    ft_cease_btn = gr.Button(value="Cease Streaming")
-                with gr.Column(scale=1):
-                    ft_epochs = gr.Slider(
-                        minimum=1,
-                        maximum=3,
-                        value=3,
-                        step=1,
-                        interactive=True,
-                        label="epochs",
-                    )
-                    ft_train_batch_size = gr.Textbox(value="2", label="train batch size", interactive=True)
-                    ft_eval_batch_size = gr.Textbox(value="2", label="eval batch size", interactive=True)
-                    ft_gradient_accumulation_steps = gr.Textbox(value="16", label="gradient accumulation steps", interactive=True)
-                    ft_learning_rate = gr.Textbox(value="2e-5", label="learning rate", interactive=True)
-                    ft_weight_decay = gr.Textbox(value="0.", label="weight decay", interactive=True)
-                    ft_model_max_length = gr.Textbox(value="1024", label="model max length", interactive=True)
-                with gr.Column(scale=8):
-                    ft_console = gr.Textbox(value="", lines=28, label="Console", interactive=False)
-            ft_jobs_info.select(ft_jobs_info_select, [ft_jobs_info, ft_model, ft_dataset_name, ft_epochs, ft_train_batch_size, ft_eval_batch_size, ft_gradient_accumulation_steps, ft_learning_rate, ft_weight_decay, ft_model_max_length], [ft_selected_row_data, ft_model, ft_dataset_name, ft_epochs, ft_train_batch_size, ft_eval_batch_size, ft_gradient_accumulation_steps, ft_learning_rate, ft_weight_decay, ft_model_max_length])
-
-            ft_show_evt = ft_show_btn.click(ft_show_click, ft_selected_row_data, ft_console)
-            ft_remove_btn.click(ft_remove_click, [ft_selected_row_data, ft_token], ft_console)
-            ft_refresh_btn.click(ft_refresh_click, None, [ft_jobs_info, ft_latest_running_cnt])
-
-            ft_submit_evt = ft_submit_btn.click(ft_submit_click, [ft_latest_running_cnt, ft_user_name, ft_model, ft_dataset_name, ft_token, ft_epochs, ft_train_batch_size, ft_eval_batch_size, ft_gradient_accumulation_steps, ft_learning_rate, ft_weight_decay, ft_model_max_length], [ft_jobs_info, ft_latest_running_cnt, ft_console])
-            ft_cease_btn.click(ft_cease_click, ft_console, ft_console, cancels=[ft_submit_evt, ft_show_evt])
-
-    return state, model_selector, chatbot, textbox, send_btn, button_row, parameter_row, ft_jobs_info, ft_latest_running_cnt
+
+    return state, model_selector, chatbot, textbox, send_btn, button_row, parameter_row


 def ft_get_job_data():
-    response = requests.get(ft_list_job_url)
-    res_lst = []
     running = 0
-    for d in response.json():
-        if isinstance(d['status'], str) and d['status'].lower() == "running":
-            running += 1
-        res_lst.append([d['jobName'], d['username'], d['created_at'], d['model'], d['dataset'], d['status'], d['parameter']])
+    res_lst = []
+    try:
+        r = requests.get(ft_list_job_url, timeout=8)
+        for d in r.json():
+            if isinstance(d['status'], str) and d['status'].lower() == "running":
+                running += 1
+            res_lst.append([d['jobName'], d['username'], d['created_at'], d['model'], d['dataset'], d['status'], json.dumps(d['parameter'])])
+        res_lst = sorted(res_lst, key=(lambda x: x[2]), reverse=True)
+        res_lst = sorted(res_lst, key=(lambda x: x[5]), reverse=True)
+    except requests.exceptions.RequestException:
+        logger.info(f"Get job list fail")
+        res_lst.append([])
     return res_lst, running


 def ft_refresh_click():
     return ft_get_job_data()

+
 def ft_cease_click(ft_console):
     output = ft_console + "\n" + "** Streaming output ceased by user **"
     return output

+
+def console_generator(addr, sleep_time):
+    total_str = ""
+    ws = websocket.WebSocket()
+    ws.connect(addr)
+    while True:
+        try:
+            new_str = ws.recv()
+            total_str = total_str + new_str
+            time.sleep(sleep_time)
+            yield total_str
+        except WebSocketConnectionClosedException:
+            ws.close()
+            break
+    ws.close()
+
+
 def ft_submit_click(ft_latest_running_cnt, ft_user_name, ft_model, ft_dataset_name, ft_token, ft_epochs, ft_train_batch_size, ft_eval_batch_size, ft_gradient_accumulation_steps, ft_learning_rate, ft_weight_decay, ft_model_max_length):
-    if ft_latest_running_cnt < allow_running:
+    if ft_user_name == "":
+        gr.Warning(f"Submit fail, empty username.")
+        res_lst, running = ft_get_job_data()
+        return res_lst, running, no_change_textbox
+    if str_not_int(ft_train_batch_size) or str_not_int(ft_eval_batch_size) or str_not_int(ft_gradient_accumulation_steps) or str_not_float(ft_learning_rate) or str_not_float(ft_weight_decay) or str_not_int(ft_model_max_length):
+        gr.Warning(f"Submit fail, check the types. [learning rate] and [weight decay] should be float, others HPs should be int.")
+        res_lst, running = ft_get_job_data()
+        return res_lst, running, no_change_textbox
+    if ft_latest_running_cnt < int(allow_running):
         midware_header = {'Content-Type': 'application/json'}
         hps_json = {
             "epochs": str(ft_epochs),
-            "train_batch_size": ft_train_batch_size,
-            "eval_batch_size": ft_eval_batch_size,
-            "gradient_accumulation_steps": ft_gradient_accumulation_steps,
-            "learning_rate": ft_learning_rate,
-            "weight_decay": ft_weight_decay,
-            "model_max_length": ft_model_max_length
+            "train_batch_size": str(ft_train_batch_size),
+            "eval_batch_size": str(ft_eval_batch_size),
+            "gradient_accumulation_steps": str(ft_gradient_accumulation_steps),
+            "learning_rate": str(ft_learning_rate),
+            "weight_decay": str(ft_weight_decay),
+            "model_max_length": str(ft_model_max_length)
         }
         json_data = {
            "dataset": ft_dataset_name,
            "model": ft_model,
-           "parameter": json.dumps(hps_json),
+           "parameter": hps_json,
            "secret": ft_token,
            "username": ft_user_name
         }
-        r = requests.post(ft_submit_job_url, json=json_data, headers=midware_header)
-        gr.Info(f"Job submit success!")
-        res_lst, running = ft_get_job_data()
-        return res_lst, running, json.dumps(json_data) + "\n" + str(r.status_code) + json.dumps(r.json())
+        try:
+            r = requests.post(ft_submit_job_url, json=json_data, headers=midware_header, timeout=120)
+            job_name = r.json()["jobName"]
+            gr.Info(f"Job {job_name} submit success.")
+            res_lst, running = ft_get_job_data()
+            total_str = ""
+            for s in console_generator(ft_console_log_url + job_name, 1):
+                total_str = s
+                yield res_lst, running, s
+            res_lst, running = ft_get_job_data()
+            yield res_lst, running, total_str
+        except requests.exceptions.RequestException:
+            gr.Warning(f"Connection Failure.")
+            res_lst, running = ft_get_job_data()
+            return res_lst, running, ""
     else:
-        gr.Info(f"Only allow {str(allow_running)} job(s) running simultaneously, please wait.")
-        return None
+        gr.Warning(f"Only allow {str(allow_running)} job(s) running simultaneously, please wait.")
+        res_lst, running = ft_get_job_data()
+        return res_lst, running, no_change_textbox
+

 def ft_show_click(ft_selected_row_data):
-    s = PRESET_ANSWERS * 10
-    for i in range(1000):
-        yield s[:i*40]
-        time.sleep(0.01)
-    yield s
+    for s in console_generator(ft_console_log_url + ft_selected_row_data[0], 0.2):
+        yield s
+

 def ft_remove_click(ft_selected_row_data, ft_token):
     status = ft_selected_row_data[5]
     if isinstance(status, str) and status.lower() == "running":
-        if not ft_token.strip():
-            gr.Info("Remove fail, token needed.")
+        r = requests.delete(ft_remove_job_url + ft_selected_row_data[0], json={"secret": ft_token})
+        if r.status_code == 200:
+            gr.Info("Remove success.")
         else:
-            pass
+            gr.Warning(f"Remove fail. {r.status_code} {r.reason}.")
     else:
-        gr.Info("Remove fail, can only remove a running job.")
-    return ft_selected_row_data[0]
+        gr.Warning("Remove fail. Can only remove a running job.")
+    return ft_get_job_data()

+
-def ft_jobs_info_select(ft_jobs_info, ft_model, ft_dataset_name, ft_epochs, ft_train_batch_size, ft_eval_batch_size, ft_gradient_accumulation_steps, ft_learning_rate, ft_weight_decay, ft_model_max_length, evt: gr.SelectData):
+def ft_jobs_info_select(ft_jobs_info, evt: gr.SelectData):
     selected_row = ft_jobs_info[evt.index[0]]
     if evt.index[1] in (3, 4, 6):
         try:
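
A minimal usage sketch of the new console_generator ("task111" is the placeholder job name from the demo table, not a live job):

    # Accumulates the websocket log and re-yields the full text after each
    # message; the loop ends when the log server closes the connection.
    for text in console_generator(ft_console_log_url + "task111", 1):
        print(text[-80:])  # tail of the log so far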
@@ -779,11 +757,12 @@ def ft_jobs_info_select(ft_jobs_info, ft_model, ft_dataset_name, ft_epochs, ft_t
        return [selected_row, selected_row[3], selected_row[4], Hps.get('epochs', ''), Hps.get('train_batch_size', ''), Hps.get('eval_batch_size', ''),
                Hps.get('gradient_accumulation_steps', ''), Hps.get('learning_rate', ''), Hps.get('weight_decay', ''), Hps.get('model_max_length', '')]
     else:
-        return [selected_row, ft_model, ft_dataset_name, ft_epochs, ft_train_batch_size, ft_eval_batch_size, ft_gradient_accumulation_steps, ft_learning_rate, ft_weight_decay, ft_model_max_length]
+        return [selected_row, no_change_dropdown, no_change_dropdown, no_change_slider, no_change_textbox, no_change_textbox, no_change_textbox, no_change_textbox, no_change_textbox, no_change_textbox]
+

 def build_demo(models):
     with gr.Blocks(
-        title="Chat with Vicuna (Ascend Backended)",
+        title="Vicuna (Ascend Backended)",
        theme=gr.themes.Base(),
        css = block_css
     ) as demo:
@@ -796,8 +775,6 @@ def build_demo(models):
         send_btn,
         button_row,
         parameter_row,
-        ft_jobs_info,
-        ft_latest_running_cnt,
     ) = build_single_model_ui(models)

     if model_list_mode not in ["once", "reload"]:
@@ -816,14 +793,6 @@
         ],
         _js=get_window_url_params_js,
     )
-    demo.load(
-        ft_get_job_data,
-        None,
-        [
-            ft_jobs_info,
-            ft_latest_running_cnt,
-        ]
-    )

     return demo
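
For reference, a hedged sketch of the submission request ft_submit_click now sends (endpoint and schema from this diff; field values are placeholders):

    payload = {
        "dataset": "cat",
        "model": "vicuna-7b-v1.5-16k",
        "parameter": {            # now a nested dict, no longer a json.dumps string
            "epochs": "3", "train_batch_size": "2", "eval_batch_size": "2",
            "gradient_accumulation_steps": "16", "learning_rate": "2e-5",
            "weight_decay": "0.", "model_max_length": "1024",
        },
        "secret": "<finetune token>",
        "username": "Tom",
    }
    # r = requests.post(ft_submit_job_url, json=payload,
    #                   headers={"Content-Type": "application/json"}, timeout=120)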