Upload BD3LM
Browse files- modeling_bd3lm.py +2 -2
modeling_bd3lm.py
CHANGED
@@ -603,8 +603,8 @@ class BD3LM(transformers.PreTrainedModel):
|
|
603 |
for block in self.backbone.blocks:
|
604 |
block.kv_cache = torch.zeros(
|
605 |
eval_batch_size,
|
606 |
-
self.
|
607 |
-
self.config.
|
608 |
device='cuda',
|
609 |
dtype=torch.bfloat16)
|
610 |
block.cache_idx = 0
|
|
|
603 |
for block in self.backbone.blocks:
|
604 |
block.kv_cache = torch.zeros(
|
605 |
eval_batch_size,
|
606 |
+
self.config.model_length,
|
607 |
+
self.config.hidden_dim * 3,
|
608 |
device='cuda',
|
609 |
dtype=torch.bfloat16)
|
610 |
block.cache_idx = 0
|