|
import math |
|
import torch |
|
from cache import FinchCache |
|
from utils import repeat_kv |
|
from transformers.models.llama.modeling_llama import rotate_half |
|
import spaces |
|
|
|
def _split_context_into_chunks(context_ids, context_attention_mask, sink_tokens, num_chunks):
    """Split the context along dim 1 into at most ``num_chunks`` ordered slices.

    The first slice also carries the leading ``sink_tokens`` tokens and the last
    slice absorbs the division remainder, so concatenating the slices gives back
    the original context. The context is returned unsplit when it is not longer
    than ``sink_tokens`` or when fewer than two chunks are requested. The chunk
    count is capped at the number of non-sink tokens so that no slice is empty.

    Returns:
        ``(ids_chunks, mask_chunks)``: two lists of equal length.
    """
    total_len = context_ids.size(1)
    if total_len <= sink_tokens or num_chunks <= 1:
        return [context_ids], [context_attention_mask]

    remainder_len = total_len - sink_tokens
    # An empty chunk would cost a forward pass over the question and add nothing.
    num_chunks = min(num_chunks, remainder_len)
    base, leftover = divmod(remainder_len, num_chunks)
    chunk_sizes = [sink_tokens + base] + [base] * (num_chunks - 2)
    if num_chunks > 1:
        chunk_sizes.append(base + leftover)

    ids_chunks = list(torch.split(context_ids, chunk_sizes, dim=1))
    mask_chunks = list(torch.split(context_attention_mask, chunk_sizes, dim=1))
    return ids_chunks, mask_chunks


def _kv_position_scores(query_states, key_states, causal_mask, keep_len,
                        num_key_value_heads, num_repeats, scaling):
    """Score cached key positions by how much the question attends to them.

    ``query_states`` are the RoPE-rotated question queries laid out as
    (batch, heads, question_len, head_dim), and ``key_states`` are keys already
    expanded to ``heads`` with ``repeat_kv``. The softmax attention is divided
    by the number of keys each query row is allowed to see (zero entries of the
    additive ``causal_mask``), truncated to the first ``keep_len`` key positions,
    summed over the question tokens and then over the ``num_repeats`` query heads
    that share one KV head.

    Returns:
        A tensor of shape (batch, num_key_value_heads, keep_len).
    """
    kv_len = key_states.shape[-2]
    mask = causal_mask[:, :, :, :kv_len]
    logits = torch.matmul(query_states, key_states.transpose(2, 3)) / scaling
    logits = logits + mask
    probs = torch.nn.functional.softmax(logits, dim=-1, dtype=torch.float32).to(query_states.dtype)

    # A zero in the additive mask means "visible"; normalise each row by that count.
    visible = (torch.abs(mask.to(torch.float32)) < 1e-8).to(torch.float32)
    counts = visible.sum(dim=3, keepdim=True).clamp_min(1.0).to(probs.dtype)
    probs = probs / counts

    per_head = probs[:, :, :, :keep_len].sum(dim=-2)
    return per_head.view(per_head.size(0), num_key_value_heads, num_repeats, -1).sum(dim=2)


def _num_tokens_to_keep(chunk_idx, num_chunks, chunk_len, past_cache_len, full_context_size,
                        sink_tokens, compression_factor, target_token_size, len_question):
    """Return ``k``, the number of KV positions kept per head after chunk ``chunk_idx``.

    * Last chunk: the sinks, ``target_token_size`` context tokens and every
      question token. This case is checked first so that a single-chunk run
      (which is both the first and the last chunk) also budgets for the
      question tokens that the caller forces into the top-k.
    * First chunk: the sinks plus the chunk's non-sink tokens divided by
      ``compression_factor``.
    * Other chunks: the previous cache plus ``chunk_len // compression_factor``.

    The result is at least ``sink_tokens`` and at most ``full_context_size``
    (the number of candidate positions), so it is always a valid ``torch.topk`` k.
    """
    if chunk_idx == num_chunks - 1:
        k = sink_tokens + target_token_size + len_question
    elif chunk_idx == 0:
        k = past_cache_len + sink_tokens + max(0, chunk_len - sink_tokens) // compression_factor
    else:
        k = past_cache_len + chunk_len // compression_factor
    return int(min(max(k, sink_tokens), full_context_size))


@spaces.GPU
def get_compressed_kv_cache(
    model,
    sink_tokens,
    step_size,
    target_token_size,
    context_ids,
    context_attention_mask,
    question_ids,
    question_attention_mask,
):
    """Build a compressed KV cache for ``context_ids`` conditioned on ``question_ids``.

    The context is processed chunk by chunk. For each chunk, ``model.model`` runs
    on ``chunk + question`` on top of the cache built so far. Forward hooks on
    every ``q_proj`` capture the question's query states, which are RoPE-rotated
    and used to score every cached key position. Per layer, the top-k positions
    of each KV head are handed to ``cache.compress_cache``. The first
    ``sink_tokens`` positions are always kept, and on the last chunk the question
    tokens are kept as well. The hooks are removed even if an error occurs.

    Args:
        model: Hugging Face causal LM exposing ``model.model.layers[i].self_attn.q_proj``,
            ``model.model.rotary_emb`` and
            ``model.model._prepare_4d_causal_attention_mask_with_cache_position``.
        sink_tokens: number of leading context tokens that are never evicted.
        step_size: requested number of chunks; raised when the context would not
            fit in ``max_position_embeddings`` next to the question.
        target_token_size: number of non-sink context tokens kept in the final cache.
        context_ids, context_attention_mask: context tokens and mask, assumed
            ``(batch, context_len)`` -- TODO confirm with callers.
        question_ids, question_attention_mask: question tokens and mask, assumed
            ``(batch, question_len)`` -- TODO confirm with callers.

    Returns:
        The ``FinchCache`` after the last chunk has been compressed.
    """
    device = model.device
    dtype = model.dtype
    config = model.config

    context_ids = context_ids.to(device)
    context_attention_mask = context_attention_mask.to(device)
    question_ids = question_ids.to(device)
    question_attention_mask = question_attention_mask.to(device)

    batch_size = question_ids.size(0)
    len_question = question_ids.size(1)
    total_len = context_ids.size(1)

    # Use more chunks when the context alone would not fit next to the question.
    num_chunks = step_size
    max_context_tokens_allowed = config.max_position_embeddings - len_question
    if total_len > max_context_tokens_allowed:
        num_chunks = max(step_size, math.ceil(total_len / max_context_tokens_allowed))

    ids_chunks, mask_chunks = _split_context_into_chunks(
        context_ids, context_attention_mask, sink_tokens, num_chunks
    )
    num_chunks = len(ids_chunks)

    # Ratio by which non-sink context is reduced; never below 1.
    compression_factor = max(max(total_len - sink_tokens, 1) // target_token_size, 1)

    print("Number of chunks: ", num_chunks)

    rotary_emb = model.model.rotary_emb.to(device)
    inv_freq = rotary_emb.inv_freq
    num_layers = len(model.model.layers)
    num_repeats = config.num_attention_heads // config.num_key_value_heads
    head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
    scaling = math.sqrt(head_dim)

    # layer index -> question query states (pre-RoPE), filled by the q_proj hooks.
    query_context_matrices = {}
    # Length of the chunk in the current forward pass; the question follows it.
    hook_state = {"chunk_len": 0}

    def query_hook_fn(module, inputs, output):
        layer_idx = getattr(module, "layer_idx", None)
        if layer_idx is None:
            return
        bsz, seq_len, hidden_dim = output.size()
        num_query_heads = module.num_query_heads
        query_states = (
            output.detach()
            .view(bsz, seq_len, num_query_heads, hidden_dim // num_query_heads)
            .transpose(1, 2)
            .contiguous()
        )
        query_context_matrices[layer_idx] = query_states[:, :, hook_state["chunk_len"]:, :].clone()

    cache = FinchCache()
    past_cache_len = 0
    past_attention_mask = torch.zeros(
        batch_size, 0, dtype=question_attention_mask.dtype, device=device
    )

    hooks = []
    try:
        for layer_idx, layer in enumerate(model.model.layers):
            q_proj = layer.self_attn.q_proj
            q_proj.layer_idx = layer_idx
            q_proj.num_query_heads = layer.self_attn.config.num_attention_heads
            hooks.append(q_proj.register_forward_hook(query_hook_fn))

        for j, (chunk_ids, chunk_mask) in enumerate(zip(ids_chunks, mask_chunks)):
            is_last_chunk = j == num_chunks - 1
            current_seq_length = chunk_ids.size(1)
            hook_state["chunk_len"] = current_seq_length
            query_context_matrices.clear()

            current_input_ids = torch.cat([chunk_ids.contiguous(), question_ids], dim=-1).contiguous()
            # Exactly one mask column per key: compressed past + this chunk + question.
            current_attention_mask = torch.cat(
                [past_attention_mask, chunk_mask.contiguous(), question_attention_mask], dim=-1
            ).contiguous()

            # Cache positions of the question tokens, which follow past cache and chunk.
            past_seen_tokens = cache.get_seq_length()
            cache_position = torch.arange(
                past_seen_tokens + current_seq_length,
                past_seen_tokens + current_seq_length + len_question,
                device=device,
            )
            causal_mask = model.model._prepare_4d_causal_attention_mask_with_cache_position(
                current_attention_mask,
                sequence_length=len_question,
                target_length=current_attention_mask.size(-1),
                dtype=dtype,
                device=device,
                cache_position=cache_position,
                batch_size=current_input_ids.size(0),
            ).contiguous()

            with torch.no_grad():
                outputs = model.model(
                    input_ids=current_input_ids,
                    use_cache=True,
                    past_key_values=cache,
                )
            cache = outputs.past_key_values

            # RoPE positions of the question tokens; identical for every layer.
            question_positions = torch.arange(
                past_cache_len + current_seq_length,
                past_cache_len + current_seq_length + len_question,
                device=device,
            ).unsqueeze(0)
            # Intermediate chunks score only context keys; the last also keeps the question.
            keep_len = past_cache_len + current_seq_length
            if is_last_chunk:
                keep_len += len_question

            for layer_idx in range(num_layers):
                query_matrix = query_context_matrices[layer_idx]
                cos, sin = rotary_emb(query_matrix, question_positions)
                cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)
                query_matrix = (query_matrix * cos) + (rotate_half(query_matrix) * sin)
                key_matrix = repeat_kv(cache.key_cache[layer_idx], num_repeats)

                scores = _kv_position_scores(
                    query_matrix, key_matrix, causal_mask, keep_len,
                    config.num_key_value_heads, num_repeats, scaling,
                )
                # Sinks (and, on the last chunk, the question) always win the top-k.
                scores[..., :sink_tokens] = float("inf")
                if is_last_chunk and len_question > 0:
                    scores[..., -len_question:] = float("inf")

                k = _num_tokens_to_keep(
                    j, num_chunks, current_seq_length, past_cache_len, scores.size(-1),
                    sink_tokens, compression_factor, target_token_size, len_question,
                )
                # Sorted so the kept positions stay in their original order.
                selected_indices = torch.topk(scores, k, dim=-1).indices.sort(dim=-1).values
                cache.compress_cache(layer_idx, selected_indices, inv_freq)

            # NOTE(review): assumes FinchCache._seen_tokens equals the number of
            # positions left after compress_cache -- confirm in FinchCache.
            past_cache_len = cache._seen_tokens
            past_attention_mask = torch.ones(
                batch_size, past_cache_len, dtype=question_attention_mask.dtype, device=device
            )
    finally:
        # Hooks live on the caller's model; never leave them attached.
        for hook in hooks:
            hook.remove()

    return cache
|
|
|
|