import math

import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeats key-value hidden states along the key-value head dimension.

    Args:
        hidden_states (torch.Tensor): Input tensor with shape either
            (batch, num_key_value_heads, seqlen, head_dim) or
            (num_layers, batch, num_key_value_heads, seqlen, head_dim).
        n_rep (int): Number of repetitions for the key-value heads.

    Returns:
        torch.Tensor: The repeated tensor with shape either
            (batch, num_attention_heads, seqlen, head_dim) or
            (num_layers, batch, num_attention_heads, seqlen, head_dim),
            where num_attention_heads = num_key_value_heads * n_rep.
    """
    if hidden_states.dim() == 4:
        batch, num_key_value_heads, slen, head_dim = hidden_states.shape
        if n_rep == 1:
            return hidden_states
        # Insert a repeat axis after the KV-head axis, broadcast it to n_rep copies
        # (a view, no copy yet), then fold it into the head axis.
        hidden_states = hidden_states.unsqueeze(2).expand(
            batch, num_key_value_heads, n_rep, slen, head_dim
        )
        return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
    elif hidden_states.dim() == 5:
        num_layers, batch, num_key_value_heads, slen, head_dim = hidden_states.shape
        if n_rep == 1:
            return hidden_states
        # Same expansion, with an extra leading num_layers axis.
        hidden_states = hidden_states.unsqueeze(3).expand(
            num_layers, batch, num_key_value_heads, n_rep, slen, head_dim
        )
        return hidden_states.reshape(
            num_layers, batch, num_key_value_heads * n_rep, slen, head_dim
        )
    else:
        raise ValueError("Input tensor must have 4 or 5 dimensions.")
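
# Illustrative shape example (sizes are arbitrary, not tied to any particular model):
# with 4 key-value heads and n_rep=2, a (2, 4, 16, 64) tensor becomes (2, 8, 16, 64):
#
#   kv = torch.randn(2, 4, 16, 64)
#   repeat_kv(kv, n_rep=2).shape  # torch.Size([2, 8, 16, 64])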
def calculate_tokens_suggest_compression_ratio(text, tokenizer, model):
    """
    Tokenizes the text and returns:
      - token_count: the number of tokens in the input text.
      - suggestions: a list of 6 candidate compression ratios.
      - tokenized: a dictionary containing 'input_ids' and 'attention_mask'.

    The suggestions are chosen so that compressing the token count by these ratios
    would (in the worst case) bring the count within the model's maximum context
    length (model.config.max_position_embeddings, e.g. 128k tokens).

    If the text already fits within the context window, the default suggestions
    [1, 2, 4, 8, 16, 32] are returned. If the text is too long, six values are
    generated in logarithmic space between max(required_ratio, 1) and 32, or up
    to required_ratio * 32 when even a 32x compression would not be enough.
    """
    tokenized = tokenizer(text, return_tensors="pt", truncation=False)
    token_ids = tokenized["input_ids"][0]
    token_count = token_ids.size(0)
    max_context = model.config.max_position_embeddings

    # Ratio by which the input would have to shrink to fit the context window.
    if token_count <= max_context:
        required_ratio = 1.0
    else:
        required_ratio = token_count / max_context

    if required_ratio <= 1.0:
        suggestions = [1, 2, 4, 8, 16, 32]
    else:
        # Six log-spaced candidates from the minimum workable ratio up to 32
        # (or up to required_ratio * 32 when even 32x would not be enough).
        lower_bound = max(required_ratio, 1)
        upper_bound = 32 if required_ratio < 32 else required_ratio * 32
        log_low, log_high = math.log(lower_bound), math.log(upper_bound)
        suggestions = [
            round(math.exp(log_low + i * (log_high - log_low) / (6 - 1)), 2)
            for i in range(6)
        ]

    return token_count, suggestions, tokenized
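
# Worked example (assuming a model whose max_position_embeddings is 131072, i.e. 128k):
# a 200,000-token input gives required_ratio ~= 1.53, so the six log-spaced
# suggestions come out to roughly [1.53, 2.8, 5.15, 9.47, 17.41, 32.0].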


def update_retrieval_context(token_count, compression_ratio):
    """Return a display string with the retrieval token count after compression."""
    retrieval_tokens = int(token_count / compression_ratio)
    return f"Retrieval context tokens (after compression): {retrieval_tokens}"