KaleiNeely committed
Commit: f683510
Parent(s): 8e8cdaa

Update modeling_rwkv5.py

modeling_rwkv5.py (+181 -18)
CHANGED
@@ -30,6 +30,8 @@ from transformers.utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+    is_ninja_available,
+    is_torch_cuda_available,
     logging,
 )

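The two new imports gate the optional kernel compilation further down. Both helpers come from transformers.utils and report whether the ninja build tool and a CUDA-enabled torch are present; a quick, standalone check (illustrative only):

from transformers.utils import is_ninja_available, is_torch_cuda_available

# Prints two booleans; the CUDA kernel path below is attempted only if both are True.
print(is_ninja_available(), is_torch_cuda_available())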
@@ -47,8 +49,121 @@ RWKV5_PRETRAINED_MODEL_ARCHIVE_LIST = [
     # See all RWKV models at https://huggingface.co/models?filter=rwkv
 ]

+rwkv5_cuda_kernel = None

-def rwkv_linear_attention_v5(
+
+def load_wkv5_cuda_kernel(head_size):
+    from torch.utils.cpp_extension import load as load_kernel
+
+    global rwkv5_cuda_kernel
+
+    kernel_folder = Path(__file__).resolve().parent.parent.parent / "kernels" / "rwkv5"
+    cuda_kernel_files = [kernel_folder / f for f in ["wkv5_op.cpp", "wkv5_cuda.cu"]]
+
+    # Only load the kernel if it's not been loaded yet or if we changed the context length
+    if rwkv5_cuda_kernel is not None and rwkv5_cuda_kernel.head_size == head_size:
+        return
+
+    logger.info(f"Loading CUDA kernel for RWKV at head size of {head_size}.")
+
+    flags = [
+        "-res-usage",
+        "--maxrregcount 60",
+        "--use_fast_math",
+        "-O3",
+        "-Xptxas -O3",
+        "--extra-device-vectorization",
+        f"-D_N_={head_size}",
+    ]
+    rwkv5_cuda_kernel = load_kernel(
+        name=f"wkv_{head_size}",
+        sources=cuda_kernel_files,
+        verbose=(logging.get_verbosity() == logging.DEBUG),
+        extra_cuda_cflags=flags,
+    )
+    rwkv5_cuda_kernel.head_size = head_size
+
+
+class WKV_5(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, B, T, C, H, r, k, v, w, u, s):
+        with torch.no_grad():
+            assert r.dtype == torch.bfloat16
+            assert k.dtype == torch.bfloat16
+            assert v.dtype == torch.bfloat16
+            assert w.dtype == torch.bfloat16
+            assert u.dtype == torch.bfloat16
+            assert s.dtype == torch.float32
+            ctx.B = B
+            ctx.T = T
+            ctx.C = C
+            ctx.H = H
+            assert r.is_contiguous()
+            assert k.is_contiguous()
+            assert v.is_contiguous()
+            assert w.is_contiguous()
+            assert u.is_contiguous()
+            ew = (-torch.exp(w.float())).contiguous()
+            eew = (torch.exp(ew)).contiguous()
+            ctx.save_for_backward(r, k, v, eew, ew, u)
+            y = torch.empty(
+                (B, T, C), device=r.device, dtype=torch.bfloat16, memory_format=torch.contiguous_format
+            )  # .uniform_(-1, 1)
+            rwkv5_cuda_kernel.forward(B, T, C, H, r, k, v, eew, u, y, s)
+            return y, s
+
+    @staticmethod
+    def backward(ctx, gy):
+        with torch.no_grad():
+            assert gy.dtype == torch.bfloat16
+            B = ctx.B
+            T = ctx.T
+            C = ctx.C
+            H = ctx.H
+            assert gy.is_contiguous()
+            r, k, v, eew, ew, u = ctx.saved_tensors
+            gr = torch.empty(
+                (B, T, C),
+                device=gy.device,
+                requires_grad=False,
+                dtype=torch.bfloat16,
+                memory_format=torch.contiguous_format,
+            )  # .uniform_(-1, 1)
+            gk = torch.empty(
+                (B, T, C),
+                device=gy.device,
+                requires_grad=False,
+                dtype=torch.bfloat16,
+                memory_format=torch.contiguous_format,
+            )  # .uniform_(-1, 1)
+            gv = torch.empty(
+                (B, T, C),
+                device=gy.device,
+                requires_grad=False,
+                dtype=torch.bfloat16,
+                memory_format=torch.contiguous_format,
+            )  # .uniform_(-1, 1)
+            gw = torch.empty(
+                (B, C),
+                device=gy.device,
+                requires_grad=False,
+                dtype=torch.bfloat16,
+                memory_format=torch.contiguous_format,
+            )  # .uniform_(-1, 1)
+            gu = torch.empty(
+                (B, C),
+                device=gy.device,
+                requires_grad=False,
+                dtype=torch.bfloat16,
+                memory_format=torch.contiguous_format,
+            )  # .uniform_(-1, 1)
+            rwkv5_cuda_kernel.backward(B, T, C, H, r, k, v, eew, ew, u, gy, gr, gk, gv, gw, gu)
+            gw = torch.sum(gw, 0).view(H, C // H)
+            gu = torch.sum(gu, 0).view(H, C // H)
+            return (None, None, None, None, gr, gk, gv, gw, gu)
+
+
+def rwkv_linear_attention_v5_cpu(
     B,
     H,
     S,
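For readers unfamiliar with the pattern, the hunk above wraps the fused CUDA kernel in a torch.autograd.Function: forward stashes whatever backward will need on ctx, and backward returns one gradient (or None) per forward argument, in order. The sketch below illustrates only that contract; ScaledLinear is an invented toy example, not part of this diff.

import torch


class ScaledLinear(torch.autograd.Function):
    # Toy autograd.Function following the same forward/backward structure as WKV_5 above.
    @staticmethod
    def forward(ctx, x, weight, scale):
        ctx.save_for_backward(x, weight)
        ctx.scale = scale
        return scale * (x @ weight)

    @staticmethod
    def backward(ctx, grad_out):
        x, weight = ctx.saved_tensors
        grad_x = ctx.scale * grad_out @ weight.t()
        grad_w = ctx.scale * x.t() @ grad_out
        # One entry per forward input; the plain Python float `scale` gets None.
        return grad_x, grad_w, None


x = torch.randn(3, 4, requires_grad=True)
w = torch.randn(4, 2, requires_grad=True)
ScaledLinear.apply(x, w, 2.0).sum().backward()
print(x.grad.shape, w.grad.shape)  # torch.Size([3, 4]) torch.Size([4, 2])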
@@ -66,6 +181,9 @@ def rwkv_linear_attention_v5(
     ow,
     state,
 ):
+    key = key.to(torch.float32).view(B, T, H, S).transpose(1, 2).transpose(-2, -1)
+    value = value.to(torch.float32).view(B, T, H, S).transpose(1, 2)
+    receptance = receptance.to(torch.float32).view(B, T, H, S).transpose(1, 2)
     time_decay = torch.exp(-torch.exp(time_decay.float())).reshape(-1, 1, 1).reshape(n_head, -1, 1)
     time_first = time_first.float().reshape(-1, 1, 1).reshape(n_head, -1, 1)
     lxw = lxw.float()
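As a quick shape check of the casts added above (sizes are arbitrary and only for illustration): a (B, T, H*S) projection is split into heads and moved to (B, H, T, S), with the key further transposed to (B, H, S, T) so the per-head matrix products line up.

import torch

B, T, H, S = 2, 5, 4, 8  # batch, tokens, heads, head size (made-up values)
key = torch.randn(B, T, H * S)
value = torch.randn(B, T, H * S)

key = key.to(torch.float32).view(B, T, H, S).transpose(1, 2).transpose(-2, -1)
value = value.to(torch.float32).view(B, T, H, S).transpose(1, 2)
print(key.shape, value.shape)  # torch.Size([2, 4, 8, 5]) torch.Size([2, 4, 5, 8])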
@@ -88,10 +206,66 @@ def rwkv_linear_attention_v5(
     return out, state


+def rwkv_linear_attention(
+    B,
+    H,
+    S,
+    T,
+    n_head,
+    hidden,
+    time_decay,
+    time_first,
+    receptance,
+    key,
+    value,
+    gate,
+    lxw,
+    lxb,
+    ow,
+    state,
+):
+    no_cuda = any(t.device.type != "cuda" for t in [time_decay, time_first, receptance, key, value])
+    # Launching the CUDA kernel for just one token will actually be slower (there is no for loop in the CPU version
+    # in this case).
+    one_token = key.size(1) == 1
+    if rwkv5_cuda_kernel is None or no_cuda or one_token:
+        return rwkv_linear_attention_v5_cpu(
+            B,
+            H,
+            S,
+            T,
+            n_head,
+            hidden,
+            time_decay,
+            time_first,
+            receptance,
+            key,
+            value,
+            gate,
+            lxw,
+            lxb,
+            ow,
+            state,
+        )
+    else:
+        out, state = WKV_5.apply(B, T, H * S, H, receptance, key, value, time_decay, time_first, state)
+        out = out.reshape(B * T, H * S)
+        out = F.group_norm(out, num_groups=H, weight=lxw, bias=lxb).reshape(B, T, H * S)
+        out = out.to(dtype=hidden.dtype) * gate
+        out = out @ ow
+        return out, state
+
+
 class RwkvSelfAttention(nn.Module):
     def __init__(self, config, layer_id=0):
         super().__init__()
         self.config = config
+        kernel_loaded = rwkv5_cuda_kernel is not None and rwkv5_cuda_kernel.head_size == config.head_size
+        if is_ninja_available() and is_torch_cuda_available() and not kernel_loaded:
+            try:
+                load_wkv5_cuda_kernel(config.context_length)
+            except Exception:
+                logger.info("Could not load the custom CUDA kernel for RWKV5 attention.")
         self.layer_id = layer_id
         hidden_size = config.hidden_size
         # https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4neo/src/model.py#L146
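To make the routing in rwkv_linear_attention above explicit: the custom kernel is used only when it compiled, every tensor sits on a CUDA device, and more than one token is being processed; otherwise the pure-PyTorch path runs. A standalone sketch of that guard, with a hypothetical kernel placeholder:

import torch


def use_cuda_kernel(kernel, time_decay, time_first, receptance, key, value):
    # Same decision as in rwkv_linear_attention: fall back to the CPU path when the
    # kernel is unavailable, any tensor is off the GPU, or only one token is processed.
    no_cuda = any(t.device.type != "cuda" for t in [time_decay, time_first, receptance, key, value])
    one_token = key.size(1) == 1
    return not (kernel is None or no_cuda or one_token)


t = torch.randn(1, 4, 64)  # CPU tensors, 4 tokens
print(use_cuda_kernel(None, t, t, t, t, t))  # False: no kernel and not on CUDA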
@@ -136,9 +310,9 @@ class RwkvSelfAttention(nn.Module):
         gate = hidden * self.time_mix_gate + shifted * (1 - self.time_mix_gate)

         # https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/model.py#L693
-        key = self.key(key)
-        value = self.value(value)
-        receptance = self.receptance(receptance)
+        key = self.key(key)
+        value = self.value(value)
+        receptance = self.receptance(receptance)
         gate = F.silu(self.gate(gate))

         if state is not None:
@@ -154,7 +328,7 @@ class RwkvSelfAttention(nn.Module):

         receptance, key, value, gate, state = self.extract_key_value(B, H, S, T, hidden, state=state)
         layer_state = state[1][:, :, :, :, self.layer_id] if state is not None else None
-        rwkv, layer_state = rwkv_linear_attention_v5(
+        rwkv, layer_state = rwkv_linear_attention(
             B,
             H,
             S,
@@ -238,6 +412,8 @@ class RwkvBlock(nn.Module):
         self.feed_forward = RwkvFeedForward(config, layer_id)

     def forward(self, hidden, state=None, use_cache=False, output_attentions=False, seq_mode=True):
+        if self.layer_id == 0:
+            hidden = self.pre_ln(hidden)
         attention, state = self.attention(self.ln1(hidden), state=state, use_cache=use_cache, seq_mode=seq_mode)
         hidden = hidden + attention

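A minimal illustration of the block-level change above, with made-up sizes: rather than folding a LayerNorm into the embedding weights once (the pre_ln_flag logic removed later in this diff), block 0 now applies its pre_ln to the hidden states on every forward pass.

import torch
import torch.nn as nn

hidden_size = 16
pre_ln = nn.LayerNorm(hidden_size)  # stands in for RwkvBlock.pre_ln on layer 0
ln1 = nn.LayerNorm(hidden_size)

hidden = torch.randn(2, 5, hidden_size)  # (batch, seq, hidden), arbitrary
layer_id = 0
if layer_id == 0:
    hidden = pre_ln(hidden)
out = ln1(hidden)  # input to the attention mixer, as in RwkvBlock.forward
print(out.shape)   # torch.Size([2, 5, 16])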
@@ -264,7 +440,6 @@ class Rwkv5PreTrainedModel(PreTrainedModel):
     _no_split_modules = ["RwkvBlock"]
     _keep_in_fp32_modules = ["time_decay", "time_first"]
     supports_gradient_checkpointing = True
-    training = False

     def _init_weights(self, module):
         """Initialize the weights."""
@@ -440,8 +615,6 @@ class Rwkv5Model(Rwkv5PreTrainedModel):
         self.ln_out = nn.LayerNorm(config.hidden_size)

         self.layers_are_rescaled = False
-        self.pre_ln_flag = False
-
         self.gradient_checkpointing = False

         # Initialize weights and apply final processing
@@ -489,16 +662,6 @@ class Rwkv5Model(Rwkv5PreTrainedModel):
             raise ValueError("You have to specify either input_ids or inputs_embeds")

         if inputs_embeds is None:
-            if not self.pre_ln_flag:
-                normalized_weight = F.layer_norm(
-                    self.embeddings.weight,
-                    (self.config.hidden_size,),
-                    weight=self.blocks[0].pre_ln.weight,
-                    bias=self.blocks[0].pre_ln.bias,
-                )
-                self.embeddings.weight = nn.Parameter(normalized_weight)
-                self.pre_ln_flag = True
-
             inputs_embeds = self.embeddings(input_ids)

         if use_cache and state is None:
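The deletion above pairs with the pre_ln change in RwkvBlock.forward: normalizing the embedding matrix once with the block-0 pre_ln parameters and normalizing the looked-up embeddings on each forward are equivalent, because an embedding lookup only selects rows. A small self-contained check of that equivalence (sizes and modules are made up):

import torch
import torch.nn as nn
import torch.nn.functional as F

vocab, hidden_size = 10, 16
emb = nn.Embedding(vocab, hidden_size)
pre_ln = nn.LayerNorm(hidden_size)
ids = torch.tensor([[1, 3, 7]])

# Old approach (removed above): bake the LayerNorm into the embedding weights once.
baked = F.layer_norm(emb.weight, (hidden_size,), weight=pre_ln.weight, bias=pre_ln.bias)[ids]
# New approach: normalize the looked-up embeddings on every forward pass.
on_the_fly = pre_ln(emb(ids))

print(torch.allclose(baked, on_the_fly, atol=1e-6))  # True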