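"""FinchCache: a DynamicCache subclass that compresses the per-layer KV cache
to a subset of "important" token positions, re-applying rotary position
embeddings (RoPE) so the kept keys stay consistent with their new positions,
with helpers to save and load the cache from disk."""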
from transformers import DynamicCache
import torch
import os

class FinchCache(DynamicCache):
    def __init__(self) -> None:
        super().__init__()
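        # Per-layer lists of key/value tensors, each shaped (batch, heads, seq_len, head_dim).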
        self.key_cache = []
        self.value_cache = []

    @staticmethod
    def _rotate_half(x):
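        """Rotate half the hidden dims of the input (standard RoPE helper)."""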
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_key_rotary_pos_emb(self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
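        """Apply a rotary position embedding to the key states."""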
        return (key_states * cos) + (self._rotate_half(key_states) * sin)

    @staticmethod
    def _rerotate_cos_sin(x, inv_freq, important_pos_batch):
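        """Compute cos/sin tables that re-rotate kept keys to their new positions.

        After compression, the token kept from original position
        important_pos_batch[..., i] sits at new position i, so the tables
        encode the per-head position delta (i - original position).
        """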
        B, H, L = important_pos_batch.shape
        device = important_pos_batch.device
        device_type = x.device.type
        dtype = x.dtype
        idx = torch.arange(0, L, device=device)
        idx = idx.unsqueeze(0)
        inv_freq = inv_freq[None, None, :, None].float().expand(B, H, -1, 1)  # (B, H, D/2, 1)
        idx = idx[:, None, :].float().expand(B, H, L)  # (B, H, L)
        delta_pos = idx - important_pos_batch  # new position minus original position
        delta_pos = delta_pos.unsqueeze(2) # (B, H, 1, L)

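        # torch.autocast does not accept "mps" here, so fall back to a CPU autocast context.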
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"

        with torch.autocast(device_type=device_type, enabled=False):
            freqs = delta_pos.float() * inv_freq.float()
            freqs = freqs.transpose(2, 3)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos().contiguous()
            sin = emb.sin().contiguous()
        return cos.to(dtype=dtype), sin.to(dtype=dtype)

    @staticmethod
    def gather_important_tokens(states, indices):
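        """Gather `indices` along the sequence dim of (B, H, L, D) `states`."""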
        return torch.gather(states, 2, indices.unsqueeze(-1).expand(-1, -1, -1, states.size(3))).contiguous()

    def compress_cache(self, layer_index, important_pos, inv_freq):
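        """Compress layer `layer_index` down to the positions in `important_pos`.

        `important_pos` has shape (B, H, new_length). Keys are gathered and
        re-rotated to their new contiguous positions; values are only gathered.
        `_seen_tokens` is reset to the compressed length.
        """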
        new_length = important_pos.size(2)
        new_cos, new_sin = self._rerotate_cos_sin(self.key_cache[layer_index], inv_freq, important_pos)
        gathered_keys = self.gather_important_tokens(self.key_cache[layer_index], important_pos).clone()
        self.key_cache[layer_index] = self._apply_key_rotary_pos_emb(gathered_keys, new_cos, new_sin)
        gathered_values = self.gather_important_tokens(self.value_cache[layer_index], important_pos).clone()
        self.value_cache[layer_index] = gathered_values
        self._seen_tokens = new_length

    def save(self, path: str):
        """Save the cache to disk, moving tensors to CPU."""
        try:
            # os.path.dirname(path) is "" for a bare filename; makedirs("") would raise.
            dirname = os.path.dirname(path)
            if dirname:
                os.makedirs(dirname, exist_ok=True)
            torch.save(
                {"key_cache": [k.cpu() for k in self.key_cache], "value_cache": [v.cpu() for v in self.value_cache]},
                path,
            )
        except Exception as e:
            print(f"Error occurred while saving: {e}")

    @classmethod
    def load(cls, path: str, device: str = "cpu") -> "FinchCache":
        """Load the cache from disk and move tensors to the specified device."""
        data = torch.load(path, map_location=device)
        cache = cls()
        cache.key_cache = [k.to(device) for k in data["key_cache"]]
        cache.value_cache = [v.to(device) for v in data["value_cache"]]
        cache._seen_tokens = cache.value_cache[0].size(2) if cache.value_cache else 0
        return cache
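
# --- Minimal usage sketch (illustrative, not part of the original class) ---
# Builds a dummy single-layer cache and compresses it to four kept positions.
# The tensor sizes and the base-10000 inv_freq are assumptions following
# standard RoPE conventions; in a real model, inv_freq would come from the
# model's rotary embedding module.
if __name__ == "__main__":
    B, H, L, D = 1, 2, 16, 8  # batch, heads, seq len, head dim (assumed sizes)

    cache = FinchCache()
    cache.key_cache.append(torch.randn(B, H, L, D))
    cache.value_cache.append(torch.randn(B, H, L, D))

    # Standard RoPE inverse frequencies, shape (D/2,).
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, D, 2).float() / D))

    # Keep four positions per head, sorted so the original token order survives.
    important_pos = torch.sort(torch.randperm(L)[:4]).values
    important_pos = important_pos[None, None, :].expand(B, H, -1)  # (B, H, 4)

    cache.compress_cache(0, important_pos, inv_freq)
    print(cache.key_cache[0].shape)  # torch.Size([1, 2, 4, 8])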