zifei9 committed
Commit 5981ddd (verified)
1 Parent(s): 4ea6b44

Update modeling_llama.py


Updating cache method calls, replacing get_max_length() with get_max_cache_shape(), to match a change in the transformers Cache API.

Files changed (1)
  1. modeling_llama.py +2 -2
modeling_llama.py CHANGED
@@ -1218,8 +1218,8 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
         if isinstance(past_key_values, Cache):
             past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
             max_cache_length = (
-                torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
-                if past_key_values.get_max_length() is not None
+                torch.tensor(past_key_values.get_max_cache_shape(), device=input_ids.device)
+                if past_key_values.get_max_cache_shape() is not None
                 else None
             )
             cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
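For code that has to run against both older and newer transformers releases, the renamed method can be probed at runtime instead of hard-coding one name. The sketch below is a minimal compatibility shim, not part of this commit; it assumes only that older Cache objects expose get_max_length() and newer ones expose get_max_cache_shape().

import torch
from transformers.cache_utils import Cache


def get_max_cache_length(past_key_values: Cache, device) -> "torch.Tensor | None":
    # Newer transformers releases renamed Cache.get_max_length() to
    # Cache.get_max_cache_shape(); probe for whichever method is available.
    if hasattr(past_key_values, "get_max_cache_shape"):
        max_len = past_key_values.get_max_cache_shape()
    else:
        max_len = past_key_values.get_max_length()
    # A dynamic cache has no fixed capacity, which both methods signal with None.
    return None if max_len is None else torch.tensor(max_len, device=device)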