Upload chat_template.py
Browse files- chat_template.py +41 -0
chat_template.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
|
3 |
+
|
4 |
+
class ChatTemplate:
|
5 |
+
cache = {}
|
6 |
+
|
7 |
+
def __init__(self, model, im_start=r'<|im_start|>', im_end=r'<|im_end|>', nl='\n'):
|
8 |
+
self.model = model
|
9 |
+
self.nl = nl
|
10 |
+
self.im_start = im_start
|
11 |
+
self.im_start_token = model.tokenize(self.im_start.encode('utf-8'), add_bos=False, special=True)
|
12 |
+
self.im_end = im_end
|
13 |
+
self.im_end_nl = model.tokenize((self.im_end + self.nl).encode('utf-8'), add_bos=False, special=True)
|
14 |
+
self.eos = [model._token_eos, self.im_end_nl[0]]
|
15 |
+
self.onenl = [self.im_end_nl[-1]]
|
16 |
+
tmp = model.tokenize(('\r' + self.nl).encode('utf-8'), add_bos=False, special=True)
|
17 |
+
if len(tmp) == 1:
|
18 |
+
self.onenl.append(tmp[0])
|
19 |
+
self.onerl = model.tokenize(b'\r', add_bos=False, special=True)
|
20 |
+
self.nlnl = None
|
21 |
+
tmp = model.tokenize((self.nl + self.nl).encode('utf-8'), add_bos=False, special=True)
|
22 |
+
if len(tmp) == 1:
|
23 |
+
self.nlnl = tmp[0]
|
24 |
+
print('ChatTemplate', self.eos, self.im_end_nl, self.onerl, self.onenl, self.nlnl)
|
25 |
+
|
26 |
+
def _get(self, key: str):
|
27 |
+
if key in self.cache:
|
28 |
+
return copy.deepcopy(self.cache[key]) # 深拷贝一下
|
29 |
+
else:
|
30 |
+
value = self.model.tokenize((self.im_start + key + self.nl).encode('utf-8'), add_bos=False, special=True)
|
31 |
+
self.cache[key] = copy.deepcopy(value) # 深拷贝一下
|
32 |
+
return value
|
33 |
+
|
34 |
+
def __call__(self, _role, prompt=None):
|
35 |
+
if prompt is None:
|
36 |
+
return self._get(_role)
|
37 |
+
# print(_role, prompt, self.cache)
|
38 |
+
prompt = self.im_start + _role + self.nl + prompt
|
39 |
+
prompt = self.model.tokenize(prompt.encode('utf-8'), add_bos=False, special=True) + self.im_end_nl
|
40 |
+
# print(self.model.str_detokenize(prompt), prompt)
|
41 |
+
return prompt
|