KaleiNeely
commited on
Commit
•
8e8cdaa
1
Parent(s):
0f622bc
Update modeling_rwkv5.py
Browse files- modeling_rwkv5.py +13 -9
modeling_rwkv5.py
CHANGED
@@ -22,6 +22,7 @@ import torch
|
|
22 |
import torch.nn.functional as F
|
23 |
import torch.utils.checkpoint
|
24 |
from torch import nn
|
|
|
25 |
|
26 |
from transformers.modeling_utils import PreTrainedModel
|
27 |
from transformers.utils import (
|
@@ -42,6 +43,7 @@ _CONFIG_FOR_DOC = "Rwkv5Config"
|
|
42 |
|
43 |
RWKV5_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
44 |
"RWKV/rwkv-5-world-1b5",
|
|
|
45 |
# See all RWKV models at https://huggingface.co/models?filter=rwkv
|
46 |
]
|
47 |
|
@@ -63,22 +65,20 @@ def rwkv_linear_attention_v5(
|
|
63 |
lxb,
|
64 |
ow,
|
65 |
state,
|
66 |
-
return_state=False,
|
67 |
-
seq_mode=True,
|
68 |
):
|
69 |
time_decay = torch.exp(-torch.exp(time_decay.float())).reshape(-1, 1, 1).reshape(n_head, -1, 1)
|
70 |
time_first = time_first.float().reshape(-1, 1, 1).reshape(n_head, -1, 1)
|
71 |
lxw = lxw.float()
|
72 |
lxb = lxb.float()
|
73 |
-
|
74 |
-
out = torch.empty((B, T, H, S), dtype=receptance.dtype, device=receptance.device)
|
75 |
for t in range(T):
|
76 |
rt = receptance[:, :, t : t + 1, :]
|
77 |
kt = key[:, :, :, t : t + 1]
|
78 |
vt = value[:, :, t : t + 1, :]
|
79 |
at = kt @ vt
|
80 |
out[:, t] = (rt @ (time_first * at + state)).squeeze(2)
|
81 |
-
|
|
|
82 |
|
83 |
out = out.reshape(B * T, H * S)
|
84 |
out = F.group_norm(out, num_groups=H, weight=lxw, bias=lxb).reshape(B, T, H * S)
|
@@ -171,8 +171,6 @@ class RwkvSelfAttention(nn.Module):
|
|
171 |
self.ln_x.bias,
|
172 |
self.output.weight.t(),
|
173 |
state=layer_state,
|
174 |
-
return_state=use_cache,
|
175 |
-
seq_mode=seq_mode,
|
176 |
)
|
177 |
|
178 |
if layer_state is not None:
|
@@ -671,8 +669,14 @@ class Rwkv5ForCausalLM(Rwkv5PreTrainedModel):
|
|
671 |
|
672 |
loss = None
|
673 |
if labels is not None:
|
674 |
-
#
|
675 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
676 |
|
677 |
if not return_dict:
|
678 |
output = (logits,) + rwkv_outputs[1:]
|
|
|
22 |
import torch.nn.functional as F
|
23 |
import torch.utils.checkpoint
|
24 |
from torch import nn
|
25 |
+
from torch.nn import CrossEntropyLoss
|
26 |
|
27 |
from transformers.modeling_utils import PreTrainedModel
|
28 |
from transformers.utils import (
|
|
|
43 |
|
44 |
RWKV5_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
45 |
"RWKV/rwkv-5-world-1b5",
|
46 |
+
"RWKV/rwkv-5-world-3b",
|
47 |
# See all RWKV models at https://huggingface.co/models?filter=rwkv
|
48 |
]
|
49 |
|
|
|
65 |
lxb,
|
66 |
ow,
|
67 |
state,
|
|
|
|
|
68 |
):
|
69 |
time_decay = torch.exp(-torch.exp(time_decay.float())).reshape(-1, 1, 1).reshape(n_head, -1, 1)
|
70 |
time_first = time_first.float().reshape(-1, 1, 1).reshape(n_head, -1, 1)
|
71 |
lxw = lxw.float()
|
72 |
lxb = lxb.float()
|
73 |
+
out = torch.zeros_like(key).reshape(B, T, H, S)
|
|
|
74 |
for t in range(T):
|
75 |
rt = receptance[:, :, t : t + 1, :]
|
76 |
kt = key[:, :, :, t : t + 1]
|
77 |
vt = value[:, :, t : t + 1, :]
|
78 |
at = kt @ vt
|
79 |
out[:, t] = (rt @ (time_first * at + state)).squeeze(2)
|
80 |
+
with torch.no_grad():
|
81 |
+
state = at + time_decay * state
|
82 |
|
83 |
out = out.reshape(B * T, H * S)
|
84 |
out = F.group_norm(out, num_groups=H, weight=lxw, bias=lxb).reshape(B, T, H * S)
|
|
|
171 |
self.ln_x.bias,
|
172 |
self.output.weight.t(),
|
173 |
state=layer_state,
|
|
|
|
|
174 |
)
|
175 |
|
176 |
if layer_state is not None:
|
|
|
669 |
|
670 |
loss = None
|
671 |
if labels is not None:
|
672 |
+
# move labels to correct device to enable model parallelism
|
673 |
+
labels = labels.to(logits.device)
|
674 |
+
# Shift so that tokens < n predict n
|
675 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
676 |
+
shift_labels = labels[..., 1:].contiguous()
|
677 |
+
# Flatten the tokens
|
678 |
+
loss_fct = CrossEntropyLoss()
|
679 |
+
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
680 |
|
681 |
if not return_dict:
|
682 |
output = (logits,) + rwkv_outputs[1:]
|