from typing import Any, Dict, List, Optional, Tuple import torch from transformers.cache_utils import Cache, DynamicCache, SinkCache from .utils import LayerTypeParser class IndexedCache(Cache): """ Similar to the `DynamicCache` class, but with the ability to index the cache by layer index. DynamicCache assumes that all layers compute KVs, while IndexedCache allows for a more flexible cache structure. """ build_position_ids_based_on_cache = False def __init__(self) -> None: super().__init__() self.key_cache: Dict[int, torch.Tensor] = {} self.value_cache: Dict[int, torch.Tensor] = {} self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen self._update = True # to prevent the cache from updating when inference with iterations def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]: """ Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the sequence length. """ if layer_idx in self.key_cache: return (self.key_cache[layer_idx], self.value_cache[layer_idx]) else: raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") def __iter__(self): """ Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over keys and values """ for layer_idx in sorted(self.key_cache.keys()): yield (self.key_cache[layer_idx], self.value_cache[layer_idx]) def __len__(self): """ Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds to the number of layers that compute KVs in the model. """ return len(self.key_cache) @property def min_layer(self) -> int: return min(self.key_cache.keys()) if len(self.key_cache) > 0 else None def is_min_layer(self, layer_idx: int) -> bool: return self.min_layer is None or self.min_layer == layer_idx def update( self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int, cache_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. Parameters: key_states (`torch.Tensor`): The new key states to cache. value_states (`torch.Tensor`): The new value states to cache. layer_idx (`int`): The index of the layer to cache the states for. cache_kwargs (`Dict[str, Any]`, `optional`): Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. Return: A tuple containing the updated key and value states. """ # Update the number of seen tokens if self.is_min_layer(layer_idx): self._seen_tokens += key_states.shape[-2] # Retrieve the cache if layer_idx not in self.key_cache: new_key_states = key_states new_value_states = value_states else: new_key_states = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) new_value_states = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) # Update the cache if self._update: self.key_cache[layer_idx] = new_key_states self.value_cache[layer_idx] = new_value_states return new_key_states, new_value_states def get_seq_length(self, layer_idx: Optional[int] = None) -> int: """Returns the sequence length of the cached states. A layer index can be optionally passed.""" if layer_idx is None: layer_idx = self.min_layer # TODO: deprecate this function in favor of `cache_position` is_empty_layer = ( (len(self.key_cache) == 0) # no cache in any layer or (layer_idx not in self.key_cache) # skipped `layer_idx` and hasn't run a layer with cache after it ) layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0 return layer_seq_length def get_max_length(self) -> Optional[int]: """Returns the maximum sequence length of the cached states. IndexedCache does not have a maximum length.""" return None @classmethod def from_cache(cls, dynamic_cache: DynamicCache, *args, **kwargs) -> "IndexedCache": """Converts a dynamic cache into an equivalent `IndexedCache`.""" cache = cls(*args, **kwargs) cache._seen_tokens = dynamic_cache._seen_tokens for layer_idx in range(len(dynamic_cache.key_cache)): key_states, value_states = dynamic_cache[layer_idx] cache.update(key_states, value_states, layer_idx) return cache class IndexedSinkCache(Cache): """ This is a fix to the SinkCache class in the transformers library. It also allows for the cache to be indexed by layer index, similar to the `IndexedCache` class. """ build_position_ids_based_on_cache = True def __init__(self, window_length: int = None, num_sink_tokens: int = None) -> None: super().__init__() self.key_cache: Dict[int, torch.Tensor] = {} self.value_cache: Dict[int, torch.Tensor] = {} self.window_length = window_length self.num_sink_tokens = num_sink_tokens self.cos_sin_rerotation_cache = {} self._cos_cache = None self._sin_cache = None self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen self._update = True # to prevent the cache from updating when inference with iterations @staticmethod def _rotate_half(x): x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def _apply_key_rotary_pos_emb( self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor ) -> torch.Tensor: rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin) return rotated_key_states def _get_rerotation_cos_sin( self, offset: int, dtype: torch.dtype, cos: torch.Tensor, sin: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: if offset not in self.cos_sin_rerotation_cache: # Upcast to float32 temporarily for better accuracy cos = cos.to(torch.float32) sin = sin.to(torch.float32) # Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence original_cos = cos[self.num_sink_tokens + offset :] shifted_cos = cos[self.num_sink_tokens : -offset] original_sin = sin[self.num_sink_tokens + offset :] shifted_sin = sin[self.num_sink_tokens : -offset] rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin self.cos_sin_rerotation_cache[offset] = ( rerotation_cos.to(dtype).unsqueeze(0), rerotation_sin.to(dtype).unsqueeze(0), ) return self.cos_sin_rerotation_cache[offset] @property def min_layer(self) -> int: return min(self.key_cache.keys()) if len(self.key_cache) > 0 else None def is_min_layer(self, layer_idx: int) -> bool: return self.min_layer is None or self.min_layer == layer_idx def get_seq_length(self, layer_idx: Optional[int] = None) -> int: """Returns the sequence length of the cached states. A layer index can be optionally passed.""" # TODO: deprecate this function in favor of `cache_position` # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length if layer_idx is None: layer_idx = self.min_layer if layer_idx not in self.key_cache: return 0 return self.key_cache[layer_idx].shape[-2] def get_max_length(self) -> Optional[int]: """Returns the maximum sequence length of the cached states.""" return self.window_length def update( self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int, cache_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. Parameters: key_states (`torch.Tensor`): The new key states to cache. value_states (`torch.Tensor`): The new value states to cache. layer_idx (`int`): The index of the layer to cache the states for. cache_kwargs (`Dict[str, Any]`, `optional`): Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`, `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the rotation as the tokens are shifted. Return: A tuple containing the updated key and value states. """ # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models # with partially rotated position embeddings, like Phi or Persimmon. sin = cache_kwargs.get("sin") cos = cache_kwargs.get("cos") partial_rotation_size = cache_kwargs.get("partial_rotation_size") using_rope = cos is not None and sin is not None # Update the number of seen tokens if self.is_min_layer(layer_idx): self._seen_tokens += key_states.shape[-2] # Update the sin/cos cache, which holds sin/cos values for all possible positions if using_rope and self.is_min_layer(layer_idx): # BC: some models still pass `sin`/`cos` with 2 dims. In those models, they are the full sin/cos. Remove # after all RoPE models have a llama-like cache utilization. if cos.dim() == 2: self._cos_cache = cos self._sin_cache = sin else: if self._cos_cache is None: self._cos_cache = cos[0, ...] self._sin_cache = sin[0, ...] elif self._cos_cache.shape[0] < self.window_length + key_states.shape[-2]: self._cos_cache = torch.cat([self._cos_cache[: self.window_length], cos[0, ...]], dim=0) self._sin_cache = torch.cat([self._sin_cache[: self.window_length], sin[0, ...]], dim=0) # [bsz, num_heads, seq_len, head_dim] if layer_idx not in self.key_cache: # Empty cache new_key_states = key_states new_value_states = value_states else: # Growing cache new_key_states = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) new_value_states = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) if self._update: self.key_cache[layer_idx] = new_key_states self.value_cache[layer_idx] = new_value_states # If the cache is full, we need to shift the cache if (seq_length := self.get_seq_length(layer_idx)) > self.window_length: # Shifting cache keys_to_keep = self.key_cache[layer_idx][:, :, -self.window_length + self.num_sink_tokens :] # On RoPE models, we need to recompute the Key rotation as the tokens are shifted if using_rope: rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin( seq_length - self.window_length, key_states.dtype, self._cos_cache[:seq_length], self._sin_cache[:seq_length], ) if partial_rotation_size is not None: keys_to_keep, keys_pass = ( keys_to_keep[..., :partial_rotation_size], keys_to_keep[..., partial_rotation_size:], ) keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin) if partial_rotation_size is not None: keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1) # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens] self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep], dim=-2) sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens] values_to_keep = self.value_cache[layer_idx][:, :, -self.window_length + self.num_sink_tokens :] self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep], dim=-2) return new_key_states, new_value_states @classmethod def from_cache(cls, sink_cache: SinkCache, *args, **kwargs) -> "IndexedSinkCache": """Converts a dynamic cache into an equivalent `IndexedCache`.""" cache = cls(*args, **kwargs) cache.window_length = sink_cache.window_length cache.num_sink_tokens = sink_cache.num_sink_tokens cache._seen_tokens = sink_cache._seen_tokens cache._cos_cache = sink_cache._cos_cache cache._sin_cache = sink_cache._sin_cache cache.cos_sin_rerotation_cache = sink_cache.cos_sin_rerotation_cache for layer_idx in range(len(sink_cache.key_cache)): cache.key_cache[layer_idx] = sink_cache.key_cache[layer_idx] cache.value_cache[layer_idx] = sink_cache.value_cache[layer_idx] return cache class IndexedSlidingWindowCache(IndexedCache): """ Similar to the `SlidingWindowCache` class, but with the ability to index the cache by layer index. It is no longer a subclass of `StaticCache` as it is dynamic. """ build_position_ids_based_on_cache = False def __init__(self, sliding_window: int = None) -> None: super().__init__() self.sliding_window = sliding_window def update( self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int, cache_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[torch.Tensor]: # Update the number of seen tokens if self.is_min_layer(layer_idx): self._seen_tokens += key_states.shape[-2] # [bsz, num_heads, seq_len, head_dim] if layer_idx not in self.key_cache: # Empty cache new_key_states = key_states new_value_states = value_states else: # Growing cache new_key_states = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) new_value_states = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) if self._update: self.key_cache[layer_idx] = new_key_states self.value_cache[layer_idx] = new_value_states # If the cache is full, we need to shift the cache if self.get_seq_length(layer_idx) > self.sliding_window: self.key_cache[layer_idx] = self.key_cache[layer_idx][:, :, -self.sliding_window :] self.value_cache[layer_idx] = self.value_cache[layer_idx][:, :, -self.sliding_window :] return new_key_states, new_value_states def get_max_length(self) -> Optional[int]: return self.sliding_window @classmethod def from_cache(cls, sliding_window_cache: "IndexedSlidingWindowCache", *args, **kwargs) -> "IndexedSlidingWindowCache": """This is to override the `from_cache` method in the `IndexedCache` class.""" cache = cls(*args, **kwargs) cache._seen_tokens = sliding_window_cache._seen_tokens cache.sliding_window = sliding_window_cache.sliding_window for layer_idx in range(len(sliding_window_cache.key_cache)): cache.key_cache[layer_idx] = sliding_window_cache.key_cache[layer_idx] cache.value_cache[layer_idx] = sliding_window_cache.value_cache[layer_idx] return cache class IndexedHybridCache(IndexedSlidingWindowCache, IndexedCache): """ Hybrid Cache class to be used for models that alternate between a local sliding window attention and global attention in every other layer. Under the hood, Hybrid Cache leverages ["IndexedSlidingWindowCache"] for sliding window attention and ["IndexedCache"] for global attention. """ build_position_ids_based_on_cache = False def __init__(self, parser: LayerTypeParser = None, sliding_window: int = None) -> None: super().__init__(sliding_window=sliding_window) self.parser = parser def update( self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int, cache_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[torch.Tensor]: if self.parser[layer_idx].use_sliding_window: return IndexedSlidingWindowCache.update(self, key_states, value_states, layer_idx, cache_kwargs) else: return IndexedCache.update(self, key_states, value_states, layer_idx, cache_kwargs) def get_max_length(self) -> Optional[int]: return IndexedCache.get_max_length(self) @classmethod def from_cache(cls, hybrid_cache: "IndexedHybridCache", *args, **kwargs) -> "IndexedHybridCache": """This is to override the `from_cache` method in the `IndexedSlidingWindowCache` class.""" cache = cls(*args, **kwargs) cache._seen_tokens = hybrid_cache._seen_tokens cache.sliding_window = hybrid_cache.sliding_window cache.parser = hybrid_cache.parser for layer_idx in range(len(hybrid_cache.key_cache)): cache.key_cache[layer_idx] = hybrid_cache.key_cache[layer_idx] cache.value_cache[layer_idx] = hybrid_cache.value_cache[layer_idx] return cache class LayerCache(torch.nn.Module): """ A cache for storing the key-value pairs for layers. """ def __init__(self) -> None: """ The placeholder is used to expand the key-value pairs if the layer attends to the top layers. Size: (batch_size, num_key_value_heads, 1, head_dim) """ super().__init__() self.key_layer_cache: Dict[int, torch.Tensor] = {} self.value_layer_cache: Dict[int, torch.Tensor] = {} self.layer_type = None self.placeholder = None def setup(self, placeholder: torch.Tensor): """setup the cache, calling this function is necessary if there is a layer that attends to the top layers""" self.placeholder = placeholder def initialize(self, parser: LayerTypeParser, sequence_length: int): """initialize the cache""" layers_to_init = {parser[idx].attends_to for idx in range(len(parser)) if parser[idx].attends_top} if layers_to_init: b, h, _, d = self.placeholder.size() init_kvs = self.placeholder.new_zeros((b, h, sequence_length, d)) for layer_idx in layers_to_init: self.layer_append(layer_idx, init_kvs, init_kvs) def layer_get(self, layer_idx: int, zerofill: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: key_states = self.key_layer_cache.get(layer_idx, None) value_states = self.value_layer_cache.get(layer_idx, None) if zerofill: if key_states is None: key_states = self.placeholder value_states = self.placeholder else: key_states = torch.cat([self.placeholder, key_states], dim=2) value_states = torch.cat([self.placeholder, value_states], dim=2) return key_states, value_states def layer_set(self, layer_idx: int, key: torch.Tensor, value: torch.Tensor): self.key_layer_cache[layer_idx] = key self.value_layer_cache[layer_idx] = value def layer_append(self, layer_idx: int, key: torch.Tensor, value: torch.Tensor): if layer_idx not in self.key_layer_cache: self.key_layer_cache[layer_idx] = key self.value_layer_cache[layer_idx] = value else: self.key_layer_cache[layer_idx] = torch.cat([self.key_layer_cache[layer_idx], key], dim=2) self.value_layer_cache[layer_idx] = torch.cat([self.value_layer_cache[layer_idx], value], dim=2) class LayerIndexedCache(LayerCache, IndexedCache): """ A cache for storing the key-value pairs for layers, in combination with the ability of standard KV cache. """ def __init__(self) -> None: LayerCache.__init__(self) IndexedCache.__init__(self) class LayerIndexedSinkCache(LayerCache, IndexedSinkCache): """ A cache for storing the key-value pairs for layers, in combination with the ability of sink KV cache. """ def __init__(self) -> None: LayerCache.__init__(self) IndexedSinkCache.__init__(self) class LayerIndexedSlidingWindowCache(LayerCache, IndexedSlidingWindowCache): """ A cache for storing the key-value pairs for layers, in combination with the ability of sliding window KV cache. """ def __init__(self) -> None: LayerCache.__init__(self) IndexedSlidingWindowCache.__init__(self) class LayerIndexedHybridCache(LayerCache, IndexedHybridCache): """ A cache for storing the key-value pairs for layers, in combination with the ability of hybrid KV cache. """ def __init__(self) -> None: LayerCache.__init__(self) IndexedHybridCache.__init__(self) class AutoLayerCache(torch.nn.Module): """ AutoLayerCache is a module that automatically creates a cache from an existing cache. """ CACHE_MAPPING = { DynamicCache: LayerIndexedCache, SinkCache: LayerIndexedSinkCache, IndexedSlidingWindowCache: LayerIndexedSlidingWindowCache, IndexedHybridCache: LayerIndexedHybridCache, } def __init__(self, *args, **kwargs): raise RuntimeError( f"{self.__class__.__name__} is designed to be instantiated " f"using the `{self.__class__.__name__}.from_cache(cache)` method." ) @classmethod def from_cache(cls, cache: Cache, *args, **kwargs): """ Create a new cache from an existing cache. The new cache will have the same type as the original cache. """ cache_type = type(cache) if cache_type not in cls.CACHE_MAPPING: raise ValueError(f"Cache type {cache_type} is not supported by {cls.__name__}.") cache_class = cls.CACHE_MAPPING[cache_type] if hasattr(cache_class, "from_cache"): return cache_class.from_cache(cache, *args, **kwargs) else: # we init an empty cache and copy the attributes new_cache = cache_class(*args, **kwargs) new_cache.__dict__.update(cache.__dict__) return new_cache