KaleiNeely committed
Commit de75931
1 Parent(s): b32aee7

Update modeling_rwkv5.py

Files changed (1)
  1. modeling_rwkv5.py +74 -3
modeling_rwkv5.py CHANGED
@@ -329,6 +329,80 @@ class RwkvPreTrainedModel(PreTrainedModel):
     _no_split_modules = ["RwkvBlock"]
     _keep_in_fp32_modules = ["time_decay", "time_first"]
 
+    def _init_weights(self, module):
+        """Initialize the weights."""
+        if isinstance(module, RwkvSelfAttention):
+            layer_id = module.layer_id
+            num_hidden_layers = module.config.num_hidden_layers
+            hidden_size = module.config.hidden_size
+            attention_hidden_size = module.attention_hidden_size
+            num_attention_heads = hidden_size // module.config.head_size
+
+            ratio_0_to_1 = layer_id / (num_hidden_layers - 1)  # 0 to 1
+            ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers)  # 1 to ~0
+
+            time_weight = torch.tensor(
+                [i / hidden_size for i in range(hidden_size)],
+                dtype=module.time_mix_key.dtype,
+                device=module.time_mix_key.device,
+            )
+            time_weight = time_weight[None, None, :]
+
+            if module.config.model_version == "5_2":
+                # https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4neo/src/model.py#L398
+                decay_speed = [
+                    -6.0 + 5.0 * (h / (attention_hidden_size - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
+                    for h in range(attention_hidden_size)
+                ]
+            else:
+                # https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4neo/src/model.py#L172
+                decay_speed = [
+                    -6.0 + 5.0 * (h / (num_attention_heads - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
+                    for h in range(num_attention_heads)
+                ]
+            decay_speed = torch.tensor(decay_speed, dtype=module.time_decay.dtype, device=module.time_decay.device)
+            if module.config.model_version == "5_2":
+                tmp = (
+                    torch.tensor(
+                        [(1.0 - (i / (attention_hidden_size - 1.0))) * ratio_0_to_1 + 0.1 * ((i + 1) % 3 - 1) for i in range(attention_hidden_size)],
+                        dtype=module.time_faaaa.dtype,
+                        device=module.time_faaaa.device,
+                    )
+                )
+            else:
+                tmp = torch.ones(num_attention_heads) * (-3.0)
+
+            with torch.no_grad():
+                if module.config.model_version == "5_2":
+                    module.time_decay.data = decay_speed.reshape(num_attention_heads, module.config.head_size)
+                    module.time_faaaa.data = tmp.reshape(num_attention_heads, module.config.head_size)
+                else:
+                    module.time_decay.data = decay_speed
+                    module.time_first.data = tmp
+
+                module.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0)
+                module.time_mix_value.data = torch.pow(time_weight, ratio_1_to_almost0) + 0.3 * ratio_0_to_1
+                module.time_mix_receptance.data = torch.pow(time_weight, 0.5 * ratio_1_to_almost0)
+                if module.config.model_version == "5_2":
+                    module.time_mix_gate.data = torch.pow(time_weight, 0.5 * ratio_1_to_almost0)
+        elif isinstance(module, RwkvFeedForward):
+            layer_id = module.layer_id
+            num_hidden_layers = module.config.num_hidden_layers
+            hidden_size = module.config.hidden_size
+
+            ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers)  # 1 to ~0
+
+            time_weight = torch.tensor(
+                [i / hidden_size for i in range(hidden_size)],
+                dtype=module.time_mix_key.dtype,
+                device=module.time_mix_key.device,
+            )
+            time_weight = time_weight[None, None, :]
+
+            with torch.no_grad():
+                module.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0)
+                module.time_mix_receptance.data = torch.pow(time_weight, ratio_1_to_almost0)
+
     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, RwkvModel):
             module.gradient_checkpointing = value
@@ -540,9 +614,6 @@ class RwkvModel(RwkvPreTrainedModel):
         if output_attentions:
             all_self_attentions = all_self_attentions + (attentions,)
 
-        if self.config.model_version == "5_2" and seq_mode:
-            hidden_states = hidden_states[:, -1, :].unsqueeze(1)
-
         hidden_states = self.ln_out(hidden_states)
 
         if output_hidden_states:
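For reference, the decay-speed schedule introduced in `_init_weights` can be reproduced in isolation. The sketch below is illustrative only, not part of the commit: the sizes (4 layers, hidden_size=64, head_size=8) are hypothetical stand-ins, and attention_hidden_size is assumed equal to hidden_size as in the default RWKV configs.

import torch

# Hypothetical toy configuration (not taken from a real checkpoint).
num_hidden_layers = 4
hidden_size = 64
head_size = 8
attention_hidden_size = hidden_size  # assumption: config default
num_attention_heads = hidden_size // head_size

for layer_id in range(num_hidden_layers):
    # Same formula as the "5_2" branch above: values run from -6.0 (h=0)
    # up to -1.0 (h=attention_hidden_size-1); the exponent grows with depth,
    # so later layers keep more heads near -6.0.
    ratio_0_to_1 = layer_id / (num_hidden_layers - 1)
    decay_speed = torch.tensor(
        [
            -6.0 + 5.0 * (h / (attention_hidden_size - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
            for h in range(attention_hidden_size)
        ]
    )
    # The "5_2" branch stores this per head: (num_attention_heads, head_size).
    time_decay = decay_speed.reshape(num_attention_heads, head_size)
    print(f"layer {layer_id}: min={time_decay.min().item():.3f} max={time_decay.max().item():.3f}")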
 
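A usage note on how this initialization is reached: transformers runs `_init_weights` over every submodule when a model is built from scratch (checkpoint loading then overwrites these values), via the `Module.apply`-style traversal triggered by `PreTrainedModel.post_init()`. Below is a minimal sketch of that mechanism with a hypothetical stand-in module, not the real RWKV classes.

import torch
import torch.nn as nn

class TinyBlock(nn.Module):
    """Stand-in for RwkvSelfAttention/RwkvFeedForward (hypothetical)."""
    def __init__(self, layer_id: int, hidden_size: int = 4):
        super().__init__()
        self.layer_id = layer_id
        self.time_mix_key = nn.Parameter(torch.empty(1, 1, hidden_size))

class TinyModel(nn.Module):
    def __init__(self, num_layers: int = 3, hidden_size: int = 4):
        super().__init__()
        self.hidden_size = hidden_size
        self.blocks = nn.ModuleList(TinyBlock(i, hidden_size) for i in range(num_layers))

    def _init_weights(self, module):
        # Layer-dependent init, in the spirit of the diff above.
        if isinstance(module, TinyBlock):
            ratio_1_to_almost0 = 1.0 - module.layer_id / len(self.blocks)  # 1 to ~0
            time_weight = torch.tensor([i / self.hidden_size for i in range(self.hidden_size)])
            with torch.no_grad():
                module.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0)[None, None, :]

model = TinyModel()
model.apply(model._init_weights)  # visits every submodule, as post_init() does
print(model.blocks[0].time_mix_key)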