KaleiNeely committed
Commit 3edf4ac
Parent(s): 1c0f950
Update modeling_rwkv5.py

modeling_rwkv5.py CHANGED (+189 -149)
@@ -16,6 +16,7 @@
 16 | """PyTorch RWKV5 World model."""
 17 |
 18 | from dataclasses import dataclass
 19 | from typing import List, Optional, Tuple, Union
 20 |
 21 | import torch
@@ -30,6 +31,7 @@ from transformers.utils import (
 30 |     add_code_sample_docstrings,
 31 |     add_start_docstrings,
 32 |     add_start_docstrings_to_model_forward,
 33 |     is_ninja_available,
 34 |     is_torch_cuda_available,
 35 |     logging,
@@ -52,6 +54,7 @@ RWKV5_PRETRAINED_MODEL_ARCHIVE_LIST = [
 52 | rwkv5_cuda_kernel = None
 53 |
 54 |
 55 | def load_wkv5_cuda_kernel(head_size):
 56 |     from torch.utils.cpp_extension import load as load_kernel
 57 |
@@ -86,89 +89,108 @@ def load_wkv5_cuda_kernel(head_size):
 86 |
 87 | class WKV_5(torch.autograd.Function):
 88 |     @staticmethod
 89 | -     def forward(ctx,
 90 |         with torch.no_grad():
 97 | -             ctx.
 98 | -             ctx.
114 |
115 |     @staticmethod
116 | -     def backward(ctx,
117 |         with torch.no_grad():
118 | -             assert
127 | -                 device=gy.device,
128 |                 requires_grad=False,
129 |                 dtype=torch.bfloat16,
130 |                 memory_format=torch.contiguous_format,
131 | -             )
133 | -             (
134 | -                 device=
135 |                 requires_grad=False,
136 |                 dtype=torch.bfloat16,
137 |                 memory_format=torch.contiguous_format,
138 | -             )
140 | -             (
141 | -                 device=
142 |                 requires_grad=False,
143 |                 dtype=torch.bfloat16,
144 |                 memory_format=torch.contiguous_format,
145 | -             )
147 | -             (
148 | -                 device=
149 |                 requires_grad=False,
150 |                 dtype=torch.bfloat16,
151 |                 memory_format=torch.contiguous_format,
152 | -             )
154 | -             (
155 | -                 device=
156 |                 requires_grad=False,
157 |                 dtype=torch.bfloat16,
158 |                 memory_format=torch.contiguous_format,
159 | -             )
160 | -             rwkv5_cuda_kernel.backward(
164 |
165 |
166 | def rwkv_linear_attention_v5_cpu(
167 | -     B,
168 | -     H,
169 | -     S,
170 | -     T,
171 | -     n_head,
172 |     hidden,
173 |     time_decay,
174 |     time_first,
@@ -176,20 +198,24 @@ def rwkv_linear_attention_v5_cpu(
176 |     key,
177 |     value,
178 |     gate,
182 |     state,
183 | ):
193 |         rt = receptance[:, :, t : t + 1, :]
194 |         kt = key[:, :, :, t : t + 1]
195 |         vt = value[:, :, t : t + 1, :]
@@ -198,20 +224,17 @@ def rwkv_linear_attention_v5_cpu(
198 |         with torch.no_grad():
199 |             state = at + time_decay * state
200 |
201 | -     out = out.reshape(
202 | -     out = F.group_norm(out, num_groups=
203 |     out = out.to(dtype=hidden.dtype) * gate
204 | -     out = out @
205 |
206 |     return out, state
207 |
208 |
209 | def rwkv_linear_attention(
210 | -     B,
211 | -     H,
212 | -     S,
213 | -     T,
214 | -     n_head,
215 |     hidden,
216 |     time_decay,
217 |     time_first,
@@ -219,22 +242,21 @@ def rwkv_linear_attention(
219 |     key,
220 |     value,
221 |     gate,
225 |     state,
226 | ):
227 |     no_cuda = any(t.device.type != "cuda" for t in [time_decay, time_first, receptance, key, value])
228 |     # Launching the CUDA kernel for just one token will actually be slower (there is no for loop in the CPU version
229 |     # in this case).
230 |     one_token = key.size(1) == 1
231 |     if rwkv5_cuda_kernel is None or no_cuda or one_token:
232 |         return rwkv_linear_attention_v5_cpu(
233 | -             B,
234 | -             H,
235 | -             S,
236 | -             T,
237 | -             n_head,
238 |             hidden,
239 |             time_decay,
240 |             time_first,
@@ -242,17 +264,30 @@ def rwkv_linear_attention(
242 |             key,
243 |             value,
244 |             gate,
248 |             state,
249 |         )
250 |     else:
251 | -         out, state = WKV_5.apply(
254 |     out = out.to(dtype=hidden.dtype) * gate
255 | -     out = out @
256 |     return out, state
257 |
258 |
@@ -268,7 +303,6 @@ class RwkvSelfAttention(nn.Module):
268 |             logger.info("Could not load the custom CUDA kernel for RWKV5 attention.")
269 |         self.layer_id = layer_id
270 |         hidden_size = config.hidden_size
271 | -         # https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4neo/src/model.py#L146
272 |         num_attention_heads = hidden_size // config.head_size
273 |         self.num_attention_heads = num_attention_heads
274 |         attention_hidden_size = (
@@ -290,11 +324,9 @@ class RwkvSelfAttention(nn.Module):
290 |         self.receptance = nn.Linear(hidden_size, attention_hidden_size, bias=False)
291 |         self.gate = nn.Linear(hidden_size, attention_hidden_size, bias=False)
292 |         self.output = nn.Linear(attention_hidden_size, hidden_size, bias=False)
293 | -         # https://github.com/BlinkDL/RWKV-LM/blob/3db37a72356b736966ddd377268f02b80963af3f/RWKV-v4neo/src/model.py#L190C1-L190C1
294 |         self.ln_x = nn.GroupNorm(hidden_size // config.head_size, hidden_size)
295 |
297 | -     def extract_key_value(self, B, H, S, T, hidden, state=None):
298 |         # Mix hidden with the previous timestep to produce key, value, receptance
299 |         if hidden.size(1) == 1 and state is not None:
300 |             shifted = state[0][:, :, self.layer_id]
@@ -309,7 +341,6 @@ class RwkvSelfAttention(nn.Module):
309 |         receptance = hidden * self.time_mix_receptance + shifted * (1 - self.time_mix_receptance)
310 |         gate = hidden * self.time_mix_gate + shifted * (1 - self.time_mix_gate)
311 |
312 | -         # https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/model.py#L693
313 |         key = self.key(key)
314 |         value = self.value(value)
315 |         receptance = self.receptance(receptance)
@@ -321,19 +352,9 @@ class RwkvSelfAttention(nn.Module):
321 |         return receptance, key, value, gate, state
322 |
323 |     def forward(self, hidden, state=None, use_cache=False, seq_mode=True):
325 | -         H = self.time_decay.shape[0]
326 | -         S = hidden.shape[-1] // H
327 | -         T = hidden.shape[1]
329 | -         receptance, key, value, gate, state = self.extract_key_value(B, H, S, T, hidden, state=state)
330 |         layer_state = state[1][:, :, :, :, self.layer_id] if state is not None else None
331 |         rwkv, layer_state = rwkv_linear_attention(
332 | -             B,
333 | -             H,
334 | -             S,
335 | -             T,
336 | -             self.num_attention_heads,
337 |             hidden,
338 |             self.time_decay,
339 |             self.time_faaaa,
@@ -359,7 +380,6 @@ class RwkvFeedForward(nn.Module):
359 |         self.config = config
360 |         self.layer_id = layer_id
361 |         hidden_size = config.hidden_size
362 | -         # https://github.com/BlinkDL/RWKV-LM/blob/3db37a72356b736966ddd377268f02b80963af3f/RWKV-v4neo/train.py#L168
363 |         intermediate_size = (
364 |             config.intermediate_size
365 |             if config.intermediate_size is not None
@@ -396,7 +416,8 @@ class RwkvFeedForward(nn.Module):
396 |         return receptance * value, state
397 |
398 |
400 |     def __init__(self, config, layer_id):
401 |         super().__init__()
402 |         self.config = config
@@ -437,7 +458,7 @@ class Rwkv5PreTrainedModel(PreTrainedModel):
437 |
438 |     config_class = Rwkv5Config
439 |     base_model_prefix = "rwkv"
440 | -     _no_split_modules = ["
441 |     _keep_in_fp32_modules = ["time_decay", "time_first"]
442 |     supports_gradient_checkpointing = True
443 |
@@ -460,7 +481,6 @@ class Rwkv5PreTrainedModel(PreTrainedModel):
460 |             )
461 |             time_weight = time_weight[None, None, :]
462 |
463 | -             # https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4neo/src/model.py#L398
464 |             decay_speed = [
465 |                 -6.0 + 5.0 * (h / (attention_hidden_size - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
466 |                 for h in range(attention_hidden_size)
@@ -503,6 +523,7 @@ class Rwkv5PreTrainedModel(PreTrainedModel):
503 |             module.time_mix_receptance.data = torch.pow(time_weight, ratio_1_to_almost0)
504 |
505 |
506 | @dataclass
507 | class Rwkv5Output(ModelOutput):
508 |     """
@@ -530,6 +551,7 @@ class Rwkv5Output(ModelOutput):
530 |     attentions: Optional[Tuple[torch.FloatTensor]] = None
531 |
532 |
533 | @dataclass
534 | class Rwkv5CausalLMOutput(ModelOutput):
535 |     """
@@ -611,7 +633,7 @@ class Rwkv5Model(Rwkv5PreTrainedModel):
611 |         super().__init__(config)
612 |
613 |         self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
614 | -         self.blocks = nn.ModuleList([
615 |         self.ln_out = nn.LayerNorm(config.hidden_size)
616 |
617 |         self.layers_are_rescaled = False
@@ -665,39 +687,35 @@ class Rwkv5Model(Rwkv5PreTrainedModel):
665 |             inputs_embeds = self.embeddings(input_ids)
666 |
667 |         if use_cache and state is None:
668 | -             # https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/model.py#L904-L906
669 |             state = []
670 |             num_attention_heads = self.config.hidden_size // self.config.num_attention_heads
698 | -                 device=inputs_embeds.device,
699 | -             ).contiguous()
700 | -             )
701 |
702 |         seq_mode = inputs_embeds.shape[1] > 1
703 |         hidden_states = inputs_embeds
@@ -752,10 +770,32 @@ class Rwkv5Model(Rwkv5PreTrainedModel):
752 |
753 |         self.layers_are_rescaled = not self.training
754 |
755 |
756 | @add_start_docstrings(
757 |     """
758 | -     The
759 |     embeddings).
760 |     """,
761 |     RWKV_START_DOCSTRING,
 16 | """PyTorch RWKV5 World model."""
 17 |
 18 | from dataclasses import dataclass
 19 | + from pathlib import Path
 20 | from typing import List, Optional, Tuple, Union
 21 |
 22 | import torch
 31 |     add_code_sample_docstrings,
 32 |     add_start_docstrings,
 33 |     add_start_docstrings_to_model_forward,
 34 | +     is_bitsandbytes_available,
 35 |     is_ninja_available,
 36 |     is_torch_cuda_available,
 37 |     logging,
 54 | rwkv5_cuda_kernel = None
 55 |
 56 |
 57 | + # Copied from https://github.com/huggingface/transformers/blob/18cbaf13dcaca7145f5652aefb9b19734c56c3cd/src/transformers/models/rwkv/modeling_rwkv.py#L65
 58 | def load_wkv5_cuda_kernel(head_size):
 59 |     from torch.utils.cpp_extension import load as load_kernel
 60 |
 89 |
 90 | class WKV_5(torch.autograd.Function):
 91 |     @staticmethod
 92 | +     def forward(ctx, receptance, key, value, time_decay, time_first, state):
 93 |         with torch.no_grad():
 94 | +             Batch = key.shape[0]
 95 | +             SequenceLength = key.shape[1]
 96 | +             HiddenSize = key.shape[2]
 97 | +             HeadSize = HiddenSize // time_decay.shape[0]
 98 | +             ctx.Batch = Batch
 99 | +             ctx.SequenceLength = SequenceLength
100 | +             ctx.HiddenSize = HiddenSize
101 | +             ctx.HeadSize = HeadSize
102 | +             e_time_decay = (-torch.exp(time_decay.float())).contiguous()
103 | +             ee_time_decay = (torch.exp(e_time_decay)).contiguous()
104 | +             ctx.save_for_backward(receptance, key, value, ee_time_decay, e_time_decay, time_first)
105 | +             out = torch.empty(
106 | +                 (Batch, SequenceLength, HiddenSize),
107 | +                 device=receptance.device,
108 | +                 dtype=torch.bfloat16,
109 | +                 memory_format=torch.contiguous_format,
110 | +             )
111 | +             rwkv5_cuda_kernel.forward(
112 | +                 Batch,
113 | +                 SequenceLength,
114 | +                 HiddenSize,
115 | +                 HeadSize,
116 | +                 receptance,
117 | +                 key,
118 | +                 value,
119 | +                 ee_time_decay,
120 | +                 time_first,
121 | +                 out,
122 | +                 state,
123 | +             )
124 | +             return out, state
125 |
126 |     @staticmethod
127 | +     def backward(ctx, gout):
128 |         with torch.no_grad():
129 | +             assert gout.dtype == torch.bfloat16
130 | +             Batch = ctx.Batch
131 | +             SequenceLength = ctx.SequenceLength
132 | +             HiddenSize = ctx.HiddenSize
133 | +             HeadSize = ctx.HeadSize
134 | +             receptance, key, value, ee_time_decay, e_time_decay, time_first = ctx.saved_tensors
135 | +             greceptance = torch.empty(
136 | +                 (Batch, SequenceLength, HiddenSize),
137 | +                 device=gout.device,
138 |                 requires_grad=False,
139 |                 dtype=torch.bfloat16,
140 |                 memory_format=torch.contiguous_format,
141 | +             )
142 | +             g_key = torch.empty(
143 | +                 (Batch, SequenceLength, HiddenSize),
144 | +                 device=gout.device,
145 |                 requires_grad=False,
146 |                 dtype=torch.bfloat16,
147 |                 memory_format=torch.contiguous_format,
148 | +             )
149 | +             g_value = torch.empty(
150 | +                 (Batch, SequenceLength, HiddenSize),
151 | +                 device=gout.device,
152 |                 requires_grad=False,
153 |                 dtype=torch.bfloat16,
154 |                 memory_format=torch.contiguous_format,
155 | +             )
156 | +             g_time_decay = torch.empty(
157 | +                 (Batch, HiddenSize),
158 | +                 device=gout.device,
159 |                 requires_grad=False,
160 |                 dtype=torch.bfloat16,
161 |                 memory_format=torch.contiguous_format,
162 | +             )
163 | +             g_time_first = torch.empty(
164 | +                 (Batch, HiddenSize),
165 | +                 device=gout.device,
166 |                 requires_grad=False,
167 |                 dtype=torch.bfloat16,
168 |                 memory_format=torch.contiguous_format,
169 | +             )
170 | +             rwkv5_cuda_kernel.backward(
171 | +                 Batch,
172 | +                 SequenceLength,
173 | +                 HiddenSize,
174 | +                 HeadSize,
175 | +                 receptance,
176 | +                 key,
177 | +                 value,
178 | +                 ee_time_decay,
179 | +                 e_time_decay,
180 | +                 time_first,
181 | +                 gout,
182 | +                 greceptance,
183 | +                 g_key,
184 | +                 g_value,
185 | +                 g_time_decay,
186 | +                 g_time_first,
187 | +             )
188 | +             g_time_decay = torch.sum(g_time_decay, 0).view(HeadSize, HiddenSize // HeadSize)
189 | +             g_time_first = torch.sum(g_time_first, 0).view(HeadSize, HiddenSize // HeadSize)
190 | +             return (None, None, None, None, greceptance, g_key, g_value, g_time_decay, g_time_first)
191 |
192 |
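A note on the decay transform used in the new forward (and mirrored in the CPU path below): the kernel is fed exp(-exp(w)) per channel, which is always strictly between 0 and 1, so the recurrent state can only decay. A minimal, self-contained sketch of just that transform, with made-up tensor sizes that are not taken from any config:

import torch

# hypothetical decay parameter for 4 heads of head size 8 (illustrative shapes only)
time_decay = torch.randn(4, 8)

e_time_decay = -torch.exp(time_decay.float())   # always negative
ee_time_decay = torch.exp(e_time_decay)         # exp(-exp(w)), strictly inside (0, 1)

assert (ee_time_decay > 0).all() and (ee_time_decay < 1).all()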
193 | def rwkv_linear_attention_v5_cpu(
194 |     hidden,
195 |     time_decay,
196 |     time_first,
198 |     key,
199 |     value,
200 |     gate,
201 | +     layer_norm_weight,
202 | +     layer_norm_bias,
203 | +     output_weight,
204 |     state,
205 | ):
206 | +     Batch = hidden.shape[0]
207 | +     AttentionHeads = time_decay.shape[0]
208 | +     HeadSize = hidden.shape[-1] // AttentionHeads
209 | +     SequenceLength = hidden.shape[1]
210 | +     key = key.to(torch.float32).view(Batch, SequenceLength, AttentionHeads, HeadSize).transpose(1, 2).transpose(-2, -1)
211 | +     value = value.to(torch.float32).view(Batch, SequenceLength, AttentionHeads, HeadSize).transpose(1, 2)
212 | +     receptance = receptance.to(torch.float32).view(Batch, SequenceLength, AttentionHeads, HeadSize).transpose(1, 2)
213 | +     time_decay = torch.exp(-torch.exp(time_decay.float())).reshape(-1, 1, 1).reshape(AttentionHeads, -1, 1)
214 | +     time_first = time_first.float().reshape(-1, 1, 1).reshape(AttentionHeads, -1, 1)
215 | +     layer_norm_weight = layer_norm_weight.float()
216 | +     layer_norm_bias = layer_norm_bias.float()
217 | +     out = torch.zeros_like(key).reshape(Batch, SequenceLength, AttentionHeads, HeadSize)
218 | +     for t in range(SequenceLength):
219 |         rt = receptance[:, :, t : t + 1, :]
220 |         kt = key[:, :, :, t : t + 1]
221 |         vt = value[:, :, t : t + 1, :]
224 |         with torch.no_grad():
225 |             state = at + time_decay * state
226 |
227 | +     out = out.reshape(Batch * SequenceLength, AttentionHeads * HeadSize)
228 | +     out = F.group_norm(out, num_groups=AttentionHeads, weight=layer_norm_weight, bias=layer_norm_bias).reshape(
229 | +         Batch, SequenceLength, AttentionHeads * HeadSize
230 | +     )
231 |     out = out.to(dtype=hidden.dtype) * gate
232 | +     out = out @ output_weight
233 |
234 |     return out, state
235 |
236 |
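The loop in rwkv_linear_attention_v5_cpu keeps, per head, a HeadSize x HeadSize state updated as state = a_t + time_decay * state. The diff does not show the lines that form a_t; treating it as the outer product k_t v_t^T weighted by time_first for the current token is an assumption based on the usual linear-attention formulation. A toy, self-contained recurrence of that shape, with invented dimensions and random tensors rather than the module's real weights:

import torch

batch, heads, seq_len, head_size = 1, 2, 5, 4        # toy dimensions, purely illustrative
receptance = torch.randn(batch, heads, seq_len, head_size)
key = torch.randn(batch, heads, head_size, seq_len)  # head dim first, time last, as after the transposes above
value = torch.randn(batch, heads, seq_len, head_size)
time_decay = torch.rand(heads, head_size, 1)         # stand-in for exp(-exp(w))
time_first = torch.randn(heads, head_size, 1)

state = torch.zeros(batch, heads, head_size, head_size)
out = torch.zeros(batch, seq_len, heads, head_size)
for t in range(seq_len):
    rt = receptance[:, :, t : t + 1, :]              # (batch, heads, 1, head_size)
    kt = key[:, :, :, t : t + 1]                     # (batch, heads, head_size, 1)
    vt = value[:, :, t : t + 1, :]                   # (batch, heads, 1, head_size)
    at = kt @ vt                                     # rank-one update (assumed form of the elided line)
    out[:, t] = (rt @ (time_first * at + state)).squeeze(2)
    state = at + time_decay * state                  # same update as line 225 above

print(out.shape)  # torch.Size([1, 5, 2, 4])

The tensor layouts mirror the reshapes done at the top of the CPU function: receptance and value as (batch, heads, time, head_size), and key transposed so that time is the last axis.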
237 | def rwkv_linear_attention(
238 |     hidden,
239 |     time_decay,
240 |     time_first,
242 |     key,
243 |     value,
244 |     gate,
245 | +     layer_norm_weight,
246 | +     layer_norm_bias,
247 | +     output_weight,
248 |     state,
249 | ):
250 | +     Batch = hidden.shape[0]
251 | +     AttentionHeads = time_decay.shape[0]
252 | +     HeadSize = hidden.shape[-1] // AttentionHeads
253 | +     SequenceLength = hidden.shape[1]
254 |     no_cuda = any(t.device.type != "cuda" for t in [time_decay, time_first, receptance, key, value])
255 |     # Launching the CUDA kernel for just one token will actually be slower (there is no for loop in the CPU version
256 |     # in this case).
257 |     one_token = key.size(1) == 1
258 |     if rwkv5_cuda_kernel is None or no_cuda or one_token:
259 |         return rwkv_linear_attention_v5_cpu(
260 |             hidden,
261 |             time_decay,
262 |             time_first,
264 |             key,
265 |             value,
266 |             gate,
267 | +             layer_norm_weight,
268 | +             layer_norm_bias,
269 | +             output_weight,
270 |             state,
271 |         )
272 |     else:
273 | +         out, state = WKV_5.apply(
274 | +             Batch,
275 | +             SequenceLength,
276 | +             AttentionHeads * HeadSize,
277 | +             AttentionHeads,
278 | +             receptance,
279 | +             key,
280 | +             value,
281 | +             time_decay,
282 | +             time_first,
283 | +             state,
284 | +         )
285 | +         out = out.reshape(Batch * SequenceLength, AttentionHeads * HeadSize)
286 | +         out = F.group_norm(out, num_groups=AttentionHeads, weight=layer_norm_weight, bias=layer_norm_bias).reshape(
287 | +             Batch, SequenceLength, AttentionHeads * HeadSize
288 | +         )
289 |         out = out.to(dtype=hidden.dtype) * gate
290 | +         out = out @ output_weight
291 |         return out, state
292 |
293 |
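Both branches end the same way: the per-head output is flattened to (Batch * SequenceLength, AttentionHeads * HeadSize), normalized by F.group_norm with one group per head (the functional form of the module's ln_x GroupNorm), multiplied by the gate and projected with output_weight. A small sketch of just the normalization step, with illustrative sizes only:

import torch
import torch.nn.functional as F

batch, seq_len, heads, head_size = 2, 3, 4, 8    # illustrative sizes only
hidden_size = heads * head_size

out = torch.randn(batch, seq_len, heads, head_size)
weight = torch.ones(hidden_size)                 # stands in for ln_x.weight
bias = torch.zeros(hidden_size)                  # stands in for ln_x.bias

flat = out.reshape(batch * seq_len, hidden_size)
normed = F.group_norm(flat, num_groups=heads, weight=weight, bias=bias)
normed = normed.reshape(batch, seq_len, hidden_size)
print(normed.shape)  # torch.Size([2, 3, 32])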
303 |             logger.info("Could not load the custom CUDA kernel for RWKV5 attention.")
304 |         self.layer_id = layer_id
305 |         hidden_size = config.hidden_size
306 |         num_attention_heads = hidden_size // config.head_size
307 |         self.num_attention_heads = num_attention_heads
308 |         attention_hidden_size = (
324 |         self.receptance = nn.Linear(hidden_size, attention_hidden_size, bias=False)
325 |         self.gate = nn.Linear(hidden_size, attention_hidden_size, bias=False)
326 |         self.output = nn.Linear(attention_hidden_size, hidden_size, bias=False)
327 |         self.ln_x = nn.GroupNorm(hidden_size // config.head_size, hidden_size)
328 |
329 | +     def extract_key_value(self, hidden, state=None):
330 |         # Mix hidden with the previous timestep to produce key, value, receptance
331 |         if hidden.size(1) == 1 and state is not None:
332 |             shifted = state[0][:, :, self.layer_id]
341 |         receptance = hidden * self.time_mix_receptance + shifted * (1 - self.time_mix_receptance)
342 |         gate = hidden * self.time_mix_gate + shifted * (1 - self.time_mix_gate)
343 |
344 |         key = self.key(key)
345 |         value = self.value(value)
346 |         receptance = self.receptance(receptance)
352 |         return receptance, key, value, gate, state
353 |
354 |     def forward(self, hidden, state=None, use_cache=False, seq_mode=True):
355 | +         receptance, key, value, gate, state = self.extract_key_value(hidden, state=state)
356 |         layer_state = state[1][:, :, :, :, self.layer_id] if state is not None else None
357 |         rwkv, layer_state = rwkv_linear_attention(
358 |             hidden,
359 |             self.time_decay,
360 |             self.time_faaaa,
380 |         self.config = config
381 |         self.layer_id = layer_id
382 |         hidden_size = config.hidden_size
383 |         intermediate_size = (
384 |             config.intermediate_size
385 |             if config.intermediate_size is not None
416 |         return receptance * value, state
417 |
418 |
419 | + # copied from HuggingFace https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py
420 | + class Rwkv5Block(nn.Module):
421 |     def __init__(self, config, layer_id):
422 |         super().__init__()
423 |         self.config = config
458 |
459 |     config_class = Rwkv5Config
460 |     base_model_prefix = "rwkv"
461 | +     _no_split_modules = ["Rwkv5Block"]
462 |     _keep_in_fp32_modules = ["time_decay", "time_first"]
463 |     supports_gradient_checkpointing = True
464 |
481 |             )
482 |             time_weight = time_weight[None, None, :]
483 |
484 |             decay_speed = [
485 |                 -6.0 + 5.0 * (h / (attention_hidden_size - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
486 |                 for h in range(attention_hidden_size)
523 |             module.time_mix_receptance.data = torch.pow(time_weight, ratio_1_to_almost0)
524 |
525 |
526 | + # copied from HuggingFace https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py
527 | @dataclass
528 | class Rwkv5Output(ModelOutput):
529 |     """
551 |     attentions: Optional[Tuple[torch.FloatTensor]] = None
552 |
553 |
554 | + # copied from HuggingFace https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py
555 | @dataclass
556 | class Rwkv5CausalLMOutput(ModelOutput):
557 |     """
633 |         super().__init__(config)
634 |
635 |         self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
636 | +         self.blocks = nn.ModuleList([Rwkv5Block(config, layer_id=idx) for idx in range(config.num_hidden_layers)])
637 |         self.ln_out = nn.LayerNorm(config.hidden_size)
638 |
639 |         self.layers_are_rescaled = False
687 |             inputs_embeds = self.embeddings(input_ids)
688 |
689 |         if use_cache and state is None:
690 |             state = []
691 |             num_attention_heads = self.config.hidden_size // self.config.num_attention_heads
692 | +             state_attn_x = torch.zeros(
693 | +                 (inputs_embeds.size(0), self.config.hidden_size, self.config.num_hidden_layers),
694 | +                 dtype=inputs_embeds.dtype,
695 | +                 requires_grad=False,
696 | +                 device=inputs_embeds.device,
697 | +             ).contiguous()
698 | +             state_attn_kv = torch.zeros(
699 | +                 (
700 | +                     inputs_embeds.size(0),
701 | +                     num_attention_heads,
702 | +                     self.config.hidden_size // num_attention_heads,
703 | +                     self.config.hidden_size // num_attention_heads,
704 | +                     self.config.num_hidden_layers,
705 | +                 ),
706 | +                 dtype=torch.float32,
707 | +                 requires_grad=False,
708 | +                 device=inputs_embeds.device,
709 | +             ).contiguous()
710 | +             state_ffn_x = torch.zeros(
711 | +                 (inputs_embeds.size(0), self.config.hidden_size, self.config.num_hidden_layers),
712 | +                 dtype=inputs_embeds.dtype,
713 | +                 requires_grad=False,
714 | +                 device=inputs_embeds.device,
715 | +             ).contiguous()
716 | +             state.append(state_attn_x)
717 | +             state.append(state_attn_kv)
718 | +             state.append(state_ffn_x)
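For orientation, the three tensors appended to state serve different purposes: state[0] holds the previous token's hidden state for the attention token shift (it is read as state[0][:, :, self.layer_id] earlier in this diff), state[1] holds the per-layer linear-attention state read as state[1][:, :, :, :, self.layer_id], and state[2] plays the token-shift role for the feed-forward branch (its use is not shown in this diff). A sketch that just reproduces the shapes the code above allocates, with hypothetical numbers that are not read from any real config:

import torch

# hypothetical sizes, for illustration only
batch_size, hidden_size, num_hidden_layers = 1, 2048, 24
config_num_attention_heads = 32
num_attention_heads = hidden_size // config_num_attention_heads  # mirrors the derivation on line 691: 64

state_attn_x = torch.zeros(batch_size, hidden_size, num_hidden_layers)
state_attn_kv = torch.zeros(
    batch_size,
    num_attention_heads,
    hidden_size // num_attention_heads,
    hidden_size // num_attention_heads,
    num_hidden_layers,
)
state_ffn_x = torch.zeros(batch_size, hidden_size, num_hidden_layers)

print(state_attn_x.shape)   # torch.Size([1, 2048, 24])
print(state_attn_kv.shape)  # torch.Size([1, 64, 32, 32, 24])
print(state_ffn_x.shape)    # torch.Size([1, 2048, 24])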
719 |
720 |         seq_mode = inputs_embeds.shape[1] > 1
721 |         hidden_states = inputs_embeds
770 |
771 |         self.layers_are_rescaled = not self.training
772 |
773 | +     def _bnb_4bit_dequantize_and_rescale(self, target_layer, block_id):
774 | +         r"""
775 | +         Perform the dequantization and rescaling of the weights of a given layer. After that operation the layer will
776 | +         be quantized again.
777 | +         """
778 | +         if not is_bitsandbytes_available():
779 | +             raise ImportError("Please install bitsandbytes to use this method.")
780 | +         import bitsandbytes as bnb
781 | +
782 | +         dequant_weights = bnb.functional.dequantize_4bit(target_layer.weight.data, target_layer.weight.quant_state)
783 | +
784 | +         dequant_weights.div_(2 ** int(block_id // self.config.rescale_every))
785 | +
786 | +         # re-quantize the model:
787 | +         # we need to put it first on CPU then back to the device
788 | +         # this will create an overhead :/
789 | +         # We set requires_grad=False as we cannot compute gradients on top of 4bit parameters anyway and to avoid
790 | +         # bugs with bnb
791 | +         quant_weight = bnb.nn.Params4bit(dequant_weights.to("cpu"), requires_grad=False).to(dequant_weights.device)
792 | +         setattr(target_layer, "weight", quant_weight)
793 | +
794 |
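_bnb_4bit_dequantize_and_rescale exists because the inference-time rescaling tracked by layers_are_rescaled (dividing certain block weights by 2 ** (block_id // rescale_every), as in line 784) cannot be applied in place to 4-bit bitsandbytes parameters, so the weight is dequantized, divided, and re-quantized. A plain, non-quantized sketch of the same division on an ordinary nn.Linear; the layer and the numbers here are hypothetical stand-ins, not the model's real weights:

import torch
from torch import nn

rescale_every = 6                      # hypothetical config value
block_id = 13                          # hypothetical layer index

layer = nn.Linear(16, 16, bias=False)  # stands in for a block's output projection
with torch.no_grad():
    # same rule as line 784: 13 // 6 == 2, so the weight is divided by 2 ** 2 == 4
    layer.weight.div_(2 ** int(block_id // rescale_every))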
795 | + # copied from HuggingFace https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py
796 | @add_start_docstrings(
797 |     """
798 | +     The RWKV5 Model transformer with a language modeling head on top (linear layer with weights tied to the input
799 |     embeddings).
800 |     """,
801 |     RWKV_START_DOCSTRING,