KaleiNeely committed
Commit de75931
Parent(s): b32aee7
Update modeling_rwkv5.py

modeling_rwkv5.py  +74 -3  CHANGED
@@ -329,6 +329,80 @@ class RwkvPreTrainedModel(PreTrainedModel):
     _no_split_modules = ["RwkvBlock"]
     _keep_in_fp32_modules = ["time_decay", "time_first"]
 
+    def _init_weights(self, module):
+        """Initialize the weights."""
+        if isinstance(module, RwkvSelfAttention):
+            layer_id = module.layer_id
+            num_hidden_layers = module.config.num_hidden_layers
+            hidden_size = module.config.hidden_size
+            attention_hidden_size = module.attention_hidden_size
+            num_attention_heads = hidden_size // module.config.head_size
+
+            ratio_0_to_1 = layer_id / (num_hidden_layers - 1)  # 0 to 1
+            ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers)  # 1 to ~0
+
+            time_weight = torch.tensor(
+                [i / hidden_size for i in range(hidden_size)],
+                dtype=module.time_mix_key.dtype,
+                device=module.time_mix_key.device,
+            )
+            time_weight = time_weight[None, None, :]
+
+            if module.config.model_version == "5_2":
+                # https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4neo/src/model.py#L398
+                decay_speed = [
+                    -6.0 + 5.0 * (h / (attention_hidden_size - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
+                    for h in range(attention_hidden_size)
+                ]
+            else:
+                # https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4neo/src/model.py#L172
+                decay_speed = [
+                    -6.0 + 5.0 * (h / (num_attention_heads - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
+                    for h in range(num_attention_heads)
+                ]
+            decay_speed = torch.tensor(decay_speed, dtype=module.time_decay.dtype, device=module.time_decay.device)
+            if module.config.model_version == "5_2":
+                tmp = (
+                    torch.tensor(
+                        [(1.0 - (i / (attention_hidden_size - 1.0))) * ratio_0_to_1 + 0.1 * ((i + 1) % 3 - 1) for i in range(attention_hidden_size)],
+                        dtype=module.time_faaaa.dtype,
+                        device=module.time_faaaa.device,
+                    )
+                )
+            else:
+                tmp = torch.ones(num_attention_heads) * (-3.0)
+
+            with torch.no_grad():
+                if module.config.model_version == "5_2":
+                    module.time_decay.data = decay_speed.reshape(num_attention_heads, module.config.head_size)
+                    module.time_faaaa.data = tmp.reshape(num_attention_heads, module.config.head_size)
+                else:
+                    module.time_decay.data = decay_speed
+                    module.time_first.data = tmp
+
+                module.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0)
+                module.time_mix_value.data = torch.pow(time_weight, ratio_1_to_almost0) + 0.3 * ratio_0_to_1
+                module.time_mix_receptance.data = torch.pow(time_weight, 0.5 * ratio_1_to_almost0)
+                if module.config.model_version == "5_2":
+                    module.time_mix_gate.data = torch.pow(time_weight, 0.5 * ratio_1_to_almost0)
+        elif isinstance(module, RwkvFeedForward):
+            layer_id = module.layer_id
+            num_hidden_layers = module.config.num_hidden_layers
+            hidden_size = module.config.hidden_size
+
+            ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers)  # 1 to ~0
+
+            time_weight = torch.tensor(
+                [i / hidden_size for i in range(hidden_size)],
+                dtype=module.time_mix_key.dtype,
+                device=module.time_mix_key.device,
+            )
+            time_weight = time_weight[None, None, :]
+
+            with torch.no_grad():
+                module.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0)
+                module.time_mix_receptance.data = torch.pow(time_weight, ratio_1_to_almost0)
+
     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, RwkvModel):
             module.gradient_checkpointing = value
@@ -540,9 +614,6 @@ class RwkvModel(RwkvPreTrainedModel):
         if output_attentions:
            all_self_attentions = all_self_attentions + (attentions,)
 
-        if self.config.model_version == "5_2" and seq_mode:
-            hidden_states = hidden_states[:, -1, :].unsqueeze(1)
-
         hidden_states = self.ln_out(hidden_states)
 
         if output_hidden_states:
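For reference, the added `_init_weights` follows BlinkDL's layer-dependent schedule: decay speeds sweep from -6 to -1 across channels, with an exponent that grows from 0.7 in the first block to 2.0 in the last, while the `time_mix_*` exponents shrink from 1 toward 0 with depth, pushing the mixing weights toward 1 (less token-shift) in deeper blocks. The sketch below evaluates the same formulas outside the model so the schedule can be inspected directly; the config values (12 layers, hidden_size 768, head_size 64, and attention_hidden_size equal to hidden_size) are illustrative assumptions, not taken from the commit.

```python
# Standalone sketch of the layer-dependent initialization schedule used by
# the new _init_weights. All sizes below are assumed for illustration.
import torch

num_hidden_layers = 12
hidden_size = 768
head_size = 64
attention_hidden_size = hidden_size  # assumption: attention width == hidden width
num_attention_heads = hidden_size // head_size

for layer_id in (0, num_hidden_layers - 1):  # first and last block
    # 0 at the first layer, 1 at the last layer
    ratio_0_to_1 = layer_id / (num_hidden_layers - 1)
    # 1 at the first layer, approaching 0 at the last layer
    ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers)

    # "5_2" branch: one decay value per channel, later reshaped in the model
    # to (num_attention_heads, head_size)
    decay_speed = torch.tensor(
        [
            -6.0 + 5.0 * (h / (attention_hidden_size - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
            for h in range(attention_hidden_size)
        ]
    )

    # position-dependent token-shift mixing weights in [0, 1)
    time_weight = torch.tensor([i / hidden_size for i in range(hidden_size)])
    time_mix_key = torch.pow(time_weight, ratio_1_to_almost0)

    print(
        f"layer {layer_id}: decay in [{decay_speed.min():.2f}, {decay_speed.max():.2f}], "
        f"time_mix_key mean {time_mix_key.mean():.3f}"
    )
```

In `transformers`, `_init_weights` is applied to every submodule through the base class's `init_weights`/`post_init` machinery, so defining it on `RwkvPreTrainedModel` means freshly constructed RWKV-5 attention and feed-forward blocks start from this schedule rather than from generic defaults.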
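The second hunk removes the `5_2` special case that kept only the final position before the output layer norm, so in sequence mode `RwkvModel` now returns hidden states for every token. A shape-level sketch of the difference, with batch/sequence/width sizes assumed for illustration:

```python
# Shape-level sketch of the behavior removed in the second hunk.
import torch

batch, seq_len, hidden = 1, 8, 768  # assumed sizes
hidden_states = torch.randn(batch, seq_len, hidden)

# Before this commit ("5_2" with seq_mode): only the last position was kept
# ahead of the final layer norm.
last_only = hidden_states[:, -1, :].unsqueeze(1)
print(last_only.shape)      # torch.Size([1, 1, 768])

# After this commit: the full sequence reaches ln_out, so per-token hidden
# states are available downstream.
print(hidden_states.shape)  # torch.Size([1, 8, 768])
```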