KaleiNeely
commited on
Commit
•
ad95cec
1
Parent(s):
47b17c7
Upload 7 files
Browse files- README.md +69 -3
- modeling_rwkv5.py +56 -37
- tokenization_rwkv_world.py +142 -12
README.md
CHANGED
@@ -85,7 +85,7 @@ Assistant:"""
|
|
85 |
model = AutoModelForCausalLM.from_pretrained("RWKV/rwkv-5-world-1b5", trust_remote_code=True, torch_dtype=torch.float16).to(0)
|
86 |
tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-5-world-1b5", trust_remote_code=True)
|
87 |
|
88 |
-
text = "
|
89 |
prompt = generate_prompt(text)
|
90 |
|
91 |
inputs = tokenizer(prompt, return_tensors="pt").to(0)
|
@@ -100,8 +100,74 @@ User: hi
|
|
100 |
|
101 |
Assistant: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.
|
102 |
|
103 |
-
User:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
104 |
|
105 |
-
Assistant: 乌兰察布市是中国新疆维吾尔自治区的一个地级市,位于新疆维吾尔自治区西南部,毗邻青海省。乌兰察布市是新疆维吾尔自治区的重要城市之一,也是新疆维吾尔自治区的第二大城市。乌兰察布市是新疆的重要经济中心之一,拥有丰富的自然资源和人口密度,是新疆的重要交通枢纽和商
|
106 |
```
|
107 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
model = AutoModelForCausalLM.from_pretrained("RWKV/rwkv-5-world-1b5", trust_remote_code=True, torch_dtype=torch.float16).to(0)
|
86 |
tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-5-world-1b5", trust_remote_code=True)
|
87 |
|
88 |
+
text = "介绍一下大熊猫"
|
89 |
prompt = generate_prompt(text)
|
90 |
|
91 |
inputs = tokenizer(prompt, return_tensors="pt").to(0)
|
|
|
100 |
|
101 |
Assistant: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.
|
102 |
|
103 |
+
User: 介绍一下大熊猫
|
104 |
+
|
105 |
+
Assistant: 大熊猫是一种中国特有的哺乳动物,也是中国的国宝之一。它们的外貌特征是圆形的黑白相间的身体,有着黑色的毛发和白色的耳朵。大熊猫的食物主要是竹子,它们会在竹林中寻找竹子,并且会将竹子放在竹笼中进行储存。大熊猫的寿命约为20至30年,但由于栖息地的丧失和人类活动的
|
106 |
+
```
|
107 |
+
|
108 |
+
#### Batch Inference
|
109 |
+
|
110 |
+
```python
|
111 |
+
import torch
|
112 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
113 |
+
|
114 |
+
def generate_prompt(instruction, input=""):
|
115 |
+
instruction = instruction.strip().replace('\r\n', '\n').replace('\n\n', '\n')
|
116 |
+
input = input.strip().replace('\r\n', '\n').replace('\n\n', '\n')
|
117 |
+
if input:
|
118 |
+
return f"""Instruction: {instruction}
|
119 |
+
|
120 |
+
Input: {input}
|
121 |
+
|
122 |
+
Response:"""
|
123 |
+
else:
|
124 |
+
return f"""User: hi
|
125 |
+
|
126 |
+
Assistant: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.
|
127 |
+
|
128 |
+
User: {instruction}
|
129 |
+
|
130 |
+
Assistant:"""
|
131 |
+
|
132 |
+
model = AutoModelForCausalLM.from_pretrained("RWKV/rwkv-5-world-1b5", trust_remote_code=True).to(torch.float32)
|
133 |
+
tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-5-world-1b5", trust_remote_code=True)
|
134 |
+
|
135 |
+
texts = ["请介绍北京的旅游景点", "介绍一下大熊猫", "乌兰察布"]
|
136 |
+
prompts = [generate_prompt(text) for text in texts]
|
137 |
+
|
138 |
+
inputs = tokenizer(prompts, return_tensors="pt", padding=True)
|
139 |
+
outputs = model.generate(inputs["input_ids"], max_new_tokens=128, do_sample=True, temperature=1.0, top_p=0.3, top_k=0, )
|
140 |
+
|
141 |
+
for output in outputs:
|
142 |
+
print(tokenizer.decode(output.tolist(), skip_special_tokens=True))
|
143 |
|
|
|
144 |
```
|
145 |
|
146 |
+
output:
|
147 |
+
|
148 |
+
```shell
|
149 |
+
User: hi
|
150 |
+
|
151 |
+
Assistant: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.
|
152 |
+
|
153 |
+
User: 请介绍北京的旅游景点
|
154 |
+
|
155 |
+
Assistant: 北京是中国的首都,拥有丰富的旅游资源和历史文化遗产。以下是一些北京的旅游景点:
|
156 |
+
1. 故宫:位于北京市中心,是明清两代的皇宫,是中国最大的古代宫殿建筑群之一。
|
157 |
+
2. 天安门广场:位于北京市中心,是中国最著名的城市广场之一,也是中国最大的城市广场。
|
158 |
+
3. 颐和
|
159 |
+
User: hi
|
160 |
+
|
161 |
+
Assistant: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.
|
162 |
+
|
163 |
+
User: 介绍一下大熊猫
|
164 |
+
|
165 |
+
Assistant: 大熊猫是一种生活在中国中部地区的哺乳动物,也是中国的国宝之一。它们的外貌特征是圆形的黑白相间的身体,有着黑色的毛发和圆圆的眼睛。大熊猫是一种濒危物种,目前只有在野外的几个保护区才能看到它们的身影。大熊猫的食物主要是竹子,它们会在竹子上寻找食物,并且可以通
|
166 |
+
User: hi
|
167 |
+
|
168 |
+
Assistant: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.
|
169 |
+
|
170 |
+
User: 乌兰察布
|
171 |
+
|
172 |
+
Assistant: 乌兰察布是中国新疆维吾尔自治区的一个县级市,位于新疆维吾尔自治区中部,是新疆的第二大城市。乌兰察布市是新疆的第一大城市,也是新疆的重要城市之一。乌兰察布市是新疆的经济中心,也是新疆的重要交通枢纽之一。乌兰察布市的人口约为2.5万人,其中汉族占绝大多数。乌
|
173 |
+
```
|
modeling_rwkv5.py
CHANGED
@@ -85,33 +85,46 @@ def rwkv_linear_attention_v5_0(H, S, T, hidden, time_decay, time_first, receptan
|
|
85 |
|
86 |
return out, state
|
87 |
|
88 |
-
|
|
|
|
|
89 |
time_decay = torch.exp(-torch.exp(time_decay.float())).reshape(-1,1,1).reshape(n_head, -1, 1)
|
90 |
time_first = time_first.float().reshape(-1,1,1).reshape(n_head, -1, 1)
|
91 |
lxw = lxw.float()
|
92 |
lxb = lxb.float()
|
93 |
-
if seq_mode:
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
else:
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
|
116 |
return out, state
|
117 |
|
@@ -153,7 +166,7 @@ class RwkvSelfAttention(nn.Module):
|
|
153 |
self.ln_x = nn.GroupNorm(hidden_size // config.head_size, hidden_size)
|
154 |
|
155 |
# TODO: maybe jit, otherwise move inside forward
|
156 |
-
def extract_key_value(self, H, S, T, hidden, state=None):
|
157 |
# Mix hidden with the previous timestep to produce key, value, receptance
|
158 |
if hidden.size(1) == 1 and state is not None:
|
159 |
shifted = state[0][:, :, self.layer_id]
|
@@ -161,25 +174,27 @@ class RwkvSelfAttention(nn.Module):
|
|
161 |
shifted = self.time_shift(hidden)
|
162 |
if state is not None:
|
163 |
shifted[:, 0] = state[0][:, :, self.layer_id]
|
|
|
|
|
164 |
key = hidden * self.time_mix_key + shifted * (1 - self.time_mix_key)
|
165 |
value = hidden * self.time_mix_value + shifted * (1 - self.time_mix_value)
|
166 |
receptance = hidden * self.time_mix_receptance + shifted * (1 - self.time_mix_receptance)
|
167 |
if self.config.model_version == "5_2":
|
168 |
gate = hidden* self.time_mix_gate + shifted * (1 - self.time_mix_gate)
|
169 |
|
170 |
-
if hidden.size(1) == 1 and state is not None:
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
else:
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
|
180 |
if self.config.model_version == "5_2":
|
181 |
gate = F.silu(self.gate(gate))
|
182 |
-
|
183 |
if state is not None:
|
184 |
state[0][:, :, self.layer_id] = hidden[:, -1]
|
185 |
|
@@ -188,17 +203,19 @@ class RwkvSelfAttention(nn.Module):
|
|
188 |
return receptance, key, value, state
|
189 |
|
190 |
def forward(self, hidden, state=None, use_cache=False, seq_mode=True):
|
|
|
191 |
H = self.time_decay.shape[0]
|
192 |
S = hidden.shape[-1] // H
|
193 |
T = hidden.shape[1]
|
194 |
|
195 |
if self.config.model_version == "5_2":
|
196 |
-
receptance, key, value, gate, state = self.extract_key_value(H, S, T, hidden, state=state)
|
197 |
else:
|
198 |
receptance, key, value, state = self.extract_key_value(H, S, T, hidden, state=state)
|
199 |
layer_state = state[1][:, :, :, :, self.layer_id] if state is not None else None
|
200 |
if self.config.model_version == "5_2":
|
201 |
rwkv, layer_state = rwkv_linear_attention_v5_2(
|
|
|
202 |
H,
|
203 |
S,
|
204 |
T,
|
@@ -273,6 +290,8 @@ class RwkvFeedForward(nn.Module):
|
|
273 |
shifted = self.time_shift(hidden)
|
274 |
if state is not None:
|
275 |
shifted[:, 0] = state[2][:, :, self.layer_id]
|
|
|
|
|
276 |
key = hidden * self.time_mix_key + shifted * (1 - self.time_mix_key)
|
277 |
receptance = hidden * self.time_mix_receptance + shifted * (1 - self.time_mix_receptance)
|
278 |
|
@@ -594,7 +613,8 @@ class RwkvModel(RwkvPreTrainedModel):
|
|
594 |
|
595 |
|
596 |
hidden_states = inputs_embeds
|
597 |
-
|
|
|
598 |
all_self_attentions = () if output_attentions else None
|
599 |
all_hidden_states = () if output_hidden_states else None
|
600 |
for idx, block in enumerate(self.blocks):
|
@@ -645,7 +665,6 @@ class RwkvModel(RwkvPreTrainedModel):
|
|
645 |
|
646 |
self.layers_are_rescaled = not self.training
|
647 |
|
648 |
-
|
649 |
@add_start_docstrings(
|
650 |
"""
|
651 |
The RWKV Model transformer with a language modeling head on top (linear layer with weights tied to the input
|
|
|
85 |
|
86 |
return out, state
|
87 |
|
88 |
+
cnt = 0
|
89 |
+
|
90 |
+
def rwkv_linear_attention_v5_2(B, H, S, T, n_head, hidden, time_decay, time_first, receptance, key, value, gate, lxw, lxb, ow, state, return_state=False, seq_mode=True):
|
91 |
time_decay = torch.exp(-torch.exp(time_decay.float())).reshape(-1,1,1).reshape(n_head, -1, 1)
|
92 |
time_first = time_first.float().reshape(-1,1,1).reshape(n_head, -1, 1)
|
93 |
lxw = lxw.float()
|
94 |
lxb = lxb.float()
|
95 |
+
# if seq_mode:
|
96 |
+
out = torch.empty((B, T, H, S), dtype=receptance.dtype, device=receptance.device)
|
97 |
+
for t in range(T):
|
98 |
+
rt = receptance[:,:,t:t+1,:]
|
99 |
+
kt = key[:,:,:,t:t+1]
|
100 |
+
vt = value[:,:,t:t+1,:]
|
101 |
+
at = kt @ vt
|
102 |
+
out[:, t] = (rt @ (time_first * at + state)).squeeze(2)
|
103 |
+
state = at + time_decay * state
|
104 |
+
|
105 |
+
out = out.reshape(B*T, H*S)
|
106 |
+
out = F.group_norm(out, num_groups=H, weight=lxw, bias=lxb).reshape(B, T, H*S)
|
107 |
+
out = out.to(dtype=hidden.dtype) * gate
|
108 |
+
out = out @ ow
|
109 |
+
# else:
|
110 |
+
# a = key @ value
|
111 |
+
# # print('key.shape: ', key.shape)
|
112 |
+
# # print('value.shape: ', value.shape)
|
113 |
+
# # print('receptance.shape: ', receptance.shape)
|
114 |
+
# # print('a.shape: ', a.shape)
|
115 |
+
# # print('time_first.shape: ', time_first.shape)
|
116 |
+
# # print('(time_first * a).shape: ', (time_first * a).shape)
|
117 |
+
# # print('time_decay.shape: ', time_decay.shape)
|
118 |
+
# # print('state.shape: ', state.shape)
|
119 |
+
# out = receptance @ (time_first * a + state)
|
120 |
+
# # print('out.shape: ', out.shape)
|
121 |
+
# state = a + time_decay * state
|
122 |
+
# # print('state.shape: ', state.shape)
|
123 |
+
# out = out.reshape(B, H*S)
|
124 |
+
# out = F.group_norm(out, num_groups=H, weight=lxw, bias=lxb).reshape(B, 1, H*S)
|
125 |
+
# out = out.to(dtype=hidden.dtype) * gate
|
126 |
+
# out = out @ ow
|
127 |
+
|
128 |
|
129 |
return out, state
|
130 |
|
|
|
166 |
self.ln_x = nn.GroupNorm(hidden_size // config.head_size, hidden_size)
|
167 |
|
168 |
# TODO: maybe jit, otherwise move inside forward
|
169 |
+
def extract_key_value(self, B, H, S, T, hidden, state=None):
|
170 |
# Mix hidden with the previous timestep to produce key, value, receptance
|
171 |
if hidden.size(1) == 1 and state is not None:
|
172 |
shifted = state[0][:, :, self.layer_id]
|
|
|
174 |
shifted = self.time_shift(hidden)
|
175 |
if state is not None:
|
176 |
shifted[:, 0] = state[0][:, :, self.layer_id]
|
177 |
+
if len(shifted.size()) == 2:
|
178 |
+
shifted = shifted.unsqueeze(1)
|
179 |
key = hidden * self.time_mix_key + shifted * (1 - self.time_mix_key)
|
180 |
value = hidden * self.time_mix_value + shifted * (1 - self.time_mix_value)
|
181 |
receptance = hidden * self.time_mix_receptance + shifted * (1 - self.time_mix_receptance)
|
182 |
if self.config.model_version == "5_2":
|
183 |
gate = hidden* self.time_mix_gate + shifted * (1 - self.time_mix_gate)
|
184 |
|
185 |
+
# if hidden.size(1) == 1 and state is not None:
|
186 |
+
# receptance = self.receptance(receptance).to(torch.float32).view(B, H, 1, S)
|
187 |
+
# key = self.key(key).to(torch.float32).view(B, H, S, 1)
|
188 |
+
# value = self.value(value).to(torch.float32).view(B, H, 1, S)
|
189 |
+
# else:
|
190 |
+
# https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/model.py#L693
|
191 |
+
key = self.key(key).to(torch.float32).view(B, T, H, S).transpose(1, 2).transpose(-2, -1)
|
192 |
+
value = self.value(value).to(torch.float32).view(B, T, H, S).transpose(1, 2)
|
193 |
+
receptance = self.receptance(receptance).to(torch.float32).view(B, T, H, S).transpose(1, 2)
|
194 |
|
195 |
if self.config.model_version == "5_2":
|
196 |
gate = F.silu(self.gate(gate))
|
197 |
+
|
198 |
if state is not None:
|
199 |
state[0][:, :, self.layer_id] = hidden[:, -1]
|
200 |
|
|
|
203 |
return receptance, key, value, state
|
204 |
|
205 |
def forward(self, hidden, state=None, use_cache=False, seq_mode=True):
|
206 |
+
B = hidden.shape[0]
|
207 |
H = self.time_decay.shape[0]
|
208 |
S = hidden.shape[-1] // H
|
209 |
T = hidden.shape[1]
|
210 |
|
211 |
if self.config.model_version == "5_2":
|
212 |
+
receptance, key, value, gate, state = self.extract_key_value(B, H, S, T, hidden, state=state)
|
213 |
else:
|
214 |
receptance, key, value, state = self.extract_key_value(H, S, T, hidden, state=state)
|
215 |
layer_state = state[1][:, :, :, :, self.layer_id] if state is not None else None
|
216 |
if self.config.model_version == "5_2":
|
217 |
rwkv, layer_state = rwkv_linear_attention_v5_2(
|
218 |
+
B,
|
219 |
H,
|
220 |
S,
|
221 |
T,
|
|
|
290 |
shifted = self.time_shift(hidden)
|
291 |
if state is not None:
|
292 |
shifted[:, 0] = state[2][:, :, self.layer_id]
|
293 |
+
if len(shifted.size()) == 2:
|
294 |
+
shifted = shifted.unsqueeze(1)
|
295 |
key = hidden * self.time_mix_key + shifted * (1 - self.time_mix_key)
|
296 |
receptance = hidden * self.time_mix_receptance + shifted * (1 - self.time_mix_receptance)
|
297 |
|
|
|
613 |
|
614 |
|
615 |
hidden_states = inputs_embeds
|
616 |
+
global cnt
|
617 |
+
cnt += 1
|
618 |
all_self_attentions = () if output_attentions else None
|
619 |
all_hidden_states = () if output_hidden_states else None
|
620 |
for idx, block in enumerate(self.blocks):
|
|
|
665 |
|
666 |
self.layers_are_rescaled = not self.training
|
667 |
|
|
|
668 |
@add_start_docstrings(
|
669 |
"""
|
670 |
The RWKV Model transformer with a language modeling head on top (linear layer with weights tied to the input
|
tokenization_rwkv_world.py
CHANGED
@@ -107,6 +107,7 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
|
|
107 |
self,
|
108 |
vocab_file,
|
109 |
errors="replace",
|
|
|
110 |
**kwargs
|
111 |
):
|
112 |
self.add_bos_token = False
|
@@ -122,11 +123,7 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
|
|
122 |
assert len(x) == int(l[l.rindex(' '):])
|
123 |
sorted += [x]
|
124 |
self.encoder[idx] = x
|
125 |
-
|
126 |
-
super().__init__(
|
127 |
-
errors=errors,
|
128 |
-
**kwargs,
|
129 |
-
)
|
130 |
self.decoder = {}
|
131 |
for k,v in self.encoder.items():
|
132 |
self.decoder[v] = int(k)
|
@@ -136,6 +133,14 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
|
|
136 |
_ = self.trie.add(t, val=(t, i))
|
137 |
self.errors = errors # how to handle errors in decoding
|
138 |
self.cache = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
|
140 |
@property
|
141 |
def vocab_size(self):
|
@@ -143,6 +148,22 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
|
|
143 |
|
144 |
def get_vocab(self):
|
145 |
return dict(self.encoder, **self.added_tokens_encoder)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
146 |
|
147 |
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
|
148 |
if self.add_bos_token:
|
@@ -219,14 +240,21 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
|
|
219 |
skip_special_tokens: bool = False,
|
220 |
**kwargs
|
221 |
) -> str:
|
222 |
-
|
|
|
|
|
|
|
|
|
|
|
223 |
# Convert inputs to python lists
|
224 |
token_ids = to_py_obj(token_ids)
|
|
|
225 |
if isinstance(token_ids, int):
|
226 |
if token_ids in self.all_special_ids and skip_special_tokens:
|
227 |
return ""
|
228 |
return self.encoder.get(token_ids, self.unk_token)
|
229 |
elif isinstance(token_ids, list):
|
|
|
230 |
out_str = ""
|
231 |
out_last = 0
|
232 |
out_tokens = []
|
@@ -268,6 +296,11 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
|
|
268 |
def prepare_for_tokenization(self, text, **kwargs):
|
269 |
return (text, kwargs)
|
270 |
|
|
|
|
|
|
|
|
|
|
|
271 |
def _encode_plus(
|
272 |
self,
|
273 |
text: Union[TextInput, EncodedInput],
|
@@ -352,19 +385,33 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
|
|
352 |
verbose: bool = True,
|
353 |
**kwargs
|
354 |
) -> BatchEncoding:
|
355 |
-
def get_input_ids(text):
|
|
|
|
|
|
|
356 |
if isinstance(text, str):
|
357 |
-
|
358 |
-
|
|
|
|
|
|
|
359 |
elif isinstance(text, list) and len(text) > 0 and isinstance(text[0], str):
|
360 |
-
|
|
|
|
|
|
|
|
|
361 |
elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
|
|
|
|
|
362 |
return text
|
|
|
363 |
else:
|
364 |
raise ValueError(
|
365 |
"Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers."
|
366 |
)
|
367 |
|
|
|
368 |
if return_offsets_mapping:
|
369 |
raise NotImplementedError(
|
370 |
"return_offset_mapping is not available when using Python tokenizers. "
|
@@ -372,15 +419,29 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
|
|
372 |
"transformers.PreTrainedTokenizerFast."
|
373 |
)
|
374 |
|
375 |
-
|
|
|
376 |
for ids_or_pair_ids in batch_text_or_text_pairs:
|
377 |
if not isinstance(ids_or_pair_ids, (list, tuple)):
|
378 |
ids, pair_ids = ids_or_pair_ids, None
|
379 |
else:
|
380 |
ids, pair_ids = ids_or_pair_ids
|
381 |
-
|
382 |
first_ids = get_input_ids(ids)
|
383 |
second_ids = get_input_ids(pair_ids) if pair_ids is not None else None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
384 |
input_ids.append((first_ids, second_ids))
|
385 |
|
386 |
batch_outputs = self._batch_prepare_for_model(
|
@@ -401,6 +462,75 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
|
|
401 |
)
|
402 |
|
403 |
return BatchEncoding(batch_outputs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
404 |
|
405 |
def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
|
406 |
input_ids = []
|
|
|
107 |
self,
|
108 |
vocab_file,
|
109 |
errors="replace",
|
110 |
+
pad_token="0",
|
111 |
**kwargs
|
112 |
):
|
113 |
self.add_bos_token = False
|
|
|
123 |
assert len(x) == int(l[l.rindex(' '):])
|
124 |
sorted += [x]
|
125 |
self.encoder[idx] = x
|
126 |
+
|
|
|
|
|
|
|
|
|
127 |
self.decoder = {}
|
128 |
for k,v in self.encoder.items():
|
129 |
self.decoder[v] = int(k)
|
|
|
133 |
_ = self.trie.add(t, val=(t, i))
|
134 |
self.errors = errors # how to handle errors in decoding
|
135 |
self.cache = {}
|
136 |
+
self.first_max_length = 0
|
137 |
+
|
138 |
+
# pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
|
139 |
+
super().__init__(
|
140 |
+
errors=errors,
|
141 |
+
# pad_token=pad_token,
|
142 |
+
**kwargs,
|
143 |
+
)
|
144 |
|
145 |
@property
|
146 |
def vocab_size(self):
|
|
|
148 |
|
149 |
def get_vocab(self):
|
150 |
return dict(self.encoder, **self.added_tokens_encoder)
|
151 |
+
|
152 |
+
def add_tokens(self, new_tokens, special_tokens: bool = False):
|
153 |
+
for token in new_tokens:
|
154 |
+
token_id = self.convert_tokens_to_ids(token)
|
155 |
+
self.added_tokens_decoder[token_id] = token
|
156 |
+
|
157 |
+
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
|
158 |
+
if isinstance(ids, int):
|
159 |
+
ids = [ids]
|
160 |
+
tokens = []
|
161 |
+
for id_ in ids:
|
162 |
+
if id_ in self.added_tokens_decoder:
|
163 |
+
tokens.append(self.added_tokens_decoder[id_])
|
164 |
+
else:
|
165 |
+
tokens.append(self._convert_id_to_token(id_))
|
166 |
+
return tokens
|
167 |
|
168 |
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
|
169 |
if self.add_bos_token:
|
|
|
240 |
skip_special_tokens: bool = False,
|
241 |
**kwargs
|
242 |
) -> str:
|
243 |
+
|
244 |
+
def remove_zeros_from_first_segment(token_ids, first_max_length):
|
245 |
+
first_segment = token_ids[:first_max_length]
|
246 |
+
first_segment_cleaned = [token for token in first_segment if token != 0]
|
247 |
+
return first_segment_cleaned + token_ids[first_max_length:]
|
248 |
+
|
249 |
# Convert inputs to python lists
|
250 |
token_ids = to_py_obj(token_ids)
|
251 |
+
token_ids = remove_zeros_from_first_segment(token_ids, self.first_max_length)
|
252 |
if isinstance(token_ids, int):
|
253 |
if token_ids in self.all_special_ids and skip_special_tokens:
|
254 |
return ""
|
255 |
return self.encoder.get(token_ids, self.unk_token)
|
256 |
elif isinstance(token_ids, list):
|
257 |
+
self.first_max_length
|
258 |
out_str = ""
|
259 |
out_last = 0
|
260 |
out_tokens = []
|
|
|
296 |
def prepare_for_tokenization(self, text, **kwargs):
|
297 |
return (text, kwargs)
|
298 |
|
299 |
+
def _get_padding_truncation_strategies(
|
300 |
+
self, padding=False, truncation=None, max_length=None, pad_to_multiple_of=None, verbose=True, **kwargs
|
301 |
+
):
|
302 |
+
return PaddingStrategy.LONGEST, TruncationStrategy.DO_NOT_TRUNCATE, -1, kwargs
|
303 |
+
|
304 |
def _encode_plus(
|
305 |
self,
|
306 |
text: Union[TextInput, EncodedInput],
|
|
|
385 |
verbose: bool = True,
|
386 |
**kwargs
|
387 |
) -> BatchEncoding:
|
388 |
+
def get_input_ids(text, max_length=None, pad_token_id=0):
|
389 |
+
def pad_sequence(seq, max_len, pad_tok):
|
390 |
+
return [pad_tok] * (max_len - len(seq)) + seq
|
391 |
+
|
392 |
if isinstance(text, str):
|
393 |
+
tokens = self._tokenize(text)
|
394 |
+
if max_length is not None:
|
395 |
+
tokens = pad_sequence(tokens, max_length, pad_token_id)
|
396 |
+
return tokens
|
397 |
+
|
398 |
elif isinstance(text, list) and len(text) > 0 and isinstance(text[0], str):
|
399 |
+
tokenized_texts = [self._tokenize(t) for t in text]
|
400 |
+
if max_length is None:
|
401 |
+
max_length = max(len(t) for t in tokenized_texts)
|
402 |
+
return [pad_sequence(t, max_length, pad_token_id) for t in tokenized_texts]
|
403 |
+
|
404 |
elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
|
405 |
+
if max_length is not None and len(text) < max_length:
|
406 |
+
return pad_sequence(text, max_length, pad_token_id)
|
407 |
return text
|
408 |
+
|
409 |
else:
|
410 |
raise ValueError(
|
411 |
"Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers."
|
412 |
)
|
413 |
|
414 |
+
|
415 |
if return_offsets_mapping:
|
416 |
raise NotImplementedError(
|
417 |
"return_offset_mapping is not available when using Python tokenizers. "
|
|
|
419 |
"transformers.PreTrainedTokenizerFast."
|
420 |
)
|
421 |
|
422 |
+
first_max_length = 0
|
423 |
+
second_max_length = 0
|
424 |
for ids_or_pair_ids in batch_text_or_text_pairs:
|
425 |
if not isinstance(ids_or_pair_ids, (list, tuple)):
|
426 |
ids, pair_ids = ids_or_pair_ids, None
|
427 |
else:
|
428 |
ids, pair_ids = ids_or_pair_ids
|
|
|
429 |
first_ids = get_input_ids(ids)
|
430 |
second_ids = get_input_ids(pair_ids) if pair_ids is not None else None
|
431 |
+
first_max_length = max(first_max_length, len(first_ids))
|
432 |
+
if second_ids is not None:
|
433 |
+
second_max_length = max(second_max_length, len(second_ids))
|
434 |
+
|
435 |
+
self.first_max_length = first_max_length
|
436 |
+
input_ids = []
|
437 |
+
for ids_or_pair_ids in batch_text_or_text_pairs:
|
438 |
+
if not isinstance(ids_or_pair_ids, (list, tuple)):
|
439 |
+
ids, pair_ids = ids_or_pair_ids, None
|
440 |
+
else:
|
441 |
+
ids, pair_ids = ids_or_pair_ids
|
442 |
+
|
443 |
+
first_ids = get_input_ids(ids, max_length=first_max_length)
|
444 |
+
second_ids = get_input_ids(pair_ids, max_length=second_max_length) if pair_ids is not None else None
|
445 |
input_ids.append((first_ids, second_ids))
|
446 |
|
447 |
batch_outputs = self._batch_prepare_for_model(
|
|
|
462 |
)
|
463 |
|
464 |
return BatchEncoding(batch_outputs)
|
465 |
+
|
466 |
+
def decode(
|
467 |
+
self,
|
468 |
+
token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor", "tf.Tensor"],
|
469 |
+
skip_special_tokens: bool = False,
|
470 |
+
clean_up_tokenization_spaces: bool = None,
|
471 |
+
**kwargs,
|
472 |
+
) -> str:
|
473 |
+
"""
|
474 |
+
Converts a sequence of ids in a string, using the tokenizer and vocabulary with options to remove special
|
475 |
+
tokens and clean up tokenization spaces.
|
476 |
+
|
477 |
+
Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`.
|
478 |
+
|
479 |
+
Args:
|
480 |
+
token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`):
|
481 |
+
List of tokenized input ids. Can be obtained using the `__call__` method.
|
482 |
+
skip_special_tokens (`bool`, *optional*, defaults to `False`):
|
483 |
+
Whether or not to remove special tokens in the decoding.
|
484 |
+
clean_up_tokenization_spaces (`bool`, *optional*):
|
485 |
+
Whether or not to clean up the tokenization spaces. If `None`, will default to
|
486 |
+
`self.clean_up_tokenization_spaces`.
|
487 |
+
kwargs (additional keyword arguments, *optional*):
|
488 |
+
Will be passed to the underlying model specific decode method.
|
489 |
+
|
490 |
+
Returns:
|
491 |
+
`str`: The decoded sentence.
|
492 |
+
"""
|
493 |
+
# Convert inputs to python lists
|
494 |
+
return self._decode(
|
495 |
+
token_ids=token_ids,
|
496 |
+
skip_special_tokens=skip_special_tokens,
|
497 |
+
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
498 |
+
**kwargs,
|
499 |
+
)
|
500 |
+
|
501 |
+
def batch_decode(
|
502 |
+
self,
|
503 |
+
sequences: Union[List[int], List[List[int]], "np.ndarray", "torch.Tensor", "tf.Tensor"],
|
504 |
+
skip_special_tokens: bool = False,
|
505 |
+
clean_up_tokenization_spaces: bool = None,
|
506 |
+
**kwargs,
|
507 |
+
) -> List[str]:
|
508 |
+
"""
|
509 |
+
Convert a list of lists of token ids into a list of strings by calling decode.
|
510 |
+
|
511 |
+
Args:
|
512 |
+
sequences (`Union[List[int], List[List[int]], np.ndarray, torch.Tensor, tf.Tensor]`):
|
513 |
+
List of tokenized input ids. Can be obtained using the `__call__` method.
|
514 |
+
skip_special_tokens (`bool`, *optional*, defaults to `False`):
|
515 |
+
Whether or not to remove special tokens in the decoding.
|
516 |
+
clean_up_tokenization_spaces (`bool`, *optional*):
|
517 |
+
Whether or not to clean up the tokenization spaces. If `None`, will default to
|
518 |
+
`self.clean_up_tokenization_spaces`.
|
519 |
+
kwargs (additional keyword arguments, *optional*):
|
520 |
+
Will be passed to the underlying model specific decode method.
|
521 |
+
|
522 |
+
Returns:
|
523 |
+
`List[str]`: The list of decoded sentences.
|
524 |
+
"""
|
525 |
+
return [
|
526 |
+
self.decode(
|
527 |
+
seq,
|
528 |
+
skip_special_tokens=skip_special_tokens,
|
529 |
+
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
530 |
+
**kwargs,
|
531 |
+
)
|
532 |
+
for seq in sequences
|
533 |
+
]
|
534 |
|
535 |
def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
|
536 |
input_ids = []
|