Limour commited on
Commit
026cf13
·
verified ·
1 Parent(s): cf65ac6

Upload 3 files

Browse files
chat_template.py CHANGED
@@ -3,6 +3,7 @@ import copy
3
 
4
  class ChatTemplate:
5
  cache = {}
 
6
 
7
  def __init__(self, model, im_start=r'<|im_start|>', im_end=r'<|im_end|>', nl='\n'):
8
  self.model = model
@@ -31,7 +32,42 @@ class ChatTemplate:
31
  self.cache[key] = copy.deepcopy(value) # 深拷贝一下
32
  return value
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  def __call__(self, _role, prompt=None):
 
35
  if prompt is None:
36
  return self._get(_role)
37
  # print(_role, prompt, self.cache)
 
3
 
4
  class ChatTemplate:
5
  cache = {}
6
+ roles = set()
7
 
8
  def __init__(self, model, im_start=r'<|im_start|>', im_end=r'<|im_end|>', nl='\n'):
9
  self.model = model
 
32
  self.cache[key] = copy.deepcopy(value) # 深拷贝一下
33
  return value
34
 
35
+ def _add_role(self, _role):
36
+ if _role:
37
+ self.roles.add('\n' + _role)
38
+
39
+ def eos_in_role(self, history: str, t_bot):
40
+ if not (history.endswith('\n') or history.endswith('\r')):
41
+ return 0
42
+ tmp = history.rstrip()
43
+ for _role in self.roles:
44
+ if tmp.endswith(_role):
45
+ n = len(t_bot)
46
+ for i in range(1, n): # 找出需要弃置的tokens长度
47
+ tmp = self.model.str_detokenize(t_bot[n - i:])
48
+ if tmp.rstrip().endswith(_role):
49
+ print('eos_in_role', t_bot[n - i:], repr(tmp))
50
+ return i
51
+ print('eos_in_role missing')
52
+ break
53
+ return 0
54
+
55
+ def eos_in_nlnl(self, history: str, t_bot):
56
+ if not (history.endswith('\n\n') or history.endswith('\n\r\n')):
57
+ return 0
58
+ n = len(t_bot)
59
+ for i in range(1, n): # 找出需要弃置的tokens长度
60
+ tmp = self.model.str_detokenize(t_bot[n - i:])
61
+ if tmp.endswith('\n\n') or tmp.endswith('\n\r\n'):
62
+ if tmp.startswith(']'): # 避免误判
63
+ return 0
64
+ print('eos_in_nlnl', t_bot[n - i:], repr(tmp))
65
+ return i
66
+ print('eos_in_nlnl missing')
67
+ return 0
68
+
69
  def __call__(self, _role, prompt=None):
70
+ self._add_role(_role)
71
  if prompt is None:
72
  return self._get(_role)
73
  # print(_role, prompt, self.cache)
gradio_streamingllm.py CHANGED
@@ -28,6 +28,9 @@ from mods.btn_reset import init as btn_reset_init
28
  # ========== 聊天的模版 默认 chatml ==========
29
  from chat_template import ChatTemplate
30
 
 
 
 
31
  # ========== 全局锁,确保只能进行一个会话 ==========
32
  cfg['session_lock'] = threading.Lock()
33
  cfg['session_active'] = False
@@ -84,8 +87,6 @@ with gr.Blocks() as role:
84
  cfg['role_chat_style'] = gr.Textbox(lines=10, label="回复示例", **cfg['role_chat_style'])
85
 
86
  # ========== 加载角色卡-缓存 ==========
87
- from mods.load_cache import init as load_cache_init
88
-
89
  text_display_init(cfg)
90
  load_cache_init(cfg)
91
 
@@ -99,15 +100,6 @@ with gr.Blocks() as chatting:
99
  cfg['vo'] = gr.Textbox(label='VO', show_copy_button=True, elem_id="VO-area")
100
  cfg['s_info'] = gr.Textbox(value=cfg['model'].venv_info, max_lines=1, label='info', interactive=False)
101
  cfg['msg'] = gr.Textbox(label='Prompt', lines=2, max_lines=2, elem_id='prompt', autofocus=True, **cfg['msg'])
102
- with gr.Row():
103
- cfg['btn_vo'] = gr.Button("旁白")
104
- cfg['btn_rag'] = gr.Button("RAG")
105
- cfg['btn_retry'] = gr.Button("Retry")
106
- cfg['btn_com1'] = gr.Button("自定义1")
107
- cfg['btn_reset'] = gr.Button("Reset")
108
- cfg['btn_debug'] = gr.Button("Debug")
109
- cfg['btn_submit'] = gr.Button("Submit")
110
- cfg['btn_suggest'] = gr.Button("建议")
111
 
112
  cfg['gr'] = gr
113
  btn_com_init(cfg)
@@ -164,4 +156,4 @@ demo = gr.TabbedInterface([chatting, setting, role],
164
  ["聊天", "设置", '角色'],
165
  css=custom_css)
166
  gr.close_all()
167
- demo.queue(api_open=False, max_size=1).launch(share=False)
 
28
  # ========== 聊天的模版 默认 chatml ==========
29
  from chat_template import ChatTemplate
30
 
31
+ # ========== 加载角色卡-缓存 ==========
32
+ from mods.load_cache import init as load_cache_init
33
+
34
  # ========== 全局锁,确保只能进行一个会话 ==========
35
  cfg['session_lock'] = threading.Lock()
36
  cfg['session_active'] = False
 
87
  cfg['role_chat_style'] = gr.Textbox(lines=10, label="回复示例", **cfg['role_chat_style'])
88
 
89
  # ========== 加载角色卡-缓存 ==========
 
 
90
  text_display_init(cfg)
91
  load_cache_init(cfg)
92
 
 
100
  cfg['vo'] = gr.Textbox(label='VO', show_copy_button=True, elem_id="VO-area")
101
  cfg['s_info'] = gr.Textbox(value=cfg['model'].venv_info, max_lines=1, label='info', interactive=False)
102
  cfg['msg'] = gr.Textbox(label='Prompt', lines=2, max_lines=2, elem_id='prompt', autofocus=True, **cfg['msg'])
 
 
 
 
 
 
 
 
 
103
 
104
  cfg['gr'] = gr
105
  btn_com_init(cfg)
 
156
  ["聊天", "设置", '角色'],
157
  css=custom_css)
158
  gr.close_all()
159
+ demo.queue(api_open=False, max_size=1).launch(share=False, show_error=True, show_api=False)
llama_cpp_python_streamingllm.py CHANGED
@@ -6,35 +6,13 @@ from ctypes import POINTER
6
  from KMP_list import kmp_search, compute_lps_array
7
 
8
 
9
- def is_UTF8_incomplete(all_text):
10
- multibyte_fix = 0
11
- if len(all_text) < 3:
12
- all_text = b'000' + all_text
13
- for k, char in enumerate(all_text[-3:]):
14
- k = 3 - k
15
- for num, pattern in [(2, 192), (3, 224), (4, 240)]:
16
- # Bitwise AND check
17
- if num > k and pattern & char == pattern:
18
- multibyte_fix = num - k
19
- return multibyte_fix
20
-
21
-
22
- def get_complete_UTF8(all_text):
23
- multibyte_fix = is_UTF8_incomplete(all_text)
24
- if multibyte_fix > 0:
25
- multibyte_fix = multibyte_fix - 3
26
- return all_text[:multibyte_fix].decode("utf-8")
27
- else:
28
- return all_text.decode("utf-8")
29
-
30
-
31
  class StreamingLLM(Llama):
32
  def __init__(self, model_path: str, **kwargs):
33
  super().__init__(model_path, **kwargs)
34
  self._venv_init()
35
 
36
  def str_detokenize(self, tokens) -> str:
37
- return get_complete_UTF8(self.detokenize(tokens))
38
 
39
  def kv_cache_seq_trim(self):
40
  self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
@@ -103,9 +81,9 @@ class StreamingLLM(Llama):
103
  break
104
  return True
105
 
106
- def venv_pop_token(self):
107
- self.n_tokens -= 1
108
- self.venv[-1] -= 1
109
  self.kv_cache_seq_trim()
110
 
111
  @property
@@ -113,6 +91,8 @@ class StreamingLLM(Llama):
113
  return str((self.n_tokens, self.venv, self.venv_idx_map))
114
 
115
  def kv_cache_seq_ltrim(self, n_keep, n_discard=256, n_past=-1, im_start=None):
 
 
116
  if n_past < 0:
117
  n_past = self.n_tokens
118
  if im_start is not None: # [<|im_start|>, name, nl]
 
6
  from KMP_list import kmp_search, compute_lps_array
7
 
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  class StreamingLLM(Llama):
10
  def __init__(self, model_path: str, **kwargs):
11
  super().__init__(model_path, **kwargs)
12
  self._venv_init()
13
 
14
  def str_detokenize(self, tokens) -> str:
15
+ return self.detokenize(tokens).decode('utf-8', errors='ignore')
16
 
17
  def kv_cache_seq_trim(self):
18
  self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
 
81
  break
82
  return True
83
 
84
+ def venv_pop_token(self, n=1):
85
+ self.n_tokens -= n
86
+ self.venv[-1] -= n
87
  self.kv_cache_seq_trim()
88
 
89
  @property
 
91
  return str((self.n_tokens, self.venv, self.venv_idx_map))
92
 
93
  def kv_cache_seq_ltrim(self, n_keep, n_discard=256, n_past=-1, im_start=None):
94
+ if n_keep < 0:
95
+ return
96
  if n_past < 0:
97
  n_past = self.n_tokens
98
  if im_start is not None: # [<|im_start|>, name, nl]