marriola committed on
Commit
94cc5ac
·
verified ·
1 Parent(s): 53f8c8f

Upload BD3LM

Browse files
Files changed (1) hide show
  1. modeling_bd3lm.py +2 -2
modeling_bd3lm.py CHANGED
@@ -603,8 +603,8 @@ class BD3LM(transformers.PreTrainedModel):
603
  for block in self.backbone.blocks:
604
  block.kv_cache = torch.zeros(
605
  eval_batch_size,
606
- self.n,
607
- self.config.model.hidden_size * 3,
608
  device='cuda',
609
  dtype=torch.bfloat16)
610
  block.cache_idx = 0
 
603
  for block in self.backbone.blocks:
604
  block.kv_cache = torch.zeros(
605
  eval_batch_size,
606
+ self.config.model_length,
607
+ self.config.hidden_dim * 3,
608
  device='cuda',
609
  dtype=torch.bfloat16)
610
  block.cache_idx = 0