KaleiNeely committed on
Commit
14f6f8d
1 Parent(s): af0c7e3

Upload 3 files

Files changed (3)
  1. configuration_rwkv5.py +22 -27
  2. modeling_rwkv5.py +160 -218
  3. tokenization_rwkv_world.py +91 -91
configuration_rwkv5.py CHANGED
@@ -21,46 +21,44 @@ from transformers.utils import logging
21
 
22
  logger = logging.get_logger(__name__)
23
 
24
- RWKV_PRETRAINED_CONFIG_ARCHIVE_MAP = {
25
-
26
- }
27
 
28
 
29
  class Rwkv5Config(PretrainedConfig):
30
  """
31
- This is the configuration class to store the configuration of a [`RwkvModel`]. It is used to instantiate a RWKV
32
  model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
33
  defaults will yield a similar configuration to that of the RWKV-4
34
- [RWKV/rwkv-4-169m-pile](https://huggingface.co/RWKV/rwkv-4-169m-pile) architecture.
35
 
36
  Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
37
  documentation from [`PretrainedConfig`] for more information.
38
 
39
 
40
  Args:
41
- vocab_size (`int`, *optional*, defaults to 50277):
42
- Vocabulary size of the RWKV model. Defines the number of different tokens that can be represented by the
43
- `inputs_ids` passed when calling [`RwkvModel`].
44
- context_length (`int`, *optional*, defaults to 1024):
45
- The maximum sequence length that this model can be used with in a single forward pass (using it in RNN mode
46
- lets us use any sequence length).
47
- hidden_size (`int`, *optional*, defaults to 4096):
48
  Dimensionality of the embeddings and hidden states.
49
- num_hidden_layers (`int`, *optional*, defaults to 32):
50
  Number of hidden layers in the model.
51
  attention_hidden_size (`int`, *optional*):
52
  Dimensionality of the attention hidden states. Will default to `hidden_size` if unset.
53
  intermediate_size (`int`, *optional*):
54
  Dimensionality of the inner feed-forward layers. Will default to 4 times `hidden_size` if unset.
55
- layer_norm_eps (`float`, *optional*, defaults to 1e-5):
56
  The epsilon to use in the layer normalization layers.
57
  bos_token_id (`int`, *optional*, defaults to 0):
58
- The id of the beginning of sentence token in the vocabulary. Defaults to 0 as RWKV uses the same tokenizer
59
  as GPTNeoX.
60
  eos_token_id (`int`, *optional*, defaults to 0):
61
- The id of the end of sentence token in the vocabulary. Defaults to 0 as RWKV uses the same tokenizer as
62
  GPTNeoX.
63
- rescale_every (`int`, *optional*, default to 6):
64
  At inference, the hidden states (and weights of the corresponding output layers) are divided by 2 every
65
  `rescale_every` layer. If set to 0 or a negative number, no rescale is done.
66
  tie_word_embeddings (`bool`, *optional*, defaults to `False`):
@@ -72,28 +70,27 @@ class Rwkv5Config(PretrainedConfig):
72
  Example:
73
 
74
  ```python
75
- >>> from transformers import RwkvConfig, RwkvModel
76
 
77
- >>> # Initializing a Rwkv configuration
78
- >>> configuration = RwkvConfig()
79
 
80
  >>> # Initializing a model (with random weights) from the configuration
81
- >>> model = RwkvModel(configuration)
82
 
83
  >>> # Accessing the model configuration
84
  >>> configuration = model.config
85
  ```"""
86
 
87
  model_type = "rwkv5"
88
- attribute_map = {"max_position_embeddings": "context_length"}
89
 
90
- def __init__( #1.5B World
91
  self,
92
  vocab_size=65536,
93
- context_length=4096,
94
  hidden_size=768,
95
  num_hidden_layers=24,
96
  attention_hidden_size=None,
 
97
  head_size=64,
98
  intermediate_size=None,
99
  layer_norm_epsilon=1e-5,
@@ -102,14 +99,13 @@ class Rwkv5Config(PretrainedConfig):
102
  rescale_every=6,
103
  tie_word_embeddings=False,
104
  use_cache=True,
105
- model_version="5_2",
106
  **kwargs,
107
  ):
108
  self.vocab_size = vocab_size
109
- self.context_length = context_length
110
  self.hidden_size = hidden_size
111
  self.num_hidden_layers = num_hidden_layers
112
  self.attention_hidden_size = attention_hidden_size if attention_hidden_size is not None else hidden_size
 
113
  self.head_size = head_size
114
  self.intermediate_size = None
115
  self.layer_norm_epsilon = layer_norm_epsilon
@@ -118,7 +114,6 @@ class Rwkv5Config(PretrainedConfig):
118
 
119
  self.bos_token_id = bos_token_id
120
  self.eos_token_id = eos_token_id
121
- self.model_version = model_version
122
 
123
  super().__init__(
124
  tie_word_embeddings=tie_word_embeddings, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs
 
21
 
22
  logger = logging.get_logger(__name__)
23
 
24
+ RWKV5_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
 
 
25
 
26
 
27
  class Rwkv5Config(PretrainedConfig):
28
  """
29
+ This is the configuration class to store the configuration of a [`Rwkv5Model`]. It is used to instantiate a RWKV5
30
  model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
31
  defaults will yield a similar configuration to that of the RWKV-5
32
+ [RWKV/rwkv-5-world-1b5](https://huggingface.co/RWKV/rwkv-5-world-1b5) architecture.
33
 
34
  Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
35
  documentation from [`PretrainedConfig`] for more information.
36
 
37
 
38
  Args:
39
+ vocab_size (`int`, *optional*, defaults to 65536):
40
+ Vocabulary size of the RWKV5 model. Defines the number of different tokens that can be represented by the
41
+ `input_ids` passed when calling [`Rwkv5Model`].
42
+ hidden_size (`int`, *optional*, defaults to 768):
43
  Dimensionality of the embeddings and hidden states.
44
+ num_hidden_layers (`int`, *optional*, defaults to 24):
45
  Number of hidden layers in the model.
46
  attention_hidden_size (`int`, *optional*):
47
  Dimensionality of the attention hidden states. Will default to `hidden_size` if unset.
48
+ num_attention_heads (`int`, *optional*, defaults to 64):
49
+ The number of attention heads used in the RWKV5 self-attention module.
50
+ head_size (`int`, *optional*, defaults to 64): Head size of the RWKV5 self-attention module.
51
  intermediate_size (`int`, *optional*):
52
  Dimensionality of the inner feed-forward layers. Will default to 4 times `hidden_size` if unset.
53
+ layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
54
  The epsilon to use in the layer normalization layers.
55
  bos_token_id (`int`, *optional*, defaults to 0):
56
+ The id of the beginning of sentence token in the vocabulary. Defaults to 0 as RWKV5 uses the same tokenizer
57
  as GPTNeoX.
58
  eos_token_id (`int`, *optional*, defaults to 0):
59
+ The id of the end of sentence token in the vocabulary. Defaults to 0 as RWKV5 uses the same tokenizer as
60
  GPTNeoX.
61
+ rescale_every (`int`, *optional*, defaults to 6):
62
  At inference, the hidden states (and weights of the corresponding output layers) are divided by 2 every
63
  `rescale_every` layer. If set to 0 or a negative number, no rescale is done.
64
  tie_word_embeddings (`bool`, *optional*, defaults to `False`):
 
70
  Example:
71
 
72
  ```python
73
+ >>> from transformers import Rwkv5Config, Rwkv5Model
74
 
75
+ >>> # Initializing a Rwkv5 configuration
76
+ >>> configuration = Rwkv5Config()
77
 
78
  >>> # Initializing a model (with random weights) from the configuration
79
+ >>> model = Rwkv5Model(configuration)
80
 
81
  >>> # Accessing the model configuration
82
  >>> configuration = model.config
83
  ```"""
84
 
85
  model_type = "rwkv5"
 
86
 
87
+ def __init__(
88
  self,
89
  vocab_size=65536,
 
90
  hidden_size=768,
91
  num_hidden_layers=24,
92
  attention_hidden_size=None,
93
+ num_attention_heads=64,
94
  head_size=64,
95
  intermediate_size=None,
96
  layer_norm_epsilon=1e-5,
 
99
  rescale_every=6,
100
  tie_word_embeddings=False,
101
  use_cache=True,
 
102
  **kwargs,
103
  ):
104
  self.vocab_size = vocab_size
 
105
  self.hidden_size = hidden_size
106
  self.num_hidden_layers = num_hidden_layers
107
  self.attention_hidden_size = attention_hidden_size if attention_hidden_size is not None else hidden_size
108
+ self.num_attention_heads = num_attention_heads
109
  self.head_size = head_size
110
  self.intermediate_size = None
111
  self.layer_norm_epsilon = layer_norm_epsilon
 
114
 
115
  self.bos_token_id = bos_token_id
116
  self.eos_token_id = eos_token_id
 
117
 
118
  super().__init__(
119
  tie_word_embeddings=tie_word_embeddings, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs
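
For readers tracking the configuration changes, here is a minimal usage sketch against the updated `__init__` signature shown above. It assumes `configuration_rwkv5.py` is importable and `transformers` is installed; the asserted values simply restate the defaults from the diff.

```python
# Minimal sketch against the updated Rwkv5Config defaults (values restate the diff).
from configuration_rwkv5 import Rwkv5Config

config = Rwkv5Config()  # defaults target the rwkv-5-world-1b5 layout
assert config.vocab_size == 65536
assert config.hidden_size == 768
assert config.num_hidden_layers == 24
# attention_hidden_size falls back to hidden_size when unset, so the head
# count works out to hidden_size // head_size = 768 // 64 = 12 heads.
assert config.attention_hidden_size == config.hidden_size
print(config.hidden_size // config.head_size)
```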
modeling_rwkv5.py CHANGED
@@ -15,16 +15,13 @@
15
  # limitations under the License.
16
  """PyTorch RWKV5 World model."""
17
 
18
- import math
19
  from dataclasses import dataclass
20
- from pathlib import Path
21
  from typing import List, Optional, Tuple, Union
22
 
23
  import torch
 
24
  import torch.utils.checkpoint
25
  from torch import nn
26
- import torch.nn.functional as F
27
- from torch.nn import CrossEntropyLoss
28
 
29
  from transformers.modeling_utils import PreTrainedModel
30
  from transformers.utils import (
@@ -32,77 +29,59 @@ from transformers.utils import (
32
  add_code_sample_docstrings,
33
  add_start_docstrings,
34
  add_start_docstrings_to_model_forward,
35
- is_ninja_available,
36
- is_torch_cuda_available,
37
  logging,
38
  )
 
39
  from .configuration_rwkv5 import Rwkv5Config
40
 
41
 
42
  logger = logging.get_logger(__name__)
43
 
44
- _CHECKPOINT_FOR_DOC = "RWKV/rwkv-5-world"
45
  _CONFIG_FOR_DOC = "Rwkv5Config"
46
 
47
- RWKV_PRETRAINED_MODEL_ARCHIVE_LIST = [
48
-
 
49
  ]
50
 
51
- def rwkv_linear_attention_v5_0(H, S, T, hidden, time_decay, time_first, receptance, key, value, lxw, lxb, ow, state, return_state=False, seq_mode=True):
52
- time_decay = torch.exp(-torch.exp(time_decay.float())).reshape(-1,1,1)
53
- time_first = torch.exp(time_first.float()).reshape(-1,1,1)
54
- lxw = lxw.float()
55
- lxb = lxb.float()
56
-
57
- if seq_mode:
58
- w = time_decay.reshape(-1, 1)
59
- u = time_first.reshape(-1, 1)
60
- ws = w.pow(T).reshape(H, 1, 1)
61
- ind = torch.arange(T-1, -1, -1, device=w.device).unsqueeze(0).repeat(H, 1)
62
- w = w.repeat(1, T).pow(ind)
63
- wk = w.reshape(H, 1, T)
64
- wb = wk.transpose(-2, -1).flip(1)
65
- w = torch.cat([w[:, 1:], u], dim=1)
66
- w = F.pad(w, (0, T))
67
- w = torch.tile(w, [T])
68
- w = w[:, :-T].reshape(-1, T, 2 * T - 1)
69
- w = w[:, :, T-1:].reshape(H, T, T)
70
- out = ((receptance @ key) * w) @ value + (receptance @ state) * wb
71
- state = ws * state + (key * wk) @ value
72
-
73
- out = out.transpose(1, 2).contiguous().reshape(T, H*S)
74
- out = F.group_norm(out, num_groups=H, weight=lxw, bias=lxb)
75
- out = out.to(dtype=hidden.dtype)
76
- out = out @ ow
77
- else:
78
- a = key @ value
79
- out = receptance @ (time_first * a + state)
80
- state = a + time_decay * state
81
- out = out.flatten()
82
- out = F.group_norm(out.unsqueeze(0), num_groups=H, weight=lxw, bias=lxb)
83
- out = out.to(dtype=hidden.dtype)
84
- out = out @ ow
85
 
86
- return out, state
87
-
88
- cnt = 0
89
-
90
- def rwkv_linear_attention_v5_2(B, H, S, T, n_head, hidden, time_decay, time_first, receptance, key, value, gate, lxw, lxb, ow, state, return_state=False, seq_mode=True):
91
- time_decay = torch.exp(-torch.exp(time_decay.float())).reshape(-1,1,1).reshape(n_head, -1, 1)
92
- time_first = time_first.float().reshape(-1,1,1).reshape(n_head, -1, 1)
93
  lxw = lxw.float()
94
  lxb = lxb.float()
 
95
  out = torch.empty((B, T, H, S), dtype=receptance.dtype, device=receptance.device)
96
  for t in range(T):
97
- rt = receptance[:,:,t:t+1,:]
98
- kt = key[:,:,:,t:t+1]
99
- vt = value[:,:,t:t+1,:]
100
  at = kt @ vt
101
  out[:, t] = (rt @ (time_first * at + state)).squeeze(2)
102
  state = at + time_decay * state
103
 
104
- out = out.reshape(B*T, H*S)
105
- out = F.group_norm(out, num_groups=H, weight=lxw, bias=lxb).reshape(B, T, H*S)
106
  out = out.to(dtype=hidden.dtype) * gate
107
  out = out @ ow
108
 
@@ -123,13 +102,9 @@ class RwkvSelfAttention(nn.Module):
123
  )
124
  self.attention_hidden_size = attention_hidden_size
125
 
126
- if self.config.model_version == "5_2":
127
- self.time_decay = nn.Parameter(torch.empty(num_attention_heads, config.head_size))
128
- self.time_faaaa = nn.Parameter(torch.empty(num_attention_heads, config.head_size))
129
- self.time_mix_gate = nn.Parameter(torch.empty(1, 1, hidden_size))
130
- else:
131
- self.time_decay = nn.Parameter(torch.empty(num_attention_heads))
132
- self.time_first = nn.Parameter(torch.empty(num_attention_heads))
133
 
134
  self.time_mix_key = nn.Parameter(torch.empty(1, 1, hidden_size))
135
  self.time_mix_value = nn.Parameter(torch.empty(1, 1, hidden_size))
@@ -139,8 +114,7 @@ class RwkvSelfAttention(nn.Module):
139
  self.key = nn.Linear(hidden_size, attention_hidden_size, bias=False)
140
  self.value = nn.Linear(hidden_size, attention_hidden_size, bias=False)
141
  self.receptance = nn.Linear(hidden_size, attention_hidden_size, bias=False)
142
- if self.config.model_version == "5_2":
143
- self.gate = nn.Linear(hidden_size, attention_hidden_size, bias=False)
144
  self.output = nn.Linear(attention_hidden_size, hidden_size, bias=False)
145
  # https://github.com/BlinkDL/RWKV-LM/blob/3db37a72356b736966ddd377268f02b80963af3f/RWKV-v4neo/src/model.py#L190C1-L190C1
146
  self.ln_x = nn.GroupNorm(hidden_size // config.head_size, hidden_size)
@@ -155,32 +129,22 @@ class RwkvSelfAttention(nn.Module):
155
  if state is not None:
156
  shifted[:, 0] = state[0][:, :, self.layer_id]
157
  if len(shifted.size()) == 2:
158
- shifted = shifted.unsqueeze(1)
159
  key = hidden * self.time_mix_key + shifted * (1 - self.time_mix_key)
160
  value = hidden * self.time_mix_value + shifted * (1 - self.time_mix_value)
161
  receptance = hidden * self.time_mix_receptance + shifted * (1 - self.time_mix_receptance)
162
- if self.config.model_version == "5_2":
163
- gate = hidden* self.time_mix_gate + shifted * (1 - self.time_mix_gate)
164
-
165
- # if hidden.size(1) == 1 and state is not None:
166
- # receptance = self.receptance(receptance).to(torch.float32).view(B, H, 1, S)
167
- # key = self.key(key).to(torch.float32).view(B, H, S, 1)
168
- # value = self.value(value).to(torch.float32).view(B, H, 1, S)
169
- # else:
170
  # https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/model.py#L693
171
  key = self.key(key).to(torch.float32).view(B, T, H, S).transpose(1, 2).transpose(-2, -1)
172
  value = self.value(value).to(torch.float32).view(B, T, H, S).transpose(1, 2)
173
  receptance = self.receptance(receptance).to(torch.float32).view(B, T, H, S).transpose(1, 2)
 
174
 
175
- if self.config.model_version == "5_2":
176
- gate = F.silu(self.gate(gate))
177
-
178
  if state is not None:
179
  state[0][:, :, self.layer_id] = hidden[:, -1]
180
-
181
- if self.config.model_version == "5_2":
182
- return receptance, key, value, gate, state
183
- return receptance, key, value, state
184
 
185
  def forward(self, hidden, state=None, use_cache=False, seq_mode=True):
186
  B = hidden.shape[0]
@@ -188,13 +152,9 @@ class RwkvSelfAttention(nn.Module):
188
  S = hidden.shape[-1] // H
189
  T = hidden.shape[1]
190
 
191
- if self.config.model_version == "5_2":
192
- receptance, key, value, gate, state = self.extract_key_value(B, H, S, T, hidden, state=state)
193
- else:
194
- receptance, key, value, state = self.extract_key_value(H, S, T, hidden, state=state)
195
  layer_state = state[1][:, :, :, :, self.layer_id] if state is not None else None
196
- if self.config.model_version == "5_2":
197
- rwkv, layer_state = rwkv_linear_attention_v5_2(
198
  B,
199
  H,
200
  S,
@@ -214,24 +174,6 @@ class RwkvSelfAttention(nn.Module):
214
  return_state=use_cache,
215
  seq_mode=seq_mode,
216
  )
217
- else:
218
- rwkv, layer_state = rwkv_linear_attention_v5_0(
219
- H,
220
- S,
221
- T,
222
- hidden,
223
- self.time_decay,
224
- self.time_first,
225
- receptance,
226
- key,
227
- value,
228
- self.ln_x.weight,
229
- self.ln_x.bias,
230
- self.output.weight.t(),
231
- state=layer_state,
232
- return_state=use_cache,
233
- seq_mode=seq_mode,
234
- )
235
 
236
  if layer_state is not None:
237
  state[1][:, :, :, :, self.layer_id] = layer_state
@@ -246,19 +188,16 @@ class RwkvFeedForward(nn.Module):
246
  self.layer_id = layer_id
247
  hidden_size = config.hidden_size
248
  # https://github.com/BlinkDL/RWKV-LM/blob/3db37a72356b736966ddd377268f02b80963af3f/RWKV-v4neo/train.py#L168
249
- if self.config.model_version == "5_2":
250
- intermediate_size = (
251
- config.intermediate_size if config.intermediate_size is not None else int((config.hidden_size * 3.5) // 32 * 32)
252
- )
253
- else:
254
- intermediate_size = (
255
- config.intermediate_size if config.intermediate_size is not None else 4 * config.hidden_size
256
- )
257
 
258
  self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
259
  self.time_mix_key = nn.Parameter(torch.empty(1, 1, hidden_size))
260
  self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, hidden_size))
261
-
262
  self.key = nn.Linear(hidden_size, intermediate_size, bias=False)
263
  self.receptance = nn.Linear(hidden_size, hidden_size, bias=False)
264
  self.value = nn.Linear(intermediate_size, hidden_size, bias=False)
@@ -301,7 +240,6 @@ class RwkvBlock(nn.Module):
301
  self.feed_forward = RwkvFeedForward(config, layer_id)
302
 
303
  def forward(self, hidden, state=None, use_cache=False, output_attentions=False, seq_mode=True):
304
-
305
  attention, state = self.attention(self.ln1(hidden), state=state, use_cache=use_cache, seq_mode=seq_mode)
306
  hidden = hidden + attention
307
 
@@ -317,16 +255,18 @@ class RwkvBlock(nn.Module):
317
  return outputs
318
 
319
 
320
- class RwkvPreTrainedModel(PreTrainedModel):
321
  """
322
  An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
323
  models.
324
  """
325
 
326
  config_class = Rwkv5Config
327
- base_model_prefix = "transformer"
328
  _no_split_modules = ["RwkvBlock"]
329
  _keep_in_fp32_modules = ["time_decay", "time_first"]
 
 
330
 
331
  def _init_weights(self, module):
332
  """Initialize the weights."""
@@ -335,7 +275,7 @@ class RwkvPreTrainedModel(PreTrainedModel):
335
  num_hidden_layers = module.config.num_hidden_layers
336
  hidden_size = module.config.hidden_size
337
  attention_hidden_size = module.attention_hidden_size
338
- num_attention_heads = hidden_size // module.config.head_size
339
 
340
  ratio_0_to_1 = layer_id / (num_hidden_layers - 1) # 0 to 1
341
  ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers) # 1 to ~0
@@ -347,43 +287,30 @@ class RwkvPreTrainedModel(PreTrainedModel):
347
  )
348
  time_weight = time_weight[None, None, :]
349
 
350
- if module.config.model_version == "5_2":
351
- # https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4neo/src/model.py#L398
352
- decay_speed = [
353
- -6.0 + 5.0 * (h / (attention_hidden_size - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
354
- for h in range(attention_hidden_size)
355
- ]
356
- else:
357
- # https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4neo/src/model.py#L172
358
- decay_speed = [
359
- -6.0 + 5.0 * (h / (num_attention_heads - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
360
- for h in range(num_attention_heads)
361
- ]
362
  decay_speed = torch.tensor(decay_speed, dtype=module.time_decay.dtype, device=module.time_decay.device)
363
- if module.config.model_version == "5_2":
364
- tmp = (
365
- torch.tensor(
366
- [(1.0 - (i / (attention_hidden_size - 1.0))) * ratio_0_to_1 + 0.1 * ((i + 1) % 3 - 1) for i in range(attention_hidden_size)],
367
- dtype=module.time_faaaa.dtype,
368
- device=module.time_faaaa.device,
369
- )
370
- )
371
- else:
372
- tmp = torch.ones(num_attention_heads) * (-3.0)
373
 
374
  with torch.no_grad():
375
- if module.config.model_version == "5_2":
376
- module.time_decay.data = decay_speed.reshape(num_attention_heads, module.config.head_size)
377
- module.time_faaaa.data = tmp.reshape(num_attention_heads, module.config.head_size)
378
- else:
379
- module.time_decay.data = decay_speed
380
- module.time_first.data = tmp
381
-
382
  module.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0)
 
383
  module.time_mix_value.data = torch.pow(time_weight, ratio_1_to_almost0) + 0.3 * ratio_0_to_1
384
  module.time_mix_receptance.data = torch.pow(time_weight, 0.5 * ratio_1_to_almost0)
385
- if module.config.model_version == "5_2":
386
- module.time_mix_gate.data = torch.pow(time_weight, 0.5 * ratio_1_to_almost0)
387
  elif isinstance(module, RwkvFeedForward):
388
  layer_id = module.layer_id
389
  num_hidden_layers = module.config.num_hidden_layers
@@ -402,13 +329,9 @@ class RwkvPreTrainedModel(PreTrainedModel):
402
  module.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0)
403
  module.time_mix_receptance.data = torch.pow(time_weight, ratio_1_to_almost0)
404
 
405
- def _set_gradient_checkpointing(self, module, value=False):
406
- if isinstance(module, RwkvModel):
407
- module.gradient_checkpointing = value
408
-
409
 
410
  @dataclass
411
- class RwkvOutput(ModelOutput):
412
  """
413
  Class for the RWKV model outputs.
414
 
@@ -420,15 +343,12 @@ class RwkvOutput(ModelOutput):
420
  avoid providing the old `input_ids`.
421
  hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
422
  Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
423
- one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
424
-
425
- Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
426
  attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
427
  Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
428
- sequence_length)`.
429
-
430
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
431
- heads.
432
  """
433
 
434
  last_hidden_state: torch.FloatTensor = None
@@ -438,7 +358,7 @@ class RwkvOutput(ModelOutput):
438
 
439
 
440
  @dataclass
441
- class RwkvCausalLMOutput(ModelOutput):
442
  """
443
  Base class for causal language model (or autoregressive) outputs.
444
 
@@ -452,33 +372,27 @@ class RwkvCausalLMOutput(ModelOutput):
452
  avoid providing the old `input_ids`.
453
  hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
454
  Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
455
- one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
456
-
457
- Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
458
  attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
459
  Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
460
- sequence_length)`.
461
-
462
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
463
- heads.
464
  """
465
 
466
  loss: Optional[torch.FloatTensor] = None
467
  logits: torch.FloatTensor = None
468
  state: Optional[List[torch.FloatTensor]] = None
469
- last_hidden_state: Optional[Tuple[torch.FloatTensor]] = None
470
  attentions: Optional[Tuple[torch.FloatTensor]] = None
471
 
472
 
473
  RWKV_START_DOCSTRING = r"""
474
-
475
  This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
476
  library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
477
- etc.)
478
-
479
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
480
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
481
- and behavior.
482
 
483
  Parameters:
484
  config ([`Rwkv5Config`]): Model configuration class with all the parameters of the model.
@@ -491,15 +405,10 @@ RWKV_INPUTS_DOCSTRING = r"""
491
  input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
492
  `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
493
  `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
494
- sequence tokens in the vocabulary.
495
-
496
- If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
497
- `input_ids`.
498
-
499
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
500
- [`PreTrainedTokenizer.__call__`] for details.
501
-
502
- [What are input IDs?](../glossary#input-ids)
503
  inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
504
  Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
505
  is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
@@ -524,7 +433,7 @@ RWKV_INPUTS_DOCSTRING = r"""
524
  "The bare RWKV Model transformer outputting raw hidden-states without any specific head on top.",
525
  RWKV_START_DOCSTRING,
526
  )
527
- class RwkvModel(RwkvPreTrainedModel):
528
  def __init__(self, config):
529
  super().__init__(config)
530
 
@@ -535,6 +444,8 @@ class RwkvModel(RwkvPreTrainedModel):
535
  self.layers_are_rescaled = False
536
  self.pre_ln_flag = False
537
 
 
 
538
  # Initialize weights and apply final processing
539
  self.post_init()
540
 
@@ -547,28 +458,31 @@ class RwkvModel(RwkvPreTrainedModel):
547
  @add_start_docstrings_to_model_forward(RWKV_INPUTS_DOCSTRING)
548
  @add_code_sample_docstrings(
549
  checkpoint=_CHECKPOINT_FOR_DOC,
550
- output_type=RwkvOutput,
551
  config_class=_CONFIG_FOR_DOC,
552
  )
553
  def forward(
554
  self,
555
  input_ids: Optional[torch.LongTensor] = None,
 
556
  inputs_embeds: Optional[torch.FloatTensor] = None,
557
  state: Optional[List[torch.FloatTensor]] = None,
558
  use_cache: Optional[bool] = None,
559
  output_attentions: Optional[bool] = None,
560
  output_hidden_states: Optional[bool] = None,
561
  return_dict: Optional[bool] = None,
562
- ) -> Union[Tuple, RwkvOutput]:
563
- seq_mode = input_ids.shape[1] > 1
564
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
565
  output_hidden_states = (
566
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
567
  )
568
- use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
 
569
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
570
 
571
- if self.training == self.layers_are_rescaled and (self.embeddings.weight.dtype == torch.float16 or self.embeddings.weight.dtype == torch.bfloat16):
 
 
572
  self._rescale_layers()
573
 
574
  if input_ids is not None and inputs_embeds is not None:
@@ -578,23 +492,55 @@ class RwkvModel(RwkvPreTrainedModel):
578
 
579
  if inputs_embeds is None:
580
  if not self.pre_ln_flag:
581
- normalized_weight = F.layer_norm(self.embeddings.weight, (self.config.hidden_size, ), weight=self.blocks[0].pre_ln.weight, bias=self.blocks[0].pre_ln.bias)
582
  self.embeddings.weight = nn.Parameter(normalized_weight)
583
  self.pre_ln_flag = True
 
584
  inputs_embeds = self.embeddings(input_ids)
585
 
586
  if use_cache and state is None:
587
  # https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/model.py#L904-L906
588
  state = []
589
- num_attention_heads = self.config.hidden_size // self.config.head_size
590
- state.append(torch.zeros((inputs_embeds.size(0), self.config.hidden_size, self.config.num_hidden_layers), dtype=inputs_embeds.dtype, requires_grad=False, device=inputs_embeds.device).contiguous())
591
- state.append(torch.zeros((inputs_embeds.size(0), num_attention_heads, self.config.hidden_size // num_attention_heads, self.config.hidden_size // num_attention_heads, self.config.num_hidden_layers), dtype=torch.float32, requires_grad=False, device=inputs_embeds.device).contiguous())
592
- state.append(torch.zeros((inputs_embeds.size(0), self.config.hidden_size, self.config.num_hidden_layers), dtype=inputs_embeds.dtype, requires_grad=False, device=inputs_embeds.device).contiguous())
593
-
594
 
 
595
  hidden_states = inputs_embeds
596
- global cnt
597
- cnt += 1
598
  all_self_attentions = () if output_attentions else None
599
  all_hidden_states = () if output_hidden_states else None
600
  for idx, block in enumerate(self.blocks):
@@ -622,11 +568,11 @@ class RwkvModel(RwkvPreTrainedModel):
622
  if not return_dict:
623
  return (hidden_states, state, all_hidden_states, all_self_attentions)
624
 
625
- return RwkvOutput(
626
  last_hidden_state=hidden_states,
627
  state=state,
628
- hidden_states=all_hidden_states, #None
629
- attentions=all_self_attentions, #None
630
  )
631
 
632
  def _rescale_layers(self):
@@ -645,6 +591,7 @@ class RwkvModel(RwkvPreTrainedModel):
645
 
646
  self.layers_are_rescaled = not self.training
647
 
 
648
  @add_start_docstrings(
649
  """
650
  The RWKV Model transformer with a language modeling head on top (linear layer with weights tied to the input
@@ -652,10 +599,12 @@ class RwkvModel(RwkvPreTrainedModel):
652
  """,
653
  RWKV_START_DOCSTRING,
654
  )
655
- class RwkvForCausalLM(RwkvPreTrainedModel):
 
 
656
  def __init__(self, config):
657
  super().__init__(config)
658
- self.rwkv = RwkvModel(config)
659
  self.head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
660
 
661
  # Initialize weights and apply final processing
@@ -684,7 +633,7 @@ class RwkvForCausalLM(RwkvPreTrainedModel):
684
  @add_start_docstrings_to_model_forward(RWKV_INPUTS_DOCSTRING)
685
  @add_code_sample_docstrings(
686
  checkpoint=_CHECKPOINT_FOR_DOC,
687
- output_type=RwkvCausalLMOutput,
688
  config_class=_CONFIG_FOR_DOC,
689
  )
690
  def forward(
@@ -698,7 +647,7 @@ class RwkvForCausalLM(RwkvPreTrainedModel):
698
  output_attentions: Optional[bool] = None,
699
  output_hidden_states: Optional[bool] = None,
700
  return_dict: Optional[bool] = None,
701
- ) -> Union[Tuple, RwkvCausalLMOutput]:
702
  r"""
703
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
704
  Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
@@ -716,30 +665,23 @@ class RwkvForCausalLM(RwkvPreTrainedModel):
716
  output_hidden_states=output_hidden_states,
717
  return_dict=return_dict,
718
  )
719
- last_hidden_state = rwkv_outputs.last_hidden_state
720
- state = rwkv_outputs.state
721
 
722
- logits = self.head(last_hidden_state)
723
 
724
  loss = None
725
  if labels is not None:
726
- # move labels to correct device to enable model parallelism
727
- labels = labels.to(logits.device)
728
- # Shift so that tokens < n predict n
729
- shift_logits = logits[..., :-1, :].contiguous()
730
- shift_labels = labels[..., 1:].contiguous()
731
- # Flatten the tokens
732
- loss_fct = CrossEntropyLoss()
733
- loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
734
 
735
  if not return_dict:
736
  output = (logits,) + rwkv_outputs[1:]
737
  return ((loss,) + output) if loss is not None else output
738
 
739
- return RwkvCausalLMOutput(
740
  loss=loss,
741
  logits=logits,
742
  state=rwkv_outputs.state,
743
- last_hidden_state=rwkv_outputs.last_hidden_state,
744
  attentions=rwkv_outputs.attentions,
745
  )
 
15
  # limitations under the License.
16
  """PyTorch RWKV5 World model."""
17
 
 
18
  from dataclasses import dataclass
 
19
  from typing import List, Optional, Tuple, Union
20
 
21
  import torch
22
+ import torch.nn.functional as F
23
  import torch.utils.checkpoint
24
  from torch import nn
 
 
25
 
26
  from transformers.modeling_utils import PreTrainedModel
27
  from transformers.utils import (
 
29
  add_code_sample_docstrings,
30
  add_start_docstrings,
31
  add_start_docstrings_to_model_forward,
 
 
32
  logging,
33
  )
34
+
35
  from .configuration_rwkv5 import Rwkv5Config
36
 
37
 
38
  logger = logging.get_logger(__name__)
39
 
40
+ _CHECKPOINT_FOR_DOC = "RWKV/rwkv-5-world-1b5"
41
  _CONFIG_FOR_DOC = "Rwkv5Config"
42
 
43
+ RWKV5_PRETRAINED_MODEL_ARCHIVE_LIST = [
44
+ "RWKV/rwkv-5-world-1b5",
45
+ # See all RWKV models at https://huggingface.co/models?filter=rwkv
46
  ]
47

48
 
49
+ def rwkv_linear_attention_v5(
50
+ B,
51
+ H,
52
+ S,
53
+ T,
54
+ n_head,
55
+ hidden,
56
+ time_decay,
57
+ time_first,
58
+ receptance,
59
+ key,
60
+ value,
61
+ gate,
62
+ lxw,
63
+ lxb,
64
+ ow,
65
+ state,
66
+ return_state=False,
67
+ seq_mode=True,
68
+ ):
69
+ time_decay = torch.exp(-torch.exp(time_decay.float())).reshape(-1, 1, 1).reshape(n_head, -1, 1)
70
+ time_first = time_first.float().reshape(-1, 1, 1).reshape(n_head, -1, 1)
71
  lxw = lxw.float()
72
  lxb = lxb.float()
73
+ # if seq_mode:
74
  out = torch.empty((B, T, H, S), dtype=receptance.dtype, device=receptance.device)
75
  for t in range(T):
76
+ rt = receptance[:, :, t : t + 1, :]
77
+ kt = key[:, :, :, t : t + 1]
78
+ vt = value[:, :, t : t + 1, :]
79
  at = kt @ vt
80
  out[:, t] = (rt @ (time_first * at + state)).squeeze(2)
81
  state = at + time_decay * state
82
 
83
+ out = out.reshape(B * T, H * S)
84
+ out = F.group_norm(out, num_groups=H, weight=lxw, bias=lxb).reshape(B, T, H * S)
85
  out = out.to(dtype=hidden.dtype) * gate
86
  out = out @ ow
87
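
Per head, the loop in `rwkv_linear_attention_v5` realizes the recurrence `state <- k_t @ v_t + w * state`, with the current token receiving the one-step bonus `u` (`time_faaaa`). A standalone toy sketch of that recurrence; shapes and names here are illustrative, not the module's API:

```python
# Toy single-head RWKV5 recurrence: state <- k_t v_t + w * state.
import torch

S, T = 4, 6                  # head size and sequence length (toy values)
w = torch.rand(S, 1)         # per-channel decay in (0, 1), like exp(-exp(time_decay))
u = torch.rand(S, 1)         # bonus applied only to the current token (time_faaaa)
r = torch.randn(T, 1, S)     # receptance rows
k = torch.randn(T, S, 1)     # key columns
v = torch.randn(T, 1, S)     # value rows

state = torch.zeros(S, S)
outs = []
for t in range(T):
    at = k[t] @ v[t]                      # (S, S) outer product k_t v_t
    outs.append(r[t] @ (u * at + state))  # current token gets the u bonus
    state = at + w * state                # decayed running sum of outer products
out = torch.cat(outs)                     # (T, S), pre-GroupNorm output
print(out.shape)
```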
 
 
102
  )
103
  self.attention_hidden_size = attention_hidden_size
104
 
105
+ self.time_decay = nn.Parameter(torch.empty(num_attention_heads, config.head_size))
106
+ self.time_faaaa = nn.Parameter(torch.empty(num_attention_heads, config.head_size))
107
+ self.time_mix_gate = nn.Parameter(torch.empty(1, 1, hidden_size))
108
 
109
  self.time_mix_key = nn.Parameter(torch.empty(1, 1, hidden_size))
110
  self.time_mix_value = nn.Parameter(torch.empty(1, 1, hidden_size))
 
114
  self.key = nn.Linear(hidden_size, attention_hidden_size, bias=False)
115
  self.value = nn.Linear(hidden_size, attention_hidden_size, bias=False)
116
  self.receptance = nn.Linear(hidden_size, attention_hidden_size, bias=False)
117
+ self.gate = nn.Linear(hidden_size, attention_hidden_size, bias=False)
 
118
  self.output = nn.Linear(attention_hidden_size, hidden_size, bias=False)
119
  # https://github.com/BlinkDL/RWKV-LM/blob/3db37a72356b736966ddd377268f02b80963af3f/RWKV-v4neo/src/model.py#L190C1-L190C1
120
  self.ln_x = nn.GroupNorm(hidden_size // config.head_size, hidden_size)
 
129
  if state is not None:
130
  shifted[:, 0] = state[0][:, :, self.layer_id]
131
  if len(shifted.size()) == 2:
132
+ shifted = shifted.unsqueeze(1)
133
  key = hidden * self.time_mix_key + shifted * (1 - self.time_mix_key)
134
  value = hidden * self.time_mix_value + shifted * (1 - self.time_mix_value)
135
  receptance = hidden * self.time_mix_receptance + shifted * (1 - self.time_mix_receptance)
136
+ gate = hidden * self.time_mix_gate + shifted * (1 - self.time_mix_gate)
137
+
138
  # https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/model.py#L693
139
  key = self.key(key).to(torch.float32).view(B, T, H, S).transpose(1, 2).transpose(-2, -1)
140
  value = self.value(value).to(torch.float32).view(B, T, H, S).transpose(1, 2)
141
  receptance = self.receptance(receptance).to(torch.float32).view(B, T, H, S).transpose(1, 2)
142
+ gate = F.silu(self.gate(gate))
143

144
  if state is not None:
145
  state[0][:, :, self.layer_id] = hidden[:, -1]
146
+
147
+ return receptance, key, value, gate, state
 
 
148
 
149
  def forward(self, hidden, state=None, use_cache=False, seq_mode=True):
150
  B = hidden.shape[0]
 
152
  S = hidden.shape[-1] // H
153
  T = hidden.shape[1]
154
 
155
+ receptance, key, value, gate, state = self.extract_key_value(B, H, S, T, hidden, state=state)
156
  layer_state = state[1][:, :, :, :, self.layer_id] if state is not None else None
157
+ rwkv, layer_state = rwkv_linear_attention_v5(
 
158
  B,
159
  H,
160
  S,
 
174
  return_state=use_cache,
175
  seq_mode=seq_mode,
176
  )
177
 
178
  if layer_state is not None:
179
  state[1][:, :, :, :, self.layer_id] = layer_state
 
188
  self.layer_id = layer_id
189
  hidden_size = config.hidden_size
190
  # https://github.com/BlinkDL/RWKV-LM/blob/3db37a72356b736966ddd377268f02b80963af3f/RWKV-v4neo/train.py#L168
191
+ intermediate_size = (
192
+ config.intermediate_size
193
+ if config.intermediate_size is not None
194
+ else int((config.hidden_size * 3.5) // 32 * 32)
195
+ )
196
 
197
  self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
198
  self.time_mix_key = nn.Parameter(torch.empty(1, 1, hidden_size))
199
  self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, hidden_size))
200
+
201
  self.key = nn.Linear(hidden_size, intermediate_size, bias=False)
202
  self.receptance = nn.Linear(hidden_size, hidden_size, bias=False)
203
  self.value = nn.Linear(intermediate_size, hidden_size, bias=False)
 
240
  self.feed_forward = RwkvFeedForward(config, layer_id)
241
 
242
  def forward(self, hidden, state=None, use_cache=False, output_attentions=False, seq_mode=True):
 
243
  attention, state = self.attention(self.ln1(hidden), state=state, use_cache=use_cache, seq_mode=seq_mode)
244
  hidden = hidden + attention
245
 
 
255
  return outputs
256
 
257
 
258
+ class Rwkv5PreTrainedModel(PreTrainedModel):
259
  """
260
  An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
261
  models.
262
  """
263
 
264
  config_class = Rwkv5Config
265
+ base_model_prefix = "rwkv"
266
  _no_split_modules = ["RwkvBlock"]
267
  _keep_in_fp32_modules = ["time_decay", "time_first"]
268
+ supports_gradient_checkpointing = True
269
+ training = False
270
 
271
  def _init_weights(self, module):
272
  """Initialize the weights."""
 
275
  num_hidden_layers = module.config.num_hidden_layers
276
  hidden_size = module.config.hidden_size
277
  attention_hidden_size = module.attention_hidden_size
278
+ num_attention_heads = hidden_size // module.config.num_attention_heads
279
 
280
  ratio_0_to_1 = layer_id / (num_hidden_layers - 1) # 0 to 1
281
  ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers) # 1 to ~0
 
287
  )
288
  time_weight = time_weight[None, None, :]
289
 
290
+ # https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4neo/src/model.py#L398
291
+ decay_speed = [
292
+ -6.0 + 5.0 * (h / (attention_hidden_size - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
293
+ for h in range(attention_hidden_size)
294
+ ]
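
This schedule spaces the per-channel decay speeds between -6.0 and -1.0, with a curvature that grows with layer depth through `ratio_0_to_1`. A toy evaluation of the same formula (hypothetical sizes, not the real `attention_hidden_size`):

```python
# Toy evaluation of the decay-speed schedule above (illustrative sizes only).
attention_hidden_size = 8
ratio_0_to_1 = 0.5  # layer_id / (num_hidden_layers - 1)
decay_speed = [
    -6.0 + 5.0 * (h / (attention_hidden_size - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
    for h in range(attention_hidden_size)
]
print(decay_speed)  # rises monotonically from -6.0 to -1.0
```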
295
  decay_speed = torch.tensor(decay_speed, dtype=module.time_decay.dtype, device=module.time_decay.device)
296
+ tmp = torch.tensor(
297
+ [
298
+ (1.0 - (i / (attention_hidden_size - 1.0))) * ratio_0_to_1 + 0.1 * ((i + 1) % 3 - 1)
299
+ for i in range(attention_hidden_size)
300
+ ],
301
+ dtype=module.time_faaaa.dtype,
302
+ device=module.time_faaaa.device,
303
+ )
 
 
304
 
305
  with torch.no_grad():
306
+ module.time_decay.data = decay_speed.reshape(num_attention_heads, module.config.num_attention_heads)
307
+ module.time_faaaa.data = tmp.reshape(num_attention_heads, module.config.num_attention_heads)
308
  module.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0)
309
+
310
  module.time_mix_value.data = torch.pow(time_weight, ratio_1_to_almost0) + 0.3 * ratio_0_to_1
311
  module.time_mix_receptance.data = torch.pow(time_weight, 0.5 * ratio_1_to_almost0)
312
+ module.time_mix_gate.data = torch.pow(time_weight, 0.5 * ratio_1_to_almost0)
313
+
314
  elif isinstance(module, RwkvFeedForward):
315
  layer_id = module.layer_id
316
  num_hidden_layers = module.config.num_hidden_layers
 
329
  module.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0)
330
  module.time_mix_receptance.data = torch.pow(time_weight, ratio_1_to_almost0)
331

332
 
333
  @dataclass
334
+ class Rwkv5Output(ModelOutput):
335
  """
336
  Class for the RWKV model outputs.
337
 
 
343
  avoid providing the old `input_ids`.
344
  hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
345
  Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
346
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of
347
+ the model at the output of each layer plus the optional initial embedding outputs.
 
348
  attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
349
  Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
350
+ sequence_length)`. Attention weights after the attention softmax, used to compute the weighted average in
351
+ the self-attention heads.
 
 
352
  """
353
 
354
  last_hidden_state: torch.FloatTensor = None
 
358
 
359
 
360
  @dataclass
361
+ class Rwkv5CausalLMOutput(ModelOutput):
362
  """
363
  Base class for causal language model (or autoregressive) outputs.
364
 
 
372
  avoid providing the old `input_ids`.
373
  hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
374
  Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
375
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of
376
+ the model at the output of each layer plus the optional initial embedding outputs.
 
377
  attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
378
  Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
379
+ sequence_length)`. Attention weights after the attention softmax, used to compute the weighted average in
380
+ the self-attention heads.
 
 
381
  """
382
 
383
  loss: Optional[torch.FloatTensor] = None
384
  logits: torch.FloatTensor = None
385
  state: Optional[List[torch.FloatTensor]] = None
386
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
387
  attentions: Optional[Tuple[torch.FloatTensor]] = None
388
 
389
 
390
  RWKV_START_DOCSTRING = r"""
 
391
  This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
392
  library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
393
+ etc.) This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)
394
+ subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to
395
+ general usage and behavior.
 
 
396
 
397
  Parameters:
398
  config ([`Rwkv5Config`]): Model configuration class with all the parameters of the model.
 
405
  input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
406
  `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
407
  `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
408
+ sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their
409
+ past calculated should be passed as `input_ids`. Indices can be obtained using [`AutoTokenizer`]. See
410
+ [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
411
+ IDs?](../glossary#input-ids)
412
  inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
413
  Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
414
  is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
 
433
  "The bare RWKV Model transformer outputting raw hidden-states without any specific head on top.",
434
  RWKV_START_DOCSTRING,
435
  )
436
+ class Rwkv5Model(Rwkv5PreTrainedModel):
437
  def __init__(self, config):
438
  super().__init__(config)
439
 
 
444
  self.layers_are_rescaled = False
445
  self.pre_ln_flag = False
446
 
447
+ self.gradient_checkpointing = False
448
+
449
  # Initialize weights and apply final processing
450
  self.post_init()
451
 
 
458
  @add_start_docstrings_to_model_forward(RWKV_INPUTS_DOCSTRING)
459
  @add_code_sample_docstrings(
460
  checkpoint=_CHECKPOINT_FOR_DOC,
461
+ output_type=Rwkv5Output,
462
  config_class=_CONFIG_FOR_DOC,
463
  )
464
  def forward(
465
  self,
466
  input_ids: Optional[torch.LongTensor] = None,
467
+ attention_mask: Optional[torch.LongTensor] = None, # noqa
468
  inputs_embeds: Optional[torch.FloatTensor] = None,
469
  state: Optional[List[torch.FloatTensor]] = None,
470
  use_cache: Optional[bool] = None,
471
  output_attentions: Optional[bool] = None,
472
  output_hidden_states: Optional[bool] = None,
473
  return_dict: Optional[bool] = None,
474
+ ) -> Union[Tuple, Rwkv5Output]:
 
475
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
476
  output_hidden_states = (
477
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
478
  )
479
+ # RWKV5 only supports inference in Hugging Face Transformers.
480
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
481
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
482
 
483
+ if self.training == self.layers_are_rescaled and (
484
+ self.embeddings.weight.dtype == torch.float16 or self.embeddings.weight.dtype == torch.bfloat16
485
+ ):
486
  self._rescale_layers()
487
 
488
  if input_ids is not None and inputs_embeds is not None:
 
492
 
493
  if inputs_embeds is None:
494
  if not self.pre_ln_flag:
495
+ normalized_weight = F.layer_norm(
496
+ self.embeddings.weight,
497
+ (self.config.hidden_size,),
498
+ weight=self.blocks[0].pre_ln.weight,
499
+ bias=self.blocks[0].pre_ln.bias,
500
+ )
501
  self.embeddings.weight = nn.Parameter(normalized_weight)
502
  self.pre_ln_flag = True
503
+
504
  inputs_embeds = self.embeddings(input_ids)
505
 
506
  if use_cache and state is None:
507
  # https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/model.py#L904-L906
508
  state = []
509
+ num_attention_heads = self.config.hidden_size // self.config.num_attention_heads
510
+ state.append(
511
+ torch.zeros(
512
+ (inputs_embeds.size(0), self.config.hidden_size, self.config.num_hidden_layers),
513
+ dtype=inputs_embeds.dtype,
514
+ requires_grad=False,
515
+ device=inputs_embeds.device,
516
+ ).contiguous()
517
+ )
518
+ state.append(
519
+ torch.zeros(
520
+ (
521
+ inputs_embeds.size(0),
522
+ num_attention_heads,
523
+ self.config.hidden_size // num_attention_heads,
524
+ self.config.hidden_size // num_attention_heads,
525
+ self.config.num_hidden_layers,
526
+ ),
527
+ dtype=torch.float32,
528
+ requires_grad=False,
529
+ device=inputs_embeds.device,
530
+ ).contiguous()
531
+ )
532
+ state.append(
533
+ torch.zeros(
534
+ (inputs_embeds.size(0), self.config.hidden_size, self.config.num_hidden_layers),
535
+ dtype=inputs_embeds.dtype,
536
+ requires_grad=False,
537
+ device=inputs_embeds.device,
538
+ ).contiguous()
539
+ )
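
The three tensors appended here are, in order, the attention time-shift state, the per-head WKV state, and the feed-forward time-shift state. A shape-only sketch using the default sizes (the batch size of 2 is arbitrary):

```python
# Shape sketch of the cached state for the defaults: hidden_size=768,
# 24 layers, 12 heads of size 64 (hypothetical batch size 2).
import torch

batch, hidden, layers, heads, head_size = 2, 768, 24, 12, 64
state = [
    torch.zeros(batch, hidden, layers),                       # attention time-shift
    torch.zeros(batch, heads, head_size, head_size, layers),  # per-head WKV state
    torch.zeros(batch, hidden, layers),                       # feed-forward time-shift
]
print([tuple(s.shape) for s in state])
```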
540
 
541
+ seq_mode = inputs_embeds.shape[1] > 1
542
  hidden_states = inputs_embeds
543
+
 
544
  all_self_attentions = () if output_attentions else None
545
  all_hidden_states = () if output_hidden_states else None
546
  for idx, block in enumerate(self.blocks):
 
568
  if not return_dict:
569
  return (hidden_states, state, all_hidden_states, all_self_attentions)
570
 
571
+ return Rwkv5Output(
572
  last_hidden_state=hidden_states,
573
  state=state,
574
+ hidden_states=all_hidden_states, # None
575
+ attentions=all_self_attentions, # None
576
  )
577
 
578
  def _rescale_layers(self):
 
591
 
592
  self.layers_are_rescaled = not self.training
593
 
594
+
595
  @add_start_docstrings(
596
  """
597
  The RWKV Model transformer with a language modeling head on top (linear layer with weights tied to the input
 
599
  """,
600
  RWKV_START_DOCSTRING,
601
  )
602
+ class Rwkv5ForCausalLM(Rwkv5PreTrainedModel):
603
+ _tied_weights_keys = ["head.weight"]
604
+
605
  def __init__(self, config):
606
  super().__init__(config)
607
+ self.rwkv = Rwkv5Model(config)
608
  self.head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
609
 
610
  # Initialize weights and apply final processing
 
633
  @add_start_docstrings_to_model_forward(RWKV_INPUTS_DOCSTRING)
634
  @add_code_sample_docstrings(
635
  checkpoint=_CHECKPOINT_FOR_DOC,
636
+ output_type=Rwkv5CausalLMOutput,
637
  config_class=_CONFIG_FOR_DOC,
638
  )
639
  def forward(
 
647
  output_attentions: Optional[bool] = None,
648
  output_hidden_states: Optional[bool] = None,
649
  return_dict: Optional[bool] = None,
650
+ ) -> Union[Tuple, Rwkv5CausalLMOutput]:
651
  r"""
652
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
653
  Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
 
665
  output_hidden_states=output_hidden_states,
666
  return_dict=return_dict,
667
  )
668
+ hidden_states = rwkv_outputs[0]
 
669
 
670
+ logits = self.head(hidden_states)
671
 
672
  loss = None
673
  if labels is not None:
674
+ # https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/model.py#L984
675
+ loss = torch.tensor(0.0, device=logits.device, dtype=logits.dtype)
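
Note that this replaces the shifted cross-entropy with a constant zero, making the port inference-only. For reference, a sketch of the standard causal-LM loss the earlier code computed (it mirrors the removed lines above and is untested here):

```python
# Sketch of the shifted causal-LM loss this commit removes (inference-only port).
import torch
from torch.nn import CrossEntropyLoss

def causal_lm_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    labels = labels.to(logits.device)                # move labels for model parallelism
    shift_logits = logits[..., :-1, :].contiguous()  # tokens < n predict token n
    shift_labels = labels[..., 1:].contiguous()
    loss_fct = CrossEntropyLoss()
    return loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
```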
676
 
677
  if not return_dict:
678
  output = (logits,) + rwkv_outputs[1:]
679
  return ((loss,) + output) if loss is not None else output
680
 
681
+ return Rwkv5CausalLMOutput(
682
  loss=loss,
683
  logits=logits,
684
  state=rwkv_outputs.state,
685
+ hidden_states=rwkv_outputs.hidden_states,
686
  attentions=rwkv_outputs.attentions,
687
  )
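
Putting the renamed classes together, a sketch of stateful greedy decoding; the checkpoint name follows `_CHECKPOINT_FOR_DOC`, and `trust_remote_code=True` is assumed because these modeling files ship with the checkpoint rather than inside `transformers`:

```python
# Stateful greedy decoding sketch for the Rwkv5 classes (assumptions noted above).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo = "RWKV/rwkv-5-world-1b5"
model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)

input_ids = tokenizer("Hello", return_tensors="pt").input_ids
state, generated = None, input_ids
for _ in range(16):
    # After the prompt pass, only the newest token is fed; the recurrence
    # lives entirely in `state` (see the state tensors built in forward()).
    out = model(input_ids=input_ids, state=state, use_cache=True)
    state = out.state
    input_ids = out.logits[:, -1].argmax(dim=-1, keepdim=True)
    generated = torch.cat([generated, input_ids], dim=-1)
print(tokenizer.decode(generated[0].tolist()))
```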
tokenization_rwkv_world.py CHANGED
@@ -12,38 +12,20 @@
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
- """Tokenization classes for OpenAI GPT."""
16
 
17
  import json
18
  import os
19
  from typing import TYPE_CHECKING, List, Optional, Tuple, Union
20
- from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
21
- from transformers.utils import logging, to_py_obj
22
- from transformers.tokenization_utils_base import BatchEncoding
23
-
24
- import bisect
25
- import itertools
26
- import re
27
- import unicodedata
28
- from collections import OrderedDict
29
- from typing import Any, Dict, List, Optional, Tuple, Union, overload
30
 
 
31
  from transformers.tokenization_utils_base import (
32
- ENCODE_KWARGS_DOCSTRING,
33
- ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING,
34
- INIT_TOKENIZER_DOCSTRING,
35
- AddedToken,
36
  BatchEncoding,
37
  EncodedInput,
38
- EncodedInputPair,
39
- PreTokenizedInput,
40
- PreTokenizedInputPair,
41
- PreTrainedTokenizerBase,
42
  TextInput,
43
- TextInputPair,
44
  TruncationStrategy,
45
  )
46
- from transformers.utils import PaddingStrategy, TensorType, add_end_docstrings, logging
47
 
48
 
49
  if TYPE_CHECKING:
@@ -54,11 +36,18 @@ logger = logging.get_logger(__name__)
54
  VOCAB_FILES_NAMES = {
55
  "vocab_file": "rwkv_vocab_v20230424.txt",
56
  }
57
 
58
  class TRIE:
59
  __slots__ = tuple("ch,to,values,front".split(","))
60
- to:list
61
- values:set
 
62
  def __init__(self, front=None, ch=None):
63
  self.ch = ch
64
  self.to = [None for ch in range(256)]
@@ -68,64 +57,59 @@ class TRIE:
68
  def __repr__(self):
69
  fr = self
70
  ret = []
71
- while(fr!=None):
72
- if(fr.ch!=None):
73
  ret.append(fr.ch)
74
  fr = fr.front
75
- return "<TRIE %s %s>"%(ret[::-1], self.values)
76
-
77
- def add(self, key:bytes, idx:int=0, val=None):
78
- if(idx == len(key)):
79
- if(val is None):
80
  val = key
81
  self.values.add(val)
82
  return self
83
  ch = key[idx]
84
- if(self.to[ch] is None):
85
  self.to[ch] = TRIE(front=self, ch=ch)
86
- return self.to[ch].add(key, idx=idx+1, val=val)
87
-
88
- def find_longest(self, key:bytes, idx:int=0):
89
- u:TRIE = self
90
- ch:int = key[idx]
91
-
92
- while(u.to[ch] is not None):
93
  u = u.to[ch]
94
  idx += 1
95
- if(u.values):
96
  ret = idx, u, u.values
97
- if(idx==len(key)):
98
  break
99
  ch = key[idx]
100
  return ret
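
`find_longest` walks the byte trie greedily and remembers the deepest node holding a vocabulary entry, i.e. longest-match tokenization. A standalone sketch of the same idea with plain dicts (illustrative only, not the class above):

```python
# Dict-based sketch of greedy longest-match over a byte trie.
def find_longest(trie: dict, src: bytes, idx: int):
    node, ret = trie, None
    while idx < len(src) and src[idx] in node:
        node = node[src[idx]]
        idx += 1
        if "token" in node:           # a complete vocab entry ends at this node
            ret = (idx, node["token"])
    return ret

trie: dict = {}
for token in (b"a", b"ab"):           # toy two-entry vocabulary
    node = trie
    for byte in token:
        node = node.setdefault(byte, {})
    node["token"] = token

print(find_longest(trie, b"abc", 0))  # (2, b'ab') -- the longer match wins
```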
101
 
 
102
  class RWKVWorldTokenizer(PreTrainedTokenizer):
103
  vocab_files_names = VOCAB_FILES_NAMES
104
  model_input_names = ["input_ids", "attention_mask"]
105
 
106
- def __init__(
107
- self,
108
- vocab_file,
109
- errors="replace",
110
- pad_token="0",
111
- **kwargs
112
- ):
113
  self.add_bos_token = False
114
  self.encoder = {}
115
- sorted = [] # must be already sorted
116
  with open(vocab_file, "r", encoding="utf-8") as f:
117
  lines = f.readlines()
118
  for l in lines:
119
- idx = int(l[:l.index(' ')])
120
- x = eval(l[l.index(' '):l.rindex(' ')])
121
  x = x.encode("utf-8") if isinstance(x, str) else x
122
  assert isinstance(x, bytes)
123
- assert len(x) == int(l[l.rindex(' '):])
124
  sorted += [x]
125
  self.encoder[idx] = x
126
-
127
  self.decoder = {}
128
- for k,v in self.encoder.items():
129
  self.decoder[v] = int(k)
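
The vocab file parsed above stores one entry per line as `<id> <python literal> <byte length>`. A sketch of parsing a single hypothetical line, with `ast.literal_eval` standing in for the `eval` used in the diff:

```python
# Parse one (hypothetical) line of rwkv_vocab_v20230424.txt:
# "<id> <python literal> <byte length>"; the literal may be a str or bytes repr.
from ast import literal_eval

line = "257 'hello' 5"
idx = int(line[: line.index(" ")])
x = literal_eval(line[line.index(" ") : line.rindex(" ")])
x = x.encode("utf-8") if isinstance(x, str) else x
assert len(x) == int(line[line.rindex(" ") :])
print(idx, x)  # 257 b'hello'
```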
130
 
131
  self.trie = TRIE()
@@ -134,13 +118,18 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
134
  self.errors = errors # how to handle errors in decoding
135
  self.cache = {}
136
  self.first_max_length = 0
137
-
138
- # pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
139
  super().__init__(
140
  errors=errors,
141
- # pad_token=pad_token,
142
  **kwargs,
143
  )
144
 
145
  @property
146
  def vocab_size(self):
@@ -148,12 +137,12 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
148
 
149
  def get_vocab(self):
150
  return dict(self.encoder, **self.added_tokens_encoder)
151
-
152
  def add_tokens(self, new_tokens, special_tokens: bool = False):
153
  for token in new_tokens:
154
  token_id = self.convert_tokens_to_ids(token)
155
  self.added_tokens_decoder[token_id] = token
156
-
157
  def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
158
  if isinstance(ids, int):
159
  ids = [ids]
@@ -179,8 +168,7 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
179
  return output + bos_token_ids + token_ids_1
180
 
181
  def get_special_tokens_mask(
182
- self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None,
183
- already_has_special_tokens: bool = False
184
  ) -> List[int]:
185
  """
186
  Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
@@ -211,19 +199,19 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
211
  return [1] + ([0] * len(token_ids_0))
212
  return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1))
213
 
214
- def encodeBytes(self, src:bytes):
215
- idx:int = 0
216
  tokens = []
217
- while (idx < len(src)):
218
- _idx:int = idx
219
  idx, _, values = self.trie.find_longest(src, idx)
220
- assert(idx != _idx)
221
- _, token = next(iter(values))
222
  tokens.append(token)
223
  return tokens
224
-
225
  def decodeBytes(self, tokens):
226
- return b''.join(map(lambda i: self.encoder[i], tokens))
227
 
228
  def _tokenize(self, text, **kwargs):
229
  """Tokenize a string."""
@@ -231,21 +219,21 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
231
 
232
  def _decode_tokens(self, tokens):
233
  try:
234
- return self.decodeBytes(tokens).decode('utf-8')
235
- except:
236
- return '\ufffd' # bad utf-8
237
-
238
- def _decode(self,
239
- token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor", "tf.Tensor"],
240
- skip_special_tokens: bool = False,
241
- **kwargs
242
- ) -> str:
243
-
244
  def remove_zeros_from_first_segment(token_ids, first_max_length):
245
  first_segment = token_ids[:first_max_length]
246
  first_segment_cleaned = [token for token in first_segment if token != 0]
247
  return first_segment_cleaned + token_ids[first_max_length:]
248
-
249
  # Convert inputs to python lists
250
  token_ids = to_py_obj(token_ids)
251
  token_ids = remove_zeros_from_first_segment(token_ids, self.first_max_length)
@@ -263,7 +251,7 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
263
  break
264
  out_tokens += [token]
265
  tmp = self._decode_tokens(out_tokens[out_last:])
266
- if '\ufffd' not in tmp:
267
  out_str += tmp
268
  out_last = i + 1
269
  return out_str
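The '\ufffd' check is what makes this loop safe for streaming: a multi-byte UTF-8 character can straddle two tokens, so decoded bytes are buffered until they form valid text. A minimal sketch of the same pattern, independent of the class:

    # '€' is three UTF-8 bytes; split it across two chunks as two tokens might.
    chunks = ["€".encode("utf-8")[:2], "€".encode("utf-8")[2:]]
    buf, out = b"", ""
    for chunk in chunks:
        buf += chunk
        text = buf.decode("utf-8", errors="replace")
        if "\ufffd" not in text:  # emit only once the buffer decodes cleanly
            out += text
            buf = b""
    assert out == "€"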
@@ -318,16 +306,29 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
318
  return_offsets_mapping: bool = False,
319
  return_length: bool = False,
320
  verbose: bool = True,
321
- **kwargs
322
  ) -> BatchEncoding:
323
- def get_input_ids(text):
324
  if isinstance(text, str):
325
- text_id = self._tokenize(text)
326
- return text_id
327
  elif isinstance(text, list) and len(text) > 0 and isinstance(text[0], str):
328
- return [self._tokenize(t) for t in text]
329
  elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
330
  return text
331
  else:
332
  raise ValueError(
333
  "Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers."
@@ -383,7 +384,7 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
383
  return_offsets_mapping: bool = False,
384
  return_length: bool = False,
385
  verbose: bool = True,
386
- **kwargs
387
  ) -> BatchEncoding:
388
  def get_input_ids(text, max_length=None, pad_token_id=0):
389
  def pad_sequence(seq, max_len, pad_tok):
@@ -411,7 +412,6 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
411
  "Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers."
412
  )
413
 
414
-
415
  if return_offsets_mapping:
416
  raise NotImplementedError(
417
  "return_offset_mapping is not available when using Python tokenizers. "
@@ -462,10 +462,10 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
462
  )
463
 
464
  return BatchEncoding(batch_outputs)
465
-
466
  def decode(
467
  self,
468
- token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor", "tf.Tensor"],
469
  skip_special_tokens: bool = False,
470
  clean_up_tokenization_spaces: bool = None,
471
  **kwargs,
@@ -500,7 +500,7 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
500
 
501
  def batch_decode(
502
  self,
503
- sequences: Union[List[int], List[List[int]], "np.ndarray", "torch.Tensor", "tf.Tensor"],
504
  skip_special_tokens: bool = False,
505
  clean_up_tokenization_spaces: bool = None,
506
  **kwargs,
@@ -537,5 +537,5 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
537
  for is_user, text in conversation.iter_texts():
538
  input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id])
539
  if len(input_ids) > self.model_max_length:
540
- input_ids = input_ids[-self.model_max_length:]
541
  return input_ids
 
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
+ """Tokenization classes for RWKV5."""
16
 
17
  import json
18
  import os
19
  from typing import TYPE_CHECKING, List, Optional, Tuple, Union
20
 
21
+ from transformers.tokenization_utils import PreTrainedTokenizer
22
  from transformers.tokenization_utils_base import (
23
  BatchEncoding,
24
  EncodedInput,
25
  TextInput,
26
  TruncationStrategy,
27
  )
28
+ from transformers.utils import PaddingStrategy, TensorType, logging, to_py_obj
29
 
30
 
31
  if TYPE_CHECKING:
 
36
  VOCAB_FILES_NAMES = {
37
  "vocab_file": "rwkv_vocab_v20230424.txt",
38
  }
39
+ PRETRAINED_VOCAB_FILES_MAP = {
40
+ "vocab_file": {
41
+ "RWKV/rwkv-5-world-169m": "https://huggingface.co/RWKV/rwkv-5-world-169m/blob/main/rwkv_vocab_v20230424.txt",
42
+ },
43
+ }
44
+
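Since this tokenizer ships as custom code next to the checkpoint rather than inside transformers, loading it goes through AutoTokenizer with remote code enabled (repo id taken from the map above):

    from transformers import AutoTokenizer

    # Custom tokenizer classes hosted on the Hub require trust_remote_code=True.
    tok = AutoTokenizer.from_pretrained("RWKV/rwkv-5-world-169m", trust_remote_code=True)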
45
 
46
  class TRIE:
47
  __slots__ = tuple("ch,to,values,front".split(","))
48
+ to: list
49
+ values: set
50
+
51
  def __init__(self, front=None, ch=None):
52
  self.ch = ch
53
  self.to = [None for ch in range(256)]
 
57
  def __repr__(self):
58
  fr = self
59
  ret = []
60
+ while fr is not None:
61
+ if fr.ch is not None:
62
  ret.append(fr.ch)
63
  fr = fr.front
64
+ return "<TRIE %s %s>" % (ret[::-1], self.values)
65
+
66
+ def add(self, key: bytes, idx: int = 0, val=None):
67
+ if idx == len(key):
68
+ if val is None:
69
  val = key
70
  self.values.add(val)
71
  return self
72
  ch = key[idx]
73
+ if self.to[ch] is None:
74
  self.to[ch] = TRIE(front=self, ch=ch)
75
+ return self.to[ch].add(key, idx=idx + 1, val=val)
76
+
77
+ def find_longest(self, key: bytes, idx: int = 0):
78
+ u: TRIE = self
79
+ ch: int = key[idx]
80
+
81
+ while u.to[ch] is not None:
82
  u = u.to[ch]
83
  idx += 1
84
+ if u.values:
85
  ret = idx, u, u.values
86
+ if idx == len(key):
87
  break
88
  ch = key[idx]
89
  return ret
90
 
91
+
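find_longest walks the trie byte by byte and remembers the last node that carried a value, which is exactly greedy longest-match tokenization. A small check against this class (assuming, per the collapsed lines, that __init__ also initializes values to an empty set):

    # With both b"ab" and b"abc" stored, b"abcd" must match the longer b"abc".
    trie = TRIE()
    trie.add(b"ab", val=(b"ab", 1))
    trie.add(b"abc", val=(b"abc", 2))
    idx, node, values = trie.find_longest(b"abcd", 0)
    assert idx == 3 and (b"abc", 2) in values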
92
  class RWKVWorldTokenizer(PreTrainedTokenizer):
93
  vocab_files_names = VOCAB_FILES_NAMES
94
  model_input_names = ["input_ids", "attention_mask"]
95
 
96
+ def __init__(self, vocab_file, errors="replace", pad_token="0", **kwargs):
97
  self.add_bos_token = False
98
  self.encoder = {}
99
+ sorted = [] # must be already sorted
100
  with open(vocab_file, "r", encoding="utf-8") as f:
101
  lines = f.readlines()
102
  for l in lines:
103
+ idx = int(l[: l.index(" ")])
104
+ x = eval(l[l.index(" ") : l.rindex(" ")])
105
  x = x.encode("utf-8") if isinstance(x, str) else x
106
  assert isinstance(x, bytes)
107
+ assert len(x) == int(l[l.rindex(" ") :])
108
  sorted += [x]
109
  self.encoder[idx] = x
110
+
111
  self.decoder = {}
112
+ for k, v in self.encoder.items():
113
  self.decoder[v] = int(k)
114
 
115
  self.trie = TRIE()
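Each vocab line is `<id> <python literal> <byte length>`: the literal is a str or bytes token, and the trailing length doubles as a checksum on the parse. A stand-alone sketch of the same parse using ast.literal_eval, a safer drop-in for eval on these literals (the sample line is illustrative):

    import ast

    def parse_vocab_line(line: str):
        idx = int(line[: line.index(" ")])
        token = ast.literal_eval(line[line.index(" ") : line.rindex(" ")])
        token = token.encode("utf-8") if isinstance(token, str) else token
        assert len(token) == int(line[line.rindex(" ") :])  # length is a checksum
        return idx, token

    assert parse_vocab_line("33 ' ' 1") == (33, b" ")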
 
118
  self.errors = errors # how to handle errors in decoding
119
  self.cache = {}
120
  self.first_max_length = 0
121
  super().__init__(
122
  errors=errors,
123
  **kwargs,
124
  )
125
+
126
+ @property
127
+ def eos_token_id(self) -> Optional[int]:
128
+ return 0
129
+
130
+ @property
131
+ def eot_token_id(self) -> Optional[int]:
132
+ return 0
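Id 0 is the World vocabulary's end-of-text token and doubles as padding, so both properties are hard-coded; a sanity check on a loaded instance (hypothetical `tok`):

    assert tok.eos_token_id == 0 and tok.eot_token_id == 0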
133
 
134
  @property
135
  def vocab_size(self):
 
137
 
138
  def get_vocab(self):
139
  return dict(self.encoder, **self.added_tokens_encoder)
140
+
141
  def add_tokens(self, new_tokens, special_tokens: bool = False):
142
  for token in new_tokens:
143
  token_id = self.convert_tokens_to_ids(token)
144
  self.added_tokens_decoder[token_id] = token
145
+
146
  def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
147
  if isinstance(ids, int):
148
  ids = [ids]
 
168
  return output + bos_token_ids + token_ids_1
169
 
170
  def get_special_tokens_mask(
171
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
172
  ) -> List[int]:
173
  """
174
  Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
 
199
  return [1] + ([0] * len(token_ids_0))
200
  return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1))
201
 
202
+ def encodeBytes(self, src: bytes):
203
+ idx: int = 0
204
  tokens = []
205
+ while idx < len(src):
206
+ _idx: int = idx
207
  idx, _, values = self.trie.find_longest(src, idx)
208
+ assert idx != _idx
209
+ _, token = next(iter(values))
210
  tokens.append(token)
211
  return tokens
212
+
213
  def decodeBytes(self, tokens):
214
+ return b''.join(map(lambda i: self.encoder[i], tokens)) # noqa
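encodeBytes greedily consumes the longest known byte string at each position (the assert guarantees progress, so an unknown byte fails loudly), and decodeBytes concatenates the stored byte strings back; together they round-trip any input the trie covers. Usage sketch (hypothetical `tok` instance):

    src = "hello".encode("utf-8")
    ids = tok.encodeBytes(src)          # greedy longest match over the trie
    assert tok.decodeBytes(ids) == src  # byte-exact round trip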
215
 
216
  def _tokenize(self, text, **kwargs):
217
  """Tokenize a string."""
 
219
 
220
  def _decode_tokens(self, tokens):
221
  try:
222
+ return self.decodeBytes(tokens).decode("utf-8")
223
+ except Exception:
224
+ return "\ufffd" # bad utf-8
225
+
226
+ def _decode(
227
+ self,
228
+ token_ids: Union[int, List[int]],
229
+ skip_special_tokens: bool = False,
230
+ **kwargs,
231
+ ) -> str:
232
  def remove_zeros_from_first_segment(token_ids, first_max_length):
233
  first_segment = token_ids[:first_max_length]
234
  first_segment_cleaned = [token for token in first_segment if token != 0]
235
  return first_segment_cleaned + token_ids[first_max_length:]
236
+
237
  # Convert inputs to python lists
238
  token_ids = to_py_obj(token_ids)
239
  token_ids = remove_zeros_from_first_segment(token_ids, self.first_max_length)
 
251
  break
252
  out_tokens += [token]
253
  tmp = self._decode_tokens(out_tokens[out_last:])
254
+ if "\ufffd" not in tmp:
255
  out_str += tmp
256
  out_last = i + 1
257
  return out_str
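Because batching left-pads the first segment with id 0, _decode strips zeros from the first first_max_length positions only, so a genuine 0 produced later in generation survives. The helper in isolation:

    def remove_zeros_from_first_segment(token_ids, first_max_length):
        first_segment = token_ids[:first_max_length]
        first_segment_cleaned = [t for t in first_segment if t != 0]
        return first_segment_cleaned + token_ids[first_max_length:]

    # Padding zeros in the prompt segment are dropped; the trailing 0 is kept.
    assert remove_zeros_from_first_segment([0, 0, 11, 22, 33, 0], 4) == [11, 22, 33, 0]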
 
306
  return_offsets_mapping: bool = False,
307
  return_length: bool = False,
308
  verbose: bool = True,
309
+ **kwargs,
310
  ) -> BatchEncoding:
311
+ def get_input_ids(text, max_length=None, pad_token_id=0):
312
+ def pad_sequence(seq, max_len, pad_tok):
313
+ return [pad_tok] * (max_len - len(seq)) + seq
314
+
315
  if isinstance(text, str):
316
+ tokens = self._tokenize(text)
317
+ if max_length is not None:
318
+ tokens = pad_sequence(tokens, max_length, pad_token_id)
319
+ return tokens
320
+
321
  elif isinstance(text, list) and len(text) > 0 and isinstance(text[0], str):
322
+ tokenized_texts = [self._tokenize(t) for t in text]
323
+ if max_length is None:
324
+ max_length = max(len(t) for t in tokenized_texts)
325
+ return [pad_sequence(t, max_length, pad_token_id) for t in tokenized_texts]
326
+
327
  elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
328
+ if max_length is not None and len(text) < max_length:
329
+ return pad_sequence(text, max_length, pad_token_id)
330
  return text
331
+
332
  else:
333
  raise ValueError(
334
  "Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers."
 
384
  return_offsets_mapping: bool = False,
385
  return_length: bool = False,
386
  verbose: bool = True,
387
+ **kwargs,
388
  ) -> BatchEncoding:
389
  def get_input_ids(text, max_length=None, pad_token_id=0):
390
  def pad_sequence(seq, max_len, pad_tok):
 
412
  "Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers."
413
  )
414

415
  if return_offsets_mapping:
416
  raise NotImplementedError(
417
  "return_offset_mapping is not available when using Python tokenizers. "
 
462
  )
463
 
464
  return BatchEncoding(batch_outputs)
465
+
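The batch path pads every sequence to the longest one before wrapping the result in a BatchEncoding, so rows come back rectangular; a hedged usage sketch (hypothetical `tok` instance, shapes assumed from the padding code above):

    enc = tok(["Hello world", "Hi"])   # batch of two strings
    ids = enc["input_ids"]
    assert len(ids[0]) == len(ids[1])  # shorter row was left-padded with 0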
466
  def decode(
467
  self,
468
+ token_ids: Union[int, List[int]],
469
  skip_special_tokens: bool = False,
470
  clean_up_tokenization_spaces: bool = None,
471
  **kwargs,
 
500
 
501
  def batch_decode(
502
  self,
503
+ sequences: Union[List[int], List[List[int]]],
504
  skip_special_tokens: bool = False,
505
  clean_up_tokenization_spaces: bool = None,
506
  **kwargs,
 
537
  for is_user, text in conversation.iter_texts():
538
  input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id])
539
  if len(input_ids) > self.model_max_length:
540
+ input_ids = input_ids[-self.model_max_length :]
541
  return input_ids
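The conversation builder appends the eos id 0 after every turn and, on overflow, keeps only the newest model_max_length ids; the trimming step in isolation:

    model_max_length = 4
    input_ids = [11, 0, 22, 23, 0]  # two turns, each terminated by eos id 0
    if len(input_ids) > model_max_length:
        input_ids = input_ids[-model_max_length:]  # keep the newest context
    assert input_ids == [0, 22, 23, 0]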