Limour committed on
Commit
9c5ce26
·
verified ·
1 Parent(s): e4ca6df

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -18
app.py CHANGED
@@ -2,12 +2,17 @@ import hashlib
2
  import os
3
  import re
4
  import json
 
5
 
6
  import gradio as gr
7
 
8
  from chat_template import ChatTemplate
9
  from llama_cpp_python_streamingllm import StreamingLLM
10
 
 
 
 
 
11
  # ========== 让聊天界面的文本框等高 ==========
12
  custom_css = r'''
13
  #area > div {
@@ -208,6 +213,9 @@ def btn_submit_com(_n_keep, _n_discard,
208
  _top_p, _min_p, _typical_p,
209
  _tfs_z, _mirostat_mode, _mirostat_eta,
210
  _mirostat_tau, _role, _max_tokens):
 
 
 
211
  # ========== 初始化输出模版 ==========
212
  t_bot = chat_template(_role)
213
  completion_tokens = [] # 有可能多个 tokens 才能构成一个 utf-8 编码的文字
@@ -267,10 +275,15 @@ def btn_submit_com(_n_keep, _n_discard,
267
 
268
  # ========== 显示用户消息 ==========
269
  def btn_submit_usr(message: str, history):
270
- # print('btn_submit_usr', message, history)
271
- if history is None:
272
- history = []
273
- return "", history + [[message.strip(), '']], gr.update(interactive=False)
 
 
 
 
 
274
 
275
 
276
  # ========== 模型流式响应 ==========
@@ -281,6 +294,9 @@ def btn_submit_bot(history, _n_keep, _n_discard,
281
  _tfs_z, _mirostat_mode, _mirostat_eta,
282
  _mirostat_tau, _usr, _char,
283
  _rag, _max_tokens):
 
 
 
284
  # ========== 需要临时注入的内容 ==========
285
  rag_idx = None
286
  if len(_rag) > 0:
@@ -336,6 +352,9 @@ def btn_submit_vo(_n_keep, _n_discard,
336
  _top_p, _min_p, _typical_p,
337
  _tfs_z, _mirostat_mode, _mirostat_eta,
338
  _mirostat_tau, _max_tokens):
 
 
 
339
  global vo_idx
340
  vo_idx = model.venv_create() # 创建隔离环境
341
  # ========== 模型输出旁白 ==========
@@ -356,6 +375,9 @@ def btn_submit_suggest(_n_keep, _n_discard,
356
  _top_p, _min_p, _typical_p,
357
  _tfs_z, _mirostat_mode, _mirostat_eta,
358
  _mirostat_tau, _usr, _max_tokens):
 
 
 
359
  model.venv_create() # 创建隔离环境
360
  # ========== 模型输出 ==========
361
  _tmp = btn_submit_com(_n_keep, _n_discard,
@@ -371,6 +393,15 @@ def btn_submit_suggest(_n_keep, _n_discard,
371
  yield _h, str((model.n_tokens, model.venv))
372
 
373
 
 
 
 
 
 
 
 
 
 
374
  # ========== 聊天页面 ==========
375
  with gr.Blocks() as chatting:
376
  with gr.Row(equal_height=True):
@@ -396,7 +427,7 @@ with gr.Blocks() as chatting:
396
  fn=btn_submit_usr, api_name="submit",
397
  inputs=[msg, chatbot],
398
  outputs=[msg, chatbot, btn_submit]
399
- ).then(
400
  fn=btn_submit_bot,
401
  inputs=[chatbot, setting_n_keep, setting_n_discard,
402
  setting_temperature, setting_repeat_penalty, setting_frequency_penalty,
@@ -406,7 +437,7 @@ with gr.Blocks() as chatting:
406
  setting_mirostat_tau, role_usr, role_char,
407
  rag, setting_max_tokens],
408
  outputs=[chatbot, s_info]
409
- ).then(
410
  fn=btn_submit_vo,
411
  inputs=[setting_n_keep, setting_n_discard,
412
  setting_temperature, setting_repeat_penalty, setting_frequency_penalty,
@@ -415,7 +446,7 @@ with gr.Blocks() as chatting:
415
  setting_tfs_z, setting_mirostat_mode, setting_mirostat_eta,
416
  setting_mirostat_tau, setting_max_tokens],
417
  outputs=[vo, s_info]
418
- ).then(
419
  fn=btn_submit_suggest,
420
  inputs=[setting_n_keep, setting_n_discard,
421
  setting_temperature, setting_repeat_penalty, setting_frequency_penalty,
@@ -424,28 +455,31 @@ with gr.Blocks() as chatting:
424
  setting_tfs_z, setting_mirostat_mode, setting_mirostat_eta,
425
  setting_mirostat_tau, role_usr, setting_max_tokens],
426
  outputs=[msg, s_info]
427
- ).then(
428
- fn=lambda: gr.update(interactive=True),
429
  outputs=btn_submit
430
  )
431
 
432
  # ========== 用于调试 ==========
433
- btn_com1.click(fn=lambda: model.str_detokenize(model._input_ids), outputs=rag)
434
 
435
 
436
  @btn_com2.click(inputs=setting_cache_path,
437
- outputs=s_info)
438
  def btn_com2(_cache_path):
439
- _tmp = model.load_session(setting_cache_path.value)
440
- print(f'load cache from {setting_cache_path.value} {_tmp}')
441
- global vo_idx
442
- vo_idx = 0
443
- model.venv = [0]
444
- return str((model.n_tokens, model.venv))
 
 
 
445
 
446
  # ========== 开始运行 ==========
447
  demo = gr.TabbedInterface([chatting, setting, role],
448
  ["聊天", "设置", '角色'],
449
  css=custom_css)
450
  gr.close_all()
451
- demo.queue().launch(share=False)
 
2
  import os
3
  import re
4
  import json
5
+ import threading
6
 
7
  import gradio as gr
8
 
9
  from chat_template import ChatTemplate
10
  from llama_cpp_python_streamingllm import StreamingLLM
11
 
12
+ # ========== 全局锁,确保只能进行一个会话 ==========
13
+ lock = threading.Lock()
14
+ session_active = False
15
+
16
  # ========== 让聊天界面的文本框等高 ==========
17
  custom_css = r'''
18
  #area > div {
 
213
  _top_p, _min_p, _typical_p,
214
  _tfs_z, _mirostat_mode, _mirostat_eta,
215
  _mirostat_tau, _role, _max_tokens):
216
+ with lock:
217
+ if not session_active:
218
+ raise RuntimeError
219
  # ========== 初始化输出模版 ==========
220
  t_bot = chat_template(_role)
221
  completion_tokens = [] # 有可能多个 tokens 才能构成一个 utf-8 编码的文字
 
275
 
276
# ========== Display the user's message ==========
def btn_submit_usr(message: str, history):
    """Append the user's message to the chat history and start a session.

    Acquires the global session flag under `lock` so that only one
    generation session can run at a time.

    Returns:
        tuple: ("" to clear the textbox,
                updated history with the stripped message and an empty
                placeholder for the bot reply,
                a gr.update disabling the submit button).

    Raises:
        RuntimeError: if another session is already active.
    """
    global session_active
    with lock:
        if session_active:
            # Refuse re-entry: a previous submit has not finished yet.
            raise RuntimeError('another session is already active')
        session_active = True
    if history is None:
        history = []
    return "", history + [[message.strip(), '']], gr.update(interactive=False)
287
 
288
 
289
  # ========== 模型流式响应 ==========
 
294
  _tfs_z, _mirostat_mode, _mirostat_eta,
295
  _mirostat_tau, _usr, _char,
296
  _rag, _max_tokens):
297
+ with lock:
298
+ if not session_active:
299
+ raise RuntimeError
300
  # ========== 需要临时注入的内容 ==========
301
  rag_idx = None
302
  if len(_rag) > 0:
 
352
  _top_p, _min_p, _typical_p,
353
  _tfs_z, _mirostat_mode, _mirostat_eta,
354
  _mirostat_tau, _max_tokens):
355
+ with lock:
356
+ if not session_active:
357
+ raise RuntimeError
358
  global vo_idx
359
  vo_idx = model.venv_create() # 创建隔离环境
360
  # ========== 模型输出旁白 ==========
 
375
  _top_p, _min_p, _typical_p,
376
  _tfs_z, _mirostat_mode, _mirostat_eta,
377
  _mirostat_tau, _usr, _max_tokens):
378
+ with lock:
379
+ if not session_active:
380
+ raise RuntimeError
381
  model.venv_create() # 创建隔离环境
382
  # ========== 模型输出 ==========
383
  _tmp = btn_submit_com(_n_keep, _n_discard,
 
393
  yield _h, str((model.n_tokens, model.venv))
394
 
395
 
396
def btn_submit_finish():
    """Mark the active session as finished and re-enable the submit button.

    Counterpart of btn_submit_usr: clears the global session flag under
    `lock` at the end of the submit event chain.

    Returns:
        gr.update re-enabling the submit button.

    Raises:
        RuntimeError: if no session is currently active.
    """
    global session_active
    with lock:
        if not session_active:
            # The chain should only reach here while a session is running.
            raise RuntimeError('no active session to finish')
        session_active = False
    return gr.update(interactive=True)
403
+
404
+
405
  # ========== 聊天页面 ==========
406
  with gr.Blocks() as chatting:
407
  with gr.Row(equal_height=True):
 
427
  fn=btn_submit_usr, api_name="submit",
428
  inputs=[msg, chatbot],
429
  outputs=[msg, chatbot, btn_submit]
430
+ ).success(
431
  fn=btn_submit_bot,
432
  inputs=[chatbot, setting_n_keep, setting_n_discard,
433
  setting_temperature, setting_repeat_penalty, setting_frequency_penalty,
 
437
  setting_mirostat_tau, role_usr, role_char,
438
  rag, setting_max_tokens],
439
  outputs=[chatbot, s_info]
440
+ ).success(
441
  fn=btn_submit_vo,
442
  inputs=[setting_n_keep, setting_n_discard,
443
  setting_temperature, setting_repeat_penalty, setting_frequency_penalty,
 
446
  setting_tfs_z, setting_mirostat_mode, setting_mirostat_eta,
447
  setting_mirostat_tau, setting_max_tokens],
448
  outputs=[vo, s_info]
449
+ ).success(
450
  fn=btn_submit_suggest,
451
  inputs=[setting_n_keep, setting_n_discard,
452
  setting_temperature, setting_repeat_penalty, setting_frequency_penalty,
 
455
  setting_tfs_z, setting_mirostat_mode, setting_mirostat_eta,
456
  setting_mirostat_tau, role_usr, setting_max_tokens],
457
  outputs=[msg, s_info]
458
+ ).success(
459
+ fn=btn_submit_finish,
460
  outputs=btn_submit
461
  )
462
 
463
  # ========== 用于调试 ==========
464
+ # btn_com1.click(fn=lambda: model.str_detokenize(model._input_ids), outputs=rag)
465
 
466
 
467
@btn_com2.click(inputs=setting_cache_path,
                outputs=[s_info, btn_submit])
def btn_com2(_cache_path):
    """Load a saved model session from the given cache path (debug helper).

    Bug fix: the original body read `setting_cache_path.value`, which is
    the component's *initial* value, not what the user typed at click
    time — Gradio passes the current value as `_cache_path`, so use it.

    Also resets the narration index, the model's venv stack and the
    global session flag, and re-enables the submit button.

    Returns:
        tuple: (str of (n_tokens, venv) for the info box,
                gr.update re-enabling the submit button).
    """
    with lock:
        _tmp = model.load_session(_cache_path)
        print(f'load cache from {_cache_path} {_tmp}')
        global vo_idx
        vo_idx = 0  # narration venv index no longer valid after reload
        model.venv = [0]
        global session_active
        session_active = False  # any in-flight session state is discarded
    return str((model.n_tokens, model.venv)), gr.update(interactive=True)
479
 
480
  # ========== 开始运行 ==========
481
  demo = gr.TabbedInterface([chatting, setting, role],
482
  ["聊天", "设置", '角色'],
483
  css=custom_css)
484
  gr.close_all()
485
+ demo.queue(max_size=1).launch(max_threads=1, share=False)