# beyondrag/utils.py
import math

import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeats key-value hidden states along the key-value head dimension.

    Args:
        hidden_states (torch.Tensor): Input tensor with shape either
            (batch, num_key_value_heads, seqlen, head_dim) or
            (num_layers, batch, num_key_value_heads, seqlen, head_dim).
        n_rep (int): Number of repetitions for key-value heads.

    Returns:
        torch.Tensor: The repeated tensor with shape either
            (batch, num_attention_heads, seqlen, head_dim) or
            (num_layers, batch, num_attention_heads, seqlen, head_dim).
    """
    if hidden_states.dim() == 4:  # (batch, num_key_value_heads, seqlen, head_dim)
        batch, num_key_value_heads, slen, head_dim = hidden_states.shape
        if n_rep == 1:
            return hidden_states
        # Add a repeat axis after the KV-head axis, broadcast it to n_rep
        # copies, then fold it back into the head axis.
        hidden_states = hidden_states.unsqueeze(2).expand(batch, num_key_value_heads, n_rep, slen, head_dim)
        return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
    elif hidden_states.dim() == 5:  # (num_layers, batch, num_key_value_heads, seqlen, head_dim)
        num_layers, batch, num_key_value_heads, slen, head_dim = hidden_states.shape
        if n_rep == 1:
            return hidden_states
        hidden_states = hidden_states.unsqueeze(3).expand(num_layers, batch, num_key_value_heads, n_rep, slen, head_dim)
        return hidden_states.reshape(num_layers, batch, num_key_value_heads * n_rep, slen, head_dim)
    else:
        raise ValueError("Input tensor must have 4 or 5 dimensions.")
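
# Shape sketch (illustrative numbers, not from the original file): with
# n_rep = 4, a KV tensor of shape (batch=2, kv_heads=4, seqlen=128, head_dim=64)
# becomes (2, 16, 128, 64), i.e. each KV head is duplicated to serve 4 query
# heads, mirroring grouped-query attention:
#
#     kv = torch.zeros(2, 4, 128, 64)
#     repeat_kv(kv, 4).shape  # torch.Size([2, 16, 128, 64])
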
def calculate_tokens_suggest_compression_ratio(text, tokenizer, model):
    """
    Tokenizes the text and returns:
      - token_count: the number of tokens in the input text.
      - suggestions: a list of 6 candidate compression ratios.
      - tokenized: a dictionary containing 'input_ids' and 'attention_mask'.

    The suggestions are chosen so that compressing the token count by these
    ratios would (in the worst case) bring the count within the model's
    maximum context length (model.config.max_position_embeddings, e.g. 128k).
    If the text already fits within the context, the default suggestions
    [1, 2, 4, 8, 16, 32] are returned. If the text is too long, six values
    are generated in logarithmic space between the required ratio and 32
    (or a proportionally higher upper bound when the required ratio itself
    exceeds 32).
    """
    tokenized = tokenizer(text, return_tensors="pt", truncation=False)
    token_ids = tokenized["input_ids"][0]
    token_count = token_ids.size(0)
    max_context = model.config.max_position_embeddings

    # Ratio needed so that token_count / ratio fits in the context window.
    if token_count <= max_context:
        required_ratio = 1.0
    else:
        required_ratio = token_count / max_context

    if required_ratio <= 1.0:
        suggestions = [1, 2, 4, 8, 16, 32]
    else:
        lower_bound = max(required_ratio, 1)
        # Keep 32 as the upper bound unless the required ratio already
        # exceeds it, in which case scale the bound up proportionally.
        if required_ratio < 32:
            upper_bound = 32
        else:
            upper_bound = required_ratio * 32
        # Six points evenly spaced in log space between the two bounds.
        suggestions = [
            round(math.exp(math.log(lower_bound) + i * (math.log(upper_bound) - math.log(lower_bound)) / (6 - 1)), 2)
            for i in range(6)
        ]
    return token_count, suggestions, tokenized
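
# Worked example (assumed numbers): with token_count = 400_000 and a 131072
# (128k) context window, required_ratio ≈ 3.05, so the six log-spaced
# suggestions come out to roughly [3.05, 4.88, 7.81, 12.49, 19.99, 32.0].
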
def update_retrieval_context(token_count, compression_ratio):
    """Returns a status string with the token count left after compression."""
    retrieval_tokens = int(token_count / compression_ratio)
    return f"Retrieval context tokens (after compression): {retrieval_tokens}"