KaleiNeely committed
Commit f683510
1 Parent(s): 8e8cdaa

Update modeling_rwkv5.py

Files changed (1)
  1. modeling_rwkv5.py +181 -18
modeling_rwkv5.py CHANGED
@@ -30,6 +30,8 @@ from transformers.utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+    is_ninja_available,
+    is_torch_cuda_available,
     logging,
 )

@@ -47,8 +49,121 @@ RWKV5_PRETRAINED_MODEL_ARCHIVE_LIST = [
     # See all RWKV models at https://huggingface.co/models?filter=rwkv
 ]

+rwkv5_cuda_kernel = None

-def rwkv_linear_attention_v5(
+
+def load_wkv5_cuda_kernel(head_size):
+    from torch.utils.cpp_extension import load as load_kernel
+
+    global rwkv5_cuda_kernel
+
+    kernel_folder = Path(__file__).resolve().parent.parent.parent / "kernels" / "rwkv5"
+    cuda_kernel_files = [kernel_folder / f for f in ["wkv5_op.cpp", "wkv5_cuda.cu"]]
+
+    # Only load the kernel if it's not been loaded yet or if we changed the context length
+    if rwkv5_cuda_kernel is not None and rwkv5_cuda_kernel.head_size == head_size:
+        return
+
+    logger.info(f"Loading CUDA kernel for RWKV at head size of {head_size}.")
+
+    flags = [
+        "-res-usage",
+        "--maxrregcount 60",
+        "--use_fast_math",
+        "-O3",
+        "-Xptxas -O3",
+        "--extra-device-vectorization",
+        f"-D_N_={head_size}",
+    ]
+    rwkv5_cuda_kernel = load_kernel(
+        name=f"wkv_{head_size}",
+        sources=cuda_kernel_files,
+        verbose=(logging.get_verbosity() == logging.DEBUG),
+        extra_cuda_cflags=flags,
+    )
+    rwkv5_cuda_kernel.head_size = head_size
+
+
+class WKV_5(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, B, T, C, H, r, k, v, w, u, s):
+        with torch.no_grad():
+            assert r.dtype == torch.bfloat16
+            assert k.dtype == torch.bfloat16
+            assert v.dtype == torch.bfloat16
+            assert w.dtype == torch.bfloat16
+            assert u.dtype == torch.bfloat16
+            assert s.dtype == torch.float32
+            ctx.B = B
+            ctx.T = T
+            ctx.C = C
+            ctx.H = H
+            assert r.is_contiguous()
+            assert k.is_contiguous()
+            assert v.is_contiguous()
+            assert w.is_contiguous()
+            assert u.is_contiguous()
+            ew = (-torch.exp(w.float())).contiguous()
+            eew = (torch.exp(ew)).contiguous()
+            ctx.save_for_backward(r, k, v, eew, ew, u)
+            y = torch.empty(
+                (B, T, C), device=r.device, dtype=torch.bfloat16, memory_format=torch.contiguous_format
+            )  # .uniform_(-1, 1)
+            rwkv5_cuda_kernel.forward(B, T, C, H, r, k, v, eew, u, y, s)
+            return y, s
+
+    @staticmethod
+    def backward(ctx, gy):
+        with torch.no_grad():
+            assert gy.dtype == torch.bfloat16
+            B = ctx.B
+            T = ctx.T
+            C = ctx.C
+            H = ctx.H
+            assert gy.is_contiguous()
+            r, k, v, eew, ew, u = ctx.saved_tensors
+            gr = torch.empty(
+                (B, T, C),
+                device=gy.device,
+                requires_grad=False,
+                dtype=torch.bfloat16,
+                memory_format=torch.contiguous_format,
+            )  # .uniform_(-1, 1)
+            gk = torch.empty(
+                (B, T, C),
+                device=gy.device,
+                requires_grad=False,
+                dtype=torch.bfloat16,
+                memory_format=torch.contiguous_format,
+            )  # .uniform_(-1, 1)
+            gv = torch.empty(
+                (B, T, C),
+                device=gy.device,
+                requires_grad=False,
+                dtype=torch.bfloat16,
+                memory_format=torch.contiguous_format,
+            )  # .uniform_(-1, 1)
+            gw = torch.empty(
+                (B, C),
+                device=gy.device,
+                requires_grad=False,
+                dtype=torch.bfloat16,
+                memory_format=torch.contiguous_format,
+            )  # .uniform_(-1, 1)
+            gu = torch.empty(
+                (B, C),
+                device=gy.device,
+                requires_grad=False,
+                dtype=torch.bfloat16,
+                memory_format=torch.contiguous_format,
+            )  # .uniform_(-1, 1)
+            rwkv5_cuda_kernel.backward(B, T, C, H, r, k, v, eew, ew, u, gy, gr, gk, gv, gw, gu)
+            gw = torch.sum(gw, 0).view(H, C // H)
+            gu = torch.sum(gu, 0).view(H, C // H)
+            return (None, None, None, None, gr, gk, gv, gw, gu)
+
+
+def rwkv_linear_attention_v5_cpu(
     B,
     H,
     S,
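
For reference, the asserts in WKV_5.forward above spell out the contract the fused kernel expects: receptance, key, value and the decay/bonus parameters arrive as contiguous bfloat16 tensors, the output is allocated as a bfloat16 buffer of shape (B, T, C) with C = H * head_size, and only the recurrent state s stays in float32. Below is a minimal standalone sketch of that contract; the sizes and the helper name are illustrative assumptions, not part of the commit.

import torch

B, T, H, head_size = 1, 8, 32, 64  # hypothetical sizes; C must equal H * head_size
C = H * head_size

def looks_like_wkv5_input(t: torch.Tensor) -> bool:
    # Mirrors the dtype/contiguity asserts in WKV_5.forward.
    return t.dtype == torch.bfloat16 and t.is_contiguous()

receptance = torch.randn(B, T, C, dtype=torch.bfloat16)
key = torch.randn(B, T, C, dtype=torch.bfloat16)
value = torch.randn(B, T, C, dtype=torch.bfloat16)
assert all(looks_like_wkv5_input(t) for t in (receptance, key, value))

# The kernel writes into a pre-allocated bfloat16 buffer, like the
# torch.empty((B, T, C), ...) call in WKV_5.forward.
y = torch.empty((B, T, C), dtype=torch.bfloat16, memory_format=torch.contiguous_format)
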
@@ -66,6 +181,9 @@ def rwkv_linear_attention_v5(
     ow,
     state,
 ):
+    key = key.to(torch.float32).view(B, T, H, S).transpose(1, 2).transpose(-2, -1)
+    value = value.to(torch.float32).view(B, T, H, S).transpose(1, 2)
+    receptance = receptance.to(torch.float32).view(B, T, H, S).transpose(1, 2)
     time_decay = torch.exp(-torch.exp(time_decay.float())).reshape(-1, 1, 1).reshape(n_head, -1, 1)
     time_first = time_first.float().reshape(-1, 1, 1).reshape(n_head, -1, 1)
     lxw = lxw.float()
@@ -88,10 +206,66 @@ def rwkv_linear_attention_v5(
     return out, state


+def rwkv_linear_attention(
+    B,
+    H,
+    S,
+    T,
+    n_head,
+    hidden,
+    time_decay,
+    time_first,
+    receptance,
+    key,
+    value,
+    gate,
+    lxw,
+    lxb,
+    ow,
+    state,
+):
+    no_cuda = any(t.device.type != "cuda" for t in [time_decay, time_first, receptance, key, value])
+    # Launching the CUDA kernel for just one token will actually be slower (there is no for loop in the CPU version
+    # in this case).
+    one_token = key.size(1) == 1
+    if rwkv5_cuda_kernel is None or no_cuda or one_token:
+        return rwkv_linear_attention_v5_cpu(
+            B,
+            H,
+            S,
+            T,
+            n_head,
+            hidden,
+            time_decay,
+            time_first,
+            receptance,
+            key,
+            value,
+            gate,
+            lxw,
+            lxb,
+            ow,
+            state,
+        )
+    else:
+        out, state = WKV_5.apply(B, T, H * S, H, receptance, key, value, time_decay, time_first, state)
+        out = out.reshape(B * T, H * S)
+        out = F.group_norm(out, num_groups=H, weight=lxw, bias=lxb).reshape(B, T, H * S)
+        out = out.to(dtype=hidden.dtype) * gate
+        out = out @ ow
+        return out, state
+
+
 class RwkvSelfAttention(nn.Module):
     def __init__(self, config, layer_id=0):
         super().__init__()
         self.config = config
+        kernel_loaded = rwkv5_cuda_kernel is not None and rwkv5_cuda_kernel.head_size == config.head_size
+        if is_ninja_available() and is_torch_cuda_available() and not kernel_loaded:
+            try:
+                load_wkv5_cuda_kernel(config.context_length)
+            except Exception:
+                logger.info("Could not load the custom CUDA kernel for RWKV5 attention.")
     self.layer_id = layer_id
     hidden_size = config.hidden_size
     # https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4neo/src/model.py#L146
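
The dispatcher added here falls back to rwkv_linear_attention_v5_cpu in three cases: the kernel was never compiled (rwkv5_cuda_kernel is None), any of the attention tensors lives off the GPU, or the input is a single token, where launching the kernel costs more than the loop-free CPU path. Below is a standalone sketch of just that decision; the function name and dummy tensors are illustrative, not part of the commit.

import torch

def would_use_cuda_kernel(kernel, time_decay, time_first, receptance, key, value):
    # Same conditions as rwkv_linear_attention, phrased as "use the kernel".
    no_cuda = any(t.device.type != "cuda" for t in [time_decay, time_first, receptance, key, value])
    one_token = key.size(1) == 1
    return kernel is not None and not no_cuda and not one_token

# Prompt processing (T > 1) with CPU tensors still takes the CPU path,
# because no_cuda is True and there is no compiled kernel object here.
dummy = torch.zeros(1, 16, 2048)
print(would_use_cuda_kernel(None, dummy, dummy, dummy, dummy, dummy))  # False
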
@@ -136,9 +310,9 @@ class RwkvSelfAttention(nn.Module):
         gate = hidden * self.time_mix_gate + shifted * (1 - self.time_mix_gate)

         # https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/model.py#L693
-        key = self.key(key).to(torch.float32).view(B, T, H, S).transpose(1, 2).transpose(-2, -1)
-        value = self.value(value).to(torch.float32).view(B, T, H, S).transpose(1, 2)
-        receptance = self.receptance(receptance).to(torch.float32).view(B, T, H, S).transpose(1, 2)
+        key = self.key(key)
+        value = self.value(value)
+        receptance = self.receptance(receptance)
         gate = F.silu(self.gate(gate))

         if state is not None:
@@ -154,7 +328,7 @@ class RwkvSelfAttention(nn.Module):

         receptance, key, value, gate, state = self.extract_key_value(B, H, S, T, hidden, state=state)
         layer_state = state[1][:, :, :, :, self.layer_id] if state is not None else None
-        rwkv, layer_state = rwkv_linear_attention_v5(
+        rwkv, layer_state = rwkv_linear_attention(
             B,
             H,
             S,
@@ -238,6 +412,8 @@ class RwkvBlock(nn.Module):
         self.feed_forward = RwkvFeedForward(config, layer_id)

     def forward(self, hidden, state=None, use_cache=False, output_attentions=False, seq_mode=True):
+        if self.layer_id == 0:
+            hidden = self.pre_ln(hidden)
         attention, state = self.attention(self.ln1(hidden), state=state, use_cache=use_cache, seq_mode=seq_mode)
         hidden = hidden + attention

@@ -264,7 +440,6 @@ class Rwkv5PreTrainedModel(PreTrainedModel):
     _no_split_modules = ["RwkvBlock"]
     _keep_in_fp32_modules = ["time_decay", "time_first"]
     supports_gradient_checkpointing = True
-    training = False

     def _init_weights(self, module):
         """Initialize the weights."""
@@ -440,8 +615,6 @@ class Rwkv5Model(Rwkv5PreTrainedModel):
         self.ln_out = nn.LayerNorm(config.hidden_size)

         self.layers_are_rescaled = False
-        self.pre_ln_flag = False
-
         self.gradient_checkpointing = False

         # Initialize weights and apply final processing
@@ -489,16 +662,6 @@ class Rwkv5Model(Rwkv5PreTrainedModel):
             raise ValueError("You have to specify either input_ids or inputs_embeds")

         if inputs_embeds is None:
-            if not self.pre_ln_flag:
-                normalized_weight = F.layer_norm(
-                    self.embeddings.weight,
-                    (self.config.hidden_size,),
-                    weight=self.blocks[0].pre_ln.weight,
-                    bias=self.blocks[0].pre_ln.bias,
-                )
-                self.embeddings.weight = nn.Parameter(normalized_weight)
-                self.pre_ln_flag = True
-
             inputs_embeds = self.embeddings(input_ids)

         if use_cache and state is None:
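
The two RwkvBlock/Rwkv5Model hunks above are halves of the same change: instead of folding pre_ln into the embedding matrix once at first use (the removed pre_ln_flag branch), the first block now applies pre_ln to the hidden states on every forward pass. The results match because layer-normalizing each embedding row and then looking rows up equals looking rows up and then layer-normalizing. A small standalone check of that equivalence, with toy sizes that are not part of the commit:

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
hidden_size, vocab = 16, 100
emb = nn.Embedding(vocab, hidden_size)
pre_ln = nn.LayerNorm(hidden_size)
ids = torch.randint(0, vocab, (2, 5))

# Removed approach: normalize the embedding weights once, then look up.
folded = F.embedding(ids, F.layer_norm(emb.weight, (hidden_size,), pre_ln.weight, pre_ln.bias))
# New approach: look up, then apply pre_ln inside the first block.
on_the_fly = pre_ln(emb(ids))

print(torch.allclose(folded, on_the_fly, atol=1e-6))  # True
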
 