KaleiNeely commited on
Commit
7a56a0e
1 Parent(s): aa8a466

Update modeling_rwkv5.py

Browse files
Files changed (1) hide show
  1. modeling_rwkv5.py +189 -149
modeling_rwkv5.py CHANGED
@@ -16,6 +16,7 @@
16
  """PyTorch RWKV5 World model."""
17
 
18
  from dataclasses import dataclass
 
19
  from typing import List, Optional, Tuple, Union
20
 
21
  import torch
@@ -30,6 +31,7 @@ from transformers.utils import (
30
  add_code_sample_docstrings,
31
  add_start_docstrings,
32
  add_start_docstrings_to_model_forward,
 
33
  is_ninja_available,
34
  is_torch_cuda_available,
35
  logging,
@@ -52,6 +54,7 @@ RWKV5_PRETRAINED_MODEL_ARCHIVE_LIST = [
52
  rwkv5_cuda_kernel = None
53
 
54
 
 
55
  def load_wkv5_cuda_kernel(head_size):
56
  from torch.utils.cpp_extension import load as load_kernel
57
 
@@ -86,89 +89,108 @@ def load_wkv5_cuda_kernel(head_size):
86
 
87
  class WKV_5(torch.autograd.Function):
88
  @staticmethod
89
- def forward(ctx, B, T, C, H, r, k, v, w, u, s):
90
  with torch.no_grad():
91
- assert r.dtype == torch.bfloat16
92
- assert k.dtype == torch.bfloat16
93
- assert v.dtype == torch.bfloat16
94
- assert w.dtype == torch.bfloat16
95
- assert u.dtype == torch.bfloat16
96
- assert s.dtype == torch.float32
97
- ctx.B = B
98
- ctx.T = T
99
- ctx.C = C
100
- ctx.H = H
101
- assert r.is_contiguous()
102
- assert k.is_contiguous()
103
- assert v.is_contiguous()
104
- assert w.is_contiguous()
105
- assert u.is_contiguous()
106
- ew = (-torch.exp(w.float())).contiguous()
107
- eew = (torch.exp(ew)).contiguous()
108
- ctx.save_for_backward(r, k, v, eew, ew, u)
109
- y = torch.empty(
110
- (B, T, C), device=r.device, dtype=torch.bfloat16, memory_format=torch.contiguous_format
111
- ) # .uniform_(-1, 1)
112
- rwkv5_cuda_kernel.forward(B, T, C, H, r, k, v, eew, u, y, s)
113
- return y, s
 
 
 
 
 
 
 
 
114
 
115
  @staticmethod
116
- def backward(ctx, gy):
117
  with torch.no_grad():
118
- assert gy.dtype == torch.bfloat16
119
- B = ctx.B
120
- T = ctx.T
121
- C = ctx.C
122
- H = ctx.H
123
- assert gy.is_contiguous()
124
- r, k, v, eew, ew, u = ctx.saved_tensors
125
- gr = torch.empty(
126
- (B, T, C),
127
- device=gy.device,
128
  requires_grad=False,
129
  dtype=torch.bfloat16,
130
  memory_format=torch.contiguous_format,
131
- ) # .uniform_(-1, 1)
132
- gk = torch.empty(
133
- (B, T, C),
134
- device=gy.device,
135
  requires_grad=False,
136
  dtype=torch.bfloat16,
137
  memory_format=torch.contiguous_format,
138
- ) # .uniform_(-1, 1)
139
- gv = torch.empty(
140
- (B, T, C),
141
- device=gy.device,
142
  requires_grad=False,
143
  dtype=torch.bfloat16,
144
  memory_format=torch.contiguous_format,
145
- ) # .uniform_(-1, 1)
146
- gw = torch.empty(
147
- (B, C),
148
- device=gy.device,
149
  requires_grad=False,
150
  dtype=torch.bfloat16,
151
  memory_format=torch.contiguous_format,
152
- ) # .uniform_(-1, 1)
153
- gu = torch.empty(
154
- (B, C),
155
- device=gy.device,
156
  requires_grad=False,
157
  dtype=torch.bfloat16,
158
  memory_format=torch.contiguous_format,
159
- ) # .uniform_(-1, 1)
160
- rwkv5_cuda_kernel.backward(B, T, C, H, r, k, v, eew, ew, u, gy, gr, gk, gv, gw, gu)
161
- gw = torch.sum(gw, 0).view(H, C // H)
162
- gu = torch.sum(gu, 0).view(H, C // H)
163
- return (None, None, None, None, gr, gk, gv, gw, gu)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
 
165
 
166
  def rwkv_linear_attention_v5_cpu(
167
- B,
168
- H,
169
- S,
170
- T,
171
- n_head,
172
  hidden,
173
  time_decay,
174
  time_first,
@@ -176,20 +198,24 @@ def rwkv_linear_attention_v5_cpu(
176
  key,
177
  value,
178
  gate,
179
- lxw,
180
- lxb,
181
- ow,
182
  state,
183
  ):
184
- key = key.to(torch.float32).view(B, T, H, S).transpose(1, 2).transpose(-2, -1)
185
- value = value.to(torch.float32).view(B, T, H, S).transpose(1, 2)
186
- receptance = receptance.to(torch.float32).view(B, T, H, S).transpose(1, 2)
187
- time_decay = torch.exp(-torch.exp(time_decay.float())).reshape(-1, 1, 1).reshape(n_head, -1, 1)
188
- time_first = time_first.float().reshape(-1, 1, 1).reshape(n_head, -1, 1)
189
- lxw = lxw.float()
190
- lxb = lxb.float()
191
- out = torch.zeros_like(key).reshape(B, T, H, S)
192
- for t in range(T):
 
 
 
 
193
  rt = receptance[:, :, t : t + 1, :]
194
  kt = key[:, :, :, t : t + 1]
195
  vt = value[:, :, t : t + 1, :]
@@ -198,20 +224,17 @@ def rwkv_linear_attention_v5_cpu(
198
  with torch.no_grad():
199
  state = at + time_decay * state
200
 
201
- out = out.reshape(B * T, H * S)
202
- out = F.group_norm(out, num_groups=H, weight=lxw, bias=lxb).reshape(B, T, H * S)
 
 
203
  out = out.to(dtype=hidden.dtype) * gate
204
- out = out @ ow
205
 
206
  return out, state
207
 
208
 
209
  def rwkv_linear_attention(
210
- B,
211
- H,
212
- S,
213
- T,
214
- n_head,
215
  hidden,
216
  time_decay,
217
  time_first,
@@ -219,22 +242,21 @@ def rwkv_linear_attention(
219
  key,
220
  value,
221
  gate,
222
- lxw,
223
- lxb,
224
- ow,
225
  state,
226
  ):
 
 
 
 
227
  no_cuda = any(t.device.type != "cuda" for t in [time_decay, time_first, receptance, key, value])
228
  # Launching the CUDA kernel for just one token will actually be slower (there is no for loop in the CPU version
229
  # in this case).
230
  one_token = key.size(1) == 1
231
  if rwkv5_cuda_kernel is None or no_cuda or one_token:
232
  return rwkv_linear_attention_v5_cpu(
233
- B,
234
- H,
235
- S,
236
- T,
237
- n_head,
238
  hidden,
239
  time_decay,
240
  time_first,
@@ -242,17 +264,30 @@ def rwkv_linear_attention(
242
  key,
243
  value,
244
  gate,
245
- lxw,
246
- lxb,
247
- ow,
248
  state,
249
  )
250
  else:
251
- out, state = WKV_5.apply(B, T, H * S, H, receptance, key, value, time_decay, time_first, state)
252
- out = out.reshape(B * T, H * S)
253
- out = F.group_norm(out, num_groups=H, weight=lxw, bias=lxb).reshape(B, T, H * S)
 
 
 
 
 
 
 
 
 
 
 
 
 
254
  out = out.to(dtype=hidden.dtype) * gate
255
- out = out @ ow
256
  return out, state
257
 
258
 
@@ -268,7 +303,6 @@ class RwkvSelfAttention(nn.Module):
268
  logger.info("Could not load the custom CUDA kernel for RWKV5 attention.")
269
  self.layer_id = layer_id
270
  hidden_size = config.hidden_size
271
- # https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4neo/src/model.py#L146
272
  num_attention_heads = hidden_size // config.head_size
273
  self.num_attention_heads = num_attention_heads
274
  attention_hidden_size = (
@@ -290,11 +324,9 @@ class RwkvSelfAttention(nn.Module):
290
  self.receptance = nn.Linear(hidden_size, attention_hidden_size, bias=False)
291
  self.gate = nn.Linear(hidden_size, attention_hidden_size, bias=False)
292
  self.output = nn.Linear(attention_hidden_size, hidden_size, bias=False)
293
- # https://github.com/BlinkDL/RWKV-LM/blob/3db37a72356b736966ddd377268f02b80963af3f/RWKV-v4neo/src/model.py#L190C1-L190C1
294
  self.ln_x = nn.GroupNorm(hidden_size // config.head_size, hidden_size)
295
 
296
- # TODO: maybe jit, otherwise move inside forward
297
- def extract_key_value(self, B, H, S, T, hidden, state=None):
298
  # Mix hidden with the previous timestep to produce key, value, receptance
299
  if hidden.size(1) == 1 and state is not None:
300
  shifted = state[0][:, :, self.layer_id]
@@ -309,7 +341,6 @@ class RwkvSelfAttention(nn.Module):
309
  receptance = hidden * self.time_mix_receptance + shifted * (1 - self.time_mix_receptance)
310
  gate = hidden * self.time_mix_gate + shifted * (1 - self.time_mix_gate)
311
 
312
- # https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/model.py#L693
313
  key = self.key(key)
314
  value = self.value(value)
315
  receptance = self.receptance(receptance)
@@ -321,19 +352,9 @@ class RwkvSelfAttention(nn.Module):
321
  return receptance, key, value, gate, state
322
 
323
  def forward(self, hidden, state=None, use_cache=False, seq_mode=True):
324
- B = hidden.shape[0]
325
- H = self.time_decay.shape[0]
326
- S = hidden.shape[-1] // H
327
- T = hidden.shape[1]
328
-
329
- receptance, key, value, gate, state = self.extract_key_value(B, H, S, T, hidden, state=state)
330
  layer_state = state[1][:, :, :, :, self.layer_id] if state is not None else None
331
  rwkv, layer_state = rwkv_linear_attention(
332
- B,
333
- H,
334
- S,
335
- T,
336
- self.num_attention_heads,
337
  hidden,
338
  self.time_decay,
339
  self.time_faaaa,
@@ -359,7 +380,6 @@ class RwkvFeedForward(nn.Module):
359
  self.config = config
360
  self.layer_id = layer_id
361
  hidden_size = config.hidden_size
362
- # https://github.com/BlinkDL/RWKV-LM/blob/3db37a72356b736966ddd377268f02b80963af3f/RWKV-v4neo/train.py#L168
363
  intermediate_size = (
364
  config.intermediate_size
365
  if config.intermediate_size is not None
@@ -396,7 +416,8 @@ class RwkvFeedForward(nn.Module):
396
  return receptance * value, state
397
 
398
 
399
- class RwkvBlock(nn.Module):
 
400
  def __init__(self, config, layer_id):
401
  super().__init__()
402
  self.config = config
@@ -437,7 +458,7 @@ class Rwkv5PreTrainedModel(PreTrainedModel):
437
 
438
  config_class = Rwkv5Config
439
  base_model_prefix = "rwkv"
440
- _no_split_modules = ["RwkvBlock"]
441
  _keep_in_fp32_modules = ["time_decay", "time_first"]
442
  supports_gradient_checkpointing = True
443
 
@@ -460,7 +481,6 @@ class Rwkv5PreTrainedModel(PreTrainedModel):
460
  )
461
  time_weight = time_weight[None, None, :]
462
 
463
- # https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4neo/src/model.py#L398
464
  decay_speed = [
465
  -6.0 + 5.0 * (h / (attention_hidden_size - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
466
  for h in range(attention_hidden_size)
@@ -503,6 +523,7 @@ class Rwkv5PreTrainedModel(PreTrainedModel):
503
  module.time_mix_receptance.data = torch.pow(time_weight, ratio_1_to_almost0)
504
 
505
 
 
506
  @dataclass
507
  class Rwkv5Output(ModelOutput):
508
  """
@@ -530,6 +551,7 @@ class Rwkv5Output(ModelOutput):
530
  attentions: Optional[Tuple[torch.FloatTensor]] = None
531
 
532
 
 
533
  @dataclass
534
  class Rwkv5CausalLMOutput(ModelOutput):
535
  """
@@ -611,7 +633,7 @@ class Rwkv5Model(Rwkv5PreTrainedModel):
611
  super().__init__(config)
612
 
613
  self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
614
- self.blocks = nn.ModuleList([RwkvBlock(config, layer_id=idx) for idx in range(config.num_hidden_layers)])
615
  self.ln_out = nn.LayerNorm(config.hidden_size)
616
 
617
  self.layers_are_rescaled = False
@@ -665,39 +687,35 @@ class Rwkv5Model(Rwkv5PreTrainedModel):
665
  inputs_embeds = self.embeddings(input_ids)
666
 
667
  if use_cache and state is None:
668
- # https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/model.py#L904-L906
669
  state = []
670
  num_attention_heads = self.config.hidden_size // self.config.num_attention_heads
671
- state.append(
672
- torch.zeros(
673
- (inputs_embeds.size(0), self.config.hidden_size, self.config.num_hidden_layers),
674
- dtype=inputs_embeds.dtype,
675
- requires_grad=False,
676
- device=inputs_embeds.device,
677
- ).contiguous()
678
- )
679
- state.append(
680
- torch.zeros(
681
- (
682
- inputs_embeds.size(0),
683
- num_attention_heads,
684
- self.config.hidden_size // num_attention_heads,
685
- self.config.hidden_size // num_attention_heads,
686
- self.config.num_hidden_layers,
687
- ),
688
- dtype=torch.float32,
689
- requires_grad=False,
690
- device=inputs_embeds.device,
691
- ).contiguous()
692
- )
693
- state.append(
694
- torch.zeros(
695
- (inputs_embeds.size(0), self.config.hidden_size, self.config.num_hidden_layers),
696
- dtype=inputs_embeds.dtype,
697
- requires_grad=False,
698
- device=inputs_embeds.device,
699
- ).contiguous()
700
- )
701
 
702
  seq_mode = inputs_embeds.shape[1] > 1
703
  hidden_states = inputs_embeds
@@ -752,10 +770,32 @@ class Rwkv5Model(Rwkv5PreTrainedModel):
752
 
753
  self.layers_are_rescaled = not self.training
754
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
755
 
 
756
  @add_start_docstrings(
757
  """
758
- The RWKV Model transformer with a language modeling head on top (linear layer with weights tied to the input
759
  embeddings).
760
  """,
761
  RWKV_START_DOCSTRING,
 
16
  """PyTorch RWKV5 World model."""
17
 
18
  from dataclasses import dataclass
19
+ from pathlib import Path
20
  from typing import List, Optional, Tuple, Union
21
 
22
  import torch
 
31
  add_code_sample_docstrings,
32
  add_start_docstrings,
33
  add_start_docstrings_to_model_forward,
34
+ is_bitsandbytes_available,
35
  is_ninja_available,
36
  is_torch_cuda_available,
37
  logging,
 
54
  rwkv5_cuda_kernel = None
55
 
56
 
57
+ # Copied from https://github.com/huggingface/transformers/blob/18cbaf13dcaca7145f5652aefb9b19734c56c3cd/src/transformers/models/rwkv/modeling_rwkv.py#L65
58
  def load_wkv5_cuda_kernel(head_size):
59
  from torch.utils.cpp_extension import load as load_kernel
60
 
 
89
 
90
  class WKV_5(torch.autograd.Function):
91
  @staticmethod
92
+ def forward(ctx, receptance, key, value, time_decay, time_first, state):
93
  with torch.no_grad():
94
+ Batch = key.shape[0]
95
+ SequenceLength = key.shape[1]
96
+ HiddenSize = key.shape[2]
97
+ HeadSize = HiddenSize // time_decay.shape[0]
98
+ ctx.Batch = Batch
99
+ ctx.SequenceLength = SequenceLength
100
+ ctx.HiddenSize = HiddenSize
101
+ ctx.HeadSize = HeadSize
102
+ e_time_decay = (-torch.exp(time_decay.float())).contiguous()
103
+ ee_time_decay = (torch.exp(e_time_decay)).contiguous()
104
+ ctx.save_for_backward(receptance, key, value, ee_time_decay, e_time_decay, time_first)
105
+ out = torch.empty(
106
+ (Batch, SequenceLength, HiddenSize),
107
+ device=receptance.device,
108
+ dtype=torch.bfloat16,
109
+ memory_format=torch.contiguous_format,
110
+ )
111
+ rwkv5_cuda_kernel.forward(
112
+ Batch,
113
+ SequenceLength,
114
+ HiddenSize,
115
+ HeadSize,
116
+ receptance,
117
+ key,
118
+ value,
119
+ ee_time_decay,
120
+ time_first,
121
+ out,
122
+ state,
123
+ )
124
+ return out, state
125
 
126
  @staticmethod
127
+ def backward(ctx, gout):
128
  with torch.no_grad():
129
+ assert gout.dtype == torch.bfloat16
130
+ Batch = ctx.Batch
131
+ SequenceLength = ctx.SequenceLength
132
+ HiddenSize = ctx.HiddenSize
133
+ HeadSize = ctx.HeadSize
134
+ receptance, key, value, ee_time_decay, e_time_decay, time_first = ctx.saved_tensors
135
+ greceptance = torch.empty(
136
+ (Batch, SequenceLength, HiddenSize),
137
+ device=gout.device,
 
138
  requires_grad=False,
139
  dtype=torch.bfloat16,
140
  memory_format=torch.contiguous_format,
141
+ )
142
+ g_key = torch.empty(
143
+ (Batch, SequenceLength, HiddenSize),
144
+ device=gout.device,
145
  requires_grad=False,
146
  dtype=torch.bfloat16,
147
  memory_format=torch.contiguous_format,
148
+ )
149
+ g_value = torch.empty(
150
+ (Batch, SequenceLength, HiddenSize),
151
+ device=gout.device,
152
  requires_grad=False,
153
  dtype=torch.bfloat16,
154
  memory_format=torch.contiguous_format,
155
+ )
156
+ g_time_decay = torch.empty(
157
+ (Batch, HiddenSize),
158
+ device=gout.device,
159
  requires_grad=False,
160
  dtype=torch.bfloat16,
161
  memory_format=torch.contiguous_format,
162
+ )
163
+ g_time_first = torch.empty(
164
+ (Batch, HiddenSize),
165
+ device=gout.device,
166
  requires_grad=False,
167
  dtype=torch.bfloat16,
168
  memory_format=torch.contiguous_format,
169
+ )
170
+ rwkv5_cuda_kernel.backward(
171
+ Batch,
172
+ SequenceLength,
173
+ HiddenSize,
174
+ HeadSize,
175
+ receptance,
176
+ key,
177
+ value,
178
+ ee_time_decay,
179
+ e_time_decay,
180
+ time_first,
181
+ gout,
182
+ greceptance,
183
+ g_key,
184
+ g_value,
185
+ g_time_decay,
186
+ g_time_first,
187
+ )
188
+ g_time_decay = torch.sum(g_time_decay, 0).view(HeadSize, HiddenSize // HeadSize)
189
+ g_time_first = torch.sum(g_time_first, 0).view(HeadSize, HiddenSize // HeadSize)
190
+ return (None, None, None, None, greceptance, g_key, g_value, g_time_decay, g_time_first)
191
 
192
 
193
  def rwkv_linear_attention_v5_cpu(
 
 
 
 
 
194
  hidden,
195
  time_decay,
196
  time_first,
 
198
  key,
199
  value,
200
  gate,
201
+ layer_norm_weight,
202
+ layer_norm_bias,
203
+ output_weight,
204
  state,
205
  ):
206
+ Batch = hidden.shape[0]
207
+ AttentionHeads = time_decay.shape[0]
208
+ HeadSize = hidden.shape[-1] // AttentionHeads
209
+ SequenceLength = hidden.shape[1]
210
+ key = key.to(torch.float32).view(Batch, SequenceLength, AttentionHeads, HeadSize).transpose(1, 2).transpose(-2, -1)
211
+ value = value.to(torch.float32).view(Batch, SequenceLength, AttentionHeads, HeadSize).transpose(1, 2)
212
+ receptance = receptance.to(torch.float32).view(Batch, SequenceLength, AttentionHeads, HeadSize).transpose(1, 2)
213
+ time_decay = torch.exp(-torch.exp(time_decay.float())).reshape(-1, 1, 1).reshape(AttentionHeads, -1, 1)
214
+ time_first = time_first.float().reshape(-1, 1, 1).reshape(AttentionHeads, -1, 1)
215
+ layer_norm_weight = layer_norm_weight.float()
216
+ layer_norm_bias = layer_norm_bias.float()
217
+ out = torch.zeros_like(key).reshape(Batch, SequenceLength, AttentionHeads, HeadSize)
218
+ for t in range(SequenceLength):
219
  rt = receptance[:, :, t : t + 1, :]
220
  kt = key[:, :, :, t : t + 1]
221
  vt = value[:, :, t : t + 1, :]
 
224
  with torch.no_grad():
225
  state = at + time_decay * state
226
 
227
+ out = out.reshape(Batch * SequenceLength, AttentionHeads * HeadSize)
228
+ out = F.group_norm(out, num_groups=AttentionHeads, weight=layer_norm_weight, bias=layer_norm_bias).reshape(
229
+ Batch, SequenceLength, AttentionHeads * HeadSize
230
+ )
231
  out = out.to(dtype=hidden.dtype) * gate
232
+ out = out @ output_weight
233
 
234
  return out, state
235
 
236
 
237
  def rwkv_linear_attention(
 
 
 
 
 
238
  hidden,
239
  time_decay,
240
  time_first,
 
242
  key,
243
  value,
244
  gate,
245
+ layer_norm_weight,
246
+ layer_norm_bias,
247
+ output_weight,
248
  state,
249
  ):
250
+ Batch = hidden.shape[0]
251
+ AttentionHeads = time_decay.shape[0]
252
+ HeadSize = hidden.shape[-1] // AttentionHeads
253
+ SequenceLength = hidden.shape[1]
254
  no_cuda = any(t.device.type != "cuda" for t in [time_decay, time_first, receptance, key, value])
255
  # Launching the CUDA kernel for just one token will actually be slower (there is no for loop in the CPU version
256
  # in this case).
257
  one_token = key.size(1) == 1
258
  if rwkv5_cuda_kernel is None or no_cuda or one_token:
259
  return rwkv_linear_attention_v5_cpu(
 
 
 
 
 
260
  hidden,
261
  time_decay,
262
  time_first,
 
264
  key,
265
  value,
266
  gate,
267
+ layer_norm_weight,
268
+ layer_norm_bias,
269
+ output_weight,
270
  state,
271
  )
272
  else:
273
+ out, state = WKV_5.apply(
274
+ Batch,
275
+ SequenceLength,
276
+ AttentionHeads * HeadSize,
277
+ AttentionHeads,
278
+ receptance,
279
+ key,
280
+ value,
281
+ time_decay,
282
+ time_first,
283
+ state,
284
+ )
285
+ out = out.reshape(Batch * SequenceLength, AttentionHeads * HeadSize)
286
+ out = F.group_norm(out, num_groups=AttentionHeads, weight=layer_norm_weight, bias=layer_norm_bias).reshape(
287
+ Batch, SequenceLength, AttentionHeads * HeadSize
288
+ )
289
  out = out.to(dtype=hidden.dtype) * gate
290
+ out = out @ output_weight
291
  return out, state
292
 
293
 
 
303
  logger.info("Could not load the custom CUDA kernel for RWKV5 attention.")
304
  self.layer_id = layer_id
305
  hidden_size = config.hidden_size
 
306
  num_attention_heads = hidden_size // config.head_size
307
  self.num_attention_heads = num_attention_heads
308
  attention_hidden_size = (
 
324
  self.receptance = nn.Linear(hidden_size, attention_hidden_size, bias=False)
325
  self.gate = nn.Linear(hidden_size, attention_hidden_size, bias=False)
326
  self.output = nn.Linear(attention_hidden_size, hidden_size, bias=False)
 
327
  self.ln_x = nn.GroupNorm(hidden_size // config.head_size, hidden_size)
328
 
329
+ def extract_key_value(self, hidden, state=None):
 
330
  # Mix hidden with the previous timestep to produce key, value, receptance
331
  if hidden.size(1) == 1 and state is not None:
332
  shifted = state[0][:, :, self.layer_id]
 
341
  receptance = hidden * self.time_mix_receptance + shifted * (1 - self.time_mix_receptance)
342
  gate = hidden * self.time_mix_gate + shifted * (1 - self.time_mix_gate)
343
 
 
344
  key = self.key(key)
345
  value = self.value(value)
346
  receptance = self.receptance(receptance)
 
352
  return receptance, key, value, gate, state
353
 
354
  def forward(self, hidden, state=None, use_cache=False, seq_mode=True):
355
+ receptance, key, value, gate, state = self.extract_key_value(hidden, state=state)
 
 
 
 
 
356
  layer_state = state[1][:, :, :, :, self.layer_id] if state is not None else None
357
  rwkv, layer_state = rwkv_linear_attention(
 
 
 
 
 
358
  hidden,
359
  self.time_decay,
360
  self.time_faaaa,
 
380
  self.config = config
381
  self.layer_id = layer_id
382
  hidden_size = config.hidden_size
 
383
  intermediate_size = (
384
  config.intermediate_size
385
  if config.intermediate_size is not None
 
416
  return receptance * value, state
417
 
418
 
419
+ # copied from HuggingFace https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py
420
+ class Rwkv5Block(nn.Module):
421
  def __init__(self, config, layer_id):
422
  super().__init__()
423
  self.config = config
 
458
 
459
  config_class = Rwkv5Config
460
  base_model_prefix = "rwkv"
461
+ _no_split_modules = ["Rwkv5Block"]
462
  _keep_in_fp32_modules = ["time_decay", "time_first"]
463
  supports_gradient_checkpointing = True
464
 
 
481
  )
482
  time_weight = time_weight[None, None, :]
483
 
 
484
  decay_speed = [
485
  -6.0 + 5.0 * (h / (attention_hidden_size - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
486
  for h in range(attention_hidden_size)
 
523
  module.time_mix_receptance.data = torch.pow(time_weight, ratio_1_to_almost0)
524
 
525
 
526
+ # copied from HuggingFace https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py
527
  @dataclass
528
  class Rwkv5Output(ModelOutput):
529
  """
 
551
  attentions: Optional[Tuple[torch.FloatTensor]] = None
552
 
553
 
554
+ # copied from HuggingFace https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py
555
  @dataclass
556
  class Rwkv5CausalLMOutput(ModelOutput):
557
  """
 
633
  super().__init__(config)
634
 
635
  self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
636
+ self.blocks = nn.ModuleList([Rwkv5Block(config, layer_id=idx) for idx in range(config.num_hidden_layers)])
637
  self.ln_out = nn.LayerNorm(config.hidden_size)
638
 
639
  self.layers_are_rescaled = False
 
687
  inputs_embeds = self.embeddings(input_ids)
688
 
689
  if use_cache and state is None:
 
690
  state = []
691
  num_attention_heads = self.config.hidden_size // self.config.num_attention_heads
692
+ state_attn_x = torch.zeros(
693
+ (inputs_embeds.size(0), self.config.hidden_size, self.config.num_hidden_layers),
694
+ dtype=inputs_embeds.dtype,
695
+ requires_grad=False,
696
+ device=inputs_embeds.device,
697
+ ).contiguous()
698
+ state_attn_kv = torch.zeros(
699
+ (
700
+ inputs_embeds.size(0),
701
+ num_attention_heads,
702
+ self.config.hidden_size // num_attention_heads,
703
+ self.config.hidden_size // num_attention_heads,
704
+ self.config.num_hidden_layers,
705
+ ),
706
+ dtype=torch.float32,
707
+ requires_grad=False,
708
+ device=inputs_embeds.device,
709
+ ).contiguous()
710
+ state_ffn_x = torch.zeros(
711
+ (inputs_embeds.size(0), self.config.hidden_size, self.config.num_hidden_layers),
712
+ dtype=inputs_embeds.dtype,
713
+ requires_grad=False,
714
+ device=inputs_embeds.device,
715
+ ).contiguous()
716
+ state.append(state_attn_x)
717
+ state.append(state_attn_kv)
718
+ state.append(state_ffn_x)
 
 
 
719
 
720
  seq_mode = inputs_embeds.shape[1] > 1
721
  hidden_states = inputs_embeds
 
770
 
771
  self.layers_are_rescaled = not self.training
772
 
773
+ def _bnb_4bit_dequantize_and_rescale(self, target_layer, block_id):
774
+ r"""
775
+ Perform the dequantization and rescaling of the weights of a given layer. After that operation the layer will
776
+ be quantized again.
777
+ """
778
+ if not is_bitsandbytes_available():
779
+ raise ImportError("Please install bitsandbytes to use this method.")
780
+ import bitsandbytes as bnb
781
+
782
+ dequant_weights = bnb.functional.dequantize_4bit(target_layer.weight.data, target_layer.weight.quant_state)
783
+
784
+ dequant_weights.div_(2 ** int(block_id // self.config.rescale_every))
785
+
786
+ # re-quantize the model:
787
+ # we need to put it first on CPU then back to the device
788
+ # this will create an overhead :/
789
+ # We set requires_grad=False as we cannot compute gradients on top of 4bit parameters anyway and to avoid
790
+ # bugs with bnb
791
+ quant_weight = bnb.nn.Params4bit(dequant_weights.to("cpu"), requires_grad=False).to(dequant_weights.device)
792
+ setattr(target_layer, "weight", quant_weight)
793
+
794
 
795
+ # copied from HuggingFace https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py
796
  @add_start_docstrings(
797
  """
798
+ The RWKV5 Model transformer with a language modeling head on top (linear layer with weights tied to the input
799
  embeddings).
800
  """,
801
  RWKV_START_DOCSTRING,