Limour committed on
Commit
6b70385
·
verified ·
1 Parent(s): 34a0cdc

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +13 -494
  2. sub_app.py +496 -0
app.py CHANGED
@@ -1,495 +1,14 @@
1
- import hashlib
2
- import os
3
- import re
4
- import json
5
- import threading
6
- from hf_api import restart_space
7
-
8
- import gradio as gr
9
-
10
- from chat_template import ChatTemplate
11
- from llama_cpp_python_streamingllm import StreamingLLM
12
-
13
- # ========== 全局锁,确保只能进行一个会话 ==========
14
- lock = threading.Lock()
15
- session_active = False
16
-
17
- # ========== 让聊天界面的文本框等高 ==========
18
- custom_css = r'''
19
- #area > div {
20
- height: 100%;
21
- }
22
- #RAG-area {
23
- flex-grow: 1;
24
- }
25
- #RAG-area > label {
26
- height: 100%;
27
- display: flex;
28
- flex-direction: column;
29
- }
30
- #RAG-area > label > textarea {
31
- flex-grow: 1;
32
- max-height: 20vh;
33
- }
34
- #VO-area {
35
- flex-grow: 1;
36
- }
37
- #VO-area > label {
38
- height: 100%;
39
- display: flex;
40
- flex-direction: column;
41
- }
42
- #VO-area > label > textarea {
43
- flex-grow: 1;
44
- max-height: 20vh;
45
- }
46
- #prompt > label > textarea {
47
- max-height: 63px;
48
- }
49
- '''
50
-
51
-
52
- # ========== 适配 SillyTavern 的模版 ==========
53
- def text_format(text: str, _env=None, **env):
54
- if _env is not None:
55
- for k, v in _env.items():
56
- text = text.replace(r'{{' + k + r'}}', v)
57
- for k, v in env.items():
58
- text = text.replace(r'{{' + k + r'}}', v)
59
- return text
60
-
61
-
62
- # ========== 哈希函数 ==========
63
- def x_hash(x: str):
64
- return hashlib.sha1(x.encode('utf-8')).hexdigest()
65
-
66
-
67
- # ========== 读取配置文件 ==========
68
- with open('rp_config.json', encoding='utf-8') as f:
69
- tmp = f.read()
70
- with open('rp_sample_config.json', encoding='utf-8') as f:
71
- cfg = json.load(f)
72
- cfg['setting_cache_path']['value'] += x_hash(tmp)
73
- cfg.update(json.loads(tmp))
74
-
75
- # ========== 给引号加粗 ==========
76
- reg_q = re.compile(r'“(.+?)”')
77
-
78
-
79
- def chat_display_format(text: str):
80
- return reg_q.sub(r' **\g<0>** ', text)
81
-
82
-
83
- # ========== 温度、采样之类的设置 ==========
84
- with gr.Blocks() as setting:
85
- with gr.Row():
86
- setting_path = gr.Textbox(label="模型路径", max_lines=1, scale=2, **cfg['setting_path'])
87
- setting_cache_path = gr.Textbox(label="缓存路径", max_lines=1, scale=2, **cfg['setting_cache_path'])
88
- setting_seed = gr.Number(label="随机种子", scale=1, **cfg['setting_seed'])
89
- setting_n_gpu_layers = gr.Number(label="n_gpu_layers", scale=1, **cfg['setting_n_gpu_layers'])
90
- with gr.Row():
91
- setting_ctx = gr.Number(label="上下文大小(Tokens)", **cfg['setting_ctx'])
92
- setting_max_tokens = gr.Number(label="最大响应长度(Tokens)", interactive=True, **cfg['setting_max_tokens'])
93
- setting_n_keep = gr.Number(value=10, label="n_keep", interactive=False)
94
- setting_n_discard = gr.Number(label="n_discard", interactive=True, **cfg['setting_n_discard'])
95
- with gr.Row():
96
- setting_temperature = gr.Number(label="温度", interactive=True, **cfg['setting_temperature'])
97
- setting_repeat_penalty = gr.Number(label="重复惩罚", interactive=True, **cfg['setting_repeat_penalty'])
98
- setting_frequency_penalty = gr.Number(label="频率惩罚", interactive=True, **cfg['setting_frequency_penalty'])
99
- setting_presence_penalty = gr.Number(label="存在惩罚", interactive=True, **cfg['setting_presence_penalty'])
100
- setting_repeat_last_n = gr.Number(label="惩罚范围", interactive=True, **cfg['setting_repeat_last_n'])
101
- with gr.Row():
102
- setting_top_k = gr.Number(label="Top-K", interactive=True, **cfg['setting_top_k'])
103
- setting_top_p = gr.Number(label="Top P", interactive=True, **cfg['setting_top_p'])
104
- setting_min_p = gr.Number(label="Min P", interactive=True, **cfg['setting_min_p'])
105
- setting_typical_p = gr.Number(label="Typical", interactive=True, **cfg['setting_typical_p'])
106
- setting_tfs_z = gr.Number(label="TFS", interactive=True, **cfg['setting_tfs_z'])
107
- with gr.Row():
108
- setting_mirostat_mode = gr.Number(label="Mirostat 模式", **cfg['setting_mirostat_mode'])
109
- setting_mirostat_eta = gr.Number(label="Mirostat 学习率", interactive=True, **cfg['setting_mirostat_eta'])
110
- setting_mirostat_tau = gr.Number(label="Mirostat 目标熵", interactive=True, **cfg['setting_mirostat_tau'])
111
-
112
- # ========== 下载模型 ==========
113
- if os.path.exists(setting_path.value):
114
- print(f"The file {setting_path.value} exists.")
115
- else:
116
- from huggingface_hub import snapshot_download
117
- os.mkdir("downloads")
118
- os.mkdir("cache")
119
- snapshot_download(repo_id='TheBloke/CausalLM-7B-GGUF', local_dir=r'downloads',
120
- allow_patterns='causallm_7b.Q5_K_M.gguf')
121
- snapshot_download(repo_id='Limour/llama-python-streamingllm-cache', repo_type='dataset', local_dir=r'cache')
122
-
123
- # ========== 加载模型 ==========
124
- model = StreamingLLM(model_path=setting_path.value,
125
- seed=setting_seed.value,
126
- n_gpu_layers=setting_n_gpu_layers.value,
127
- n_ctx=setting_ctx.value)
128
- setting_ctx.value = model.n_ctx()
129
-
130
- # ========== 聊天的模版 默认 chatml ==========
131
- chat_template = ChatTemplate(model)
132
-
133
- # ========== 展示角色卡 ==========
134
- with gr.Blocks() as role:
135
- with gr.Row():
136
- role_usr = gr.Textbox(label="用户名称", max_lines=1, interactive=False, **cfg['role_usr'])
137
- role_char = gr.Textbox(label="角色名称", max_lines=1, interactive=False, **cfg['role_char'])
138
-
139
- role_char_d = gr.Textbox(lines=10, label="故事描述", **cfg['role_char_d'])
140
- role_chat_style = gr.Textbox(lines=10, label="回复示例", **cfg['role_chat_style'])
141
-
142
- # model.eval_t([1]) # 这个暖机的 bos [1] 删了就不正常了
143
- if os.path.exists(setting_cache_path.value):
144
- # ========== 加载角色卡-缓存 ==========
145
- tmp = model.load_session(setting_cache_path.value)
146
- print(f'load cache from {setting_cache_path.value} {tmp}')
147
- tmp = chat_template('system',
148
- text_format(role_char_d.value,
149
- char=role_char.value,
150
- user=role_usr.value))
151
- setting_n_keep.value = len(tmp)
152
- tmp = chat_template(role_char.value,
153
- text_format(role_chat_style.value,
154
- char=role_char.value,
155
- user=role_usr.value))
156
- setting_n_keep.value += len(tmp)
157
- # ========== 加载角色卡-第一条消息 ==========
158
- chatbot = []
159
- for one in cfg["role_char_first"]:
160
- one['name'] = text_format(one['name'],
161
- char=role_char.value,
162
- user=role_usr.value)
163
- one['value'] = text_format(one['value'],
164
- char=role_char.value,
165
- user=role_usr.value)
166
- if one['name'] == role_char.value:
167
- chatbot.append((None, chat_display_format(one['value'])))
168
- print(one)
169
- else:
170
- # ========== 加载角色卡-角色描述 ==========
171
- tmp = chat_template('system',
172
- text_format(role_char_d.value,
173
- char=role_char.value,
174
- user=role_usr.value))
175
- setting_n_keep.value = model.eval_t(tmp) # 此内容永久存在
176
-
177
- # ========== 加载角色卡-回复示例 ==========
178
- tmp = chat_template(role_char.value,
179
- text_format(role_chat_style.value,
180
- char=role_char.value,
181
- user=role_usr.value))
182
- setting_n_keep.value = model.eval_t(tmp) # 此内容永久存在
183
-
184
- # ========== 加载角色卡-第一条消息 ==========
185
- chatbot = []
186
- for one in cfg["role_char_first"]:
187
- one['name'] = text_format(one['name'],
188
- char=role_char.value,
189
- user=role_usr.value)
190
- one['value'] = text_format(one['value'],
191
- char=role_char.value,
192
- user=role_usr.value)
193
- if one['name'] == role_char.value:
194
- chatbot.append((None, chat_display_format(one['value'])))
195
- print(one)
196
- tmp = chat_template(one['name'], one['value'])
197
- model.eval_t(tmp) # 此内容随上下文增加将被丢弃
198
-
199
- # ========== 保存角色卡-缓存 ==========
200
- with open(setting_cache_path.value, 'wb') as f:
201
- pass
202
- tmp = model.save_session(setting_cache_path.value)
203
- print(f'save cache {tmp}')
204
- # ========== 上传缓存 ==========
205
- from huggingface_hub import login, CommitScheduler
206
- login(token=os.environ.get("HF_TOKEN"), write_permission=True)
207
- CommitScheduler(repo_id='Limour/llama-python-streamingllm-cache', repo_type='dataset', folder_path='cache')
208
-
209
-
210
- # ========== 流式输出函数 ==========
211
- def btn_submit_com(_n_keep, _n_discard,
212
- _temperature, _repeat_penalty, _frequency_penalty,
213
- _presence_penalty, _repeat_last_n, _top_k,
214
- _top_p, _min_p, _typical_p,
215
- _tfs_z, _mirostat_mode, _mirostat_eta,
216
- _mirostat_tau, _role, _max_tokens):
217
- with lock:
218
- if not session_active:
219
- raise RuntimeError
220
- # ========== 初始化输出模版 ==========
221
- t_bot = chat_template(_role)
222
- completion_tokens = [] # 有可能多个 tokens 才能构成一个 utf-8 编码的文字
223
- history = ''
224
- # ========== 流式输出 ==========
225
- for token in model.generate_t(
226
- tokens=t_bot,
227
- n_keep=_n_keep,
228
- n_discard=_n_discard,
229
- im_start=chat_template.im_start_token,
230
- top_k=_top_k,
231
- top_p=_top_p,
232
- min_p=_min_p,
233
- typical_p=_typical_p,
234
- temp=_temperature,
235
- repeat_penalty=_repeat_penalty,
236
- repeat_last_n=_repeat_last_n,
237
- frequency_penalty=_frequency_penalty,
238
- presence_penalty=_presence_penalty,
239
- tfs_z=_tfs_z,
240
- mirostat_mode=_mirostat_mode,
241
- mirostat_tau=_mirostat_tau,
242
- mirostat_eta=_mirostat_eta,
243
- ):
244
- if token in chat_template.eos or token == chat_template.nlnl:
245
- t_bot.extend(completion_tokens)
246
- print('token in eos', token)
247
  break
248
- completion_tokens.append(token)
249
- all_text = model.str_detokenize(completion_tokens)
250
- if not all_text:
251
- continue
252
- t_bot.extend(completion_tokens)
253
- history += all_text
254
- yield history
255
- if token in chat_template.onenl:
256
- # ========== 移除末尾的换行符 ==========
257
- if t_bot[-2] in chat_template.onenl:
258
- model.venv_pop_token()
259
- break
260
- if t_bot[-2] in chat_template.onerl and t_bot[-3] in chat_template.onenl:
261
- model.venv_pop_token()
262
- break
263
- if history[-2:] == '\n\n': # 各种 'x\n\n' 的token,比如'。\n\n'
264
- print('t_bot[-4:]', t_bot[-4:], repr(model.str_detokenize(t_bot[-4:])),
265
- repr(model.str_detokenize(t_bot[-1:])))
266
- break
267
- if len(t_bot) > _max_tokens:
268
- break
269
- completion_tokens = []
270
- # ========== 查看末尾的换行符 ==========
271
- print('history', repr(history))
272
- # ========== 给 kv_cache 加上输出结束符 ==========
273
- model.eval_t(chat_template.im_end_nl, _n_keep, _n_discard)
274
- t_bot.extend(chat_template.im_end_nl)
275
-
276
-
277
- # ========== 显示用户消息 ==========
278
- def btn_submit_usr(message: str, history):
279
- global session_active
280
- with lock:
281
- if session_active:
282
- raise RuntimeError
283
- session_active = True
284
- # print('btn_submit_usr', message, history)
285
- if history is None:
286
- history = []
287
- return "", history + [[message.strip(), '']], gr.update(interactive=False)
288
-
289
-
290
- # ========== 模型流式响应 ==========
291
- def btn_submit_bot(history, _n_keep, _n_discard,
292
- _temperature, _repeat_penalty, _frequency_penalty,
293
- _presence_penalty, _repeat_last_n, _top_k,
294
- _top_p, _min_p, _typical_p,
295
- _tfs_z, _mirostat_mode, _mirostat_eta,
296
- _mirostat_tau, _usr, _char,
297
- _rag, _max_tokens):
298
- with lock:
299
- if not session_active:
300
- raise RuntimeError
301
- # ========== 需要临时注入的内容 ==========
302
- rag_idx = None
303
- if len(_rag) > 0:
304
- rag_idx = model.venv_create() # 记录 venv_idx
305
- t_rag = chat_template('system', _rag)
306
- model.eval_t(t_rag, _n_keep, _n_discard)
307
- model.venv_create() # 与 t_rag 隔离
308
- # ========== 用户输入 ==========
309
- t_msg = history[-1][0]
310
- t_msg = chat_template(_usr, t_msg)
311
- model.eval_t(t_msg, _n_keep, _n_discard)
312
- # ========== 模型输出 ==========
313
- _tmp = btn_submit_com(_n_keep, _n_discard,
314
- _temperature, _repeat_penalty, _frequency_penalty,
315
- _presence_penalty, _repeat_last_n, _top_k,
316
- _top_p, _min_p, _typical_p,
317
- _tfs_z, _mirostat_mode, _mirostat_eta,
318
- _mirostat_tau, _char, _max_tokens)
319
- for _h in _tmp:
320
- history[-1][1] = _h
321
- yield history, str((model.n_tokens, model.venv))
322
- # ========== 输出完毕后格式化输出 ==========
323
- history[-1][1] = chat_display_format(history[-1][1])
324
- yield history, str((model.n_tokens, model.venv))
325
- # ========== 及时清理上一次生成的旁白 ==========
326
- if vo_idx > 0:
327
- print('vo_idx', vo_idx, model.venv)
328
- model.venv_remove(vo_idx)
329
- print('vo_idx', vo_idx, model.venv)
330
- if rag_idx and vo_idx < rag_idx:
331
- rag_idx -= 1
332
- # ========== 响应完毕后清除注入的内容 ==========
333
- if rag_idx is not None:
334
- model.venv_remove(rag_idx) # 销毁对应的 venv
335
- model.venv_disband() # 退出隔离环境
336
- yield history, str((model.n_tokens, model.venv))
337
- print('venv_disband', vo_idx, model.venv)
338
-
339
-
340
- # ========== 待实现 ==========
341
- def btn_rag_(_rag, _msg):
342
- retn = ''
343
- return retn
344
-
345
-
346
- vo_idx = 0
347
-
348
-
349
- # ========== 输出一段旁白 ==========
350
- def btn_submit_vo(_n_keep, _n_discard,
351
- _temperature, _repeat_penalty, _frequency_penalty,
352
- _presence_penalty, _repeat_last_n, _top_k,
353
- _top_p, _min_p, _typical_p,
354
- _tfs_z, _mirostat_mode, _mirostat_eta,
355
- _mirostat_tau, _max_tokens):
356
- with lock:
357
- if not session_active:
358
- raise RuntimeError
359
- global vo_idx
360
- vo_idx = model.venv_create() # 创建隔离环境
361
- # ========== 模型输出旁白 ==========
362
- _tmp = btn_submit_com(_n_keep, _n_discard,
363
- _temperature, _repeat_penalty, _frequency_penalty,
364
- _presence_penalty, _repeat_last_n, _top_k,
365
- _top_p, _min_p, _typical_p,
366
- _tfs_z, _mirostat_mode, _mirostat_eta,
367
- _mirostat_tau, '旁白', _max_tokens)
368
- for _h in _tmp:
369
- yield _h, str((model.n_tokens, model.venv))
370
-
371
-
372
- # ========== 给用户提供默认回复 ==========
373
- def btn_submit_suggest(_n_keep, _n_discard,
374
- _temperature, _repeat_penalty, _frequency_penalty,
375
- _presence_penalty, _repeat_last_n, _top_k,
376
- _top_p, _min_p, _typical_p,
377
- _tfs_z, _mirostat_mode, _mirostat_eta,
378
- _mirostat_tau, _usr, _max_tokens):
379
- with lock:
380
- if not session_active:
381
- raise RuntimeError
382
- model.venv_create() # 创建隔离环境
383
- # ========== 模型输出 ==========
384
- _tmp = btn_submit_com(_n_keep, _n_discard,
385
- _temperature, _repeat_penalty, _frequency_penalty,
386
- _presence_penalty, _repeat_last_n, _top_k,
387
- _top_p, _min_p, _typical_p,
388
- _tfs_z, _mirostat_mode, _mirostat_eta,
389
- _mirostat_tau, _usr, _max_tokens)
390
- _h = ''
391
- for _h in _tmp:
392
- yield _h, str((model.n_tokens, model.venv))
393
- model.venv_remove() # 销毁隔离环境
394
- yield _h, str((model.n_tokens, model.venv))
395
-
396
-
397
- def btn_submit_finish():
398
- global session_active
399
- with lock:
400
- if not session_active:
401
- raise RuntimeError
402
- session_active = False
403
- return gr.update(interactive=True)
404
-
405
-
406
- # ========== 聊天页面 ==========
407
- with gr.Blocks() as chatting:
408
- with gr.Row(equal_height=True):
409
- chatbot = gr.Chatbot(height='60vh', scale=2, value=chatbot,
410
- avatar_images=(r'assets/user.png', r'assets/chatbot.webp'))
411
- with gr.Column(scale=1, elem_id="area"):
412
- rag = gr.Textbox(label='RAG', show_copy_button=True, elem_id="RAG-area")
413
- vo = gr.Textbox(label='VO', show_copy_button=True, elem_id="VO-area")
414
- s_info = gr.Textbox(value=str((model.n_tokens, model.venv)), max_lines=1, label='info', interactive=False)
415
- msg = gr.Textbox(label='Prompt', lines=2, max_lines=2, elem_id='prompt', autofocus=True, **cfg['msg'])
416
- with gr.Row():
417
- btn_rag = gr.Button("RAG")
418
- btn_submit = gr.Button("Submit")
419
- btn_retry = gr.Button("Retry")
420
- btn_com1 = gr.Button("自定义1")
421
- btn_com2 = gr.Button("自定义2")
422
- btn_com3 = gr.Button("自定义3")
423
-
424
- btn_rag.click(fn=btn_rag_, outputs=rag,
425
- inputs=[rag, msg])
426
-
427
- btn_submit.click(
428
- fn=btn_submit_usr, api_name="submit",
429
- inputs=[msg, chatbot],
430
- outputs=[msg, chatbot, btn_submit]
431
- ).success(
432
- fn=btn_submit_bot,
433
- inputs=[chatbot, setting_n_keep, setting_n_discard,
434
- setting_temperature, setting_repeat_penalty, setting_frequency_penalty,
435
- setting_presence_penalty, setting_repeat_last_n, setting_top_k,
436
- setting_top_p, setting_min_p, setting_typical_p,
437
- setting_tfs_z, setting_mirostat_mode, setting_mirostat_eta,
438
- setting_mirostat_tau, role_usr, role_char,
439
- rag, setting_max_tokens],
440
- outputs=[chatbot, s_info]
441
- ).success(
442
- fn=btn_submit_vo,
443
- inputs=[setting_n_keep, setting_n_discard,
444
- setting_temperature, setting_repeat_penalty, setting_frequency_penalty,
445
- setting_presence_penalty, setting_repeat_last_n, setting_top_k,
446
- setting_top_p, setting_min_p, setting_typical_p,
447
- setting_tfs_z, setting_mirostat_mode, setting_mirostat_eta,
448
- setting_mirostat_tau, setting_max_tokens],
449
- outputs=[vo, s_info]
450
- ).success(
451
- fn=btn_submit_suggest,
452
- inputs=[setting_n_keep, setting_n_discard,
453
- setting_temperature, setting_repeat_penalty, setting_frequency_penalty,
454
- setting_presence_penalty, setting_repeat_last_n, setting_top_k,
455
- setting_top_p, setting_min_p, setting_typical_p,
456
- setting_tfs_z, setting_mirostat_mode, setting_mirostat_eta,
457
- setting_mirostat_tau, role_usr, setting_max_tokens],
458
- outputs=[msg, s_info]
459
- ).success(
460
- fn=btn_submit_finish,
461
- outputs=btn_submit
462
- )
463
-
464
- # ========== 用于调试 ==========
465
- # btn_com1.click(fn=lambda: model.str_detokenize(model._input_ids), outputs=rag)
466
-
467
-
468
- @btn_com2.click(inputs=setting_cache_path,
469
- outputs=[s_info, btn_submit])
470
- def btn_com2(_cache_path):
471
- try:
472
- with lock:
473
- _tmp = model.load_session(setting_cache_path.value)
474
- print(f'load cache from {setting_cache_path.value} {_tmp}')
475
- global vo_idx
476
- vo_idx = 0
477
- model.venv = [0]
478
- global session_active
479
- session_active = False
480
- return str((model.n_tokens, model.venv)), gr.update(interactive=True)
481
- except Exception as e:
482
- restart_space()
483
- raise e
484
-
485
- @btn_com3.click()
486
- def btn_com3():
487
- restart_space()
488
-
489
-
490
- # ========== 开始运行 ==========
491
- demo = gr.TabbedInterface([chatting, setting, role],
492
- ["聊天", "设置", '角色'],
493
- css=custom_css)
494
- gr.close_all()
495
- demo.queue(max_size=1).launch(share=False)
 
1
+ import subprocess
2
+ from hf_api import restart_space
3
+
4
+ try:
5
+ # 启动另一个程序,并通过管道捕获其输出
6
+ process = subprocess.Popen(["python", "sub_app.py"], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
7
+ while True:
8
+ output = process.stdout.readline()
9
+ if output:
10
+ print(output.decode("utf-8").strip())
11
+ if process.poll() is not None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  break
13
+ finally:
14
+ restart_space()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
sub_app.py ADDED
@@ -0,0 +1,496 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import os
3
+ import re
4
+ import json
5
+ import threading
6
+ from hf_api import restart_space
7
+
8
+ import gradio as gr
9
+
10
+ from chat_template import ChatTemplate
11
+ from llama_cpp_python_streamingllm import StreamingLLM
12
+
13
+ # ========== 全局锁,确保只能进行一个会话 ==========
14
+ lock = threading.Lock()
15
+ session_active = False
16
+
17
+ # ========== 让聊天界面的文本框等高 ==========
18
+ custom_css = r'''
19
+ #area > div {
20
+ height: 100%;
21
+ }
22
+ #RAG-area {
23
+ flex-grow: 1;
24
+ }
25
+ #RAG-area > label {
26
+ height: 100%;
27
+ display: flex;
28
+ flex-direction: column;
29
+ }
30
+ #RAG-area > label > textarea {
31
+ flex-grow: 1;
32
+ max-height: 20vh;
33
+ }
34
+ #VO-area {
35
+ flex-grow: 1;
36
+ }
37
+ #VO-area > label {
38
+ height: 100%;
39
+ display: flex;
40
+ flex-direction: column;
41
+ }
42
+ #VO-area > label > textarea {
43
+ flex-grow: 1;
44
+ max-height: 20vh;
45
+ }
46
+ #prompt > label > textarea {
47
+ max-height: 63px;
48
+ }
49
+ '''
50
+
51
+
52
+ # ========== 适配 SillyTavern 的模版 ==========
53
+ def text_format(text: str, _env=None, **env):
54
+ if _env is not None:
55
+ for k, v in _env.items():
56
+ text = text.replace(r'{{' + k + r'}}', v)
57
+ for k, v in env.items():
58
+ text = text.replace(r'{{' + k + r'}}', v)
59
+ return text
60
+
61
+
62
+ # ========== 哈希函数 ==========
63
+ def x_hash(x: str):
64
+ return hashlib.sha1(x.encode('utf-8')).hexdigest()
65
+
66
+
67
+ # ========== 读取配置文件 ==========
68
+ with open('rp_config.json', encoding='utf-8') as f:
69
+ tmp = f.read()
70
+ with open('rp_sample_config.json', encoding='utf-8') as f:
71
+ cfg = json.load(f)
72
+ cfg['setting_cache_path']['value'] += x_hash(tmp)
73
+ cfg.update(json.loads(tmp))
74
+
75
+ # ========== 给引号加粗 ==========
76
+ reg_q = re.compile(r'“(.+?)”')
77
+
78
+
79
+ def chat_display_format(text: str):
80
+ return reg_q.sub(r' **\g<0>** ', text)
81
+
82
+
83
+ # ========== 温度、采样之类的设置 ==========
84
+ with gr.Blocks() as setting:
85
+ with gr.Row():
86
+ setting_path = gr.Textbox(label="模型路径", max_lines=1, scale=2, **cfg['setting_path'])
87
+ setting_cache_path = gr.Textbox(label="缓存路径", max_lines=1, scale=2, **cfg['setting_cache_path'])
88
+ setting_seed = gr.Number(label="随机种子", scale=1, **cfg['setting_seed'])
89
+ setting_n_gpu_layers = gr.Number(label="n_gpu_layers", scale=1, **cfg['setting_n_gpu_layers'])
90
+ with gr.Row():
91
+ setting_ctx = gr.Number(label="上下文大小(Tokens)", **cfg['setting_ctx'])
92
+ setting_max_tokens = gr.Number(label="最大响应长度(Tokens)", interactive=True, **cfg['setting_max_tokens'])
93
+ setting_n_keep = gr.Number(value=10, label="n_keep", interactive=False)
94
+ setting_n_discard = gr.Number(label="n_discard", interactive=True, **cfg['setting_n_discard'])
95
+ with gr.Row():
96
+ setting_temperature = gr.Number(label="温度", interactive=True, **cfg['setting_temperature'])
97
+ setting_repeat_penalty = gr.Number(label="重复惩罚", interactive=True, **cfg['setting_repeat_penalty'])
98
+ setting_frequency_penalty = gr.Number(label="频率惩罚", interactive=True, **cfg['setting_frequency_penalty'])
99
+ setting_presence_penalty = gr.Number(label="存在惩罚", interactive=True, **cfg['setting_presence_penalty'])
100
+ setting_repeat_last_n = gr.Number(label="惩罚范围", interactive=True, **cfg['setting_repeat_last_n'])
101
+ with gr.Row():
102
+ setting_top_k = gr.Number(label="Top-K", interactive=True, **cfg['setting_top_k'])
103
+ setting_top_p = gr.Number(label="Top P", interactive=True, **cfg['setting_top_p'])
104
+ setting_min_p = gr.Number(label="Min P", interactive=True, **cfg['setting_min_p'])
105
+ setting_typical_p = gr.Number(label="Typical", interactive=True, **cfg['setting_typical_p'])
106
+ setting_tfs_z = gr.Number(label="TFS", interactive=True, **cfg['setting_tfs_z'])
107
+ with gr.Row():
108
+ setting_mirostat_mode = gr.Number(label="Mirostat 模式", **cfg['setting_mirostat_mode'])
109
+ setting_mirostat_eta = gr.Number(label="Mirostat 学习率", interactive=True, **cfg['setting_mirostat_eta'])
110
+ setting_mirostat_tau = gr.Number(label="Mirostat 目标熵", interactive=True, **cfg['setting_mirostat_tau'])
111
+
112
+ # ========== 下载模型 ==========
113
+ if os.path.exists(setting_path.value):
114
+ print(f"The file {setting_path.value} exists.")
115
+ else:
116
+ from huggingface_hub import snapshot_download
117
+
118
+ os.mkdir("downloads")
119
+ os.mkdir("cache")
120
+ snapshot_download(repo_id='TheBloke/CausalLM-7B-GGUF', local_dir=r'downloads',
121
+ allow_patterns='causallm_7b.Q5_K_M.gguf')
122
+ snapshot_download(repo_id='Limour/llama-python-streamingllm-cache', repo_type='dataset', local_dir=r'cache')
123
+
124
+ # ========== 加载模型 ==========
125
+ model = StreamingLLM(model_path=setting_path.value,
126
+ seed=setting_seed.value,
127
+ n_gpu_layers=setting_n_gpu_layers.value,
128
+ n_ctx=setting_ctx.value)
129
+ setting_ctx.value = model.n_ctx()
130
+
131
+ # ========== 聊天的模版 默认 chatml ==========
132
+ chat_template = ChatTemplate(model)
133
+
134
+ # ========== 展示角色卡 ==========
135
+ with gr.Blocks() as role:
136
+ with gr.Row():
137
+ role_usr = gr.Textbox(label="用户名称", max_lines=1, interactive=False, **cfg['role_usr'])
138
+ role_char = gr.Textbox(label="角色名称", max_lines=1, interactive=False, **cfg['role_char'])
139
+
140
+ role_char_d = gr.Textbox(lines=10, label="故事描述", **cfg['role_char_d'])
141
+ role_chat_style = gr.Textbox(lines=10, label="回复示例", **cfg['role_chat_style'])
142
+
143
+ # model.eval_t([1]) # 这个暖机的 bos [1] 删了就不正常了
144
+ if os.path.exists(setting_cache_path.value):
145
+ # ========== 加载角色卡-缓存 ==========
146
+ tmp = model.load_session(setting_cache_path.value)
147
+ print(f'load cache from {setting_cache_path.value} {tmp}')
148
+ tmp = chat_template('system',
149
+ text_format(role_char_d.value,
150
+ char=role_char.value,
151
+ user=role_usr.value))
152
+ setting_n_keep.value = len(tmp)
153
+ tmp = chat_template(role_char.value,
154
+ text_format(role_chat_style.value,
155
+ char=role_char.value,
156
+ user=role_usr.value))
157
+ setting_n_keep.value += len(tmp)
158
+ # ========== 加载角色卡-第一条消息 ==========
159
+ chatbot = []
160
+ for one in cfg["role_char_first"]:
161
+ one['name'] = text_format(one['name'],
162
+ char=role_char.value,
163
+ user=role_usr.value)
164
+ one['value'] = text_format(one['value'],
165
+ char=role_char.value,
166
+ user=role_usr.value)
167
+ if one['name'] == role_char.value:
168
+ chatbot.append((None, chat_display_format(one['value'])))
169
+ print(one)
170
+ else:
171
+ # ========== 加载角色卡-角色描述 ==========
172
+ tmp = chat_template('system',
173
+ text_format(role_char_d.value,
174
+ char=role_char.value,
175
+ user=role_usr.value))
176
+ setting_n_keep.value = model.eval_t(tmp) # 此内容永久存在
177
+
178
+ # ========== 加载角色卡-回复示例 ==========
179
+ tmp = chat_template(role_char.value,
180
+ text_format(role_chat_style.value,
181
+ char=role_char.value,
182
+ user=role_usr.value))
183
+ setting_n_keep.value = model.eval_t(tmp) # 此内容永久存在
184
+
185
+ # ========== 加载角色卡-第一条消息 ==========
186
+ chatbot = []
187
+ for one in cfg["role_char_first"]:
188
+ one['name'] = text_format(one['name'],
189
+ char=role_char.value,
190
+ user=role_usr.value)
191
+ one['value'] = text_format(one['value'],
192
+ char=role_char.value,
193
+ user=role_usr.value)
194
+ if one['name'] == role_char.value:
195
+ chatbot.append((None, chat_display_format(one['value'])))
196
+ print(one)
197
+ tmp = chat_template(one['name'], one['value'])
198
+ model.eval_t(tmp) # 此内容随上下文增加将被丢弃
199
+
200
+ # ========== 保存角色卡-缓存 ==========
201
+ with open(setting_cache_path.value, 'wb') as f:
202
+ pass
203
+ tmp = model.save_session(setting_cache_path.value)
204
+ print(f'save cache {tmp}')
205
+ # ========== 上传缓存 ==========
206
+ from huggingface_hub import login, CommitScheduler
207
+
208
+ login(token=os.environ.get("HF_TOKEN"), write_permission=True)
209
+ CommitScheduler(repo_id='Limour/llama-python-streamingllm-cache', repo_type='dataset', folder_path='cache')
210
+
211
+
212
+ # ========== 流式输出函数 ==========
213
+ def btn_submit_com(_n_keep, _n_discard,
214
+ _temperature, _repeat_penalty, _frequency_penalty,
215
+ _presence_penalty, _repeat_last_n, _top_k,
216
+ _top_p, _min_p, _typical_p,
217
+ _tfs_z, _mirostat_mode, _mirostat_eta,
218
+ _mirostat_tau, _role, _max_tokens):
219
+ with lock:
220
+ if not session_active:
221
+ raise RuntimeError
222
+ # ========== 初始化输出模版 ==========
223
+ t_bot = chat_template(_role)
224
+ completion_tokens = [] # 有可能多个 tokens 才能构成一个 utf-8 编码的文字
225
+ history = ''
226
+ # ========== 流式输出 ==========
227
+ for token in model.generate_t(
228
+ tokens=t_bot,
229
+ n_keep=_n_keep,
230
+ n_discard=_n_discard,
231
+ im_start=chat_template.im_start_token,
232
+ top_k=_top_k,
233
+ top_p=_top_p,
234
+ min_p=_min_p,
235
+ typical_p=_typical_p,
236
+ temp=_temperature,
237
+ repeat_penalty=_repeat_penalty,
238
+ repeat_last_n=_repeat_last_n,
239
+ frequency_penalty=_frequency_penalty,
240
+ presence_penalty=_presence_penalty,
241
+ tfs_z=_tfs_z,
242
+ mirostat_mode=_mirostat_mode,
243
+ mirostat_tau=_mirostat_tau,
244
+ mirostat_eta=_mirostat_eta,
245
+ ):
246
+ if token in chat_template.eos or token == chat_template.nlnl:
247
+ t_bot.extend(completion_tokens)
248
+ print('token in eos', token)
249
+ break
250
+ completion_tokens.append(token)
251
+ all_text = model.str_detokenize(completion_tokens)
252
+ if not all_text:
253
+ continue
254
+ t_bot.extend(completion_tokens)
255
+ history += all_text
256
+ yield history
257
+ if token in chat_template.onenl:
258
+ # ========== 移除末尾的换行符 ==========
259
+ if t_bot[-2] in chat_template.onenl:
260
+ model.venv_pop_token()
261
+ break
262
+ if t_bot[-2] in chat_template.onerl and t_bot[-3] in chat_template.onenl:
263
+ model.venv_pop_token()
264
+ break
265
+ if history[-2:] == '\n\n': # 各种 'x\n\n' 的token,比如'。\n\n'
266
+ print('t_bot[-4:]', t_bot[-4:], repr(model.str_detokenize(t_bot[-4:])),
267
+ repr(model.str_detokenize(t_bot[-1:])))
268
+ break
269
+ if len(t_bot) > _max_tokens:
270
+ break
271
+ completion_tokens = []
272
+ # ========== 查看末尾的换行符 ==========
273
+ print('history', repr(history))
274
+ # ========== 给 kv_cache 加上输出结束符 ==========
275
+ model.eval_t(chat_template.im_end_nl, _n_keep, _n_discard)
276
+ t_bot.extend(chat_template.im_end_nl)
277
+
278
+
279
+ # ========== 显示用户消息 ==========
280
+ def btn_submit_usr(message: str, history):
281
+ global session_active
282
+ with lock:
283
+ if session_active:
284
+ raise RuntimeError
285
+ session_active = True
286
+ # print('btn_submit_usr', message, history)
287
+ if history is None:
288
+ history = []
289
+ return "", history + [[message.strip(), '']], gr.update(interactive=False)
290
+
291
+
292
# ========== Streamed model response ==========
def btn_submit_bot(history, _n_keep, _n_discard,
                   _temperature, _repeat_penalty, _frequency_penalty,
                   _presence_penalty, _repeat_last_n, _top_k,
                   _top_p, _min_p, _typical_p,
                   _tfs_z, _mirostat_mode, _mirostat_eta,
                   _mirostat_tau, _usr, _char,
                   _rag, _max_tokens):
    """Evaluate the latest user turn and stream the character's reply.

    This is a generator. Every item is ``(history, info)`` where ``history`` is
    the same list object with ``history[-1][1]`` updated, and ``info`` is
    ``str((model.n_tokens, model.venv))`` at the moment of the yield.

    Args:
        history: Chat history; ``history[-1][0]`` is the user message to answer.
        _n_keep, _n_discard: Passed to ``model.eval_t`` and ``btn_submit_com``.
        _temperature ... _mirostat_tau: Sampler settings, forwarded unchanged to
            ``btn_submit_com``.
        _usr: Role name used to wrap the user message with ``chat_template``.
        _char: Role name handed to ``btn_submit_com`` for the reply.
        _rag: Extra text; when non-empty it is evaluated as a ``'system'``
            message in its own venv and that venv is removed afterwards.
        _max_tokens: Forwarded to ``btn_submit_com``.

    Raises:
        RuntimeError: If no session is marked active.
    """
    with lock:
        if not session_active:
            raise RuntimeError

    def _info():
        # Status string shown in the info box; computed fresh at each yield.
        return str((model.n_tokens, model.venv))

    # ========== Content injected only for this turn ==========
    rag_idx = None
    if len(_rag) > 0:
        rag_idx = model.venv_create()  # remember the venv index
        rag_tokens = chat_template('system', _rag)
        model.eval_t(rag_tokens, _n_keep, _n_discard)
    model.venv_create()  # keep what follows separate from the RAG tokens
    # ========== User input ==========
    user_tokens = chat_template(_usr, history[-1][0])
    model.eval_t(user_tokens, _n_keep, _n_discard)
    # ========== Model output ==========
    sampler_args = (_temperature, _repeat_penalty, _frequency_penalty,
                    _presence_penalty, _repeat_last_n, _top_k,
                    _top_p, _min_p, _typical_p,
                    _tfs_z, _mirostat_mode, _mirostat_eta,
                    _mirostat_tau)
    reply_stream = btn_submit_com(_n_keep, _n_discard, *sampler_args, _char, _max_tokens)
    for partial_reply in reply_stream:
        history[-1][1] = partial_reply
        yield history, _info()
    # ========== Format the reply once streaming has finished ==========
    history[-1][1] = chat_display_format(history[-1][1])
    yield history, _info()
    # ========== Promptly clean up the narration generated last time ==========
    if vo_idx > 0:
        print('vo_idx', vo_idx, model.venv)
        model.venv_remove(vo_idx)
        print('vo_idx', vo_idx, model.venv)
        # Removing an earlier venv shifts the RAG venv down by one position.
        if rag_idx and vo_idx < rag_idx:
            rag_idx -= 1
    # ========== Remove the injected content now that the reply is complete ==========
    if rag_idx is not None:
        model.venv_remove(rag_idx)  # destroy the corresponding venv
    model.venv_disband()  # leave the isolation environment
    yield history, _info()
    print('venv_disband', vo_idx, model.venv)
340
+
341
+
342
# ========== To be implemented ==========
def btn_rag_(_rag, _msg):
    """Placeholder RAG handler: ignores both arguments and returns ``''``."""
    return ''
346
+
347
+
348
# venv index recorded by the latest narration run. 0 means none is recorded:
# btn_submit_bot only removes the narration venv when this is > 0.
vo_idx = 0


# ========== Emit a passage of narration ==========
def btn_submit_vo(_n_keep, _n_discard,
                  _temperature, _repeat_penalty, _frequency_penalty,
                  _presence_penalty, _repeat_last_n, _top_k,
                  _top_p, _min_p, _typical_p,
                  _tfs_z, _mirostat_mode, _mirostat_eta,
                  _mirostat_tau, _max_tokens):
    """Stream a narration passage generated under the role name ``'旁白'``.

    This is a generator yielding ``(text, info)`` pairs, where ``text`` is what
    ``btn_submit_com`` has produced so far and ``info`` is
    ``str((model.n_tokens, model.venv))``.

    The venv created here is stored in the module-level ``vo_idx`` and is not
    removed by this function; ``btn_submit_bot`` removes it on its next run.

    Args:
        _n_keep ... _max_tokens: Forwarded unchanged to ``btn_submit_com``.

    Raises:
        RuntimeError: If no session is marked active.
    """
    global vo_idx
    with lock:
        if not session_active:
            raise RuntimeError
    vo_idx = model.venv_create()  # create an isolation environment
    # ========== Model writes the narration ==========
    sampler_args = (_temperature, _repeat_penalty, _frequency_penalty,
                    _presence_penalty, _repeat_last_n, _top_k,
                    _top_p, _min_p, _typical_p,
                    _tfs_z, _mirostat_mode, _mirostat_eta,
                    _mirostat_tau)
    narration_stream = btn_submit_com(_n_keep, _n_discard, *sampler_args, '旁白', _max_tokens)
    for narration in narration_stream:
        yield narration, str((model.n_tokens, model.venv))
372
+
373
+
374
# ========== Offer the user a default reply ==========
def btn_submit_suggest(_n_keep, _n_discard,
                       _temperature, _repeat_penalty, _frequency_penalty,
                       _presence_penalty, _repeat_last_n, _top_k,
                       _top_p, _min_p, _typical_p,
                       _tfs_z, _mirostat_mode, _mirostat_eta,
                       _mirostat_tau, _usr, _max_tokens):
    """Stream a suggested next message written under the user's role name.

    This is a generator yielding ``(text, info)`` pairs, where ``info`` is
    ``str((model.n_tokens, model.venv))``. The suggestion is generated inside a
    fresh venv that is removed once streaming ends; one final pair is yielded
    after the removal so the info string reflects the cleaned-up state.

    Args:
        _usr: Role name handed to ``btn_submit_com``.
        All other arguments are forwarded unchanged to ``btn_submit_com``.

    Raises:
        RuntimeError: If no session is marked active.
    """
    with lock:
        if not session_active:
            raise RuntimeError
    model.venv_create()  # create an isolation environment
    # ========== Model output ==========
    sampler_args = (_temperature, _repeat_penalty, _frequency_penalty,
                    _presence_penalty, _repeat_last_n, _top_k,
                    _top_p, _min_p, _typical_p,
                    _tfs_z, _mirostat_mode, _mirostat_eta,
                    _mirostat_tau)
    # Pre-seed so the final yield still has a value if nothing is streamed.
    suggestion = ''
    for suggestion in btn_submit_com(_n_keep, _n_discard, *sampler_args, _usr, _max_tokens):
        yield suggestion, str((model.n_tokens, model.venv))
    model.venv_remove()  # destroy the isolation environment
    yield suggestion, str((model.n_tokens, model.venv))
397
+
398
+
399
def btn_submit_finish():
    """Release the chat session and re-enable the submit button.

    Returns:
        A Gradio update that makes the submit button interactive again.

    Raises:
        RuntimeError: If no session was marked active.
    """
    global session_active
    with lock:
        was_active = session_active
        # Clearing an already-clear flag is a no-op, so this is safe either way.
        session_active = False
    if not was_active:
        raise RuntimeError
    return gr.update(interactive=True)
406
+
407
+
408
# ========== Chat page ==========
with gr.Blocks() as chatting:
    with gr.Row(equal_height=True):
        # `value=chatbot` reads the object already bound to `chatbot` earlier in
        # the file; the assignment then rebinds the name to the component.
        chatbot = gr.Chatbot(height='60vh', scale=2, value=chatbot,
                             avatar_images=(r'assets/user.png', r'assets/chatbot.webp'))
        with gr.Column(scale=1, elem_id="area"):
            rag = gr.Textbox(label='RAG', show_copy_button=True, elem_id="RAG-area")
            vo = gr.Textbox(label='VO', show_copy_button=True, elem_id="VO-area")
            s_info = gr.Textbox(value=str((model.n_tokens, model.venv)), max_lines=1, label='info', interactive=False)
    msg = gr.Textbox(label='Prompt', lines=2, max_lines=2, elem_id='prompt', autofocus=True, **cfg['msg'])
    with gr.Row():
        btn_rag = gr.Button("RAG")
        btn_submit = gr.Button("Submit")
        btn_retry = gr.Button("Retry")
        btn_com1 = gr.Button("自定义1")
        btn_com2 = gr.Button("自定义2")
        btn_com3 = gr.Button("自定义3")

    btn_rag.click(fn=btn_rag_, outputs=rag,
                  inputs=[rag, msg])

    # Sampler settings shared, in this order, by the three generation steps.
    _sampler_inputs = [setting_temperature, setting_repeat_penalty, setting_frequency_penalty,
                       setting_presence_penalty, setting_repeat_last_n, setting_top_k,
                       setting_top_p, setting_min_p, setting_typical_p,
                       setting_tfs_z, setting_mirostat_mode, setting_mirostat_eta,
                       setting_mirostat_tau]

    # One submit runs five steps; each runs only if the previous one succeeded:
    # user message -> character reply -> narration -> suggested reply -> release.
    _turn = btn_submit.click(
        fn=btn_submit_usr, api_name="submit",
        inputs=[msg, chatbot],
        outputs=[msg, chatbot, btn_submit]
    )
    _turn = _turn.success(
        fn=btn_submit_bot,
        inputs=[chatbot, setting_n_keep, setting_n_discard,
                *_sampler_inputs,
                role_usr, role_char,
                rag, setting_max_tokens],
        outputs=[chatbot, s_info]
    )
    _turn = _turn.success(
        fn=btn_submit_vo,
        inputs=[setting_n_keep, setting_n_discard,
                *_sampler_inputs,
                setting_max_tokens],
        outputs=[vo, s_info]
    )
    _turn = _turn.success(
        fn=btn_submit_suggest,
        inputs=[setting_n_keep, setting_n_discard,
                *_sampler_inputs,
                role_usr, setting_max_tokens],
        outputs=[msg, s_info]
    )
    _turn.success(
        fn=btn_submit_finish,
        outputs=btn_submit
    )


    # ========== For debugging ==========
    # btn_com1.click(fn=lambda: model.str_detokenize(model._input_ids), outputs=rag)

    @btn_com2.click(inputs=setting_cache_path,
                    outputs=[s_info, btn_submit])
    def btn_com2(_cache_path):
        """Reload the saved session and reset the venv/session bookkeeping.

        Returns the refreshed info string and an update re-enabling the submit
        button. On any exception, ``restart_space()`` is called and the
        exception is re-raised.

        NOTE(review): ``_cache_path`` is not used; the session is loaded from
        ``setting_cache_path.value`` instead — confirm this is intentional.
        """
        global vo_idx, session_active
        try:
            with lock:
                loaded = model.load_session(setting_cache_path.value)
                print(f'load cache from {setting_cache_path.value} {loaded}')
                vo_idx = 0
                model.venv = [0]
                session_active = False
                return str((model.n_tokens, model.venv)), gr.update(interactive=True)
        except Exception as e:
            restart_space()
            raise e

    # @btn_com3.click()
    # def btn_com3():
    #     restart_space()
490
+
491
# ========== Start running ==========
# Tab pages and their titles, paired by position.
_tab_pages = [chatting, setting, role]
_tab_titles = ["聊天", "设置", '角色']
demo = gr.TabbedInterface(_tab_pages, _tab_titles, css=custom_css)
gr.close_all()
_queued_demo = demo.queue(max_size=1)
_queued_demo.launch(share=False)