celestialli committed
Commit b1f3eeb · 0 Parent(s)

Duplicate from ascend-ai/vicuna-on-ascend

Files changed (4)
  1. .gitattributes +35 -0
  2. README.md +14 -0
  3. app.py +839 -0
  4. requirements.txt +3 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
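
These rules route matching files through Git LFS rather than plain Git. As a rough illustration of which paths the bare glob patterns capture, here is a sketch (not part of the commit) using a subset of the rules above; real matching is done by Git itself, and `fnmatch` is only a close stand-in for bare globs, which match against the basename at any directory depth:

```python
# A sketch: approximating which files the .gitattributes patterns above
# would send through Git LFS. Pattern subset chosen from the list above.
from fnmatch import fnmatch

lfs_patterns = ["*.bin", "*.safetensors", "*.ckpt", "*.zip"]

def is_lfs_tracked(path: str) -> bool:
    # Bare gitattributes globs match the basename at any depth.
    basename = path.rsplit("/", 1)[-1]
    return any(fnmatch(basename, pattern) for pattern in lfs_patterns)

print(is_lfs_tracked("weights/pytorch_model.bin"))  # True
print(is_lfs_tracked("app.py"))                     # False
```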
README.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ title: Vicuna On Ascend
+ emoji: 🏃
+ colorFrom: red
+ colorTo: blue
+ sdk: gradio
+ sdk_version: 3.41.2
+ app_file: app.py
+ pinned: false
+ license: apache-2.0
+ duplicated_from: ascend-ai/vicuna-on-ascend
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,839 @@
+ import argparse
+ from collections import defaultdict
+ import datetime
+ import json
+ import os
+ import random
+ import time
+ import uuid
+ import websocket
+
+ import gradio as gr
+ import requests
+
+ from fastchat.conversation import SeparatorStyle
+ from fastchat.constants import (
+     LOGDIR,
+     WORKER_API_TIMEOUT,
+     ErrorCode,
+     MODERATION_MSG,
+     CONVERSATION_LIMIT_MSG,
+     SERVER_ERROR_MSG,
+     INACTIVE_MSG,
+     INPUT_CHAR_LEN_LIMIT,
+     CONVERSATION_TURN_LIMIT,
+     SESSION_EXPIRATION_TIME,
+ )
+ from fastchat.model.model_adapter import get_conversation_template
+ from fastchat.model.model_registry import model_info
+ from fastchat.serve.api_provider import (
+     anthropic_api_stream_iter,
+     openai_api_stream_iter,
+     palm_api_stream_iter,
+     init_palm_chat,
+ )
+ from fastchat.utils import (
+     build_logger,
+     violates_moderation,
+     get_window_url_params_js,
+     parse_gradio_auth_creds,
+ )
+
+
+ logger = build_logger("gradio_web_server", "gradio_web_server.log")
+
+ PRESET_ANSWERS = "When I first arrived in the United States, I thought Americans were fools; there were loopholes everywhere. Every restroom stocks free paper, and some people regularly tear off a lot to take home so they never have to buy any. Some fast-food drinks come with unlimited refills, so a few people buy one drink and pour the refills into their own cups; and so on. Despite its many 'loopholes', America as a superpower clearly offers endless stories and food for thought. Let me share some things I only learned after going to the US, mainly around daily life, food and clothing, housing and transport, culture shock, and education and healthcare. The article runs about 5,000 characters, and you can skip to whatever part interests you. American cityscapes and infrastructure: 1. Only after arriving did I learn that New York does not look as 'developed' as I had imagined; the real New York streetscape is noisy and down-to-earth. For example, along the roadsides of Manhattan you see small stalls everywhere selling flowers, snacks, handmade jewelry, artwork, and more. I noticed that every street stall holds a valid business license."
+
+
+ no_change_btn = gr.Button.update()
+ enable_btn = gr.Button.update(interactive=True)
+ disable_btn = gr.Button.update(interactive=False)
+
+ # enable_moderation = False
+ # concurrency_count = 10
+ # model_list_mode = 'reload'
+
+ # midware_url = "http://159.138.58.253:8080/api/v1/chat/models"
+ # chat_token = 'abc'
+ # worker_addr = 'http://159.138.58.253:8080/api/v1/chat'
+
+ # allow_running = 5
+ # ft_list_job_url = "http://49.0.247.41:30139/api/v1/job"
+ # ft_submit_job_url = "http://49.0.247.41:30139/api/v1/job"
+ # ft_remove_job_url = "http://49.0.247.41:30139/api/v1/job/"
+ # ft_console_log_url = "ws://49.0.247.41:30139/api/v1/log/"
+
+
+ enable_moderation = os.environ.get('enable_moderation', default='False') == "True"
+ concurrency_count = int(os.environ.get('concurrency_count', default='10'))
+ model_list_mode = os.environ.get('model_list_mode', default='reload')
+
+ midware_url = os.environ['midware_url']
+ chat_token = os.environ.get('chat_token', default='')
+ worker_addr = os.environ.get('worker_addr', default='')
+
+ allow_running = int(os.environ.get('allow_running', default='1'))
+ ft_list_job_url = os.environ.get('ft_list_job_url', default='')
+ ft_submit_job_url = os.environ.get('ft_submit_job_url', default='')
+ ft_remove_job_url = os.environ.get('ft_remove_job_url', default='')
+ ft_console_log_url = os.environ.get('ft_console_log_url', default='')
+
+
+ headers = {"User-Agent": "FastChat Client", "PRIVATE-TOKEN": chat_token}
+
+ learn_more_md = """
+ ### License
+ The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/LICENSE) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
+ """
+
+ ip_expiration_dict = defaultdict(lambda: 0)
+
+
+ class State:
+     def __init__(self, model_name):
+         self.conv = get_conversation_template(model_name)
+         self.conv_id = uuid.uuid4().hex
+         self.skip_next = False
+         self.model_name = model_name
+
+         if model_name == "palm-2":
+             # According to release note, "chat-bison@001" is PaLM 2 for chat.
+             # https://cloud.google.com/vertex-ai/docs/release-notes#May_10_2023
+             self.palm_chat = init_palm_chat("chat-bison@001")
+
+     def to_gradio_chatbot(self):
+         return self.conv.to_gradio_chatbot()
+
+     def dict(self):
+         base = self.conv.dict()
+         base.update(
+             {
+                 "conv_id": self.conv_id,
+                 "model_name": self.model_name,
+             }
+         )
+         return base
+
+
+ def get_conv_log_filename():
+     t = datetime.datetime.now()
+     name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
+     return name
+
+
+ def get_model_list(midware_url):
+     ret = requests.get(midware_url, headers={"PRIVATE-TOKEN": chat_token})
+     models = ret.json()["data"]
+
+     priority = {k: f"___{i:02d}" for i, k in enumerate(model_info)}
+     models.sort(key=lambda x: priority.get(x, x))
+     logger.info(f"Models: {models}")
+     return models
+
+
+ df_headers = [
+     "Job Name",
+     "Create By",
+     "Create At",
+     "Model",
+     "Dataset",
+     "Status",
+     "HPs"
+ ]
+ values = [["task111", "Tom", "20230829 14:30", "Vicuna", "cat", "Done", "{\"epochs\": \"1\", \"train_batch_size\": \"2\", \"eval_batch_size\": \"3\", \"gradient_accumulation_steps\": \"16\", \"learning_rate\": \"2e-5\", \"weight_decay\": \"0.\", \"model_max_length\": \"1024\"}"],
+           ["task222", "Jerry", "20230829 15:30", "Vicuna", "dog", "Doing", "{\"epochs\": \"2\", \"train_batch_size\": \"2\", \"eval_batch_size\": \"2\", \"gradient_accumulation_steps\": \"16\", \"learning_rate\": \"2e-5\", \"weight_decay\": \"0.\", \"model_max_length\": \"1024\"}"],
+           ["task333", "Somebody", "20230830 15:30", "Vicuna", "cat", "Error", "{\"epochs\": \"3\", \"train_batch_size\": \"2\", \"eval_batch_size\": \"2\", \"gradient_accumulation_steps\": \"16\", \"learning_rate\": \"2e-5\", \"weight_decay\": \"0.\", \"model_max_length\": \"1024\"}"]]
+
+ def load_demo_single(models, url_params):
+     selected_model = models[0] if len(models) > 0 else ""
+     if "model" in url_params:
+         model = url_params["model"]
+         if model in models:
+             selected_model = model
+
+     dropdown_update = gr.Dropdown.update(
+         choices=models, value=selected_model, visible=True
+     )
+
+     state = None
+     return (
+         state,
+         dropdown_update,
+         gr.Chatbot.update(visible=True),
+         gr.Textbox.update(visible=True),
+         gr.Button.update(visible=True),
+         gr.Row.update(visible=True),
+         gr.Accordion.update(visible=True),
+     )
+
+
+ def load_demo(url_params, request: gr.Request):
+     global models
+
+     ip = request.client.host
+     logger.info(f"load_demo. ip: {ip}. params: {url_params}")
+     ip_expiration_dict[ip] = time.time() + SESSION_EXPIRATION_TIME
+
+     if model_list_mode == "reload":
+         models = get_model_list(midware_url)
+
+     return load_demo_single(models, url_params)
+
+
+ def vote_last_response(state, vote_type, model_selector, request: gr.Request):
+     with open(get_conv_log_filename(), "a") as fout:
+         data = {
+             "tstamp": round(time.time(), 4),
+             "type": vote_type,
+             "model": model_selector,
+             "state": state.dict(),
+             "ip": request.client.host,
+         }
+         fout.write(json.dumps(data) + "\n")
+
+
+ def upvote_last_response(state, model_selector, request: gr.Request):
+     logger.info(f"upvote. ip: {request.client.host}")
+     vote_last_response(state, "upvote", model_selector, request)
+     return ("",) + (disable_btn,) * 3
+
+
+ def downvote_last_response(state, model_selector, request: gr.Request):
+     logger.info(f"downvote. ip: {request.client.host}")
+     vote_last_response(state, "downvote", model_selector, request)
+     return ("",) + (disable_btn,) * 3
+
+
+ def flag_last_response(state, model_selector, request: gr.Request):
+     logger.info(f"flag. ip: {request.client.host}")
+     vote_last_response(state, "flag", model_selector, request)
+     return ("",) + (disable_btn,) * 3
+
+
+ def regenerate(state, request: gr.Request):
+     logger.info(f"regenerate. ip: {request.client.host}")
+     state.conv.update_last_message(None)
+     return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
+
+
+ def clear_history(request: gr.Request):
+     logger.info(f"clear_history. ip: {request.client.host}")
+     state = None
+     return (state, [], "") + (disable_btn,) * 5
+
+
+ def add_text(state, model_selector, text, request: gr.Request):
+     ip = request.client.host
+     logger.info(f"add_text. ip: {ip}. len: {len(text)}")
+
+     if state is None:
+         state = State(model_selector)
+
+     if len(text) <= 0:
+         state.skip_next = True
+         return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5
+
+     if ip_expiration_dict[ip] < time.time():
+         logger.info(f"inactive. ip: {request.client.host}. text: {text}")
+         state.skip_next = True
+         return (state, state.to_gradio_chatbot(), INACTIVE_MSG) + (no_change_btn,) * 5
+
+     if enable_moderation:
+         flagged = violates_moderation(text)
+         if flagged:
+             logger.info(f"violate moderation. ip: {request.client.host}. text: {text}")
+             state.skip_next = True
+             return (state, state.to_gradio_chatbot(), MODERATION_MSG) + (
+                 no_change_btn,
+             ) * 5
+
+     conv = state.conv
+     if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT:
+         logger.info(f"conversation turn limit. ip: {request.client.host}. text: {text}")
+         state.skip_next = True
+         return (state, state.to_gradio_chatbot(), CONVERSATION_LIMIT_MSG) + (
+             no_change_btn,
+         ) * 5
+
+     text = text[:INPUT_CHAR_LEN_LIMIT]  # Hard cut-off
+     conv.append_message(conv.roles[0], text)
+     conv.append_message(conv.roles[1], None)
+     return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
+
+
+ def post_process_code(code):
+     sep = "\n```"
+     if sep in code:
+         blocks = code.split(sep)
+         if len(blocks) % 2 == 1:
+             for i in range(1, len(blocks), 2):
+                 blocks[i] = blocks[i].replace("\\_", "_")
+             code = sep.join(blocks)
+     return code
+
+
+ def model_worker_stream_iter(
+     conv,
+     model_name,
+     worker_addr,
+     prompt,
+     temperature,
+     repetition_penalty,
+     top_p,
+     max_new_tokens,
+ ):
+     # Make requests
+     gen_params = {
+         "model": model_name,
+         "prompt": prompt,
+         "temperature": temperature,
+         "repetition_penalty": repetition_penalty,
+         "top_p": top_p,
+         "max_new_tokens": max_new_tokens,
+         "stop": conv.stop_str,
+         "stop_token_ids": conv.stop_token_ids,
+         "echo": False,
+     }
+     logger.info(f"==== request ====\n{gen_params}")
+
+     # Stream output
+     response = requests.post(
+         worker_addr,
+         headers=headers,
+         json=gen_params,
+         stream=True,
+         timeout=WORKER_API_TIMEOUT,
+     )
+     for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+         if chunk:
+             data = json.loads(chunk.decode())
+             yield data
+
+
+ def bot_response(state, temperature, top_p, max_new_tokens, request: gr.Request):
+     logger.info(f"bot_response. ip: {request.client.host}")
+     start_tstamp = time.time()
+     temperature = float(temperature)
+     top_p = float(top_p)
+     max_new_tokens = int(max_new_tokens)
+
+     if state.skip_next:
+         # This generate call is skipped due to invalid inputs
+         state.skip_next = False
+         yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
+         return
+
+     conv, model_name = state.conv, state.model_name
+     if model_name == "gpt-3.5-turbo" or model_name == "gpt-4":
+         prompt = conv.to_openai_api_messages()
+         stream_iter = openai_api_stream_iter(
+             model_name, prompt, temperature, top_p, max_new_tokens
+         )
+     elif model_name == "claude-2" or model_name == "claude-instant-1":
+         prompt = conv.get_prompt()
+         stream_iter = anthropic_api_stream_iter(
+             model_name, prompt, temperature, top_p, max_new_tokens
+         )
+     elif model_name == "palm-2":
+         stream_iter = palm_api_stream_iter(
+             state.palm_chat, conv.messages[-2][1], temperature, top_p, max_new_tokens
+         )
+     else:
+         # Get worker address
+         logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
+         # No available worker
+         if worker_addr == "":
+             conv.update_last_message(SERVER_ERROR_MSG)
+             yield (
+                 state,
+                 state.to_gradio_chatbot(),
+                 disable_btn,
+                 disable_btn,
+                 disable_btn,
+                 enable_btn,
+                 enable_btn,
+             )
+             return
+
+         # Construct prompt.
+         # We need to call it here, so it will not be affected by "▌".
+         prompt = conv.get_prompt()
+
+         # Set repetition_penalty
+         if "t5" in model_name:
+             repetition_penalty = 1.2
+         else:
+             repetition_penalty = 1.0
+
+         stream_iter = model_worker_stream_iter(
+             conv,
+             model_name,
+             worker_addr,
+             prompt,
+             temperature,
+             repetition_penalty,
+             top_p,
+             max_new_tokens,
+         )
+
+     conv.update_last_message("▌")
+     yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
+
+     try:
+         for data in stream_iter:
+             if data["error_code"] == 0:
+                 output = data["text"].strip()
+                 if "vicuna" in model_name:
+                     output = post_process_code(output)
+                 conv.update_last_message(output + "▌")
+                 yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
+             else:
+                 output = data["text"] + f"\n\n(error_code: {data['error_code']})"
+                 conv.update_last_message(output)
+                 yield (state, state.to_gradio_chatbot()) + (
+                     disable_btn,
+                     disable_btn,
+                     disable_btn,
+                     enable_btn,
+                     enable_btn,
+                 )
+                 return
+             time.sleep(0.015)
+     except requests.exceptions.RequestException as e:
+         conv.update_last_message(
+             f"{SERVER_ERROR_MSG}\n\n"
+             f"(error_code: {ErrorCode.GRADIO_REQUEST_ERROR}, {e})"
+         )
+         yield (state, state.to_gradio_chatbot()) + (
+             disable_btn,
+             disable_btn,
+             disable_btn,
+             enable_btn,
+             enable_btn,
+         )
+         return
+     except Exception as e:
+         conv.update_last_message(
+             f"{SERVER_ERROR_MSG}\n\n"
+             f"(error_code: {ErrorCode.GRADIO_STREAM_UNKNOWN_ERROR}, {e})"
+         )
+         yield (state, state.to_gradio_chatbot()) + (
+             disable_btn,
+             disable_btn,
+             disable_btn,
+             enable_btn,
+             enable_btn,
+         )
+         return
+
+     # Delete "▌"
+     conv.update_last_message(conv.messages[-1][-1][:-1])
+     yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
+
+     finish_tstamp = time.time()
+     logger.info(f"{output}")
+
+     with open(get_conv_log_filename(), "a") as fout:
+         data = {
+             "tstamp": round(finish_tstamp, 4),
+             "type": "chat",
+             "model": model_name,
+             "gen_params": {
+                 "temperature": temperature,
+                 "top_p": top_p,
+                 "max_new_tokens": max_new_tokens,
+             },
+             "start": round(start_tstamp, 4),
+             "finish": round(finish_tstamp, 4),
+             "state": state.dict(),
+             "ip": request.client.host,
+         }
+         fout.write(json.dumps(data) + "\n")
+
+
+ block_css = """
+ #dialog_notice_markdown {
+     font-size: 104%
+ }
+ #dialog_notice_markdown th {
+     display: none;
+ }
+ #dialog_notice_markdown td {
+     padding-top: 6px;
+     padding-bottom: 6px;
+ }
+ #leaderboard_markdown {
+     font-size: 104%
+ }
+ #leaderboard_markdown td {
+     padding-top: 6px;
+     padding-bottom: 6px;
+ }
+ #leaderboard_dataframe td {
+     line-height: 0.1em;
+ }
+ """
+
+
+
+ def get_model_description_md(models):
+     model_description_md = """
+ | | | |
+ | ---- | ---- | ---- |
+ """
+     ct = 0
+     visited = set()
+     for i, name in enumerate(models):
+         if name in model_info:
+             minfo = model_info[name]
+             if minfo.simple_name in visited:
+                 continue
+             visited.add(minfo.simple_name)
+             one_model_md = f"[{minfo.simple_name}]({minfo.link}): {minfo.description}"
+         else:
+             visited.add(name)
+             one_model_md = (
+                 f"[{name}](): Add the description at fastchat/model/model_registry.py"
+             )
+
+         if ct % 3 == 0:
+             model_description_md += "|"
+         model_description_md += f" {one_model_md} |"
+         if ct % 3 == 2:
+             model_description_md += "\n"
+         ct += 1
+     return model_description_md
+
+
+ def build_single_model_ui(models, add_promotion_links=False):
+     global_notice_markdown = f"""
+ # Vicuna runs on Ascend
+ ## What does this space do
+ This space provides a demo for trying the Vicuna large model on Ascend 910B hardware. Using this space you can chat with Vicuna or finetune it.
+ ## What is changed
+ We modified some open-source libraries to make them run well on Ascend: fastchat, torch_npu, and deepspeed-npu.
+ ## What is not changed
+ 1. The Vicuna model itself is not changed. All the models running here are from lmsys.
+ 2. No other libraries are changed, except the ones mentioned above.
+ ## What hardware is used
+ 1. This web page is hosted on Hugging Face with the free resource tier (2 vCPU, 16 GB).
+ 2. The chat/finetune functions are hosted on a Kunpeng 920 (CPU) + Ascend 910B (NPU) machine.
+ ## Useful links
+ - [Ascend home page](https://www.hiascend.com/)
+ - [Ascend related libraries](https://github.com/ascend)
+ """
+
+     dialog_notice_markdown = f"""
+ # Chat with Vicuna (Ascend backend)
+
+ ### Notice
+ This space originally comes from [FastChat](https://github.com/lm-sys/FastChat), but the backend computing hardware is Ascend.
+
+ ### Choose a model to chat with
+ """
+     finetune_notice_markdown = f"""
+ # Finetune with Ascend
+ ### Finetuning with Ascend
+ ### Access to Finetuning
+ Because of the limited computational resources, you need a token to finetune models. Send an e-mail to [email protected] to apply for one.
+ """
+     gr.Markdown(global_notice_markdown)
+     with gr.Column():
+         with gr.Tab("🧠 Model Dialog"):
+             state = gr.State()
+             gr.Markdown(dialog_notice_markdown, elem_id="dialog_notice_markdown")
+
+             with gr.Row(elem_id="model_selector_row"):
+                 model_selector = gr.Dropdown(
+                     choices=models,
+                     value=models[0] if len(models) > 0 else "",
+                     interactive=True,
+                     show_label=False,
+                     container=False,
+                 )
+
+             chatbot = gr.Chatbot(
+                 elem_id="chatbot",
+                 label="Scroll down and start chatting",
+                 visible=False,
+                 height=550,
+             )
+             with gr.Row():
+                 with gr.Column(scale=20):
+                     textbox = gr.Textbox(
+                         show_label=False,
+                         placeholder="Enter text and press ENTER",
+                         visible=False,
+                         container=False,
+                     )
+                 with gr.Column(scale=1, min_width=50):
+                     send_btn = gr.Button(value="Send", visible=False)
+
+             with gr.Row(visible=False) as button_row:
+                 upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
+                 downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
+                 flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
+                 regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
+                 clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)
+
+             with gr.Accordion("Parameters", open=False, visible=False) as parameter_row:
+                 temperature = gr.Slider(
+                     minimum=0.0,
+                     maximum=1.0,
+                     value=0.7,
+                     step=0.1,
+                     interactive=True,
+                     label="Temperature",
+                 )
+                 top_p = gr.Slider(
+                     minimum=0.0,
+                     maximum=1.0,
+                     value=1.0,
+                     step=0.1,
+                     interactive=True,
+                     label="Top P",
+                 )
+                 max_output_tokens = gr.Slider(
+                     minimum=16,
+                     maximum=1024,
+                     value=512,
+                     step=64,
+                     interactive=True,
+                     label="Max output tokens",
+                 )
+
+             gr.Markdown(learn_more_md)
+
+             # Register listeners
+             btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
+             upvote_btn.click(
+                 upvote_last_response,
+                 [state, model_selector],
+                 [textbox, upvote_btn, downvote_btn, flag_btn],
+             )
+             downvote_btn.click(
+                 downvote_last_response,
+                 [state, model_selector],
+                 [textbox, upvote_btn, downvote_btn, flag_btn],
+             )
+             flag_btn.click(
+                 flag_last_response,
+                 [state, model_selector],
+                 [textbox, upvote_btn, downvote_btn, flag_btn],
+             )
+             regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
+                 bot_response,
+                 [state, temperature, top_p, max_output_tokens],
+                 [state, chatbot] + btn_list,
+             )
+             clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
+
+             model_selector.change(clear_history, None, [state, chatbot, textbox] + btn_list)
+
+             textbox.submit(
+                 add_text, [state, model_selector, textbox], [state, chatbot, textbox] + btn_list
+             ).then(
+                 bot_response,
+                 [state, temperature, top_p, max_output_tokens],
+                 [state, chatbot] + btn_list,
+             )
+             send_btn.click(
+                 add_text, [state, model_selector, textbox], [state, chatbot, textbox] + btn_list
+             ).then(
+                 bot_response,
+                 [state, temperature, top_p, max_output_tokens],
+                 [state, chatbot] + btn_list,
+             )
+         with gr.Tab("🎚️ Model Finetune"):
+             gr.Markdown(finetune_notice_markdown)
+             ft_selected_row_data = gr.State()
+             ft_latest_running_cnt = gr.State()
+             df_headers = [
+                 "Job Name",
+                 "Create By",
+                 "Create At",
+                 "Model",
+                 "Dataset",
+                 "Status",
+                 "HPs"
+             ]
+             values = [["task111", "Tom", "20230829 14:30", "Vicuna", "cat", "Done", "{\"epochs\": \"1\", \"train_batch_size\": \"2\", \"eval_batch_size\": \"3\", \"gradient_accumulation_steps\": \"16\", \"learning_rate\": \"2e-5\", \"weight_decay\": \"0.\", \"model_max_length\": \"1024\"}"],
+                       ["task222", "Jerry", "20230829 15:30", "Vicuna", "dog", "Doing", "{\"epochs\": \"2\", \"train_batch_size\": \"2\", \"eval_batch_size\": \"2\", \"gradient_accumulation_steps\": \"16\", \"learning_rate\": \"2e-5\", \"weight_decay\": \"0.\", \"model_max_length\": \"1024\"}"],
+                       ["task333", "Somebody", "20230830 15:30", "Vicuna", "cat", "Error", "{\"epochs\": \"3\", \"train_batch_size\": \"2\", \"eval_batch_size\": \"2\", \"gradient_accumulation_steps\": \"16\", \"learning_rate\": \"2e-5\", \"weight_decay\": \"0.\", \"model_max_length\": \"1024\"}"]]
+             ft_jobs_info = gr.Dataframe(
+                 headers=df_headers,
+                 type='array',
+                 datatype=["str", "str", "str", "str", "str", "str", "str"],
+                 value=values,
+                 interactive=False,
+             )
+             with gr.Row():
+                 ft_show_btn = gr.Button(value="Show Logs")
+                 ft_refresh_btn = gr.Button(value="Refresh")
+                 ft_remove_btn = gr.Button(value="Remove Running")
+             with gr.Row():
+                 with gr.Column(scale=1):
+                     ft_user_name = gr.Textbox(value="", label="User Name")
+                     ft_model = gr.Dropdown(["vicuna-7b-v1.5-16k"], value="vicuna-7b-v1.5-16k", label="Model", interactive=True)
+                     ft_dataset_name = gr.Dropdown(["cat", "dog", "bird"], value="cat", label="Dataset", interactive=True)
+                     ft_token = gr.Textbox(value="", label="Finetune token")
+                     ft_submit_btn = gr.Button(value="Submit")
+                     ft_cease_btn = gr.Button(value="Cease Streaming")
+                 with gr.Column(scale=1):
+                     ft_epochs = gr.Slider(
+                         minimum=1,
+                         maximum=3,
+                         value=3,
+                         step=1,
+                         interactive=True,
+                         label="epochs",
+                     )
+                     ft_train_batch_size = gr.Textbox(value="2", label="train batch size", interactive=True)
+                     ft_eval_batch_size = gr.Textbox(value="2", label="eval batch size", interactive=True)
+                     ft_gradient_accumulation_steps = gr.Textbox(value="16", label="gradient accumulation steps", interactive=True)
+                     ft_learning_rate = gr.Textbox(value="2e-5", label="learning rate", interactive=True)
+                     ft_weight_decay = gr.Textbox(value="0.", label="weight decay", interactive=True)
+                     ft_model_max_length = gr.Textbox(value="1024", label="model max length", interactive=True)
+                 with gr.Column(scale=8):
+                     ft_console = gr.Textbox(value="", lines=28, label="Console", interactive=False)
+             ft_jobs_info.select(ft_jobs_info_select, [ft_jobs_info, ft_model, ft_dataset_name, ft_epochs, ft_train_batch_size, ft_eval_batch_size, ft_gradient_accumulation_steps, ft_learning_rate, ft_weight_decay, ft_model_max_length], [ft_selected_row_data, ft_model, ft_dataset_name, ft_epochs, ft_train_batch_size, ft_eval_batch_size, ft_gradient_accumulation_steps, ft_learning_rate, ft_weight_decay, ft_model_max_length])
+
+             ft_show_evt = ft_show_btn.click(ft_show_click, ft_selected_row_data, ft_console)
+             ft_remove_btn.click(ft_remove_click, [ft_selected_row_data, ft_token], ft_console)
+             ft_refresh_btn.click(ft_refresh_click, None, [ft_jobs_info, ft_latest_running_cnt])
+
+             ft_submit_evt = ft_submit_btn.click(ft_submit_click, [ft_latest_running_cnt, ft_user_name, ft_model, ft_dataset_name, ft_token, ft_epochs, ft_train_batch_size, ft_eval_batch_size, ft_gradient_accumulation_steps, ft_learning_rate, ft_weight_decay, ft_model_max_length], [ft_jobs_info, ft_latest_running_cnt, ft_console])
+             ft_cease_btn.click(ft_cease_click, ft_console, ft_console, cancels=[ft_submit_evt, ft_show_evt])
+
+     return state, model_selector, chatbot, textbox, send_btn, button_row, parameter_row, ft_jobs_info, ft_latest_running_cnt
+
+
+ def ft_get_job_data():
+     response = requests.get(ft_list_job_url)
+     res_lst = []
+     running = 0
+     for d in response.json():
+         if isinstance(d['status'], str) and d['status'].lower() == "running":
+             running += 1
+         res_lst.append([d['jobName'], d['username'], d['created_at'], d['model'], d['dataset'], d['status'], d['parameter']])
+     return res_lst, running
+
+
+ def ft_refresh_click():
+     return ft_get_job_data()
+
+ def ft_cease_click(ft_console):
+     output = ft_console + "\n" + "** Streaming output ceased by user **"
+     return output
+
+ def ft_submit_click(ft_latest_running_cnt, ft_user_name, ft_model, ft_dataset_name, ft_token, ft_epochs, ft_train_batch_size, ft_eval_batch_size, ft_gradient_accumulation_steps, ft_learning_rate, ft_weight_decay, ft_model_max_length):
+     if ft_latest_running_cnt < allow_running:
+         midware_header = {'Content-Type': 'application/json'}
+         hps_json = {
+             "epochs": str(ft_epochs),
+             "train_batch_size": ft_train_batch_size,
+             "eval_batch_size": ft_eval_batch_size,
+             "gradient_accumulation_steps": ft_gradient_accumulation_steps,
+             "learning_rate": ft_learning_rate,
+             "weight_decay": ft_weight_decay,
+             "model_max_length": ft_model_max_length
+         }
+         json_data = {
+             "dataset": ft_dataset_name,
+             "model": ft_model,
+             "parameter": json.dumps(hps_json),
+             "secret": ft_token,
+             "username": ft_user_name
+         }
+         r = requests.post(ft_submit_job_url, json=json_data, headers=midware_header)
+         gr.Info("Job submitted successfully!")
+         res_lst, running = ft_get_job_data()
+         return res_lst, running, json.dumps(json_data) + "\n" + str(r.status_code) + json.dumps(r.json())
+     else:
+         gr.Info(f"Only allow {str(allow_running)} job(s) running simultaneously, please wait.")
+         # Refresh the table and counter; leave the console unchanged.
+         res_lst, running = ft_get_job_data()
+         return res_lst, running, gr.update()
+
+ def ft_show_click(ft_selected_row_data):
+     s = PRESET_ANSWERS * 10
+     for i in range(1000):
+         yield s[:i*40]
+         time.sleep(0.01)
+     yield s
+
+ def ft_remove_click(ft_selected_row_data, ft_token):
+     status = ft_selected_row_data[5]
+     if isinstance(status, str) and status.lower() == "running":
+         if not ft_token.strip():
+             gr.Info("Remove failed, a token is needed.")
+         else:
+             pass
+     else:
+         gr.Info("Remove failed, only a running job can be removed.")
+     return ft_selected_row_data[0]
+
+ def ft_jobs_info_select(ft_jobs_info, ft_model, ft_dataset_name, ft_epochs, ft_train_batch_size, ft_eval_batch_size, ft_gradient_accumulation_steps, ft_learning_rate, ft_weight_decay, ft_model_max_length, evt: gr.SelectData):
+     selected_row = ft_jobs_info[evt.index[0]]
+     if evt.index[1] in (3, 4, 6):
+         try:
+             hps = json.loads(selected_row[6])
+         except json.decoder.JSONDecodeError:
+             hps = dict()
+         return [selected_row, selected_row[3], selected_row[4], hps.get('epochs', ''), hps.get('train_batch_size', ''), hps.get('eval_batch_size', ''),
+                 hps.get('gradient_accumulation_steps', ''), hps.get('learning_rate', ''), hps.get('weight_decay', ''), hps.get('model_max_length', '')]
+     else:
+         return [selected_row, ft_model, ft_dataset_name, ft_epochs, ft_train_batch_size, ft_eval_batch_size, ft_gradient_accumulation_steps, ft_learning_rate, ft_weight_decay, ft_model_max_length]
+
+ def build_demo(models):
+     with gr.Blocks(
+         title="Chat with Vicuna (Ascend backend)",
+         theme=gr.themes.Base(),
+         css=block_css,
+     ) as demo:
+         url_params = gr.JSON(visible=False)
+         (
+             state,
+             model_selector,
+             chatbot,
+             textbox,
+             send_btn,
+             button_row,
+             parameter_row,
+             ft_jobs_info,
+             ft_latest_running_cnt,
+         ) = build_single_model_ui(models)
+
+         if model_list_mode not in ["once", "reload"]:
+             raise ValueError(f"Unknown model list mode: {model_list_mode}")
+         demo.load(
+             load_demo,
+             [url_params],
+             [
+                 state,
+                 model_selector,
+                 chatbot,
+                 textbox,
+                 send_btn,
+                 button_row,
+                 parameter_row,
+             ],
+             _js=get_window_url_params_js,
+         )
+         demo.load(
+             ft_get_job_data,
+             None,
+             [
+                 ft_jobs_info,
+                 ft_latest_running_cnt,
+             ]
+         )
+
+     return demo
+
+
+ models = get_model_list(midware_url)
+
+ # Launch the demo
+ demo = build_demo(models)
+ demo.queue(
+     concurrency_count=concurrency_count, status_update_rate=10, api_open=False
+ ).launch(
+     max_threads=200,
+ )
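
For reference, the stream that `model_worker_stream_iter` consumes is a sequence of NUL-delimited JSON objects, each carrying the cumulative `text` so far plus an `error_code` (field names taken from the code above). A minimal parsing sketch, with invented sample bytes:

```python
# A sketch of the NUL-delimited JSON chunk format iterated over by
# model_worker_stream_iter; the sample bytes are invented, but the
# "text"/"error_code" fields mirror the code above.
import json

raw_stream = (
    b'{"text": "Hello", "error_code": 0}\x00'
    b'{"text": "Hello, world", "error_code": 0}\x00'
)

for chunk in raw_stream.split(b"\0"):
    if chunk:
        data = json.loads(chunk.decode())
        if data["error_code"] == 0:
            print(data["text"])  # each chunk carries the cumulative output so far
```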
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ requests
+ fschat[model_worker,webui]
+ websocket
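
Since `app.py` reads its configuration from environment variables and launches at import time, a minimal local run might look like the sketch below. The URLs are placeholders for a FastChat-compatible middleware, not real endpoints; only `midware_url` is strictly required (the script raises a `KeyError` without it).

```python
# A minimal local-launch sketch with placeholder endpoints.
import os

os.environ["midware_url"] = "http://localhost:8080/api/v1/chat/models"  # placeholder
os.environ["worker_addr"] = "http://localhost:8080/api/v1/chat"         # placeholder
os.environ["chat_token"] = "example-token"                              # sent as PRIVATE-TOKEN

import app  # noqa: E402  # importing app.py builds the UI and calls demo.launch()
```

One pinning note: `app.py` does `import websocket` (unused in the code shown), and the PyPI distribution that usually provides that module is `websocket-client` rather than `websocket`, so the last requirement may need adjusting.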