Update modeling_llama.py
modeling_llama.py  CHANGED  (+7, -2)
@@ -961,12 +961,13 @@ class LlamaModel(LlamaPreTrainedModel):
             inputs_embeds = self.embed_tokens(input_ids)

         past_seen_tokens = 0
+        used_legacy_cache = False
         if use_cache:  # kept for BC (cache positions)
             if past_key_values is not None and not isinstance(
                 past_key_values, StaticCache
             ):
                 if not isinstance(past_key_values, DynamicCache):
-                    used_legacy_cache=True
+                    used_legacy_cache = True
                     past_key_values = DynamicCache.from_legacy_cache(past_key_values)
                 past_seen_tokens = past_key_values.get_seq_length()

@@ -1038,7 +1039,11 @@

         next_cache = None
         if use_cache:
-            next_cache =
+            next_cache = (
+                next_decoder_cache.to_legacy_cache()
+                if used_legacy_cache
+                else next_decoder_cache
+            )
         if not return_dict:
             return tuple(
                 v
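Taken together, the two hunks make the cache handling symmetric: used_legacy_cache is initialized to False before the use_cache branch so the flag is always bound, it is set to True only when an old-style tuple cache is converted with DynamicCache.from_legacy_cache, and the completed next_cache expression converts the cache back with to_legacy_cache() so the caller receives it in the same format it passed in. Below is a minimal, self-contained sketch of that round-trip using the public DynamicCache API from transformers; the helper names (normalize_cache, restore_cache_format) are illustrative and not part of modeling_llama.py, and the StaticCache short-circuit from the file is omitted for brevity.

    # Sketch of the round-trip the diff implements; helper names are illustrative.
    from transformers.cache_utils import DynamicCache

    def normalize_cache(past_key_values):
        # Accept either a legacy tuple-of-tuples cache or a DynamicCache and
        # remember which format came in, so the result can be handed back the same way.
        used_legacy_cache = False  # always bound, even when nothing is converted
        if past_key_values is not None and not isinstance(past_key_values, DynamicCache):
            used_legacy_cache = True
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        return past_key_values, used_legacy_cache, past_seen_tokens

    def restore_cache_format(next_decoder_cache, used_legacy_cache):
        # Return the updated cache in the same format the caller originally provided.
        if next_decoder_cache is None:
            return None
        return next_decoder_cache.to_legacy_cache() if used_legacy_cache else next_decoder_cache

The used_legacy_cache = False initialization is what keeps the ternary in the second hunk safe: if the flag were assigned only inside the conversion branch, calling the model with use_cache=True and a cache that is already a DynamicCache (or no cache at all) would hit an unbound local when next_cache is computed.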