beyondrag / cache.py
giulio98's picture
Update app.py
b5ac9e4
raw
history blame
3.33 kB
from transformers import DynamicCache
import torch
import os
class FinchCache(DynamicCache):
    """``DynamicCache`` that can compress a layer's KV cache to selected token slots.

    ``compress_cache`` gathers a subset of cached keys/values per layer and
    re-applies a rotary (RoPE-style, half-split layout) rotation to the kept
    keys so they line up with new contiguous slot indices.  ``save``/``load``
    persist the per-layer tensors with ``torch.save``/``torch.load``.

    NOTE(review): this relies on the list-style ``key_cache``/``value_cache``
    attributes and the private ``_seen_tokens`` counter of ``DynamicCache``;
    confirm they exist in the installed ``transformers`` version.
    """

    def __init__(self) -> None:
        super().__init__()
        # Per-layer key / value tensors, indexed by layer number.
        self.key_cache = []
        self.value_cache = []

    @staticmethod
    def _rotate_half(x: torch.Tensor) -> torch.Tensor:
        """Return ``cat(-x2, x1)`` where ``x1``/``x2`` are the two halves of the last dim."""
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_key_rotary_pos_emb(self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        """Rotate ``key_states`` by the angles encoded in ``cos``/``sin`` (half-split layout)."""
        return (key_states * cos) + (self._rotate_half(key_states) * sin)

    @staticmethod
    def _rerotate_cos_sin(x: torch.Tensor, inv_freq: torch.Tensor, important_pos_batch: torch.Tensor):
        """Build cos/sin tables that shift each kept key from its old index to its new slot.

        Args:
            x: Reference tensor; only its device type and dtype are used.
            inv_freq: 1-D tensor of M inverse frequencies.
            important_pos_batch: Tensor of shape (B, H, L) holding, for each output
                slot ``i``, the original index of the token kept there.

        Returns:
            ``(cos, sin)``, each of shape (B, H, L, 2 * M) and cast to ``x.dtype``.
            The angle for slot ``i`` is ``(i - important_pos_batch[..., i]) * inv_freq``.

        NOTE(review): assumes the cached keys were rotated with absolute positions
        equal to their cache index, so rotating by (new - old) re-positions them.
        """
        B, H, L = important_pos_batch.shape
        device = important_pos_batch.device
        device_type = x.device.type
        dtype = x.dtype
        idx = torch.arange(0, L, device=device)
        idx = idx.unsqueeze(0)
        # Align inv_freq with the positions' device so the broadcast multiply cannot fail.
        inv_freq = inv_freq.to(device)[None, None, :, None].float().expand(B, H, -1, 1)  # (B, H, M, 1)
        idx = idx[:, None, :].float().expand(B, H, L)  # (B, H, L)
        delta_pos = idx - important_pos_batch  # new slot index minus original index
        delta_pos = delta_pos.unsqueeze(2)  # (B, H, 1, L)
        # Autocast does not support "mps"; fall back to "cpu" as the device type there.
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        # Compute the angles in full float32 regardless of any surrounding autocast.
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = delta_pos.float() * inv_freq.float()  # (B, H, M, L)
            freqs = freqs.transpose(2, 3)  # (B, H, L, M)
            # Duplicate to match the half-split layout used by _rotate_half.
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos().contiguous()
            sin = emb.sin().contiguous()
        return cos.to(dtype=dtype), sin.to(dtype=dtype)

    @staticmethod
    def gather_important_tokens(states: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
        """Select entries of 4-D ``states`` along dim 2 using 3-D ``indices``.

        Each index is broadcast across the last dim of ``states``.  ``indices``
        must be an integer (int64) tensor as required by ``torch.gather``.
        """
        return torch.gather(states, 2, indices.unsqueeze(-1).expand(-1, -1, -1, states.size(3))).contiguous()

    def compress_cache(self, layer_index: int, important_pos: torch.Tensor, inv_freq: torch.Tensor) -> None:
        """Keep only the cached tokens at ``important_pos`` for one layer.

        Keys are gathered and then re-rotated so they correspond to slots
        ``0 .. important_pos.size(2) - 1``; values are gathered only.
        ``_seen_tokens`` is set to the new length.

        Args:
            layer_index: Which layer's cache to compress.
            important_pos: Index tensor of shape (B, H, L) selecting positions along dim 2.
            inv_freq: Rotary inverse frequencies (1-D), forwarded to ``_rerotate_cos_sin``.
        """
        new_length = important_pos.size(2)
        keys = self.key_cache[layer_index]
        new_cos, new_sin = self._rerotate_cos_sin(keys, inv_freq, important_pos)
        # torch.gather allocates new storage, so no extra .clone() is needed here.
        gathered_keys = self.gather_important_tokens(keys, important_pos)
        self.key_cache[layer_index] = self._apply_key_rotary_pos_emb(gathered_keys, new_cos, new_sin)
        self.value_cache[layer_index] = self.gather_important_tokens(self.value_cache[layer_index], important_pos)
        self._seen_tokens = new_length

    def save(self, path: str):
        """Save the cache to disk, moving tensors to CPU.

        Best-effort: any exception is printed and not re-raised.
        """
        try:
            directory = os.path.dirname(path)
            # dirname() is "" for a bare filename and os.makedirs("") raises,
            # which would abort the save; only create a directory when one is given.
            if directory:
                os.makedirs(directory, exist_ok=True)
            torch.save(
                {"key_cache": [k.cpu() for k in self.key_cache], "value_cache": [v.cpu() for v in self.value_cache]},
                path,
            )
        except Exception as e:
            print(f"Error occurred while saving: {e}")

    @classmethod
    def load(cls, path: str, device: str = "cpu") -> "FinchCache":
        """Load a cache written by ``save`` and move its tensors to ``device``.

        ``_seen_tokens`` is restored from dim 2 of the first layer's values
        (0 when the file holds no layers).
        """
        # weights_only=True limits unpickling to tensors and plain containers,
        # which is all ``save`` writes, so a tampered file cannot execute code.
        data = torch.load(path, map_location=device, weights_only=True)
        cache = cls()
        cache.key_cache = [k.to(device) for k in data["key_cache"]]
        cache.value_cache = [v.to(device) for v in data["value_cache"]]
        cache._seen_tokens = cache.value_cache[0].size(2) if cache.value_cache else 0
        return cache