Limour commited on
Commit
34a0cdc
·
verified ·
1 Parent(s): d080993

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +19 -9
  2. hf_api.py +10 -0
app.py CHANGED
@@ -3,6 +3,7 @@ import os
3
  import re
4
  import json
5
  import threading
 
6
 
7
  import gradio as gr
8
 
@@ -467,15 +468,24 @@ with gr.Blocks() as chatting:
467
  @btn_com2.click(inputs=setting_cache_path,
468
  outputs=[s_info, btn_submit])
469
  def btn_com2(_cache_path):
470
- with lock:
471
- _tmp = model.load_session(setting_cache_path.value)
472
- print(f'load cache from {setting_cache_path.value} {_tmp}')
473
- global vo_idx
474
- vo_idx = 0
475
- model.venv = [0]
476
- global session_active
477
- session_active = False
478
- return str((model.n_tokens, model.venv)), gr.update(interactive=True)
 
 
 
 
 
 
 
 
 
479
 
480
  # ========== 开始运行 ==========
481
  demo = gr.TabbedInterface([chatting, setting, role],
 
3
  import re
4
  import json
5
  import threading
6
+ from hf_api import restart_space
7
 
8
  import gradio as gr
9
 
 
468
  @btn_com2.click(inputs=setting_cache_path,
469
  outputs=[s_info, btn_submit])
470
  def btn_com2(_cache_path):
471
+ try:
472
+ with lock:
473
+ _tmp = model.load_session(setting_cache_path.value)
474
+ print(f'load cache from {setting_cache_path.value} {_tmp}')
475
+ global vo_idx
476
+ vo_idx = 0
477
+ model.venv = [0]
478
+ global session_active
479
+ session_active = False
480
+ return str((model.n_tokens, model.venv)), gr.update(interactive=True)
481
+ except Exception as e:
482
+ restart_space()
483
+ raise e
484
+
485
+ @btn_com3.click()
486
+ def btn_com3():
487
+ restart_space()
488
+
489
 
490
  # ========== 开始运行 ==========
491
  demo = gr.TabbedInterface([chatting, setting, role],
hf_api.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from huggingface_hub import HfApi
4
+
5
+ API = HfApi(token=os.environ.get("HF_TOKEN"))
6
+ REPO_ID = "Limour/llama-python-streamingllm"
7
+
8
+
9
+ def restart_space():
10
+ API.restart_space(repo_id=REPO_ID, token=os.environ.get("HF_TOKEN"))