whynlp commited on
Commit
61c49cc
1 Parent(s): 2a52d2e

Upload LCKVLlamaForCausalLM

Browse files
Files changed (8) hide show
  1. README.md +199 -0
  2. cache_utils.py +550 -0
  3. config.json +41 -0
  4. configuration_lckv.py +81 -0
  5. generation_config.json +6 -0
  6. model.safetensors +3 -0
  7. modeling_lckv.py +807 -0
  8. utils.py +190 -0
README.md ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+ This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
cache_utils.py ADDED
@@ -0,0 +1,550 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, Optional, Tuple
2
+
3
+ import torch
4
+
5
+ from transformers.cache_utils import Cache, DynamicCache, SinkCache
6
+
7
+ from .utils import LayerTypeParser
8
+
9
+
10
+ class IndexedCache(Cache):
11
+ """
12
+ Similar to the `DynamicCache` class, but with the ability to index the cache by layer index. DynamicCache
13
+ assumes that all layers compute KVs, while IndexedCache allows for a more flexible cache structure.
14
+ """
15
+ build_position_ids_based_on_cache = False
16
+
17
+ def __init__(self) -> None:
18
+ super().__init__()
19
+ self.key_cache: Dict[int, torch.Tensor] = {}
20
+ self.value_cache: Dict[int, torch.Tensor] = {}
21
+ self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
22
+ self._update = True # to prevent the cache from updating when inference with iterations
23
+
24
+ def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
25
+ """
26
+ Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
27
+ sequence length.
28
+ """
29
+ if layer_idx in self.key_cache:
30
+ return (self.key_cache[layer_idx], self.value_cache[layer_idx])
31
+ else:
32
+ raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
33
+
34
+ def __iter__(self):
35
+ """
36
+ Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
37
+ keys and values
38
+ """
39
+ for layer_idx in sorted(self.key_cache.keys()):
40
+ yield (self.key_cache[layer_idx], self.value_cache[layer_idx])
41
+
42
+ def __len__(self):
43
+ """
44
+ Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
45
+ to the number of layers that compute KVs in the model.
46
+ """
47
+ return len(self.key_cache)
48
+
49
+ @property
50
+ def min_layer(self) -> int:
51
+ return min(self.key_cache.keys()) if len(self.key_cache) > 0 else None
52
+
53
+ def is_min_layer(self, layer_idx: int) -> bool:
54
+ return self.min_layer is None or self.min_layer == layer_idx
55
+
56
+ def update(
57
+ self,
58
+ key_states: torch.Tensor,
59
+ value_states: torch.Tensor,
60
+ layer_idx: int,
61
+ cache_kwargs: Optional[Dict[str, Any]] = None,
62
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
63
+ """
64
+ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
65
+
66
+ Parameters:
67
+ key_states (`torch.Tensor`):
68
+ The new key states to cache.
69
+ value_states (`torch.Tensor`):
70
+ The new value states to cache.
71
+ layer_idx (`int`):
72
+ The index of the layer to cache the states for.
73
+ cache_kwargs (`Dict[str, Any]`, `optional`):
74
+ Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
75
+
76
+ Return:
77
+ A tuple containing the updated key and value states.
78
+ """
79
+ # Update the number of seen tokens
80
+ if self.is_min_layer(layer_idx):
81
+ self._seen_tokens += key_states.shape[-2]
82
+
83
+ # Retrieve the cache
84
+ if layer_idx not in self.key_cache:
85
+ new_key_states = key_states
86
+ new_value_states = value_states
87
+ else:
88
+ new_key_states = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
89
+ new_value_states = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
90
+
91
+ # Update the cache
92
+ if self._update:
93
+ self.key_cache[layer_idx] = new_key_states
94
+ self.value_cache[layer_idx] = new_value_states
95
+
96
+ return new_key_states, new_value_states
97
+
98
+ def get_seq_length(self, layer_idx: Optional[int] = None) -> int:
99
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
100
+ if layer_idx is None:
101
+ layer_idx = self.min_layer
102
+
103
+ # TODO: deprecate this function in favor of `cache_position`
104
+ is_empty_layer = (
105
+ (len(self.key_cache) == 0) # no cache in any layer
106
+ or (layer_idx not in self.key_cache) # skipped `layer_idx` and hasn't run a layer with cache after it
107
+ )
108
+ layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
109
+ return layer_seq_length
110
+
111
+ def get_max_length(self) -> Optional[int]:
112
+ """Returns the maximum sequence length of the cached states. IndexedCache does not have a maximum length."""
113
+ return None
114
+
115
+ @classmethod
116
+ def from_cache(cls, dynamic_cache: DynamicCache, *args, **kwargs) -> "IndexedCache":
117
+ """Converts a dynamic cache into an equivalent `IndexedCache`."""
118
+ cache = cls(*args, **kwargs)
119
+
120
+ cache._seen_tokens = dynamic_cache._seen_tokens
121
+ for layer_idx in range(len(dynamic_cache.key_cache)):
122
+ key_states, value_states = dynamic_cache[layer_idx]
123
+ cache.update(key_states, value_states, layer_idx)
124
+
125
+ return cache
126
+
127
+
128
+ class IndexedSinkCache(Cache):
129
+ """
130
+ This is a fix to the SinkCache class in the transformers library. It also allows for the cache to be indexed by
131
+ layer index, similar to the `IndexedCache` class.
132
+ """
133
+ build_position_ids_based_on_cache = True
134
+
135
+ def __init__(self, window_length: int = None, num_sink_tokens: int = None) -> None:
136
+ super().__init__()
137
+ self.key_cache: Dict[int, torch.Tensor] = {}
138
+ self.value_cache: Dict[int, torch.Tensor] = {}
139
+ self.window_length = window_length
140
+ self.num_sink_tokens = num_sink_tokens
141
+ self.cos_sin_rerotation_cache = {}
142
+ self._cos_cache = None
143
+ self._sin_cache = None
144
+ self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
145
+ self._update = True # to prevent the cache from updating when inference with iterations
146
+
147
+ @staticmethod
148
+ def _rotate_half(x):
149
+ x1 = x[..., : x.shape[-1] // 2]
150
+ x2 = x[..., x.shape[-1] // 2 :]
151
+ return torch.cat((-x2, x1), dim=-1)
152
+
153
+ def _apply_key_rotary_pos_emb(
154
+ self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
155
+ ) -> torch.Tensor:
156
+ rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin)
157
+ return rotated_key_states
158
+
159
+ def _get_rerotation_cos_sin(
160
+ self, offset: int, dtype: torch.dtype, cos: torch.Tensor, sin: torch.Tensor
161
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
162
+ if offset not in self.cos_sin_rerotation_cache:
163
+ # Upcast to float32 temporarily for better accuracy
164
+ cos = cos.to(torch.float32)
165
+ sin = sin.to(torch.float32)
166
+
167
+ # Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence
168
+ original_cos = cos[self.num_sink_tokens + offset :]
169
+ shifted_cos = cos[self.num_sink_tokens : -offset]
170
+ original_sin = sin[self.num_sink_tokens + offset :]
171
+ shifted_sin = sin[self.num_sink_tokens : -offset]
172
+ rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
173
+ rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin
174
+
175
+ self.cos_sin_rerotation_cache[offset] = (
176
+ rerotation_cos.to(dtype).unsqueeze(0),
177
+ rerotation_sin.to(dtype).unsqueeze(0),
178
+ )
179
+ return self.cos_sin_rerotation_cache[offset]
180
+
181
+ @property
182
+ def min_layer(self) -> int:
183
+ return min(self.key_cache.keys()) if len(self.key_cache) > 0 else None
184
+
185
+ def is_min_layer(self, layer_idx: int) -> bool:
186
+ return self.min_layer is None or self.min_layer == layer_idx
187
+
188
+ def get_seq_length(self, layer_idx: Optional[int] = None) -> int:
189
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
190
+ # TODO: deprecate this function in favor of `cache_position`
191
+ # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
192
+ if layer_idx is None:
193
+ layer_idx = self.min_layer
194
+
195
+ if layer_idx not in self.key_cache:
196
+ return 0
197
+
198
+ return self.key_cache[layer_idx].shape[-2]
199
+
200
+ def get_max_length(self) -> Optional[int]:
201
+ """Returns the maximum sequence length of the cached states."""
202
+ return self.window_length
203
+
204
+ def update(
205
+ self,
206
+ key_states: torch.Tensor,
207
+ value_states: torch.Tensor,
208
+ layer_idx: int,
209
+ cache_kwargs: Optional[Dict[str, Any]] = None,
210
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
211
+ """
212
+ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
213
+
214
+ Parameters:
215
+ key_states (`torch.Tensor`):
216
+ The new key states to cache.
217
+ value_states (`torch.Tensor`):
218
+ The new value states to cache.
219
+ layer_idx (`int`):
220
+ The index of the layer to cache the states for.
221
+ cache_kwargs (`Dict[str, Any]`, `optional`):
222
+ Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`,
223
+ `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the
224
+ rotation as the tokens are shifted.
225
+
226
+ Return:
227
+ A tuple containing the updated key and value states.
228
+ """
229
+ # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
230
+ # with partially rotated position embeddings, like Phi or Persimmon.
231
+ sin = cache_kwargs.get("sin")
232
+ cos = cache_kwargs.get("cos")
233
+ partial_rotation_size = cache_kwargs.get("partial_rotation_size")
234
+ using_rope = cos is not None and sin is not None
235
+
236
+ # Update the number of seen tokens
237
+ if self.is_min_layer(layer_idx):
238
+ self._seen_tokens += key_states.shape[-2]
239
+
240
+ # Update the sin/cos cache, which holds sin/cos values for all possible positions
241
+ if using_rope and self.is_min_layer(layer_idx):
242
+ # BC: some models still pass `sin`/`cos` with 2 dims. In those models, they are the full sin/cos. Remove
243
+ # after all RoPE models have a llama-like cache utilization.
244
+ if cos.dim() == 2:
245
+ self._cos_cache = cos
246
+ self._sin_cache = sin
247
+ else:
248
+ if self._cos_cache is None:
249
+ self._cos_cache = cos[0, ...]
250
+ self._sin_cache = sin[0, ...]
251
+ elif self._cos_cache.shape[0] < self.window_length + key_states.shape[-2]:
252
+ self._cos_cache = torch.cat([self._cos_cache[: self.window_length], cos[0, ...]], dim=0)
253
+ self._sin_cache = torch.cat([self._sin_cache[: self.window_length], sin[0, ...]], dim=0)
254
+
255
+ # [bsz, num_heads, seq_len, head_dim]
256
+ if layer_idx not in self.key_cache:
257
+ # Empty cache
258
+ new_key_states = key_states
259
+ new_value_states = value_states
260
+
261
+ else:
262
+ # Growing cache
263
+ new_key_states = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
264
+ new_value_states = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
265
+
266
+ if self._update:
267
+ self.key_cache[layer_idx] = new_key_states
268
+ self.value_cache[layer_idx] = new_value_states
269
+
270
+ # If the cache is full, we need to shift the cache
271
+ if (seq_length := self.get_seq_length(layer_idx)) > self.window_length:
272
+ # Shifting cache
273
+ keys_to_keep = self.key_cache[layer_idx][:, :, -self.window_length + self.num_sink_tokens :]
274
+
275
+ # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
276
+ if using_rope:
277
+ rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
278
+ seq_length - self.window_length,
279
+ key_states.dtype,
280
+ self._cos_cache[:seq_length],
281
+ self._sin_cache[:seq_length],
282
+ )
283
+ if partial_rotation_size is not None:
284
+ keys_to_keep, keys_pass = (
285
+ keys_to_keep[..., :partial_rotation_size],
286
+ keys_to_keep[..., partial_rotation_size:],
287
+ )
288
+ keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin)
289
+ if partial_rotation_size is not None:
290
+ keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)
291
+
292
+ # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens
293
+ sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
294
+ self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep], dim=-2)
295
+
296
+ sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
297
+ values_to_keep = self.value_cache[layer_idx][:, :, -self.window_length + self.num_sink_tokens :]
298
+ self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep], dim=-2)
299
+
300
+ return new_key_states, new_value_states
301
+
302
+ @classmethod
303
+ def from_cache(cls, sink_cache: SinkCache, *args, **kwargs) -> "IndexedSinkCache":
304
+ """Converts a dynamic cache into an equivalent `IndexedCache`."""
305
+ cache = cls(*args, **kwargs)
306
+
307
+ cache.window_length = sink_cache.window_length
308
+ cache.num_sink_tokens = sink_cache.num_sink_tokens
309
+ cache._seen_tokens = sink_cache._seen_tokens
310
+ cache._cos_cache = sink_cache._cos_cache
311
+ cache._sin_cache = sink_cache._sin_cache
312
+ cache.cos_sin_rerotation_cache = sink_cache.cos_sin_rerotation_cache
313
+ for layer_idx in range(len(sink_cache.key_cache)):
314
+ cache.key_cache[layer_idx] = sink_cache.key_cache[layer_idx]
315
+ cache.value_cache[layer_idx] = sink_cache.value_cache[layer_idx]
316
+
317
+ return cache
318
+
319
+
320
+ class IndexedSlidingWindowCache(IndexedCache):
321
+ """
322
+ Similar to the `SlidingWindowCache` class, but with the ability to index the cache by layer index. It is no longer
323
+ a subclass of `StaticCache` as it is dynamic.
324
+ """
325
+ build_position_ids_based_on_cache = False
326
+
327
+ def __init__(self, sliding_window: int = None) -> None:
328
+ super().__init__()
329
+ self.sliding_window = sliding_window
330
+
331
+ def update(
332
+ self,
333
+ key_states: torch.Tensor,
334
+ value_states: torch.Tensor,
335
+ layer_idx: int,
336
+ cache_kwargs: Optional[Dict[str, Any]] = None,
337
+ ) -> Tuple[torch.Tensor]:
338
+ # Update the number of seen tokens
339
+ if self.is_min_layer(layer_idx):
340
+ self._seen_tokens += key_states.shape[-2]
341
+
342
+ # [bsz, num_heads, seq_len, head_dim]
343
+ if layer_idx not in self.key_cache:
344
+ # Empty cache
345
+ new_key_states = key_states
346
+ new_value_states = value_states
347
+
348
+ else:
349
+ # Growing cache
350
+ new_key_states = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
351
+ new_value_states = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
352
+
353
+ if self._update:
354
+ self.key_cache[layer_idx] = new_key_states
355
+ self.value_cache[layer_idx] = new_value_states
356
+
357
+ # If the cache is full, we need to shift the cache
358
+ if self.get_seq_length(layer_idx) > self.sliding_window:
359
+ self.key_cache[layer_idx] = self.key_cache[layer_idx][:, :, -self.sliding_window :]
360
+ self.value_cache[layer_idx] = self.value_cache[layer_idx][:, :, -self.sliding_window :]
361
+
362
+ return new_key_states, new_value_states
363
+
364
+ def get_max_length(self) -> Optional[int]:
365
+ return self.sliding_window
366
+
367
+ @classmethod
368
+ def from_cache(cls, sliding_window_cache: "IndexedSlidingWindowCache", *args, **kwargs) -> "IndexedSlidingWindowCache":
369
+ """This is to override the `from_cache` method in the `IndexedCache` class."""
370
+ cache = cls(*args, **kwargs)
371
+
372
+ cache._seen_tokens = sliding_window_cache._seen_tokens
373
+ cache.sliding_window = sliding_window_cache.sliding_window
374
+ for layer_idx in range(len(sliding_window_cache.key_cache)):
375
+ cache.key_cache[layer_idx] = sliding_window_cache.key_cache[layer_idx]
376
+ cache.value_cache[layer_idx] = sliding_window_cache.value_cache[layer_idx]
377
+
378
+ return cache
379
+
380
+
381
+ class IndexedHybridCache(IndexedSlidingWindowCache, IndexedCache):
382
+ """
383
+ Hybrid Cache class to be used for models that alternate between a local sliding window attention and global
384
+ attention in every other layer. Under the hood, Hybrid Cache leverages ["IndexedSlidingWindowCache"] for
385
+ sliding window attention and ["IndexedCache"] for global attention.
386
+ """
387
+ build_position_ids_based_on_cache = False
388
+
389
+ def __init__(self, parser: LayerTypeParser = None, sliding_window: int = None) -> None:
390
+ super().__init__(sliding_window=sliding_window)
391
+ self.parser = parser
392
+
393
+ def update(
394
+ self,
395
+ key_states: torch.Tensor,
396
+ value_states: torch.Tensor,
397
+ layer_idx: int,
398
+ cache_kwargs: Optional[Dict[str, Any]] = None,
399
+ ) -> Tuple[torch.Tensor]:
400
+ if self.parser[layer_idx].use_sliding_window:
401
+ return IndexedSlidingWindowCache.update(self, key_states, value_states, layer_idx, cache_kwargs)
402
+ else:
403
+ return IndexedCache.update(self, key_states, value_states, layer_idx, cache_kwargs)
404
+
405
+ def get_max_length(self) -> Optional[int]:
406
+ return IndexedCache.get_max_length(self)
407
+
408
+ @classmethod
409
+ def from_cache(cls, hybrid_cache: "IndexedHybridCache", *args, **kwargs) -> "IndexedHybridCache":
410
+ """This is to override the `from_cache` method in the `IndexedSlidingWindowCache` class."""
411
+ cache = cls(*args, **kwargs)
412
+
413
+ cache._seen_tokens = hybrid_cache._seen_tokens
414
+ cache.sliding_window = hybrid_cache.sliding_window
415
+ cache.parser = hybrid_cache.parser
416
+ for layer_idx in range(len(hybrid_cache.key_cache)):
417
+ cache.key_cache[layer_idx] = hybrid_cache.key_cache[layer_idx]
418
+ cache.value_cache[layer_idx] = hybrid_cache.value_cache[layer_idx]
419
+
420
+ return cache
421
+
422
+
423
+ class LayerCache(torch.nn.Module):
424
+ """
425
+ A cache for storing the key-value pairs for layers.
426
+ """
427
+ def __init__(self) -> None:
428
+ """
429
+ The placeholder is used to expand the key-value pairs if the layer attends to the top layers.
430
+ Size: (batch_size, num_key_value_heads, 1, head_dim)
431
+ """
432
+ super().__init__()
433
+ self.key_layer_cache: Dict[int, torch.Tensor] = {}
434
+ self.value_layer_cache: Dict[int, torch.Tensor] = {}
435
+ self.layer_type = None
436
+ self.placeholder = None
437
+
438
+ def setup(self, placeholder: torch.Tensor):
439
+ """setup the cache, calling this function is necessary if there is a layer that attends to the top layers"""
440
+ self.placeholder = placeholder
441
+
442
+ def initialize(self, parser: LayerTypeParser, sequence_length: int):
443
+ """initialize the cache"""
444
+ layers_to_init = {parser[idx].attends_to for idx in range(len(parser)) if parser[idx].attends_top}
445
+
446
+ if layers_to_init:
447
+ b, h, _, d = self.placeholder.size()
448
+ init_kvs = self.placeholder.new_zeros((b, h, sequence_length, d))
449
+
450
+ for layer_idx in layers_to_init:
451
+ self.layer_append(layer_idx, init_kvs, init_kvs)
452
+
453
+ def layer_get(self, layer_idx: int, zerofill: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
454
+ key_states = self.key_layer_cache.get(layer_idx, None)
455
+ value_states = self.value_layer_cache.get(layer_idx, None)
456
+
457
+ if zerofill:
458
+ if key_states is None:
459
+ key_states = self.placeholder
460
+ value_states = self.placeholder
461
+ else:
462
+ key_states = torch.cat([self.placeholder, key_states], dim=2)
463
+ value_states = torch.cat([self.placeholder, value_states], dim=2)
464
+
465
+ return key_states, value_states
466
+
467
+ def layer_set(self, layer_idx: int, key: torch.Tensor, value: torch.Tensor):
468
+ self.key_layer_cache[layer_idx] = key
469
+ self.value_layer_cache[layer_idx] = value
470
+
471
+ def layer_append(self, layer_idx: int, key: torch.Tensor, value: torch.Tensor):
472
+ if layer_idx not in self.key_layer_cache:
473
+ self.key_layer_cache[layer_idx] = key
474
+ self.value_layer_cache[layer_idx] = value
475
+ else:
476
+ self.key_layer_cache[layer_idx] = torch.cat([self.key_layer_cache[layer_idx], key], dim=2)
477
+ self.value_layer_cache[layer_idx] = torch.cat([self.value_layer_cache[layer_idx], value], dim=2)
478
+
479
+
480
+ class LayerIndexedCache(LayerCache, IndexedCache):
481
+ """
482
+ A cache for storing the key-value pairs for layers, in combination with the ability of standard KV cache.
483
+ """
484
+ def __init__(self) -> None:
485
+ LayerCache.__init__(self)
486
+ IndexedCache.__init__(self)
487
+
488
+
489
+ class LayerIndexedSinkCache(LayerCache, IndexedSinkCache):
490
+ """
491
+ A cache for storing the key-value pairs for layers, in combination with the ability of sink KV cache.
492
+ """
493
+ def __init__(self) -> None:
494
+ LayerCache.__init__(self)
495
+ IndexedSinkCache.__init__(self)
496
+
497
+
498
+ class LayerIndexedSlidingWindowCache(LayerCache, IndexedSlidingWindowCache):
499
+ """
500
+ A cache for storing the key-value pairs for layers, in combination with the ability of sliding window KV cache.
501
+ """
502
+ def __init__(self) -> None:
503
+ LayerCache.__init__(self)
504
+ IndexedSlidingWindowCache.__init__(self)
505
+
506
+
507
+ class LayerIndexedHybridCache(LayerCache, IndexedHybridCache):
508
+ """
509
+ A cache for storing the key-value pairs for layers, in combination with the ability of hybrid KV cache.
510
+ """
511
+ def __init__(self) -> None:
512
+ LayerCache.__init__(self)
513
+ IndexedHybridCache.__init__(self)
514
+
515
+
516
+ class AutoLayerCache(torch.nn.Module):
517
+ """
518
+ AutoLayerCache is a module that automatically creates a cache from an existing cache.
519
+ """
520
+ CACHE_MAPPING = {
521
+ DynamicCache: LayerIndexedCache,
522
+ SinkCache: LayerIndexedSinkCache,
523
+ IndexedSlidingWindowCache: LayerIndexedSlidingWindowCache,
524
+ IndexedHybridCache: LayerIndexedHybridCache,
525
+ }
526
+
527
+ def __init__(self, *args, **kwargs):
528
+ raise RuntimeError(
529
+ f"{self.__class__.__name__} is designed to be instantiated "
530
+ f"using the `{self.__class__.__name__}.from_cache(cache)` method."
531
+ )
532
+
533
+ @classmethod
534
+ def from_cache(cls, cache: Cache, *args, **kwargs):
535
+ """
536
+ Create a new cache from an existing cache. The new cache will have the same type as the original cache.
537
+ """
538
+ cache_type = type(cache)
539
+ if cache_type not in cls.CACHE_MAPPING:
540
+ raise ValueError(f"Cache type {cache_type} is not supported by {cls.__name__}.")
541
+
542
+ cache_class = cls.CACHE_MAPPING[cache_type]
543
+
544
+ if hasattr(cache_class, "from_cache"):
545
+ return cache_class.from_cache(cache, *args, **kwargs)
546
+ else:
547
+ # we init an empty cache and copy the attributes
548
+ new_cache = cache_class(*args, **kwargs)
549
+ new_cache.__dict__.update(cache.__dict__)
550
+ return new_cache
config.json ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "outputs/tinyllama-lckv-w10-ft-250b",
3
+ "architectures": [
4
+ "LCKVLlamaForCausalLM"
5
+ ],
6
+ "attention_bias": false,
7
+ "attention_dropout": 0.0,
8
+ "auto_map": {
9
+ "AutoConfig": "configuration_lckv.LCKVLlamaConfig",
10
+ "AutoModelForCausalLM": "modeling_lckv.LCKVLlamaForCausalLM"
11
+ },
12
+ "backward_passes": 2,
13
+ "bos_token_id": 1,
14
+ "eos_token_id": 2,
15
+ "force_nodiag": true,
16
+ "forward_passes": 7,
17
+ "head_dim": 64,
18
+ "hidden_act": "silu",
19
+ "hidden_size": 2048,
20
+ "initializer_range": 0.02,
21
+ "intermediate_size": 5632,
22
+ "layer_types": "0_1_2_3_4_16_16_16_16_16_16_16_16_16_16_16_16_17_18_19_20_21",
23
+ "max_position_embeddings": 2048,
24
+ "mlp_bias": false,
25
+ "model_type": "lckv-llama",
26
+ "num_attention_heads": 32,
27
+ "num_hidden_layers": 22,
28
+ "num_key_value_heads": 4,
29
+ "pretraining_tp": 1,
30
+ "rms_norm_eps": 1e-05,
31
+ "rope_scaling": null,
32
+ "rope_theta": 10000.0,
33
+ "sliding_window": 4096,
34
+ "tie_word_embeddings": false,
35
+ "tokenizer_class": "LlamaTokenizer",
36
+ "torch_dtype": "bfloat16",
37
+ "transformers_version": "4.45.2",
38
+ "use_cache": true,
39
+ "use_sequential": false,
40
+ "vocab_size": 32000
41
+ }
configuration_lckv.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ LCKV LLaMA model configuration"""
21
+ from transformers.models.llama.configuration_llama import LlamaConfig
22
+
23
+ from .utils import LayerTypeParser
24
+
25
+
26
+ class LCKVLlamaConfig(LlamaConfig):
27
+
28
+ model_type = "lckv-llama"
29
+
30
+ def __init__(
31
+ self,
32
+ layer_types: str = None,
33
+ forward_passes: int = 7,
34
+ backward_passes: int = 2,
35
+ sliding_window: int = 4096,
36
+ use_sequential: bool = False,
37
+ force_nodiag: bool = False,
38
+ **kwargs,
39
+ ):
40
+ """
41
+ Initialize a LCKV LLaMA configuration. Instantiating a configuration with the defaults
42
+ will yield a similar configuration to that of the LLaMA-7B with the standard transformer
43
+ training scheme.
44
+
45
+ Args:
46
+ layer_types (`str`, *optional*):
47
+ A string of integers separated by underscores. The i-th integer means the layer
48
+ will use the key-value pair in the i-th layer as the kv cache. Special characters
49
+ may be placed after the integers:
50
+ - `s` means the layer will use sliding window attention.
51
+ The default value is "0_1_2_..." till the number of layers in the current config.
52
+ forward_passes (`int`, *optional*, defaults to 7):
53
+ The number of forward passes during training and prompt encoding. Equivlent
54
+ to `m` in the paper.
55
+ backward_passes (`int`, *optional*, defaults to 2):
56
+ The number of backward passes during training and prompt encoding. Equivlent
57
+ to `b` in the paper.
58
+ sliding_window (`int`, *optional*, defaults to 4096):
59
+ Sliding window attention window size. If not specified, will default to `4096`.
60
+ It will only be effective if the corresponding layer uses sliding window attention.
61
+ use_sequential (`bool`, *optional*, defaults to False):
62
+ Whether to do forwarding sequentially, token by token. Useful for testing purpose
63
+ for models with cyclic dependency. Also can be used for sequential training.
64
+ force_nodiag (`bool`, *optional*, defaults to False):
65
+ Whether to force the model to not use the diagonal attention. By default, the model
66
+ will mask the diagonal attention only in layers necessary. If set to `True`, the model
67
+ will never use the diagonal attention in any layer. This is mainly for backward compatibility.
68
+ """
69
+ super().__init__(**kwargs)
70
+ self.layer_types = layer_types
71
+ self.forward_passes = forward_passes
72
+ self.backward_passes = backward_passes
73
+ self.sliding_window = sliding_window
74
+ self.use_sequential = use_sequential
75
+ self.force_nodiag = force_nodiag
76
+
77
+ if self.layer_types is None:
78
+ self.layer_types = "_".join(map(str, range(self.num_hidden_layers)))
79
+
80
+ # post check
81
+ LayerTypeParser(self.layer_types).check(self.num_hidden_layers)
generation_config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "transformers_version": "4.45.2"
6
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6d07eabdbaa1683be0a28d1af9a6495ee3f4201b9d04bd9ab9ac4e831188c9b2
3
+ size 2177048648
modeling_lckv.py ADDED
@@ -0,0 +1,807 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ PyTorch LLaMA model."""
21
+ import copy
22
+ import math
23
+ from typing import List, Optional, Tuple, Union
24
+
25
+ import torch
26
+ from torch import nn
27
+
28
+ from transformers.cache_utils import Cache, StaticCache
29
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
30
+ from transformers.modeling_outputs import BaseModelOutputWithPast
31
+ from transformers.models.llama.modeling_llama import (
32
+ LLAMA_INPUTS_DOCSTRING,
33
+ LlamaAttention,
34
+ LlamaDecoderLayer,
35
+ LlamaForCausalLM,
36
+ LlamaModel,
37
+ LlamaPreTrainedModel,
38
+ _prepare_4d_causal_attention_mask_with_cache_position,
39
+ logger,
40
+ repeat_kv,
41
+ rotate_half,
42
+ )
43
+ from transformers.utils import add_start_docstrings_to_model_forward, is_flash_attn_greater_or_equal_2_10
44
+
45
+ from .cache_utils import AutoLayerCache, LayerCache
46
+ from .configuration_lckv import LCKVLlamaConfig
47
+ from .utils import IterStep, LayerTypeParser, flash_attention_forward
48
+
49
+
50
+ def apply_rotary(q, cos, sin, unsqueeze_dim=1):
51
+ cos = cos.unsqueeze(unsqueeze_dim)
52
+ sin = sin.unsqueeze(unsqueeze_dim)
53
+ q_embed = (q * cos) + (rotate_half(q) * sin)
54
+ return q_embed
55
+
56
+
57
+ class LCKVLlamaAttention(LlamaAttention):
58
+ """
59
+ LCKV Attention may not need to initialize weights for the key and value projections.
60
+ """
61
+
62
+ def __init__(self, config: LCKVLlamaConfig, layer_idx: Optional[int] = None):
63
+ super().__init__(config, layer_idx)
64
+ self.layer_type = LayerTypeParser(config.layer_types)[layer_idx]
65
+ self.sliding_window = config.sliding_window if self.layer_type.use_sliding_window else None
66
+
67
+ # Some layers may not need to compute key-value pairs
68
+ if not self.layer_type.computes_kv:
69
+ del self.k_proj
70
+ del self.v_proj
71
+
72
+ def forward(
73
+ self,
74
+ hidden_states: torch.Tensor,
75
+ attention_mask: Optional[torch.Tensor] = None,
76
+ position_ids: Optional[torch.LongTensor] = None,
77
+ past_key_value: Optional[Cache] = None,
78
+ output_attentions: bool = False,
79
+ use_cache: bool = False,
80
+ cache_position: Optional[torch.LongTensor] = None,
81
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
82
+ **kwargs,
83
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
84
+ bsz, q_len, _ = hidden_states.size()
85
+ cos, sin = position_embeddings
86
+
87
+ query_states = self.q_proj(hidden_states)
88
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
89
+ query_states = apply_rotary(query_states, cos, sin)
90
+
91
+ # compute key and value states
92
+ if self.layer_type.computes_kv:
93
+ key_states = self.k_proj(hidden_states)
94
+ value_states = self.v_proj(hidden_states)
95
+
96
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
97
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
98
+ key_states = apply_rotary(key_states, cos, sin)
99
+
100
+ if isinstance(past_key_value, Cache):
101
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
102
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
103
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
104
+
105
+ past_key_value.layer_set(self.layer_idx, key_states, value_states)
106
+
107
+ # get the cached key and value states
108
+ # if the layer attends to the top layers, there are two cases:
109
+ # 1. the query length is 1, in which case we will not do iterative updates. Therefore, the kv lacks the current
110
+ # query length and we need to fill it with zeros.
111
+ # 2. the query length is greater than 1, in which case we will do iterative updates and the kv will have the
112
+ # correct query length.
113
+ key_states, value_states = past_key_value.layer_get(
114
+ self.layer_type.attends_to,
115
+ zerofill=self.layer_type.attends_top and q_len == 1,
116
+ )
117
+
118
+ # handle GQA
119
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
120
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
121
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
122
+
123
+ if attention_mask is not None: # no matter the length, we just slice it
124
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
125
+ attn_weights = attn_weights + causal_mask
126
+
127
+ # diagonal mask from the right bottom corner
128
+ if self.config.force_nodiag or self.layer_type.attends_top:
129
+ kv_len = key_states.size(2)
130
+ mask = attn_weights.new_full((q_len, kv_len), torch.finfo(attn_weights.dtype).min)
131
+ mask = mask.tril(diagonal=kv_len - q_len).triu(diagonal=kv_len - q_len)
132
+ attn_weights = attn_weights + mask
133
+
134
+ # sliding window mask
135
+ if self.sliding_window:
136
+ kv_len = key_states.size(2)
137
+ mask = attn_weights.new_full((q_len, kv_len), torch.finfo(attn_weights.dtype).min)
138
+ mask = mask.tril(diagonal=kv_len - q_len - self.sliding_window)
139
+ attn_weights = attn_weights + mask
140
+
141
+ # upcast attention to fp32
142
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
143
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
144
+ attn_output = torch.matmul(attn_weights, value_states)
145
+
146
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
147
+ raise ValueError(
148
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
149
+ f" {attn_output.size()}"
150
+ )
151
+
152
+ attn_output = attn_output.transpose(1, 2).contiguous()
153
+ attn_output = attn_output.reshape(bsz, q_len, -1)
154
+ attn_output = self.o_proj(attn_output)
155
+
156
+ if not output_attentions:
157
+ attn_weights = None
158
+
159
+ return attn_output, attn_weights, past_key_value
160
+
161
+
162
+ class LCKVLlamaFlashAttention2(LCKVLlamaAttention):
163
+ """
164
+ LCKV Attention may not need to initialize weights for the key and value projections.
165
+ """
166
+
167
+ def __init__(self, config: LCKVLlamaConfig, layer_idx: Optional[int] = None):
168
+ super().__init__(config, layer_idx)
169
+
170
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
171
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
172
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
173
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
174
+
175
+ def forward(
176
+ self,
177
+ hidden_states: torch.Tensor,
178
+ attention_mask: Optional[torch.LongTensor] = None,
179
+ position_ids: Optional[torch.LongTensor] = None,
180
+ past_key_value: Optional[LayerCache] = None,
181
+ output_attentions: bool = False,
182
+ use_cache: bool = False,
183
+ cache_position: Optional[torch.LongTensor] = None,
184
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
185
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
186
+
187
+ output_attentions = False
188
+
189
+ bsz, q_len, _ = hidden_states.size()
190
+ cos, sin = position_embeddings
191
+
192
+ # Flash attention requires the input to have the shape
193
+ # batch_size x seq_length x head_dim x hidden_dim
194
+ # therefore we just need to keep the original shape
195
+ query_states = self.q_proj(hidden_states)
196
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
197
+ query_states = apply_rotary(query_states, cos, sin)
198
+
199
+ # compute key and value states
200
+ if self.layer_type.computes_kv:
201
+ key_states = self.k_proj(hidden_states)
202
+ value_states = self.v_proj(hidden_states)
203
+
204
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
205
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
206
+ key_states = apply_rotary(key_states, cos, sin)
207
+
208
+ if isinstance(past_key_value, Cache):
209
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
210
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
211
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
212
+
213
+ past_key_value.layer_set(self.layer_idx, key_states, value_states)
214
+
215
+ # get the cached key and value states
216
+ # if the layer attends to the top layers, there are two cases:
217
+ # 1. the query length is 1, in which case we will not do iterative updates. Therefore, the kv lacks the current
218
+ # query length and we need to fill it with zeros.
219
+ # 2. the query length is greater than 1, in which case we will do iterative updates and the kv will have the
220
+ # correct query length.
221
+ key_states, value_states = past_key_value.layer_get(
222
+ self.layer_type.attends_to,
223
+ zerofill=self.layer_type.attends_top and q_len == 1,
224
+ )
225
+
226
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
227
+ # to be able to avoid many of these transpose/reshape/view.
228
+ query_states = query_states.transpose(1, 2)
229
+ key_states = key_states.transpose(1, 2)
230
+ value_states = value_states.transpose(1, 2)
231
+
232
+ dropout_rate = self.attention_dropout if self.training else 0.0
233
+
234
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
235
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
236
+ # cast them back in the correct dtype just to be sure everything works as expected.
237
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
238
+ # in fp32. (LlamaRMSNorm handles it correctly)
239
+
240
+ input_dtype = query_states.dtype
241
+ if input_dtype == torch.float32:
242
+ if torch.is_autocast_enabled():
243
+ target_dtype = torch.get_autocast_gpu_dtype()
244
+ # Handle the case where the model is quantized
245
+ elif hasattr(self.config, "_pre_quantization_dtype"):
246
+ target_dtype = self.config._pre_quantization_dtype
247
+ else:
248
+ target_dtype = self.q_proj.weight.dtype
249
+
250
+ logger.warning_once(
251
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
252
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
253
+ f" {target_dtype}."
254
+ )
255
+
256
+ query_states = query_states.to(target_dtype)
257
+ key_states = key_states.to(target_dtype)
258
+ value_states = value_states.to(target_dtype)
259
+
260
+ attn_output = flash_attention_forward(
261
+ query_states,
262
+ key_states,
263
+ value_states,
264
+ attention_mask,
265
+ q_len,
266
+ position_ids=position_ids,
267
+ dropout=dropout_rate,
268
+ sliding_window=self.sliding_window,
269
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
270
+ is_causal=self.is_causal,
271
+ no_diag=(self.config.force_nodiag or self.layer_type.attends_top),
272
+ )
273
+
274
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
275
+ attn_output = self.o_proj(attn_output)
276
+
277
+ if not output_attentions:
278
+ attn_weights = None
279
+
280
+ return attn_output, attn_weights, past_key_value
281
+
282
+
283
+ LCKV_LLAMA_ATTENTION_CLASSES = {
284
+ "eager": LCKVLlamaAttention,
285
+ "flash_attention_2": LCKVLlamaFlashAttention2,
286
+ }
287
+
288
+
289
+ class LCKVLlamaDecoderLayer(LlamaDecoderLayer):
290
+ def __init__(self, config: LCKVLlamaConfig, layer_idx: int):
291
+ super().__init__(config, layer_idx)
292
+ self.self_attn = LCKV_LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
293
+
294
+
295
+ class LCKVLlamaPreTrainedModel(LlamaPreTrainedModel):
296
+ config_class = LCKVLlamaConfig
297
+ supports_gradient_checkpointing = False # not tested yet
298
+ _no_split_modules = ["LCKVLlamaDecoderLayer"]
299
+ _supports_flash_attn_2 = True
300
+ _supports_sdpa = False
301
+
302
+
303
+ class LCKVLlamaModel(LCKVLlamaPreTrainedModel, LlamaModel):
304
+ def __init__(self, config: LCKVLlamaConfig):
305
+ LCKVLlamaPreTrainedModel.__init__(self, config)
306
+ LlamaModel.__init__(self, copy.deepcopy(config)) # copy config to avoid modifying the original
307
+ self.layers = nn.ModuleList([LCKVLlamaDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
308
+ self.parser = LayerTypeParser(config.layer_types)
309
+
310
+ # Initialize weights and apply final processing
311
+ self.post_init()
312
+
313
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
314
+ def forward(
315
+ self,
316
+ input_ids: torch.LongTensor = None,
317
+ attention_mask: Optional[torch.Tensor] = None,
318
+ position_ids: Optional[torch.LongTensor] = None,
319
+ past_key_values: Optional[LayerCache] = None,
320
+ inputs_embeds: Optional[torch.FloatTensor] = None,
321
+ use_cache: Optional[bool] = None,
322
+ output_attentions: Optional[bool] = None,
323
+ output_hidden_states: Optional[bool] = None,
324
+ return_dict: Optional[bool] = None,
325
+ cache_position: Optional[torch.LongTensor] = None,
326
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
327
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
328
+ output_hidden_states = (
329
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
330
+ )
331
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
332
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
333
+
334
+ if (input_ids is None) ^ (inputs_embeds is not None):
335
+ raise ValueError(
336
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
337
+ )
338
+
339
+ if self.gradient_checkpointing and self.training and use_cache:
340
+ logger.warning_once(
341
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
342
+ )
343
+ use_cache = False
344
+
345
+ if inputs_embeds is None:
346
+ inputs_embeds = self.embed_tokens(input_ids)
347
+
348
+ # build the cache object
349
+ if not isinstance(past_key_values, LayerCache):
350
+ placeholder = inputs_embeds.new_zeros(
351
+ inputs_embeds.shape[0],
352
+ self.config.num_key_value_heads,
353
+ 1,
354
+ getattr(self.config, "head_dim", self.config.hidden_size // self.config.num_attention_heads)
355
+ )
356
+
357
+ if past_key_values is None:
358
+ past_key_values = LayerCache()
359
+ elif isinstance(past_key_values, Cache):
360
+ past_key_values = AutoLayerCache.from_cache(past_key_values)
361
+ else:
362
+ raise NotImplementedError("Only DynamicCache is supported for now.")
363
+
364
+ past_key_values.setup(placeholder)
365
+
366
+ if cache_position is None:
367
+ past_seen_tokens = past_key_values.get_seq_length() if isinstance(past_key_values, Cache) else 0
368
+ cache_position = torch.arange(
369
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
370
+ )
371
+ if position_ids is None:
372
+ position_ids = cache_position.unsqueeze(0)
373
+
374
+ causal_mask = self._update_causal_mask(
375
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
376
+ )
377
+ hidden_states = inputs_embeds
378
+
379
+ # create position embeddings to be shared across the decoder layers
380
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
381
+
382
+ # whether to forward sequentially
383
+ use_sequential = (
384
+ self.config.use_sequential
385
+ or inputs_embeds.shape[1] <= self.config.forward_passes + self.config.backward_passes
386
+ and self.parser.attends_top()
387
+ )
388
+
389
+ if use_sequential:
390
+
391
+ iteration_outputs = self._modeling_sequential(
392
+ hidden_states,
393
+ attention_mask=causal_mask,
394
+ position_ids=position_ids,
395
+ past_key_values=past_key_values,
396
+ output_attentions=output_attentions,
397
+ use_cache=use_cache,
398
+ cache_position=cache_position,
399
+ position_embeddings=position_embeddings,
400
+ output_hidden_states=output_hidden_states,
401
+ )
402
+
403
+ else:
404
+
405
+ # initialize the cache
406
+ past_key_values.initialize(self.parser, inputs_embeds.shape[1])
407
+
408
+ # we need to do forward passes based on a plan if the input is a prompt
409
+ plan = self.parser.iteration_plan(self.config.forward_passes, self.config.backward_passes)
410
+
411
+ iteration_outputs = self._modeling_with_plan(
412
+ hidden_states,
413
+ attention_mask=causal_mask,
414
+ position_ids=position_ids,
415
+ past_key_values=past_key_values,
416
+ output_attentions=output_attentions,
417
+ use_cache=use_cache,
418
+ cache_position=cache_position,
419
+ position_embeddings=position_embeddings,
420
+ output_hidden_states=output_hidden_states,
421
+ modeling_plan=plan,
422
+ )
423
+
424
+ hidden_states = iteration_outputs.last_hidden_state
425
+ all_hidden_states = iteration_outputs.hidden_states
426
+ all_self_attns = iteration_outputs.attentions
427
+ next_decoder_cache = iteration_outputs.past_key_values
428
+
429
+ hidden_states = self.norm(hidden_states)
430
+
431
+ # add hidden states from the last decoder layer
432
+ if output_hidden_states:
433
+ all_hidden_states += (hidden_states,)
434
+
435
+ next_cache = next_decoder_cache if use_cache else None
436
+
437
+ if not return_dict:
438
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
439
+ return BaseModelOutputWithPast(
440
+ last_hidden_state=hidden_states,
441
+ past_key_values=next_cache,
442
+ hidden_states=all_hidden_states,
443
+ attentions=all_self_attns,
444
+ )
445
+
446
+ def _iterate_layers(
447
+ self,
448
+ hidden_states: torch.Tensor,
449
+ attention_mask: Optional[torch.LongTensor] = None,
450
+ position_ids: Optional[torch.LongTensor] = None,
451
+ past_key_values: Optional[LayerCache] = None,
452
+ output_attentions: bool = False,
453
+ use_cache: bool = False,
454
+ cache_position: Optional[torch.LongTensor] = None,
455
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
456
+ output_hidden_states: Optional[bool] = False,
457
+ layer_slice: Optional[slice] = None,
458
+ ) -> BaseModelOutputWithPast:
459
+ """
460
+ Iterates over the layers of the model, calling each layer in turn.
461
+ """
462
+ # decoder layers
463
+ all_hidden_states = () if output_hidden_states else None
464
+ all_self_attns = () if output_attentions else None
465
+ next_decoder_cache = None
466
+
467
+ # layers to compute
468
+ if layer_slice is None:
469
+ layer_slice = slice(None)
470
+
471
+ for decoder_layer in self.layers[layer_slice]:
472
+ if output_hidden_states:
473
+ all_hidden_states += (hidden_states,)
474
+
475
+ if self.gradient_checkpointing and self.training:
476
+ layer_outputs = self._gradient_checkpointing_func(
477
+ decoder_layer.__call__,
478
+ hidden_states,
479
+ attention_mask,
480
+ position_ids,
481
+ past_key_values,
482
+ output_attentions,
483
+ use_cache,
484
+ cache_position,
485
+ position_embeddings,
486
+ )
487
+ else:
488
+ layer_outputs = decoder_layer(
489
+ hidden_states,
490
+ attention_mask=attention_mask,
491
+ position_ids=position_ids,
492
+ past_key_value=past_key_values,
493
+ output_attentions=output_attentions,
494
+ use_cache=use_cache,
495
+ cache_position=cache_position,
496
+ position_embeddings=position_embeddings,
497
+ )
498
+
499
+ hidden_states = layer_outputs[0]
500
+
501
+ if use_cache:
502
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
503
+
504
+ if output_attentions:
505
+ all_self_attns += (layer_outputs[1],)
506
+
507
+ next_cache = next_decoder_cache if use_cache else None
508
+
509
+ return BaseModelOutputWithPast(
510
+ last_hidden_state=hidden_states,
511
+ past_key_values=next_cache,
512
+ hidden_states=all_hidden_states,
513
+ attentions=all_self_attns,
514
+ )
515
+
516
+ def _modeling_with_plan(
517
+ self,
518
+ hidden_states: torch.Tensor,
519
+ attention_mask: Optional[torch.LongTensor] = None,
520
+ position_ids: Optional[torch.LongTensor] = None,
521
+ past_key_values: Optional[LayerCache] = None,
522
+ output_attentions: bool = False,
523
+ use_cache: bool = False,
524
+ cache_position: Optional[torch.LongTensor] = None,
525
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
526
+ output_hidden_states: Optional[bool] = False,
527
+ modeling_plan: List[IterStep] = None,
528
+ ) -> BaseModelOutputWithPast:
529
+ """
530
+ Given a plan, iteratively update the hidden states.
531
+ """
532
+ # decoder layers
533
+ all_hidden_states = () if output_hidden_states else None
534
+ all_self_attns = () if output_attentions else None
535
+ next_decoder_cache = None
536
+
537
+ for step in modeling_plan:
538
+ end = len(self.layers) if step.layer_slice.stop is None else step.layer_slice.stop
539
+ iteration_func = self._iterate_layers if step.requires_grad else torch.no_grad()(self._iterate_layers)
540
+
541
+ if isinstance(past_key_values, Cache):
542
+ past_key_values._update = step.update
543
+
544
+ iteration_outputs = iteration_func(
545
+ hidden_states,
546
+ attention_mask=attention_mask,
547
+ position_ids=position_ids,
548
+ past_key_values=past_key_values,
549
+ output_attentions=output_attentions,
550
+ use_cache=use_cache,
551
+ cache_position=cache_position,
552
+ position_embeddings=position_embeddings,
553
+ output_hidden_states=output_hidden_states,
554
+ layer_slice=step.layer_slice
555
+ )
556
+
557
+ # Update the hidden states cache
558
+ if step.update:
559
+ hidden_states = iteration_outputs.last_hidden_state
560
+
561
+ if output_hidden_states:
562
+ all_hidden_states = all_hidden_states[:end] + iteration_outputs.hidden_states
563
+
564
+ if output_attentions:
565
+ all_self_attns = all_self_attns[:end] + iteration_outputs.attentions
566
+
567
+ if use_cache:
568
+ next_decoder_cache = iteration_outputs.past_key_values
569
+
570
+ return BaseModelOutputWithPast(
571
+ last_hidden_state=hidden_states,
572
+ past_key_values=next_decoder_cache,
573
+ hidden_states=all_hidden_states,
574
+ attentions=all_self_attns,
575
+ )
576
+
577
+ def _modeling_sequential(
578
+ self,
579
+ hidden_states: torch.Tensor,
580
+ attention_mask: Optional[torch.LongTensor] = None,
581
+ position_ids: Optional[torch.LongTensor] = None,
582
+ past_key_values: Optional[LayerCache] = None,
583
+ output_attentions: bool = False,
584
+ use_cache: bool = False,
585
+ cache_position: Optional[torch.LongTensor] = None,
586
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
587
+ output_hidden_states: Optional[bool] = False,
588
+ ) -> BaseModelOutputWithPast:
589
+ """
590
+ Sequentially update the hidden states, token by token.
591
+ """
592
+ seq_len = hidden_states.shape[1]
593
+ last_hidden_state = []
594
+ all_hidden_states = []
595
+ all_self_attns = []
596
+
597
+ for i in range(seq_len):
598
+ m_hidden_states = hidden_states[:, i:i+1]
599
+ m_attention_mask = (
600
+ (attention_mask[:, : i + 1] if attention_mask.ndim == 2 else attention_mask[:, :, i : i + 1])
601
+ if attention_mask is not None
602
+ else None
603
+ )
604
+ m_position_ids = position_ids[:, i:i+1] if position_ids is not None else None
605
+ m_cache_position = cache_position[i:i+1] if cache_position is not None else None
606
+ m_position_embeddings = (
607
+ position_embeddings[0][:, i:i+1],
608
+ position_embeddings[1][:, i:i+1]
609
+ )
610
+
611
+ outputs = self._iterate_layers(
612
+ m_hidden_states,
613
+ attention_mask=m_attention_mask,
614
+ position_ids=m_position_ids,
615
+ past_key_values=past_key_values,
616
+ output_attentions=output_attentions,
617
+ use_cache=use_cache,
618
+ cache_position=m_cache_position,
619
+ position_embeddings=m_position_embeddings,
620
+ output_hidden_states=output_hidden_states
621
+ )
622
+
623
+ last_hidden_state.append(outputs.last_hidden_state)
624
+
625
+ if output_hidden_states:
626
+ all_hidden_states.append(outputs.hidden_states)
627
+
628
+ if output_attentions:
629
+ all_self_attns.append(outputs.attentions)
630
+
631
+ if use_cache:
632
+ past_key_values = outputs.past_key_values
633
+
634
+ last_hidden_state = torch.cat(last_hidden_state, dim=1)
635
+
636
+ if output_hidden_states:
637
+ all_hidden_states = [
638
+ torch.cat([hs[i] for hs in all_hidden_states], dim=1) for i in range(len(all_hidden_states[0]))
639
+ ]
640
+
641
+ if output_attentions:
642
+ # TODO: deal with attention outputs for non-flash-attention implmentations
643
+ all_self_attns = all_self_attns[-1]
644
+
645
+ return BaseModelOutputWithPast(
646
+ last_hidden_state=last_hidden_state,
647
+ past_key_values=past_key_values,
648
+ hidden_states=all_hidden_states,
649
+ attentions=all_self_attns,
650
+ )
651
+
652
+ def _update_causal_mask(
653
+ self,
654
+ attention_mask: torch.Tensor,
655
+ input_tensor: torch.Tensor,
656
+ cache_position: torch.Tensor,
657
+ past_key_values: Cache,
658
+ output_attentions: bool,
659
+ ):
660
+ """fix this function to handle layer cache"""
661
+ if self.config._attn_implementation == "flash_attention_2":
662
+ if attention_mask is not None and 0.0 in attention_mask:
663
+ return attention_mask
664
+ return None
665
+
666
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
667
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
668
+ # to infer the attention mask.
669
+ past_seen_tokens = past_key_values.get_seq_length() if isinstance(past_key_values, Cache) else 0
670
+ using_static_cache = isinstance(past_key_values, StaticCache)
671
+
672
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
673
+ if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
674
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
675
+ attention_mask,
676
+ inputs_embeds=input_tensor,
677
+ past_key_values_length=past_seen_tokens,
678
+ is_training=self.training,
679
+ ):
680
+ return None
681
+
682
+ dtype, device = input_tensor.dtype, input_tensor.device
683
+ min_dtype = torch.finfo(dtype).min
684
+ sequence_length = input_tensor.shape[1]
685
+ if using_static_cache:
686
+ target_length = past_key_values.get_max_length()
687
+ else:
688
+ target_length = (
689
+ attention_mask.shape[-1]
690
+ if isinstance(attention_mask, torch.Tensor)
691
+ else past_seen_tokens + sequence_length + 1
692
+ )
693
+
694
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
695
+ causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
696
+ attention_mask,
697
+ sequence_length=sequence_length,
698
+ target_length=target_length,
699
+ dtype=dtype,
700
+ device=device,
701
+ min_dtype=min_dtype,
702
+ cache_position=cache_position,
703
+ batch_size=input_tensor.shape[0],
704
+ )
705
+
706
+ if (
707
+ self.config._attn_implementation == "sdpa"
708
+ and attention_mask is not None
709
+ and attention_mask.device.type == "cuda"
710
+ and not output_attentions
711
+ ):
712
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
713
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
714
+ # Details: https://github.com/pytorch/pytorch/issues/110213
715
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
716
+
717
+ return causal_mask
718
+
719
+
720
+ class LCKVLlamaForCausalLM(LCKVLlamaPreTrainedModel, LlamaForCausalLM):
721
+ def __init__(self, config):
722
+ LCKVLlamaPreTrainedModel.__init__(self, config)
723
+ LlamaForCausalLM.__init__(self, copy.deepcopy(config)) # copy config to avoid modifying the original
724
+ self.model = LCKVLlamaModel(config)
725
+
726
+ # Initialize weights and apply final processing
727
+ self.post_init()
728
+
729
+ def prepare_inputs_for_generation(
730
+ self,
731
+ input_ids,
732
+ past_key_values=None,
733
+ attention_mask=None,
734
+ inputs_embeds=None,
735
+ cache_position=None,
736
+ position_ids=None,
737
+ use_cache=True,
738
+ num_logits_to_keep=None,
739
+ **kwargs,
740
+ ):
741
+ """fix this function to handle sink cache"""
742
+ # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
743
+ # Exception 1: when passing input_embeds, input_ids may be missing entries
744
+ # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
745
+ if isinstance(past_key_values, Cache):
746
+ if inputs_embeds is not None: # Exception 1
747
+ input_ids = input_ids[:, -cache_position.shape[0] :]
748
+ elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
749
+ input_ids = input_ids[:, cache_position]
750
+
751
+ if attention_mask is not None and position_ids is None:
752
+ # create position_ids on the fly for batch generation
753
+ position_ids = attention_mask.long().cumsum(-1) - 1
754
+ position_ids.masked_fill_(attention_mask == 0, 1)
755
+ if isinstance(past_key_values, Cache):
756
+
757
+ if getattr(past_key_values, "build_position_ids_based_on_cache", False):
758
+ cur_cache_length = past_key_values.get_seq_length()
759
+ position_ids = position_ids[:, cur_cache_length :cur_cache_length + input_ids.shape[1]]
760
+ else:
761
+ position_ids = position_ids[:, -input_ids.shape[1] :]
762
+
763
+ # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
764
+ position_ids = position_ids.clone(memory_format=torch.contiguous_format)
765
+
766
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
767
+ if inputs_embeds is not None and cache_position[0] == 0:
768
+ model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
769
+ else:
770
+ # The clone here is for the same reason as for `position_ids`.
771
+ model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
772
+
773
+ if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
774
+ if model_inputs["inputs_embeds"] is not None:
775
+ batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
776
+ device = model_inputs["inputs_embeds"].device
777
+ else:
778
+ batch_size, sequence_length = model_inputs["input_ids"].shape
779
+ device = model_inputs["input_ids"].device
780
+
781
+ dtype = self.lm_head.weight.dtype
782
+ min_dtype = torch.finfo(dtype).min
783
+
784
+ attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
785
+ attention_mask,
786
+ sequence_length=sequence_length,
787
+ target_length=past_key_values.get_max_length(),
788
+ dtype=dtype,
789
+ device=device,
790
+ min_dtype=min_dtype,
791
+ cache_position=cache_position,
792
+ batch_size=batch_size,
793
+ )
794
+
795
+ if num_logits_to_keep is not None:
796
+ model_inputs["num_logits_to_keep"] = num_logits_to_keep
797
+
798
+ model_inputs.update(
799
+ {
800
+ "position_ids": position_ids,
801
+ "cache_position": cache_position,
802
+ "past_key_values": past_key_values,
803
+ "use_cache": use_cache,
804
+ "attention_mask": attention_mask,
805
+ }
806
+ )
807
+ return model_inputs
utils.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from dataclasses import dataclass
3
+ from typing import List, Optional
4
+
5
+ import torch
6
+
7
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
8
+
9
+
10
+ @dataclass
11
+ class IterStep:
12
+ """A helper class for the iteration plan"""
13
+ layer_slice: slice = slice(None)
14
+ requires_grad: bool = True
15
+ update: bool = True
16
+
17
+ @dataclass
18
+ class LayerType:
19
+ """A helper class to collect the layer type information"""
20
+ layer_idx: int
21
+ use_sliding_window: bool
22
+ attends_to: int
23
+ attends_top: bool
24
+ computes_kv: bool
25
+
26
+ class LayerTypeParser:
27
+ """
28
+ A helper class to parse the layer type string and provide some useful methods.
29
+
30
+ Arguments:
31
+ layer_type (str): A string of integers separated by underscores. The i-th integer
32
+ means the layer will use the key-value pair in the i-th layer as the kv cache.
33
+ Special characters may be placed after the integers:
34
+ - `s` means the layer will use sliding window attention.
35
+
36
+ >>> layer_type = LayerTypeParser("0_0_0_5s_5s_5s_8_8_8")[3]
37
+ >>> layer_type.attends_to
38
+ 5
39
+ >>> layer_type.attends_top
40
+ True
41
+ >>> layer_type.use_sliding_window
42
+ True
43
+ """
44
+ def __init__(self, layer_type: str):
45
+ self._layer_type = layer_type
46
+
47
+ # parse the layer type
48
+ self.layer_indices = []
49
+ self.sliding_window = []
50
+ for s in layer_type.split("_"):
51
+ layer_idx, sliding_window = re.match(r"^(\d+)(s)?$", s).groups()
52
+ self.layer_indices.append(int(layer_idx))
53
+ self.sliding_window.append(bool(sliding_window))
54
+
55
+ def __len__(self):
56
+ return len(self.layer_indices)
57
+
58
+ def __getitem__(self, layer_idx: int) -> LayerType:
59
+ """return the layer type information for the given layer index"""
60
+ return LayerType(
61
+ layer_idx=layer_idx,
62
+ use_sliding_window=self.sliding_window[layer_idx],
63
+ attends_to=self.layer_indices[layer_idx],
64
+ attends_top=self.layer_indices[layer_idx] > layer_idx,
65
+ computes_kv=layer_idx in self.layer_indices,
66
+ )
67
+
68
+ def use_sliding_window(self) -> bool:
69
+ """whether there exists a layer that uses sliding window attention"""
70
+ return any(self.sliding_window)
71
+
72
+ def attends_top(self) -> bool:
73
+ """whether there exists a layer that attends to layers above it"""
74
+ return any(self.layer_indices[i] > i for i in range(len(self)))
75
+
76
+ def iteration_plan(self, forward_passes: int = 7, backward_passes: int = 2) -> List[IterStep]:
77
+ """
78
+ Return a iteration plan for the layer types. The plan is a list of IterStep objects.
79
+ """
80
+ # if there is no cyclic dependency, return the default plan
81
+ if not self.attends_top():
82
+ return [IterStep()]
83
+
84
+ # otherwise, return the plan for the cyclic dependency
85
+ plan = []
86
+ i = 0
87
+ while i < len(self):
88
+
89
+ # if the layer attends to top layers, resolve the cyclic dependency
90
+ if self[i].attends_top:
91
+
92
+ # find the top layer in the cyclic dependency
93
+ top = self[i].attends_to
94
+ while top < max(self.layer_indices[i: top + 1]):
95
+ top = max(self.layer_indices[i: top + 1])
96
+ top += 1
97
+
98
+ # create iteration plan for this group
99
+ layer_slice = slice(i, top)
100
+ plan.extend([
101
+ *forward_passes * [IterStep(layer_slice, requires_grad=False, update=False)],
102
+ *(backward_passes - 1) * [IterStep(layer_slice, update=False)],
103
+ IterStep(layer_slice)
104
+ ])
105
+
106
+ # otherwise, create a default plan
107
+ else:
108
+
109
+ top = i + 1
110
+ while top < len(self) and not self[top].attends_top:
111
+ top += 1
112
+ plan.append(IterStep(slice(i, top)))
113
+
114
+ # update the index
115
+ i = top
116
+
117
+ return plan
118
+
119
+ def check(self, num_hidden_layers: int):
120
+ """Check if the layer type is valid"""
121
+ if len(self.layer_indices) != num_hidden_layers:
122
+ raise ValueError("The number of layer types should be equal to the number of hidden layers.")
123
+ for i in range(num_hidden_layers):
124
+ if self.layer_indices[i] not in range(num_hidden_layers):
125
+ raise ValueError("The layer type should be in the range of the number of hidden layers.")
126
+
127
+
128
+ def flash_attention_forward(
129
+ query_states: torch.Tensor,
130
+ key_states: torch.Tensor,
131
+ value_states: torch.Tensor,
132
+ attention_mask: torch.Tensor,
133
+ query_length: int,
134
+ is_causal: bool,
135
+ dropout: float = 0.0,
136
+ position_ids: Optional[torch.Tensor] = None,
137
+ softmax_scale: Optional[float] = None,
138
+ sliding_window: Optional[int] = None,
139
+ use_top_left_mask: bool = False,
140
+ softcap: Optional[float] = None,
141
+ deterministic: bool = None,
142
+ no_diag: bool = False,
143
+ ):
144
+ """
145
+ This function is a wrapper around the _flash_attention_forward function in the
146
+ transformers library. It adds support to mask the diagonal elements of the attention
147
+ matrix. The diagonal mask is used to resolve the cyclic dependencies in the LCKV model.
148
+ """
149
+ prune_query = False
150
+ if no_diag:
151
+ if key_states.size(1) == 1:
152
+ b, l, _, d = value_states.size()
153
+ _, _, h, _ = query_states.size()
154
+ return value_states.new_zeros((b, l, h, d))
155
+
156
+ if key_states.size(1) == query_states.size(1):
157
+ prune_query = True
158
+ query_states = query_states[:, 1:, :, :]
159
+ query_length -= 1
160
+
161
+ if attention_mask is not None:
162
+ attention_mask = attention_mask[:, 1:]
163
+
164
+ key_states = key_states[:, :-1, :, :]
165
+ value_states = value_states[:, :-1, :, :]
166
+
167
+ if sliding_window is not None:
168
+ sliding_window = sliding_window - 1
169
+
170
+ result: torch.Tensor = _flash_attention_forward(
171
+ query_states=query_states,
172
+ key_states=key_states,
173
+ value_states=value_states,
174
+ attention_mask=attention_mask,
175
+ query_length=query_length,
176
+ is_causal=is_causal,
177
+ dropout=dropout,
178
+ position_ids=position_ids,
179
+ softmax_scale=softmax_scale,
180
+ sliding_window=sliding_window,
181
+ use_top_left_mask=use_top_left_mask,
182
+ softcap=softcap,
183
+ deterministic=deterministic,
184
+ )
185
+
186
+ if prune_query:
187
+ b, _, h, d = result.size()
188
+ result = torch.cat([result.new_zeros((b, 1, h, d)), result], dim=1)
189
+
190
+ return result