Update modeling_llama.py
Updating cache methods due to a change in transformers
- modeling_llama.py +2 -2
--- a/modeling_llama.py
+++ b/modeling_llama.py
@@ -1218,8 +1218,8 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
         if isinstance(past_key_values, Cache):
             past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
             max_cache_length = (
-                torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
-                if past_key_values.get_max_length() is not None
+                torch.tensor(past_key_values.get_max_cache_shape(), device=input_ids.device)
+                if past_key_values.get_max_cache_shape() is not None
                 else None
             )
             cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
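For reference, newer transformers releases expose Cache.get_max_cache_shape() in place of the older max-length lookup, which is what this commit tracks. Below is a minimal sketch of the new call in isolation (the exact transformers version that introduced it is an assumption, and the DynamicCache example is illustrative rather than part of this repository):

import torch
from transformers import DynamicCache

past_key_values = DynamicCache()

# Assumption: get_max_cache_shape() returns None for growable caches such as
# DynamicCache and an integer maximum length for fixed-size caches (e.g. StaticCache).
max_cache_shape = past_key_values.get_max_cache_shape()
max_cache_length = torch.tensor(max_cache_shape) if max_cache_shape is not None else None
print(max_cache_length)  # None for a DynamicCache, which has no fixed maximum length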