import math
import torch
from cache import FinchCache
from utils import repeat_kv
from transformers.models.llama.modeling_llama import rotate_half
import spaces
@spaces.GPU
def get_compressed_kv_cache(model, sink_tokens, step_size, target_token_size, context_ids, context_attention_mask, question_ids, question_attention_mask):
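    """
    Build a compressed KV cache for a long context using attention-guided (FINCH-style) selection.

    The context is split into chunks, always preserving the first `sink_tokens` tokens. Each chunk is
    prefilled together with the question; the question-token queries, captured via forward hooks on
    every layer's q_proj, score the cached key positions, and only the highest-scoring positions are
    kept per layer. The returned FinchCache holds roughly `sink_tokens + target_token_size` entries
    plus the question tokens.
    """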
device = model.device
dtype = model.dtype
    num_chunks = step_size
context_ids = context_ids.to(device)
context_attention_mask = context_attention_mask.to(device)
question_ids = question_ids.to(device)
question_attention_mask = question_attention_mask.to(device)
question_len = question_ids.size(1)
total_len = context_ids.size(1)
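    # Make sure each chunk plus the question fits inside the model's positional window;
    # increase the number of chunks if the raw context is too long.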
max_context_tokens_allowed = model.config.max_position_embeddings - question_len
if total_len > max_context_tokens_allowed:
num_chunks = max(step_size, math.ceil(total_len / max_context_tokens_allowed))
if total_len <= sink_tokens or num_chunks == 1:
# If the context is too short or only one chunk is desired, use the entire context.
context_ids_list = [context_ids]
context_attention_mask_list = [context_attention_mask]
else:
# Calculate how many tokens remain after the sink tokens.
remainder_len = total_len - sink_tokens
# Compute the base tokens per chunk and any leftover.
base = remainder_len // num_chunks
leftover = remainder_len % num_chunks
# Build a list of chunk sizes.
# First chunk gets the sink tokens plus base tokens.
chunk_sizes = [sink_tokens + base]
# Chunks 2 to num_chunks-1 get base tokens each.
for _ in range(num_chunks - 2):
chunk_sizes.append(base)
# The last chunk gets the remaining tokens (base + leftover).
if num_chunks > 1:
chunk_sizes.append(base + leftover)
# Now slice the context using the calculated sizes.
context_ids_list = []
context_attention_mask_list = []
offset = 0
for size in chunk_sizes:
end = offset + size
context_ids_list.append(context_ids[:, offset:end])
context_attention_mask_list.append(context_attention_mask[:, offset:end])
offset = end
    # Derive the compression factor: how aggressively the non-sink context must shrink to fit target_token_size.
len_rest = max(total_len - sink_tokens, 1)
compression_factor = len_rest // target_token_size
if compression_factor < 1:
compression_factor = 1
tokenized_doc_chunks = []
for ids_chunk, mask_chunk in zip(context_ids_list, context_attention_mask_list):
tokenized_doc_chunks.append({"input_ids": ids_chunk, "attention_mask": mask_chunk})
print("Number of chunks: ", len(tokenized_doc_chunks))
rotary_emb = model.model.rotary_emb.to(device)
inv_freq = rotary_emb.inv_freq
batch_size = question_ids.size(0)
ones_mask = torch.ones(batch_size, 1, dtype=question_attention_mask.dtype, device=device)
cache = FinchCache()
past_cache_len = 0
past_attention_mask = torch.zeros(batch_size, 0, dtype=question_attention_mask.dtype, device=device)
num_chunks = len(tokenized_doc_chunks)
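    # Forward hooks on every layer's q_proj capture the question-token query states;
    # they are later used to score which cached key/value positions to keep.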
# Prepare a shared dictionary for hook outputs.
query_context_matrices = {}
    # Hook that captures each layer's query states; it reads the per-chunk offset from the enclosing scope.
def query_hook_fn(module, input, output):
layer_idx = getattr(module, "layer_idx", None)
if layer_idx is not None:
query_states = output.detach()
bsz, seq_len, hidden_dim = query_states.size()
num_query_heads = module.num_query_heads
head_dim = hidden_dim // num_query_heads
query_states = (
query_states.view(bsz, seq_len, num_query_heads, head_dim)
.transpose(1, 2)
.contiguous()
)
            # _current_chunk_offset is a closure variable set before each forward pass;
            # slice off the chunk positions so only the question-token queries are kept.
            query_context_matrices[layer_idx] = query_states[:, :, _current_chunk_offset:, :].clone()
# Pre-register hooks for all layers only once.
hooks = []
for i, layer in enumerate(model.model.layers):
layer.self_attn.q_proj.layer_idx = i # For tracking.
layer.self_attn.q_proj.num_query_heads = layer.self_attn.config.num_attention_heads
hook = layer.self_attn.q_proj.register_forward_hook(query_hook_fn)
hooks.append(hook)
# Process each document chunk sequentially.
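    # Each iteration prefills one chunk plus the question on top of the already-compressed cache,
    # scores the cached positions with the question queries, and compresses the cache again.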
for j, tokenized_doc_chunk in enumerate(tokenized_doc_chunks):
current_seq_length = tokenized_doc_chunk["input_ids"].size(1)
        # Record the chunk length in a closure variable so the hook knows where the question tokens start.
        _current_chunk_offset = current_seq_length
# Clear the dictionary from any previous chunk.
query_context_matrices.clear()
# These chunks are already on the device.
chunk_input_ids = tokenized_doc_chunk["input_ids"].contiguous()
chunk_attention_mask = tokenized_doc_chunk["attention_mask"].contiguous()
segment_attention_mask = torch.cat(
[past_attention_mask, chunk_attention_mask, ones_mask], dim=-1
).contiguous()
current_input_ids = torch.cat([chunk_input_ids, question_ids], dim=-1).contiguous()
current_attention_mask = torch.cat([segment_attention_mask, question_attention_mask], dim=-1).contiguous()
past_seen_tokens = cache.get_seq_length() if cache is not None else 0
cache_position = torch.arange(
past_seen_tokens + chunk_input_ids.shape[1],
past_seen_tokens + current_input_ids.shape[1],
device=device
)
causal_mask = model.model._prepare_4d_causal_attention_mask_with_cache_position(
current_attention_mask,
sequence_length=question_ids.size(1),
target_length=current_attention_mask.size(-1),
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=current_input_ids.size(0),
).contiguous()
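        # Prefill this chunk together with the question: the hooks capture the question queries, and the
        # cache accumulates keys/values for both (question entries are pruned again for intermediate chunks).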
with torch.no_grad():
outputs = model.model(
input_ids=current_input_ids,
use_cache=True,
past_key_values=cache,
)
cache = outputs.past_key_values
len_question = question_ids.size(1)
# Now, for each transformer layer, update the cache using the query/key attention.
for layer_idx in range(len(model.model.layers)):
key_matrix = cache.key_cache[layer_idx]
query_matrix = query_context_matrices[layer_idx]
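            # q_proj outputs are captured before rotary embeddings are applied, so re-apply RoPE here at the
            # question tokens' absolute positions (after the compressed past and the current chunk).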
layer_cache_pos = torch.arange(
past_cache_len + current_seq_length,
past_cache_len + current_seq_length + len_question,
device=device
)
position_ids = layer_cache_pos.unsqueeze(0)
cos, sin = rotary_emb(query_matrix, position_ids)
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
query_matrix = (query_matrix * cos) + (rotate_half(query_matrix) * sin)
num_repeats = model.config.num_attention_heads // model.config.num_key_value_heads
key_matrix = repeat_kv(key_matrix, num_repeats)
scaling = math.sqrt(model.config.head_dim)
attention_matrix = torch.matmul(query_matrix, key_matrix.transpose(2, 3)) / scaling
causal_mask_sliced = causal_mask[:, :, :, : key_matrix.shape[-2]]
attention_matrix = attention_matrix + causal_mask_sliced
attention_matrix = torch.nn.functional.softmax(attention_matrix, dim=-1, dtype=torch.float32).to(query_matrix.dtype)
            # Normalize: divide each query row by its number of non-masked key positions.
tol = 1e-8
binary_mask = (torch.abs(causal_mask_sliced.to(torch.float32)) < tol).to(torch.float32)
non_zero_counts = binary_mask.sum(dim=3, keepdim=True)
non_zero_counts = torch.clamp_min(non_zero_counts, 1.0).to(attention_matrix.dtype)
attention_matrix = attention_matrix / non_zero_counts
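            # Restrict scoring to the key positions eligible to stay in the cache: intermediate chunks drop the
            # question columns (their KV entries are discarded below), while the last chunk keeps them so the
            # question's own entries survive compression.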
if j != num_chunks - 1:
attention_matrix = attention_matrix[:, :, :, : past_cache_len + current_seq_length].clone().contiguous()
else:
attention_matrix = attention_matrix[:, :, :, : past_cache_len + current_seq_length + len_question].clone().contiguous()
attention_matrix = torch.sum(attention_matrix, dim=-2)
attention_matrix = attention_matrix.view(
attention_matrix.size(0), model.config.num_key_value_heads, num_repeats, -1
).sum(dim=2)
full_context_size = attention_matrix.size(-1)
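            # Give sink tokens (and, on the final chunk, the question tokens) infinite scores so top-k always keeps them.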
attention_matrix[..., :sink_tokens] = float("inf")
if j == num_chunks - 1:
attention_matrix[..., -len_question:] = float("inf")
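            # Decide how many cache positions to keep after this chunk: the first chunk keeps the sinks plus a
            # compressed share of its tokens, intermediate chunks add a compressed share of their new tokens,
            # and the final chunk trims everything down to the target budget.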
if j == 0:
k = int(sink_tokens + (max(0, current_seq_length - sink_tokens) // compression_factor))
k = min(k + past_cache_len, full_context_size)
elif j < num_chunks - 1:
to_keep_new = int(current_seq_length // compression_factor)
k = min(past_cache_len + to_keep_new, full_context_size)
else:
                desired_final = sink_tokens + target_token_size + len_question  # final budget also keeps the question tokens
k = desired_final if full_context_size >= desired_final else full_context_size
k = max(k, sink_tokens)
selected_indices = torch.topk(attention_matrix, k, dim=-1).indices
selected_indices, _ = torch.sort(selected_indices, dim=-1)
cache.compress_cache(layer_idx, selected_indices, inv_freq)
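        # All layers compressed: record the new cache length and build a matching attention mask for the next chunk.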
past_cache_len = cache._seen_tokens
        past_attention_mask = torch.ones(batch_size, past_cache_len, dtype=question_attention_mask.dtype, device=device)
# Remove the hooks once after all chunks are processed.
for hook in hooks:
hook.remove()
return cache
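
# Illustrative usage sketch (assumptions, not part of the Space's pipeline: any Llama-style model and
# tokenizer already loaded; the parameter values below are placeholders):
#
#   ctx = tokenizer(long_document, return_tensors="pt")
#   q = tokenizer(question, return_tensors="pt")
#   cache = get_compressed_kv_cache(
#       model, sink_tokens=4, step_size=2, target_token_size=2048,
#       context_ids=ctx.input_ids, context_attention_mask=ctx.attention_mask,
#       question_ids=q.input_ids, question_attention_mask=q.attention_mask,
#   )
#   # The returned cache can then be passed as past_key_values when generating the answer.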