from typing import Any, Dict, Optional, Tuple

import torch

from transformers.cache_utils import Cache, DynamicCache, SinkCache

from .utils import LayerTypeParser


class IndexedCache(Cache):
    """
    Similar to the `DynamicCache` class, but with the ability to index the cache by layer index. `DynamicCache`
    assumes that all layers compute KVs, while `IndexedCache` allows for a sparser, more flexible cache structure.
    """

    build_position_ids_based_on_cache = False

    def __init__(self) -> None:
        super().__init__()
        self.key_cache: Dict[int, torch.Tensor] = {}
        self.value_cache: Dict[int, torch.Tensor] = {}
        self._seen_tokens = 0  # number of tokens the cache has seen, tracked on the first cached layer
        self._update = True  # when False, `update` returns the concatenated states without storing them

    def __getitem__(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
        sequence length.
        """
        if layer_idx in self.key_cache:
            return (self.key_cache[layer_idx], self.value_cache[layer_idx])
        else:
            raise KeyError(
                f"Cache has no KVs for layer {layer_idx} (cached layers: {sorted(self.key_cache.keys())})"
            )

    def __iter__(self):
        """
        Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
        keys and values.
        """
        for layer_idx in sorted(self.key_cache.keys()):
            yield (self.key_cache[layer_idx], self.value_cache[layer_idx])

    def __len__(self):
        """
        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
        to the number of layers that compute KVs in the model.
        """
        return len(self.key_cache)

    @property
    def min_layer(self) -> Optional[int]:
        """The smallest layer index present in the cache, or `None` if the cache is empty."""
        return min(self.key_cache.keys()) if len(self.key_cache) > 0 else None

    def is_min_layer(self, layer_idx: int) -> bool:
        return self.min_layer is None or self.min_layer == layer_idx

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. No additional arguments are used in `IndexedCache`.

        Return:
            A tuple containing the updated key and value states.
        """
        # Update the number of seen tokens (only once per forward pass, on the first cached layer)
        if self.is_min_layer(layer_idx):
            self._seen_tokens += key_states.shape[-2]

        # Concatenate the new states with the existing cache along the sequence dimension
        if layer_idx not in self.key_cache:
            new_key_states = key_states
            new_value_states = value_states
        else:
            new_key_states = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            new_value_states = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        if self._update:
            self.key_cache[layer_idx] = new_key_states
            self.value_cache[layer_idx] = new_value_states

        return new_key_states, new_value_states

    def get_seq_length(self, layer_idx: Optional[int] = None) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        if layer_idx is None:
            layer_idx = self.min_layer

        if layer_idx not in self.key_cache:
            return 0

        return self.key_cache[layer_idx].shape[-2]

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states. `IndexedCache` does not have a maximum length."""
        return None

    @classmethod
    def from_cache(cls, dynamic_cache: DynamicCache, *args, **kwargs) -> "IndexedCache":
        """Converts a `DynamicCache` into an equivalent `IndexedCache`."""
        cache = cls(*args, **kwargs)

        cache._seen_tokens = dynamic_cache._seen_tokens
        for layer_idx in range(len(dynamic_cache.key_cache)):
            key_states, value_states = dynamic_cache[layer_idx]
            cache.update(key_states, value_states, layer_idx)

        return cache


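# A minimal usage sketch (illustrative only, not part of the library API): toy
# tensors of shape (batch=1, heads=2, seq_len=3, head_dim=4) show that layers may
# be cached sparsely and that `update` concatenates along the sequence dimension.
def _demo_indexed_cache() -> None:
    cache = IndexedCache()
    key = torch.zeros(1, 2, 3, 4)
    value = torch.zeros(1, 2, 3, 4)
    cache.update(key, value, layer_idx=2)  # cached layers need not start at 0
    cache.update(key, value, layer_idx=5)
    assert len(cache) == 2  # two layers hold KVs
    assert cache.get_seq_length() == 3  # length is read from the min layer
    new_key, _ = cache.update(key, value, layer_idx=2)
    assert new_key.shape[-2] == 6  # states are concatenated along dim=-2

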
class IndexedSinkCache(Cache):
    """
    A fixed version of the `SinkCache` class from the transformers library, which additionally allows the cache to
    be indexed by layer index, similar to the `IndexedCache` class.
    """

    build_position_ids_based_on_cache = True

    def __init__(self, window_length: Optional[int] = None, num_sink_tokens: Optional[int] = None) -> None:
        super().__init__()
        self.key_cache: Dict[int, torch.Tensor] = {}
        self.value_cache: Dict[int, torch.Tensor] = {}
        self.window_length = window_length
        self.num_sink_tokens = num_sink_tokens
        self.cos_sin_rerotation_cache = {}
        self._cos_cache = None
        self._sin_cache = None
        self._seen_tokens = 0
        self._update = True

    @staticmethod
    def _rotate_half(x):
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_key_rotary_pos_emb(
        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> torch.Tensor:
        rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin)
        return rotated_key_states

    def _get_rerotation_cos_sin(
        self, offset: int, dtype: torch.dtype, cos: torch.Tensor, sin: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if offset not in self.cos_sin_rerotation_cache:
            # Upcast to float32 temporarily for better accuracy
            cos = cos.to(torch.float32)
            sin = sin.to(torch.float32)

            # Compute the cos and sin required to back-rotate the kept keys by `offset` positions,
            # using the trigonometric angle-difference identities
            original_cos = cos[self.num_sink_tokens + offset :]
            shifted_cos = cos[self.num_sink_tokens : -offset]
            original_sin = sin[self.num_sink_tokens + offset :]
            shifted_sin = sin[self.num_sink_tokens : -offset]
            rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
            rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin

            self.cos_sin_rerotation_cache[offset] = (
                rerotation_cos.to(dtype).unsqueeze(0),
                rerotation_sin.to(dtype).unsqueeze(0),
            )
        return self.cos_sin_rerotation_cache[offset]

    @property
    def min_layer(self) -> Optional[int]:
        """The smallest layer index present in the cache, or `None` if the cache is empty."""
        return min(self.key_cache.keys()) if len(self.key_cache) > 0 else None

    def is_min_layer(self, layer_idx: int) -> bool:
        return self.min_layer is None or self.min_layer == layer_idx

    def get_seq_length(self, layer_idx: Optional[int] = None) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        if layer_idx is None:
            layer_idx = self.min_layer

        if layer_idx not in self.key_cache:
            return 0

        return self.key_cache[layer_idx].shape[-2]

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states."""
        return self.window_length

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. The following arguments can be used in
                `IndexedSinkCache`: `sin`, `cos` and `partial_rotation_size`. These arguments are used with models
                using RoPE, to recompute the rotation as the tokens are shifted.

        Return:
            A tuple containing the updated key and value states.
        """
        # Optional kwargs for `IndexedSinkCache` -- needed on models using RoPE
        cache_kwargs = cache_kwargs if cache_kwargs is not None else {}
        sin = cache_kwargs.get("sin")
        cos = cache_kwargs.get("cos")
        partial_rotation_size = cache_kwargs.get("partial_rotation_size")
        using_rope = cos is not None and sin is not None

        # Update the number of seen tokens (only once per forward pass, on the first cached layer)
        if self.is_min_layer(layer_idx):
            self._seen_tokens += key_states.shape[-2]

        # Update the sin/cos caches, which hold the sin/cos values for all possible positions
        if using_rope and self.is_min_layer(layer_idx):
            # `sin`/`cos` may come with or without a leading batch dimension
            if cos.dim() == 2:
                self._cos_cache = cos
                self._sin_cache = sin
            else:
                if self._cos_cache is None:
                    self._cos_cache = cos[0, ...]
                    self._sin_cache = sin[0, ...]
                elif self._cos_cache.shape[0] < self.window_length + key_states.shape[-2]:
                    self._cos_cache = torch.cat([self._cos_cache[: self.window_length], cos[0, ...]], dim=0)
                    self._sin_cache = torch.cat([self._sin_cache[: self.window_length], sin[0, ...]], dim=0)

        # Concatenate the new states with the existing cache along the sequence dimension
        if layer_idx not in self.key_cache:
            new_key_states = key_states
            new_value_states = value_states
        else:
            new_key_states = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            new_value_states = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        if self._update:
            self.key_cache[layer_idx] = new_key_states
            self.value_cache[layer_idx] = new_value_states

        # Evict tokens beyond the window, keeping the sink tokens and the most recent tokens
        if (seq_length := self.get_seq_length(layer_idx)) > self.window_length:
            keys_to_keep = self.key_cache[layer_idx][:, :, -self.window_length + self.num_sink_tokens :]

            # On RoPE models, we need to recompute the key rotation as the tokens are shifted
            if using_rope:
                rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
                    seq_length - self.window_length,
                    key_states.dtype,
                    self._cos_cache[:seq_length],
                    self._sin_cache[:seq_length],
                )
                if partial_rotation_size is not None:
                    keys_to_keep, keys_pass = (
                        keys_to_keep[..., :partial_rotation_size],
                        keys_to_keep[..., partial_rotation_size:],
                    )
                keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin)
                if partial_rotation_size is not None:
                    keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)

            # Concatenate the sink tokens with the kept tokens
            sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
            self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep], dim=-2)

            sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
            values_to_keep = self.value_cache[layer_idx][:, :, -self.window_length + self.num_sink_tokens :]
            self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep], dim=-2)

        return new_key_states, new_value_states

    @classmethod
    def from_cache(cls, sink_cache: SinkCache, *args, **kwargs) -> "IndexedSinkCache":
        """Converts a `SinkCache` into an equivalent `IndexedSinkCache`."""
        cache = cls(*args, **kwargs)

        cache.window_length = sink_cache.window_length
        cache.num_sink_tokens = sink_cache.num_sink_tokens
        cache._seen_tokens = sink_cache._seen_tokens
        cache._cos_cache = sink_cache._cos_cache
        cache._sin_cache = sink_cache._sin_cache
        cache.cos_sin_rerotation_cache = sink_cache.cos_sin_rerotation_cache
        for layer_idx in range(len(sink_cache.key_cache)):
            cache.key_cache[layer_idx] = sink_cache.key_cache[layer_idx]
            cache.value_cache[layer_idx] = sink_cache.value_cache[layer_idx]

        return cache


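# A minimal sketch of the sink eviction policy (illustrative only): without RoPE
# kwargs (`cos`/`sin` omitted), the cache simply keeps the `num_sink_tokens`
# oldest tokens plus the most recent tokens, `window_length` tokens in total.
def _demo_indexed_sink_cache() -> None:
    cache = IndexedSinkCache(window_length=6, num_sink_tokens=2)
    key = torch.zeros(1, 2, 8, 4)
    value = torch.zeros(1, 2, 8, 4)
    cache.update(key, value, layer_idx=0, cache_kwargs={})
    assert cache._seen_tokens == 8  # 8 tokens were seen ...
    assert cache.get_seq_length(0) == 6  # ... but only 2 sink + 4 recent are kept

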
class IndexedSlidingWindowCache(IndexedCache):
    """
    Similar to the `SlidingWindowCache` class, but with the ability to index the cache by layer index. It is no
    longer a subclass of `StaticCache`, as this cache is dynamic.
    """

    build_position_ids_based_on_cache = False

    def __init__(self, sliding_window: Optional[int] = None) -> None:
        super().__init__()
        self.sliding_window = sliding_window

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Update the number of seen tokens (only once per forward pass, on the first cached layer)
        if self.is_min_layer(layer_idx):
            self._seen_tokens += key_states.shape[-2]

        # Concatenate the new states with the existing cache along the sequence dimension
        if layer_idx not in self.key_cache:
            new_key_states = key_states
            new_value_states = value_states
        else:
            new_key_states = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            new_value_states = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        if self._update:
            self.key_cache[layer_idx] = new_key_states
            self.value_cache[layer_idx] = new_value_states

        # Keep only the most recent `sliding_window` tokens
        if self.get_seq_length(layer_idx) > self.sliding_window:
            self.key_cache[layer_idx] = self.key_cache[layer_idx][:, :, -self.sliding_window :]
            self.value_cache[layer_idx] = self.value_cache[layer_idx][:, :, -self.sliding_window :]

        return new_key_states, new_value_states

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states, i.e. the sliding window size."""
        return self.sliding_window

    @classmethod
    def from_cache(
        cls, sliding_window_cache: "IndexedSlidingWindowCache", *args, **kwargs
    ) -> "IndexedSlidingWindowCache":
        """Overrides `IndexedCache.from_cache` to also copy the sliding window size."""
        cache = cls(*args, **kwargs)

        cache._seen_tokens = sliding_window_cache._seen_tokens
        cache.sliding_window = sliding_window_cache.sliding_window
        for layer_idx in range(len(sliding_window_cache.key_cache)):
            cache.key_cache[layer_idx] = sliding_window_cache.key_cache[layer_idx]
            cache.value_cache[layer_idx] = sliding_window_cache.value_cache[layer_idx]

        return cache


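# A minimal sketch (illustrative only): once the cached length exceeds
# `sliding_window`, only the most recent `sliding_window` tokens are kept.
def _demo_indexed_sliding_window_cache() -> None:
    cache = IndexedSlidingWindowCache(sliding_window=4)
    key = torch.zeros(1, 2, 3, 4)
    value = torch.zeros(1, 2, 3, 4)
    cache.update(key, value, layer_idx=0)
    cache.update(key, value, layer_idx=0)  # 6 tokens seen in total
    assert cache._seen_tokens == 6
    assert cache.get_seq_length(0) == 4  # trimmed to the window

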
class IndexedHybridCache(IndexedSlidingWindowCache, IndexedCache):
    """
    Hybrid cache class to be used for models that alternate between local sliding-window attention and global
    attention in every other layer. Under the hood, it leverages `IndexedSlidingWindowCache` for sliding-window
    attention layers and `IndexedCache` for global attention layers.
    """

    build_position_ids_based_on_cache = False

    def __init__(self, parser: Optional[LayerTypeParser] = None, sliding_window: Optional[int] = None) -> None:
        super().__init__(sliding_window=sliding_window)
        self.parser = parser

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Dispatch to the sliding-window or the global update depending on the layer type
        if self.parser[layer_idx].use_sliding_window:
            return IndexedSlidingWindowCache.update(self, key_states, value_states, layer_idx, cache_kwargs)
        else:
            return IndexedCache.update(self, key_states, value_states, layer_idx, cache_kwargs)

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states. Global layers impose no maximum length."""
        return IndexedCache.get_max_length(self)

    @classmethod
    def from_cache(cls, hybrid_cache: "IndexedHybridCache", *args, **kwargs) -> "IndexedHybridCache":
        """Overrides `IndexedSlidingWindowCache.from_cache` to also copy the layer-type parser."""
        cache = cls(*args, **kwargs)

        cache._seen_tokens = hybrid_cache._seen_tokens
        cache.sliding_window = hybrid_cache.sliding_window
        cache.parser = hybrid_cache.parser
        for layer_idx in range(len(hybrid_cache.key_cache)):
            cache.key_cache[layer_idx] = hybrid_cache.key_cache[layer_idx]
            cache.value_cache[layer_idx] = hybrid_cache.value_cache[layer_idx]

        return cache


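# A minimal sketch (illustrative only) using a stand-in for `LayerTypeParser`:
# only the `use_sliding_window` flag consulted by `update` is stubbed out, so
# sliding-window layers get trimmed while global-attention layers grow freely.
class _StubLayerInfo:
    def __init__(self, use_sliding_window: bool) -> None:
        self.use_sliding_window = use_sliding_window


class _StubParser:
    def __getitem__(self, layer_idx: int) -> _StubLayerInfo:
        return _StubLayerInfo(use_sliding_window=layer_idx % 2 == 0)


def _demo_indexed_hybrid_cache() -> None:
    cache = IndexedHybridCache(parser=_StubParser(), sliding_window=4)
    key = torch.zeros(1, 2, 3, 4)
    value = torch.zeros(1, 2, 3, 4)
    for _ in range(2):
        cache.update(key, value, layer_idx=0)  # sliding-window layer
        cache.update(key, value, layer_idx=1)  # global-attention layer
    assert cache.get_seq_length(0) == 4  # trimmed to the window
    assert cache.get_seq_length(1) == 6  # grows without bound

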
class LayerCache(torch.nn.Module):
    """
    A cache for storing the key-value pairs for layers.
    """

    def __init__(self) -> None:
        """
        The placeholder is used to expand the key-value pairs if a layer attends to the top layers.
        Size: (batch_size, num_key_value_heads, 1, head_dim)
        """
        super().__init__()
        self.key_layer_cache: Dict[int, torch.Tensor] = {}
        self.value_layer_cache: Dict[int, torch.Tensor] = {}
        self.layer_type = None
        self.placeholder = None

    def setup(self, placeholder: torch.Tensor):
        """Set up the cache. Calling this method is necessary if any layer attends to the top layers."""
        self.placeholder = placeholder

    def initialize(self, parser: LayerTypeParser, sequence_length: int):
        """Initialize the cache, zero-filling KVs for the layers whose KVs are needed before they are computed."""
        layers_to_init = {parser[idx].attends_to for idx in range(len(parser)) if parser[idx].attends_top}

        if layers_to_init:
            b, h, _, d = self.placeholder.size()
            init_kvs = self.placeholder.new_zeros((b, h, sequence_length, d))

            for layer_idx in layers_to_init:
                self.layer_append(layer_idx, init_kvs, init_kvs)

    def layer_get(self, layer_idx: int, zerofill: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        key_states = self.key_layer_cache.get(layer_idx, None)
        value_states = self.value_layer_cache.get(layer_idx, None)

        # When `zerofill` is set, missing KVs are substituted (or prefixed) with the placeholder
        if zerofill:
            if key_states is None:
                key_states = self.placeholder
                value_states = self.placeholder
            else:
                key_states = torch.cat([self.placeholder, key_states], dim=2)
                value_states = torch.cat([self.placeholder, value_states], dim=2)

        return key_states, value_states

    def layer_set(self, layer_idx: int, key: torch.Tensor, value: torch.Tensor):
        self.key_layer_cache[layer_idx] = key
        self.value_layer_cache[layer_idx] = value

    def layer_append(self, layer_idx: int, key: torch.Tensor, value: torch.Tensor):
        if layer_idx not in self.key_layer_cache:
            self.key_layer_cache[layer_idx] = key
            self.value_layer_cache[layer_idx] = value
        else:
            self.key_layer_cache[layer_idx] = torch.cat([self.key_layer_cache[layer_idx], key], dim=2)
            self.value_layer_cache[layer_idx] = torch.cat([self.value_layer_cache[layer_idx], value], dim=2)


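# A minimal sketch (illustrative only) of the layer-level cache: with `zerofill`,
# the placeholder is prepended so that a layer attending to a not-yet-computed
# top layer still receives tensors of the expected shape.
def _demo_layer_cache() -> None:
    cache = LayerCache()
    cache.setup(placeholder=torch.zeros(1, 2, 1, 4))
    key, _ = cache.layer_get(layer_idx=0, zerofill=True)
    assert key.shape == (1, 2, 1, 4)  # placeholder only
    cache.layer_append(0, torch.zeros(1, 2, 5, 4), torch.zeros(1, 2, 5, 4))
    key, _ = cache.layer_get(layer_idx=0, zerofill=True)
    assert key.shape == (1, 2, 6, 4)  # placeholder + cached KVs

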
class LayerIndexedCache(LayerCache, IndexedCache):
    """
    A cache for storing the key-value pairs for layers, combined with the abilities of the standard KV cache.
    """

    def __init__(self) -> None:
        LayerCache.__init__(self)
        IndexedCache.__init__(self)


class LayerIndexedSinkCache(LayerCache, IndexedSinkCache):
    """
    A cache for storing the key-value pairs for layers, combined with the abilities of the sink KV cache.
    """

    def __init__(self) -> None:
        LayerCache.__init__(self)
        IndexedSinkCache.__init__(self)


class LayerIndexedSlidingWindowCache(LayerCache, IndexedSlidingWindowCache):
    """
    A cache for storing the key-value pairs for layers, combined with the abilities of the sliding-window KV cache.
    """

    def __init__(self) -> None:
        LayerCache.__init__(self)
        IndexedSlidingWindowCache.__init__(self)


class LayerIndexedHybridCache(LayerCache, IndexedHybridCache):
    """
    A cache for storing the key-value pairs for layers, combined with the abilities of the hybrid KV cache.
    """

    def __init__(self) -> None:
        LayerCache.__init__(self)
        IndexedHybridCache.__init__(self)


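# A minimal sketch (illustrative only): the combined classes expose both the
# token-level `Cache` API (`update`) and the layer-level API (`layer_set`/`layer_get`)
# on a single object.
def _demo_layer_indexed_cache() -> None:
    cache = LayerIndexedCache()
    key = torch.zeros(1, 2, 3, 4)
    value = torch.zeros(1, 2, 3, 4)
    cache.update(key, value, layer_idx=0)  # token-level KV cache
    cache.layer_set(0, key, value)  # layer-level KV cache
    assert cache.get_seq_length(0) == 3
    assert cache.layer_get(0)[0].shape[-2] == 3

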
class AutoLayerCache(torch.nn.Module):
    """
    AutoLayerCache is a module that automatically creates a layer-aware cache from an existing cache.
    """

    CACHE_MAPPING = {
        DynamicCache: LayerIndexedCache,
        SinkCache: LayerIndexedSinkCache,
        IndexedSlidingWindowCache: LayerIndexedSlidingWindowCache,
        IndexedHybridCache: LayerIndexedHybridCache,
    }

    def __init__(self, *args, **kwargs):
        raise RuntimeError(
            f"{self.__class__.__name__} is designed to be instantiated "
            f"using the `{self.__class__.__name__}.from_cache(cache)` method."
        )

    @classmethod
    def from_cache(cls, cache: Cache, *args, **kwargs):
        """
        Create a new cache from an existing cache. The type of the new cache is looked up in `CACHE_MAPPING` based
        on the type of the original cache.
        """
        cache_type = type(cache)
        if cache_type not in cls.CACHE_MAPPING:
            raise ValueError(f"Cache type {cache_type} is not supported by {cls.__name__}.")

        cache_class = cls.CACHE_MAPPING[cache_type]

        if hasattr(cache_class, "from_cache"):
            return cache_class.from_cache(cache, *args, **kwargs)
        else:
            # Fall back to copying the attributes of the original cache into a fresh instance
            new_cache = cache_class(*args, **kwargs)
            new_cache.__dict__.update(cache.__dict__)
            return new_cache


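# A minimal sketch (illustrative only): converting a prefilled `DynamicCache`
# into its layer-aware counterpart via the mapping above.
def _demo_auto_layer_cache() -> None:
    dynamic_cache = DynamicCache()
    key = torch.zeros(1, 2, 3, 4)
    value = torch.zeros(1, 2, 3, 4)
    dynamic_cache.update(key, value, layer_idx=0)
    cache = AutoLayerCache.from_cache(dynamic_cache)
    assert isinstance(cache, LayerIndexedCache)
    assert cache.get_seq_length(0) == 3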