# beyondrag / global_compression.py
# Last commit b5ac9e4 by giulio98 ("Update app.py"); file size 9.93 kB.
import math
import torch
from cache import FinchCache
from utils import repeat_kv
from transformers.models.llama.modeling_llama import rotate_half
import spaces
def _split_context(context_ids, context_attention_mask, sink_tokens, num_chunks):
    """Split the context ids and mask into contiguous chunks along dimension 1.

    The first chunk holds the ``sink_tokens`` leading tokens plus an equal share of the
    remaining tokens. Middle chunks get that equal share. The last chunk also absorbs the
    division remainder. When the context is no longer than ``sink_tokens``, or at most one
    chunk is requested, the whole context is returned as a single chunk.

    Args:
        context_ids: Token ids; dimension 1 is treated as the sequence axis.
        context_attention_mask: Mask aligned with ``context_ids``.
        sink_tokens: Number of leading tokens that always go into the first chunk.
        num_chunks: Requested number of chunks.

    Returns:
        A list of ``(ids_chunk, mask_chunk)`` pairs covering the context in order.
    """
    total_len = context_ids.size(1)
    if total_len <= sink_tokens or num_chunks <= 1:
        return [(context_ids, context_attention_mask)]

    base, leftover = divmod(total_len - sink_tokens, num_chunks)
    chunk_sizes = [sink_tokens + base] + [base] * (num_chunks - 2) + [base + leftover]

    chunks = []
    offset = 0
    for size in chunk_sizes:
        end = offset + size
        chunks.append((context_ids[:, offset:end], context_attention_mask[:, offset:end]))
        offset = end
    return chunks


@spaces.GPU
def get_compressed_kv_cache(model, sink_tokens, step_size, target_token_size, context_ids, context_attention_mask, question_ids, question_attention_mask):
    """Build a question-aware, compressed KV cache for a (possibly long) context.

    How it works:

    - The context is split into chunks.
    - Each chunk, followed by the question, is run through ``model.model`` and appended to a
      ``FinchCache``.
    - Forward hooks on every layer's ``q_proj`` capture the question's query vectors.
    - Those queries are rotated with the model's rotary embedding and scored against the
      cached keys.
    - Each layer's cache is pruned to its top-scoring positions via ``cache.compress_cache``.

    The first ``sink_tokens`` positions are always kept. After the final chunk, the question
    tokens are kept as well.

    Args:
        model: Causal LM exposing ``model.model.layers``, ``model.model.rotary_emb`` and
            ``model.config``. Presumably a Llama-style model, since queries are rotated with
            Llama's ``rotate_half`` (TODO confirm for other architectures).
        sink_tokens: Number of leading context tokens that are never evicted.
        step_size: Minimum number of chunks. It is increased when one chunk plus the
            question would exceed ``model.config.max_position_embeddings``.
        target_token_size: Number of non-sink context tokens to retain in the final cache.
        context_ids: Context token ids; dimension 1 is the sequence axis.
        context_attention_mask: Mask aligned with ``context_ids``.
        question_ids: Question token ids; dimension 0 is the batch, dimension 1 the sequence.
        question_attention_mask: Mask aligned with ``question_ids``.

    Returns:
        The compressed ``FinchCache``. After the last chunk, each layer is pruned to
        ``sink_tokens + target_token_size + question_len`` positions, or fewer if the cache
        is smaller. This assumes ``compress_cache`` keeps exactly the selected indices.

    Forward hooks registered on the model are always removed before returning, including
    when an exception is raised.
    """
    device = model.device
    dtype = model.dtype

    context_ids = context_ids.to(device)
    context_attention_mask = context_attention_mask.to(device)
    question_ids = question_ids.to(device)
    question_attention_mask = question_attention_mask.to(device)

    question_len = question_ids.size(1)
    total_len = context_ids.size(1)

    # Every forward pass sees one chunk followed by the question, so use enough chunks that a
    # chunk plus the question fits in the model's positional range.
    num_chunks = step_size
    max_context_tokens_allowed = model.config.max_position_embeddings - question_len
    if total_len > max_context_tokens_allowed:
        num_chunks = max(step_size, math.ceil(total_len / max_context_tokens_allowed))

    chunks = _split_context(context_ids, context_attention_mask, sink_tokens, num_chunks)

    # Ratio between non-sink context length and the target budget, floored and at least 1.
    compression_factor = max(max(total_len - sink_tokens, 1) // target_token_size, 1)

    print("Number of chunks: ", len(chunks))
    num_chunks = len(chunks)

    rotary_emb = model.model.rotary_emb.to(device)
    inv_freq = rotary_emb.inv_freq
    batch_size = question_ids.size(0)
    mask_dtype = question_attention_mask.dtype
    ones_mask = torch.ones(batch_size, 1, dtype=mask_dtype, device=device)
    past_attention_mask = torch.zeros(batch_size, 0, dtype=mask_dtype, device=device)
    cache = FinchCache()
    past_cache_len = 0

    num_layers = len(model.model.layers)
    num_kv_heads = model.config.num_key_value_heads
    num_repeats = model.config.num_attention_heads // num_kv_heads
    scaling = math.sqrt(model.config.head_dim)
    tol = 1e-8

    # Filled by the hooks during each forward pass: layer index -> question queries,
    # viewed as (batch, heads, question_len, head_dim).
    query_context_matrices = {}
    # Number of leading (context) positions the hooks skip, so that only the question's
    # queries are stored. Updated before every forward pass.
    hook_state = {"offset": 0}

    def query_hook_fn(module, inputs, output):
        layer_idx = getattr(module, "layer_idx", None)
        if layer_idx is None:
            return
        query_states = output.detach()
        bsz, seq_len, hidden_dim = query_states.size()
        num_query_heads = module.num_query_heads
        head_dim = hidden_dim // num_query_heads
        query_states = (
            query_states.view(bsz, seq_len, num_query_heads, head_dim)
            .transpose(1, 2)
            .contiguous()
        )
        query_context_matrices[layer_idx] = query_states[:, :, hook_state["offset"]:, :].clone()

    hooks = []
    try:
        for layer_idx, layer in enumerate(model.model.layers):
            q_proj = layer.self_attn.q_proj
            q_proj.layer_idx = layer_idx  # read back by the hook to key its output
            q_proj.num_query_heads = layer.self_attn.config.num_attention_heads
            hooks.append(q_proj.register_forward_hook(query_hook_fn))

        for j, (chunk_ids, chunk_mask) in enumerate(chunks):
            is_last_chunk = j == num_chunks - 1
            chunk_input_ids = chunk_ids.contiguous()
            chunk_attention_mask = chunk_mask.contiguous()
            current_seq_length = chunk_input_ids.size(1)

            hook_state["offset"] = current_seq_length
            query_context_matrices.clear()

            # NOTE(review): the extra ``ones_mask`` column makes this mask one position longer
            # than the keys actually present (cache + chunk + question). This likely shifts the
            # question's padding mask by one. It is harmless when the question has no padding;
            # confirm whether it is intentional.
            segment_attention_mask = torch.cat(
                [past_attention_mask, chunk_attention_mask, ones_mask], dim=-1
            )
            current_input_ids = torch.cat([chunk_input_ids, question_ids], dim=-1).contiguous()
            current_attention_mask = torch.cat(
                [segment_attention_mask, question_attention_mask], dim=-1
            ).contiguous()

            # Cache positions of the question tokens in this pass.
            past_seen_tokens = cache.get_seq_length()
            cache_position = torch.arange(
                past_seen_tokens + current_seq_length,
                past_seen_tokens + current_input_ids.size(1),
                device=device,
            )
            causal_mask = model.model._prepare_4d_causal_attention_mask_with_cache_position(
                current_attention_mask,
                sequence_length=question_len,
                target_length=current_attention_mask.size(-1),
                dtype=dtype,
                device=device,
                cache_position=cache_position,
                batch_size=current_input_ids.size(0),
            ).contiguous()

            # The forward pass itself is not given ``causal_mask``. That mask is used only
            # for the manual attention scoring below.
            with torch.no_grad():
                outputs = model.model(
                    input_ids=current_input_ids,
                    use_cache=True,
                    past_key_values=cache,
                )
            cache = outputs.past_key_values

            # Rotary positions applied to the captured question queries.
            position_ids = torch.arange(
                past_cache_len + current_seq_length,
                past_cache_len + current_seq_length + question_len,
                device=device,
            ).unsqueeze(0)

            # Scores are restricted to the context tokens, plus the question tokens on the
            # final chunk (earlier chunks drop the question from the cache).
            keep_len = past_cache_len + current_seq_length + (question_len if is_last_chunk else 0)

            # Number of positions to keep per layer. The final-chunk case is tested first so
            # that a single chunk (both first and last) also reserves room for the question.
            if is_last_chunk:
                budget = sink_tokens + target_token_size + question_len
            elif j == 0:
                # past_cache_len is 0 for the first chunk.
                budget = sink_tokens + max(0, current_seq_length - sink_tokens) // compression_factor
            else:
                budget = past_cache_len + current_seq_length // compression_factor

            for layer_idx in range(num_layers):
                key_matrix = repeat_kv(cache.key_cache[layer_idx], num_repeats)
                query_matrix = query_context_matrices[layer_idx]

                cos, sin = rotary_emb(query_matrix, position_ids)
                cos = cos.unsqueeze(1)
                sin = sin.unsqueeze(1)
                query_matrix = (query_matrix * cos) + (rotate_half(query_matrix) * sin)

                attention_matrix = torch.matmul(query_matrix, key_matrix.transpose(2, 3)) / scaling
                causal_mask_sliced = causal_mask[:, :, :, : key_matrix.shape[-2]]
                attention_matrix = attention_matrix + causal_mask_sliced
                attention_matrix = torch.nn.functional.softmax(
                    attention_matrix, dim=-1, dtype=torch.float32
                ).to(query_matrix.dtype)

                # Divide each query row by the number of keys it may attend to (the entries
                # where the causal mask is ~0), clamped to at least 1.
                binary_mask = (torch.abs(causal_mask_sliced.to(torch.float32)) < tol).to(torch.float32)
                non_zero_counts = binary_mask.sum(dim=3, keepdim=True)
                non_zero_counts = torch.clamp_min(non_zero_counts, 1.0).to(attention_matrix.dtype)
                attention_matrix = attention_matrix / non_zero_counts

                # Sum over question queries, then fold the repeated query heads back onto
                # their key/value head. repeat_kv lays the copies of each KV head out
                # consecutively, which is what this view relies on.
                scores = attention_matrix[..., :keep_len].sum(dim=-2)
                scores = scores.view(scores.size(0), num_kv_heads, num_repeats, -1).sum(dim=2)
                full_context_size = scores.size(-1)

                # Force-keep the sink tokens and, on the final chunk, the question tokens.
                scores[..., :sink_tokens] = float("inf")
                if is_last_chunk and question_len > 0:
                    scores[..., full_context_size - question_len:] = float("inf")

                # Never ask topk for more positions than exist.
                k = min(max(budget, sink_tokens), full_context_size)
                selected_indices = torch.topk(scores, k, dim=-1).indices
                selected_indices = torch.sort(selected_indices, dim=-1).values
                cache.compress_cache(layer_idx, selected_indices, inv_freq)

            # Assumes compress_cache leaves _seen_tokens equal to the compressed length
            # (TODO confirm in FinchCache).
            past_cache_len = cache._seen_tokens
            past_attention_mask = torch.ones(
                batch_size, past_cache_len, dtype=mask_dtype, device=device
            )
    finally:
        # Always detach the hooks so they do not stay on the (shared) model.
        for hook in hooks:
            hook.remove()

    return cache