import os

import torch
from transformers import DynamicCache


class FinchCache(DynamicCache):
    """DynamicCache subclass whose per-layer key/value tensors can be compressed down
    to a subset of "important" token positions, with the kept keys re-rotated so their
    rotary position embedding matches the new, compacted positions."""

    def __init__(self) -> None:
        super().__init__()
        self.key_cache = []
        self.value_cache = []

    @staticmethod
    def _rotate_half(x):
        # Standard RoPE helper: rotate the two halves of the last dimension.
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_key_rotary_pos_emb(self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        # Apply a rotary position embedding to the key states with the given cos/sin.
        return (key_states * cos) + (self._rotate_half(key_states) * sin)

    @staticmethod
    def _rerotate_cos_sin(x, inv_freq, important_pos_batch):
        # Build cos/sin for the positional *delta* between each kept token's new,
        # compacted index (0..L-1) and its original position, so keys that were already
        # rotated at their original positions can be re-rotated to the new ones.
        B, H, L = important_pos_batch.shape
        device = important_pos_batch.device
        device_type = x.device.type
        dtype = x.dtype
        idx = torch.arange(0, L, device=device)
        idx = idx.unsqueeze(0)
        inv_freq = inv_freq[None, None, :, None].float().expand(B, H, -1, 1)
        idx = idx[:, None, :].float().expand(B, H, L)
        delta_pos = idx - important_pos_batch
        delta_pos = delta_pos.unsqueeze(2)

        # Run the trig math in full precision; use "cpu" as the autocast device type on MPS.
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"

        with torch.autocast(device_type=device_type, enabled=False):
            freqs = delta_pos.float() * inv_freq.float()
            freqs = freqs.transpose(2, 3)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos().contiguous()
            sin = emb.sin().contiguous()
        return cos.to(dtype=dtype), sin.to(dtype=dtype)

    @staticmethod
    def gather_important_tokens(states, indices):
        # Gather entries along the sequence dimension (dim 2) at the given indices.
        return torch.gather(states, 2, indices.unsqueeze(-1).expand(-1, -1, -1, states.size(3))).contiguous()

    def compress_cache(self, layer_index, important_pos, inv_freq):
        # Keep only the tokens at `important_pos` for this layer and re-rotate the kept
        # keys so their RoPE phase corresponds to their new, compacted positions.
        new_length = important_pos.size(2)
        new_cos, new_sin = self._rerotate_cos_sin(self.key_cache[layer_index], inv_freq, important_pos)
        gathered_keys = self.gather_important_tokens(self.key_cache[layer_index], important_pos).clone()
        self.key_cache[layer_index] = self._apply_key_rotary_pos_emb(gathered_keys, new_cos, new_sin)
        gathered_values = self.gather_important_tokens(self.value_cache[layer_index], important_pos).clone()
        self.value_cache[layer_index] = gathered_values
        self._seen_tokens = new_length

    def save(self, path: str):
        """Save the cache to disk, moving tensors to CPU."""
        try:
            # `or "."` guards against a bare filename with no directory component.
            os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
            torch.save(
                {"key_cache": [k.cpu() for k in self.key_cache], "value_cache": [v.cpu() for v in self.value_cache]},
                path,
            )
        except Exception as e:
            print(f"Error occurred while saving: {e}")

    @classmethod
    def load(cls, path: str, device: str = "cpu") -> "FinchCache":
        """Load the cache from disk and move tensors to the specified device."""
        data = torch.load(path, map_location=device)
        cache = cls()
        cache.key_cache = [k.to(device) for k in data["key_cache"]]
        cache.value_cache = [v.to(device) for v in data["value_cache"]]
        # Sequence length of the restored cache (dim 2 of the value tensors).
        cache._seen_tokens = cache.value_cache[0].size(2) if cache.value_cache else 0
        return cache
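

if __name__ == "__main__":
    # Minimal synthetic usage sketch (not part of the original class): it assumes the
    # (batch, heads, seq_len, head_dim) layout used by the methods above; the sizes,
    # the RoPE base of 10000.0, and the choice of kept positions are illustrative only.
    B, H, L, D = 1, 2, 16, 64
    cache = FinchCache()
    cache.key_cache.append(torch.randn(B, H, L, D))
    cache.value_cache.append(torch.randn(B, H, L, D))

    # Inverse frequencies as computed by standard rotary-embedding implementations.
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, D, 2).float() / D))

    # Keep 8 arbitrary positions per head, sorted to preserve their original order.
    important_pos = torch.stack([torch.sort(torch.randperm(L)[:8]).values for _ in range(H)]).unsqueeze(0)

    cache.compress_cache(0, important_pos, inv_freq)
    print(cache.key_cache[0].shape)  # torch.Size([1, 2, 8, 64])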