akhauriyash committed on
Commit
2898354
·
verified ·
1 Parent(s): 7a3b00d

Update modeling_llama_butler.py

Browse files

Fix inference so that it actually uses sparsity; remove the printing of effective sparsity (the value is still calculated).

Files changed (1) hide show
  1. modeling_llama_butler.py +2 -2
modeling_llama_butler.py CHANGED
@@ -918,7 +918,7 @@ class LlamaAttentionExperimental(nn.Module):
918
  self.num_key_value_groups = self.num_heads // self.num_key_value_heads
919
  self.max_position_embeddings = config.max_position_embeddings
920
  self.rope_theta = config.rope_theta
921
- self.inference_mode = False
922
  self.producer = producer
923
  self.layer_idx = layer_idx
924
  self.token_sparse_method = None
@@ -1217,7 +1217,7 @@ class LlamaAttentionExperimental(nn.Module):
1217
  num_active = (~attention_mask.bool()).sum(dim=-1).expand_as(num_deact) # Number of tokens active at this position if zero-sparsity
1218
  effective_sparsity = 100 * (additional_deact.float() / num_active.float()).mean().item()
1219
  self.effective_sparsity = effective_sparsity
1220
- print("Effective Sparsity:", effective_sparsity, "%\t Sequence Length:", q_len)
1221
  if self.layer_idx == 0:
1222
  if self.effective_sparsity is None:
1223
  self.effective_sparsity = 0.0
 
918
  self.num_key_value_groups = self.num_heads // self.num_key_value_heads
919
  self.max_position_embeddings = config.max_position_embeddings
920
  self.rope_theta = config.rope_theta
921
+ self.inference_mode = True
922
  self.producer = producer
923
  self.layer_idx = layer_idx
924
  self.token_sparse_method = None
 
1217
  num_active = (~attention_mask.bool()).sum(dim=-1).expand_as(num_deact) # Number of tokens active at this position if zero-sparsity
1218
  effective_sparsity = 100 * (additional_deact.float() / num_active.float()).mean().item()
1219
  self.effective_sparsity = effective_sparsity
1220
+ # print("Effective Sparsity:", effective_sparsity, "%\t Sequence Length:", q_len)
1221
  if self.layer_idx == 0:
1222
  if self.effective_sparsity is None:
1223
  self.effective_sparsity = 0.0