huginn-0125 / raven_modeling_minimal.py
JonasGeiping's picture
Update raven_modeling_minimal.py
972cea6 verified
"""Modeling file for HF compatibility and zero-shot experiments."""
import torch
import math
from torch import Tensor
from torch.nn.attention.flex_attention import create_block_mask, BlockMask, flex_attention
from torch.nn.attention import bias as attn_bias
from dataclasses import dataclass
from typing import Union, Optional, Any
from .raven_config_minimal import RavenConfig
from transformers.cache_utils import Cache, DynamicCache, StaticCache
###################### Huggingface Glue code I ##################################################################
from transformers import PreTrainedModel, GenerationMixin
from transformers.utils import ModelOutput
from transformers.generation.utils import GenerateDecoderOnlyOutput
import torch.nn.functional as F
from transformers import GenerationConfig
torch.backends.cuda.enable_math_sdp(False)
class RavenPreTrainedModel(PreTrainedModel):
config_class = RavenConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["SandwichBlock"]
_skip_keys_device_placement = ["past_key_values"]
_tied_weights_keys = ["lm_head.weight"]
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = False
_supports_static_cache = True
_tp_plan = {}
def _init_weights(self, module):
if not torch.rand((1,)).is_meta:
print("Random Initialization not implemented.")
@dataclass
class CausalLMOutputRecurrentLatents(ModelOutput):
loss: Optional[torch.Tensor] = None
log_ppl: Optional[torch.Tensor] = None
logits: Optional[torch.Tensor] = None
past_key_values: Optional[Cache] = None
latent_states: Optional[torch.Tensor] = None
hidden_states: Optional[torch.Tensor] = None
attention_maps: Optional[dict[int, torch.Tensor]] = None
stats: Optional[dict] = None
###################### Minimal implementation from here ############################################################
class RMSNorm(torch.nn.Module):
"""Saner dtype handling and slightly better for fusion"""
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = torch.nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
with torch.autocast(enabled=False, device_type=x.device.type if x.device.type != "meta" else "cuda"):
return self._norm(x.float()).type_as(x) * self.weight
def reset_parameters(self) -> None:
torch.nn.init.ones_(self.weight)
class HuginnDynamicCache(DynamicCache):
def __init__(self, lookup_strategy: str = "full") -> None:
super().__init__()
self._seen_tokens = 0
self.key_cache: dict[int, dict[int, torch.Tensor]] = {}
self.value_cache: dict[int, dict[int, torch.Tensor]] = {}
# structure: cache[index_of_layer_or_recurrent_step][index_in_sequence]
# the cache is held uncoalesced because certain recurrent steps may be missing for some sequence ids if using
# per-token adaptive compute. In those cases, the "lookup_strategy" determines how to proceed
# Also, It is critical that the head indices do not overlap with the recurrent iteration indices
self.lookup_strategy = lookup_strategy
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
step_idx_tensor: torch.Tensor,
lookup_strategy: Optional[str] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
step_idx: int = int(step_idx_tensor) # todo: fix dicts with tensor step_idx, currently the memberships fail
lookup_strategy = self.lookup_strategy if lookup_strategy is None else lookup_strategy
if "compress-" in self.lookup_strategy and step_idx > 1: # hardcode for current model!
if "compress-s" in self.lookup_strategy:
compression_stage = int(self.lookup_strategy.split("compress-")[1][1:])
new_step_idx = (step_idx - 2) % compression_stage + 2
elif "compress-anchor" in self.lookup_strategy:
if step_idx - 2 < 4 * 8: # anchor onto first 8 recurrence steps # noqa: SIM108
new_step_idx = step_idx
else: # then re-use the next 4 KV states = one recurrence for all future recurrence
new_step_idx = 34 + (step_idx - 34) % 4
# print(step_idx, new_step_idx)
else: # compress-r
compression_stage = int(self.lookup_strategy.split("compress-")[1][1:])
new_step_idx = (step_idx - 2) // compression_stage + 2
step_idx = new_step_idx
# Init
if step_idx not in self.key_cache:
self.key_cache[step_idx] = {}
self.value_cache[step_idx] = {}
# Update the number of seen tokens, we assume that step_idx=0 (first prelude) is always hit
if step_idx == 0:
self._seen_tokens += key_states.shape[-2]
# Add entries to cache
for idx, entry in enumerate(key_states.unbind(dim=-2)):
if "compress-" not in self.lookup_strategy:
assert step_idx < 0 or self._seen_tokens - key_states.shape[-2] + idx not in self.key_cache[step_idx]
self.key_cache[step_idx][self._seen_tokens - key_states.shape[-2] + idx] = entry
for idx, entry in enumerate(value_states.unbind(dim=-2)):
self.value_cache[step_idx][self._seen_tokens - value_states.shape[-2] + idx] = entry
# Materialize past state based on lookup strategy:
if len(self.key_cache[step_idx]) == self._seen_tokens or self.lookup_strategy == "full":
# All entries are present, materialize cache as normal
return (
torch.stack(list(self.key_cache[step_idx].values()), dim=-2),
torch.stack(list(self.value_cache[step_idx].values()), dim=-2),
)
else: # some entries were not previously computed
if lookup_strategy.startswith("latest-m4"):
latest_keys = []
latest_values = []
for token_pos in range(self._seen_tokens):
# For steps >= 2, use modulo 4, this hard-codes the huginn block structure for now
if step_idx >= 2:
# Find valid steps for this token position
valid_steps = [s for s in range(step_idx + 1) if token_pos in self.key_cache[s]]
max_step = max([s for s in valid_steps if s >= 2 and s % 4 == step_idx % 4])
else:
max_step = step_idx if token_pos in self.key_cache[step_idx] else 0
latest_keys.append(self.key_cache[max_step][token_pos])
latest_values.append(self.value_cache[max_step][token_pos])
return torch.stack(latest_keys, dim=-2), torch.stack(latest_values, dim=-2)
elif lookup_strategy.startswith("available-m4"):
latest_keys = []
latest_values = []
for token_pos in range(self._seen_tokens):
if token_pos in self.key_cache[step_idx]:
step = step_idx
else:
# Find valid steps for this token position
valid_steps = [s for s in range(step_idx + 1) if token_pos in self.key_cache[s]]
step = max([s for s in valid_steps if s >= 2 and s % 4 == step_idx % 4])
latest_keys.append(self.key_cache[step][token_pos])
latest_values.append(self.value_cache[step][token_pos])
return torch.stack(latest_keys, dim=-2), torch.stack(latest_values, dim=-2)
elif lookup_strategy.startswith("always-last-m4"):
latest_keys = []
latest_values = []
for token_pos in range(self._seen_tokens):
# For steps >= 2, use modulo 4, this hard-codes the huginn block structure for now
if step_idx >= 2:
# Find valid steps for this token position
valid_steps = [key_step for key_step in self.key_cache if token_pos in self.key_cache[key_step]]
max_step = max([s for s in valid_steps if s >= 2 and s % 4 == step_idx % 4])
else:
max_step = step_idx if token_pos in self.key_cache[step_idx] else 0
latest_keys.append(self.key_cache[max_step][token_pos])
latest_values.append(self.value_cache[max_step][token_pos])
return torch.stack(latest_keys, dim=-2), torch.stack(latest_values, dim=-2)
elif lookup_strategy.startswith("skip"):
existing_keys = []
existing_values = []
for token_pos in range(self._seen_tokens):
if token_pos in self.key_cache[step_idx]:
existing_keys.append(self.key_cache[step_idx][token_pos])
existing_values.append(self.value_cache[step_idx][token_pos])
return torch.stack(existing_keys, dim=-2), torch.stack(existing_values, dim=-2)
elif lookup_strategy.startswith("randomized"): # sanity check
rand_keys = []
rand_values = []
for token_pos in range(self._seen_tokens):
if step_idx < 2: # For prelude steps
max_step = step_idx if token_pos in self.key_cache[step_idx] else 0
else: # Get all steps from same block position
curr_modulo = (step_idx - 2) % 4 + 2
valid_steps = [
s
for s in range(2, step_idx + 1)
if (s - 2) % 4 + 2 == curr_modulo and token_pos in self.key_cache[s]
]
max_step = valid_steps[torch.randint(len(valid_steps), (1,))]
rand_keys.append(self.key_cache[max_step][token_pos])
rand_values.append(self.value_cache[max_step][token_pos])
return torch.stack(rand_keys, dim=-2), torch.stack(rand_values, dim=-2)
else:
raise ValueError(f"Unknown lookup strategy: {lookup_strategy}")
def reset(self) -> None:
"""Reset the cache state."""
self._seen_tokens = 0
self.key_cache.clear()
self.value_cache.clear()
def clear_last_k_entries(self, k: int = 0):
"""Partially clear cache."""
assert self._seen_tokens >= k
self._seen_tokens = self._seen_tokens - k
# self.key_cache[step_idx][self._seen_tokens - key_states.shape[-2] + idx] = entry
self.key_cache = {
step: {seq: seq_cache for seq, seq_cache in cache.items() if seq < self._seen_tokens}
for step, cache in self.key_cache.items()
}
self.value_cache = {
step: {seq: seq_cache for seq, seq_cache in cache.items() if seq < self._seen_tokens}
for step, cache in self.value_cache.items()
}
def get_seq_length(self, step_idx: int = 0) -> int:
return self._seen_tokens
def get_memory_usage(self) -> float:
total_bytes = 0
# For each recurrent step/layer index
for step_idx in self.key_cache:
# Get the sequence cache for this step
key_seq_cache = self.key_cache[step_idx]
for seq_idx in key_seq_cache:
key_tensor = key_seq_cache[seq_idx]
# Add memory for of key tensors, assuming value is the same
total_bytes += key_tensor.nelement() * key_tensor.element_size()
return total_bytes * 2 / (1024 * 1024)
class HuginnStaticCache(Cache):
"""Static Cache for the recurrent model"""
is_compileable = False # this is todo
def __init__(
self,
max_length: int,
max_num_steps: int,
num_heads: int,
hidden_dim: int,
batch_size: int = 1,
lookup_strategy: str = "full",
device: Optional[Union[torch.device, str]] = None,
dtype: torch.dtype = torch.float32,
) -> None:
super().__init__()
self._seen_tokens = 0
self.max_length = max_length
self.lookup_strategy = lookup_strategy
# Adjust max_num_steps based on compression strategy
if "compress-" in lookup_strategy:
compression_stage = int(lookup_strategy.split("compress-")[1][1:])
if "compress-s" in lookup_strategy:
# For modulo compression (s), we need steps for 0,1 + compressed steps
self.max_num_steps = 4 + compression_stage
else:
# For relative compression, we need steps for 0,1 + compressed steps
self.max_num_steps = 4 + (max_num_steps - 4 + compression_stage - 1) // compression_stage
else:
self.max_num_steps = max_num_steps
# Pre-allocate cache tensors [steps, batch, heads, seq_len, head_dim]
device = torch.device(device) if device is not None else None
cache_shape = (self.max_num_steps, batch_size, num_heads, max_length, hidden_dim)
self.key_cache = torch.zeros(cache_shape, dtype=dtype, device=device)
self.value_cache = torch.zeros(cache_shape, dtype=dtype, device=device)
self.valid_mask = torch.zeros((self.max_num_steps, max_length), dtype=torch.bool, device=device)
# Mark tensors as static for compile
torch._dynamo.mark_static_address(self.key_cache)
torch._dynamo.mark_static_address(self.value_cache)
torch._dynamo.mark_static_address(self.valid_mask)
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
step_idx: torch.Tensor,
lookup_strategy: Optional[str] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if step_idx == 0:
self._seen_tokens += key_states.shape[-2]
# Adjust step_idx for compression
lookup_strategy = self.lookup_strategy if lookup_strategy is None else lookup_strategy
if "compress-" in lookup_strategy and step_idx > 1:
compression_stage = int(lookup_strategy.split("compress-")[1][1:])
if "compress-s" in lookup_strategy:
step_idx = (step_idx - 2) % compression_stage + 2
else:
step_idx = (step_idx - 2) // compression_stage + 2
start_idx = self._seen_tokens - key_states.shape[-2]
indices = torch.arange(start_idx, start_idx + key_states.shape[-2], device=key_states.device)
self.key_cache[step_idx].index_copy_(2, indices, key_states)
self.value_cache[step_idx].index_copy_(2, indices, value_states)
self.valid_mask[step_idx, start_idx : start_idx + key_states.shape[-2]] = True
# Return based on lookup strategy
if lookup_strategy == "full":
return (
self.key_cache[step_idx, :, :, : self._seen_tokens],
self.value_cache[step_idx, :, :, : self._seen_tokens],
)
elif lookup_strategy.startswith("latest-m4"):
if step_idx >= 2:
pattern_steps = torch.arange(2, step_idx.item() + 1, 4, device=self.valid_mask.device)
pattern_valid = self.valid_mask[pattern_steps]
max_valid_step = pattern_steps[pattern_valid.to(torch.long).argmax(dim=0)]
return (
self.key_cache[max_valid_step, torch.arange(self._seen_tokens)],
self.value_cache[max_valid_step, torch.arange(self._seen_tokens)],
)
return self.key_cache[step_idx, :, :, : self._seen_tokens], self.value_cache[
step_idx, :, :, : self._seen_tokens
]
elif lookup_strategy == "skip":
valid_mask = self.valid_mask[step_idx, : self._seen_tokens]
return (
self.key_cache[step_idx, :, :, : self._seen_tokens][valid_mask],
self.value_cache[step_idx, :, :, : self._seen_tokens][valid_mask],
)
elif lookup_strategy.startswith("randomized"):
if step_idx < 2:
max_step = step_idx
else:
curr_modulo = (step_idx - 2) % 4 + 2
valid_steps = (
torch.where(
(torch.arange(2, step_idx.item() + 1, device=self.valid_mask.device) - 2) % 4 + 2 == curr_modulo
)[0]
+ 2
)
rand_idx = torch.randint(len(valid_steps), (1,), device=valid_steps.device)
max_step = valid_steps[rand_idx]
return self.key_cache[max_step, : self._seen_tokens], self.value_cache[max_step, : self._seen_tokens]
else:
raise ValueError(f"Unknown lookup strategy: {lookup_strategy}")
def reset(self) -> None:
self._seen_tokens = 0
self.key_cache.zero_()
self.value_cache.zero_()
self.valid_mask.zero_()
def get_seq_length(self, step_idx: int = 0) -> int:
return self._seen_tokens
def get_memory_usage(self) -> float:
return (self.key_cache.nelement() + self.value_cache.nelement()) * self.key_cache.element_size() / (1024 * 1024)
ValidCache = HuginnDynamicCache | HuginnStaticCache
class CausalSelfAttention(torch.nn.Module):
def __init__(self, config: RavenConfig) -> None:
super().__init__()
self.config = config
self.n_head = config.num_attention_heads
self.n_kv_heads = config.num_key_value_heads
self.head_dim = config.n_embd // self.n_head
shape = (self.n_head + 2 * self.n_kv_heads) * self.head_dim
self.chunks = [config.n_embd, self.n_kv_heads * self.head_dim, self.n_kv_heads * self.head_dim]
self.Wqkv = torch.nn.Linear(config.n_embd, shape, bias=False)
if config.qk_bias:
self.qk_bias = torch.nn.Parameter(torch.zeros(2, 1, self.n_head, self.head_dim))
self.proj = torch.nn.Linear(config.n_embd, config.n_embd, bias=False)
def forward(
self,
x: Tensor,
freqs_cis: Tensor,
block_idx: torch.Tensor,
mask: Optional[BlockMask] = None,
past_key_values: Optional[ValidCache] = None,
) -> Tensor:
B, S, E = x.shape # batch size, sequence length, embedding dimensionality (n_embd)
q, k, v = self.Wqkv(x).split(self.chunks, dim=2)
q = q.view(B, S, self.n_head, self.head_dim)
k = k.view(B, S, self.n_kv_heads, self.head_dim)
v = v.view(B, S, self.n_kv_heads, self.head_dim)
# bias?
if self.config.qk_bias:
q_bias, k_bias = self.qk_bias.split(1, dim=0)
q, k = (q + q_bias).to(q.dtype), (k + k_bias).to(q.dtype)
# apply rotary
q, k = apply_rotary_emb_complex_like(q, k, freqs_cis=freqs_cis)
q = q.transpose(1, 2) # (B, nh, S, hs)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
if past_key_values is not None:
k, v = past_key_values.update(k, v, block_idx)
if mask is not None:
y: torch.Tensor = flex_attention(q, k, v, block_mask=mask) # type: ignore
else:
if q.shape[2] < k.shape[2]:
if q.shape[2] > 1:
bias = attn_bias.causal_lower_right(q.shape[2], k.shape[2])
y = torch.nn.functional.scaled_dot_product_attention(q, k, v, bias, dropout_p=0.0)
else:
y = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
else:
y = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)
y = y.transpose(1, 2).reshape(B, S, E).contiguous() # reshape is a view if possible (it mostly is)
return self.proj(y)
class GatedMLP(torch.nn.Module):
def __init__(self, config: RavenConfig, in_features: int = 0) -> None:
super().__init__()
in_features = config.n_embd if in_features == 0 else in_features
self.fc = torch.nn.Linear(in_features, config.intermediate_size * 2, bias=False)
self.proj = torch.nn.Linear(config.intermediate_size, config.n_embd, bias=False)
self.nonlin = torch.nn.SiLU()
def forward(self, x: Tensor) -> Tensor:
# modified to single FC layer to improve parallelism
x_fc_1, x_fc_2 = self.fc(x).chunk(2, dim=-1)
x = self.nonlin(x_fc_1) * x_fc_2
return self.proj(x)
class SandwichBlock(torch.nn.Module):
expanded = False
def __init__(self, config: RavenConfig, layer_id: int) -> None:
super().__init__()
self.norm_1 = RMSNorm(config.n_embd, eps=config.norm_eps)
self.attn = CausalSelfAttention(config)
self.norm_2 = RMSNorm(config.n_embd, eps=config.norm_eps)
self.mlp = GatedMLP(config)
self.norm_3 = RMSNorm(config.n_embd, eps=config.norm_eps)
self.norm_4 = RMSNorm(config.n_embd, eps=config.norm_eps)
self.layer_id = layer_id
def forward(
self,
x: Tensor,
freqs_cis: Tensor,
step_idx: int,
mask: Optional[BlockMask] = None,
past_key_values: Optional[ValidCache] = None,
) -> Tensor:
attn_out = self.attn(self.norm_1(x), freqs_cis, step_idx, mask, past_key_values)
x = self.norm_2(attn_out + x)
x = self.norm_4(self.mlp(self.norm_3(x)) + x)
return x
class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
freqs_cis: torch.Tensor
def __init__(
self,
config: RavenConfig,
) -> None:
super().__init__(config)
self.config = config
# Transformer layers
prelude = torch.nn.ModuleList(SandwichBlock(config, layer_id=i) for i in range(config.n_layers_in_prelude))
adapter = torch.nn.Linear(config.n_embd * 2, config.n_embd, bias=config.bias)
core_block = torch.nn.ModuleList(
SandwichBlock(config, layer_id=i + config.n_layers_in_prelude)
for i in range(config.n_layers_in_recurrent_block)
)
o = config.n_layers_in_prelude + config.n_layers_in_recurrent_block * config.mean_recurrence
coda = torch.nn.ModuleList(SandwichBlock(config, layer_id=i + o) for i in range(config.n_layers_in_coda))
self.transformer = torch.nn.ModuleDict(
dict(
wte=torch.nn.Embedding(config.padded_vocab_size, config.n_embd),
prelude=prelude,
adapter=adapter,
core_block=core_block,
coda=coda,
ln_f=RMSNorm(config.n_embd, eps=config.norm_eps), # used twice :>
)
)
self.emb_scale = config.init_values["embed_scale"]
# Head
self.lm_head = torch.nn.Linear(config.n_embd, config.padded_vocab_size, bias=False)
if self.config.tie_embeddings:
self.tie_weights()
# rope
self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True)
def get_input_embeddings(self):
return self.transformer.wte
def get_output_embeddings(self):
return self.lm_head
def _precompute_freqs_cis(self):
# can actually be a buffer now, and remains in fp32! (at least in the settings I tested)
freqs_cis = precompute_freqs_cis(
self.config.n_embd // self.config.num_attention_heads, self.config.block_size, self.config.rope_base, 1
)
return freqs_cis
def compile_mask(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[ValidCache] = None,
pad_token_id=65509,
) -> Optional[BlockMask]:
batch_size, seq_len = input_ids.shape[0], input_ids.shape[1]
# If no padding and no attention mask, no need for a mask
if attention_mask is None and (input_ids == pad_token_id).sum() == 0:
return None
if past_key_values is not None and seq_len == 1:
return None
# Get total sequence length including cache
cache_len = past_key_values.get_seq_length() if past_key_values is not None else 0
kv_length = cache_len + seq_len
if attention_mask is None:
def mask_mod(b, h, q_idx, kv_idx):
return q_idx >= kv_idx & (input_ids[b, kv_idx] != pad_token_id)
else:
def mask_mod(b, h, q_idx, kv_idx):
return (q_idx >= kv_idx) & (input_ids[b, kv_idx] != pad_token_id) & attention_mask[b, q_idx, kv_idx]
kv_length = past_key_values.get_seq_length() if past_key_values is not None else seq_len
if kv_length == 0:
kv_length = seq_len # prefill
block_mask = create_block_mask(
mask_mod,
B=batch_size,
H=None,
Q_LEN=seq_len,
KV_LEN=kv_length,
device=input_ids.device,
)
# # Define mask_mod function
# def mask_mod(b, h, q_idx, kv_idx):
# # Always apply causal constraint
# is_causal = q_idx >= kv_idx
# # Handle cache vs current tokens
# is_cache = kv_idx < cache_len
# current_idx = kv_idx - cache_len
# # For cache: always valid; For current: check padding
# not_pad = input_ids[b, current_idx] != pad_token_id
# valid = is_cache | not_pad
# # Apply attention mask if provided
# if attention_mask is not None:
# q_idx_curr = q_idx - cache_len
# attn_valid = attention_mask[b, q_idx_curr, current_idx]
# valid = valid & (is_cache | attn_valid)
# return is_causal & valid
# def mask_mod(b, h, q_idx, kv_idx):
# is_causal = q_idx >= kv_idx
# is_current = (kv_idx >= cache_len) & (kv_idx < kv_length)
# current_idx = kv_idx - cache_len
# is_valid = (~is_current) | (
# (current_idx >= 0) & (current_idx < seq_len) & (input_ids != pad_token_id)[b, current_idx % seq_len]
# )
# return is_causal & is_valid
# # Define mask_mod function
# def mask_mod(b, h, q_idx, kv_idx):
# # Always apply causal constraint
# is_causal = q_idx >= kv_idx
# # Handle cache vs current tokens
# is_cache = kv_idx < cache_len
# current_idx = kv_idx - cache_len
# in_bounds = (current_idx >= 0) & (current_idx < seq_len)
# # For cache: always valid; For current: check padding
# not_pad = (input_ids[b, current_idx % seq_len] != pad_token_id) | ~in_bounds
# valid = is_cache | (not_pad & in_bounds)
# # Apply attention mask if provided
# if attention_mask is not None:
# q_idx_curr = q_idx - cache_len
# q_in_bounds = (q_idx_curr >= 0) & (q_idx_curr < seq_len)
# attn_valid = attention_mask[b, q_idx_curr % seq_len, current_idx % seq_len] | ~(in_bounds & q_in_bounds)
# valid = valid & (is_cache | attn_valid)
# return is_causal & valid
# Create block mask
block_mask = create_block_mask(
mask_mod,
B=batch_size,
H=None,
Q_LEN=seq_len,
KV_LEN=kv_length,
device=input_ids.device,
)
return block_mask
def forward(
self,
input_ids: torch.Tensor,
input_embeds: Optional[torch.Tensor] = None,
input_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None, # binary mask of shape q x kv, True=valid position
position_ids: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
num_steps: Optional[torch.Tensor] = None,
past_key_values: Optional[ValidCache] = None,
output_details: dict = {
"return_logits": True,
"return_latents": True,
"return_head": False,
"return_stats": False,
},
use_cache: bool = False,
cache_position: Optional[torch.Tensor] = None,
init_scale: float = 1.0,
**kwargs,
) -> CausalLMOutputRecurrentLatents:
# Support multiple position formats:
if position_ids is None and cache_position is None:
freqs_cis = self.freqs_cis[:, : input_ids.shape[1]]
elif position_ids is not None:
freqs_cis = self.freqs_cis.index_select(1, position_ids.squeeze())
elif cache_position is not None:
freqs_cis = self.freqs_cis[:, cache_position]
if input_embeds is None:
input_embeds = self.transformer.wte(input_ids) # type: ignore # types broken in 2.6+
if self.emb_scale != 1:
input_embeds = input_embeds * self.emb_scale # type: ignore
if use_cache and past_key_values is None:
past_key_values = HuginnDynamicCache()
prepared_attn_mask = None # self.compile_mask(input_ids, attention_mask, past_key_values)
block_idx = torch.tensor(-1, device=torch.device("cpu"), dtype=torch.long) # count in tensors for compile
# Non-recurrent prelude
for block in self.transformer.prelude: # type: ignore # types broken in 2.6+
block_idx += 1
input_embeds = block(input_embeds, freqs_cis, block_idx, prepared_attn_mask, past_key_values)
# Main recurrence
x, num_steps_no_grad, num_steps_with_grad, xk, block_idx = self.iterate_forward(
input_embeds, # type: ignore # mystery typing error
input_states,
freqs_cis,
block_idx,
prepared_attn_mask,
past_key_values,
num_steps,
init_scale,
)
latent_states = x.clone().detach()
# Coda layers
block_idx = torch.tensor(0, device=torch.device("cpu"), dtype=torch.long) # use negative indices for head
for block in self.transformer.coda: # type: ignore # types broken in 2.6+
block_idx -= 1
x = block(x, freqs_cis, block_idx, prepared_attn_mask, past_key_values)
x = self.transformer.ln_f(x) # type: ignore # types broken in 2.6+
# Prediction head, assuming labels really are labels and not equal to input_ids
if labels is not None:
logits = self.lm_head(x).float()
loss = torch.nn.functional.cross_entropy(
logits.view(-1, logits.shape[-1]), labels.view(-1), ignore_index=-100
)
log_ppl = loss.clone().detach().exp()
else:
logits = self.lm_head(x).float()
loss, log_ppl = torch.as_tensor(0.0), torch.as_tensor(0.0)
return CausalLMOutputRecurrentLatents(
loss=loss,
log_ppl=log_ppl,
logits=logits if output_details["return_logits"] else None,
past_key_values=past_key_values,
hidden_states=x if output_details["return_head"] else None,
latent_states=latent_states if output_details["return_latents"] else None,
stats=self.get_stats(logits, x, latent_states, xk, input_embeds, num_steps_no_grad, num_steps_with_grad)
if output_details["return_stats"]
else None,
)
@torch._dynamo.disable(recursive=False) # type: ignore
def iterate_forward(
self,
input_embeds: torch.Tensor,
input_states: torch.Tensor,
freqs_cis,
block_idx: torch.Tensor,
mask: Optional[BlockMask],
past_key_values: Optional[ValidCache] = None,
num_steps: Optional[torch.Tensor] = None,
init_scale: float = 1.0,
):
x = xk = self.initialize_state(input_embeds, scale=init_scale) if input_states is None else input_states.clone()
if num_steps is None:
num_steps_no_grad, num_steps_with_grad = self.randomized_iteration_sampler() # type: ignore
elif hasattr(num_steps, "__len__") and len(num_steps) > 1:
num_steps_no_grad, num_steps_with_grad = num_steps
else:
num_steps_no_grad, num_steps_with_grad = num_steps, torch.tensor(0) if not x.is_meta else 0
with torch.no_grad():
# ultra annoying in ddp due to
# https://discuss.pytorch.org/t/does-distributeddataparallel-work-with-torch-no-grad-and-find-unused-parameters-false/122594
# for now running with find_unused_params=True enabled even though the graph structure is (technically) clear
# and all parameters are always used
for no_grad_step in range(num_steps_no_grad):
xk = x
x, block_idx = self.core_block_forward(
xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, no_grad_step
)
for grad_step in range(num_steps_with_grad):
xk = x
x, block_idx = self.core_block_forward(
xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, num_steps_no_grad + grad_step
)
return self.transformer.ln_f(x), num_steps_no_grad, num_steps_with_grad, xk.detach(), block_idx # type: ignore # types broken in 2.6+
def core_block_forward(
self,
x,
input_embeds,
freqs_cis,
mask: Optional[BlockMask],
past_key_values,
block_idx: torch.Tensor,
current_step: int | Tensor,
):
x = self._maybe_inject_noise(x, current_step)
x = self.transformer.adapter(torch.cat([x, input_embeds.to(x.device)], dim=-1)) # type: ignore # types broken in 2.6+
for block in self.transformer.core_block: # type: ignore # types broken in 2.6+
block_idx += 1
x = block(x, freqs_cis, block_idx, mask, past_key_values)
return x, block_idx
@torch.no_grad()
def iterate_one_step(
self,
input_embeds,
input_states,
position_ids: Optional[torch.Tensor] = None,
cache_position: Optional[torch.Tensor] = None,
block_idx: torch.Tensor = torch.tensor(0, dtype=torch.long),
attention_mask: Optional[BlockMask] = None,
past_key_values: Optional[ValidCache] = None,
current_step: int = 0,
):
if position_ids is None and cache_position is None:
freqs_cis = self.freqs_cis[:, : input_embeds.shape[1]]
elif position_ids is not None:
freqs_cis = self.freqs_cis.index_select(1, position_ids.squeeze())
elif cache_position is not None:
freqs_cis = self.freqs_cis[:, cache_position]
x, block_idx = self.core_block_forward(
input_states,
input_embeds,
freqs_cis,
attention_mask,
past_key_values,
block_idx,
current_step=current_step,
)
return x, block_idx, current_step + 1
def predict_from_latents(
self,
latents,
attention_mask: Optional[BlockMask] = None,
position_ids: Optional[torch.Tensor] = None,
cache_position: Optional[torch.Tensor] = None,
past_key_values: Optional[ValidCache] = None,
):
if position_ids is None and cache_position is None:
freqs_cis = self.freqs_cis[:, : latents.shape[1]]
elif position_ids is not None:
freqs_cis = self.freqs_cis.index_select(1, position_ids.squeeze())
elif cache_position is not None:
freqs_cis = self.freqs_cis[:, cache_position]
x = self.transformer.ln_f(latents) # type: ignore # types broken in 2.6+
# Coda layers
block_idx = torch.tensor(0, device=torch.device("cpu"), dtype=torch.long) # use negative indices for head
for block in self.transformer.coda: # type: ignore # types broken in 2.6+
block_idx -= 1
x = block(x, freqs_cis, block_idx, attention_mask, past_key_values)
x = self.transformer.ln_f(x) # type: ignore # types broken in 2.6+
logits = self.lm_head(x).float()
return CausalLMOutputRecurrentLatents(
loss=torch.as_tensor(0.0),
log_ppl=torch.as_tensor(0.0),
logits=logits,
past_key_values=past_key_values,
latent_states=x,
)
def embed_inputs(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
past_key_values: Optional[ValidCache] = None,
use_cache: bool = False,
cache_position: Optional[torch.Tensor] = None,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
# Support multiple position formats:
if position_ids is None and cache_position is None:
freqs_cis = self.freqs_cis[:, : input_ids.shape[1]]
elif position_ids is not None:
freqs_cis = self.freqs_cis.index_select(1, position_ids.squeeze())
elif cache_position is not None:
freqs_cis = self.freqs_cis[:, cache_position]
input_embeds = self.transformer.wte(input_ids) # type: ignore # types broken in 2.6+
prepared_attn_mask = self.compile_mask(input_ids, attention_mask)
if self.emb_scale != 1:
input_embeds = input_embeds * self.emb_scale # type: ignore
if use_cache and past_key_values is None:
past_key_values = HuginnDynamicCache()
block_idx = torch.tensor(-1, device=torch.device("cpu"), dtype=torch.long) # count in tensors for compile
# Non-recurrent prelude
for block in self.transformer.prelude: # type: ignore # types broken in 2.6+
block_idx += 1
input_embeds = block(input_embeds, freqs_cis, block_idx, prepared_attn_mask, past_key_values)
return input_embeds, block_idx
@torch._dynamo.disable(recursive=False) # type: ignore
def randomized_iteration_sampler(self) -> tuple[torch.Tensor, torch.Tensor]:
"""Outputs are long tensors so that they can be passed through compiled functions"""
t = max(self.config.mean_recurrence - self.config.mean_backprop_depth, 0)
s = self.config.mean_backprop_depth
if torch.rand((1,)).is_meta: # annoying clause to make meta-tensor-based flop counting work
# these values are only the mean TFLOPs of the randomized sampler
# Note that this clause also breaks the contract, and returns ints in meta tensor mode
return t, s # type: ignore
if self.training:
sigma = 0.5
mu = math.log(t + s) - (sigma**2 / 2)
rate = torch.zeros((1,)).log_normal_(mean=mu, std=sigma)
p = torch.poisson(torch.tensor([rate], dtype=torch.float)) + 1
n = torch.clamp(p - s, min=0)
k = torch.as_tensor(torch.minimum(torch.as_tensor(s), p))
else:
n, k = torch.as_tensor(self.config.mean_recurrence), torch.as_tensor(0)
return n.to(dtype=torch.long), k.to(dtype=torch.long)
def initialize_state(self, input_embeds, scale: float = 1.0):
x = torch.randn_like(input_embeds)
std = self.config.init_values["std"] * scale
if std > 0:
torch.nn.init.trunc_normal_(x, mean=0.0, std=std, a=-3 * std, b=3 * std)
if self.emb_scale != 1:
x = x * self.emb_scale
else:
x.zero_()
return x
def _maybe_inject_noise(self, x, current_step, renorm=False):
if self.config.test_time_noise > 0:
n = self.config.test_time_noise * self.config.init_values["std"] * self.emb_scale
if self.config.test_time_noise_type == "geom":
step1 = torch.as_tensor(current_step + 1, device=x.device) # need to cast for compile
x = x * (1 - n / step1) + torch.randn_like(x) * n / step1
elif self.config.test_time_noise_type == "sqrt":
step1sqrt = torch.as_tensor(current_step + 1, device=x.device).sqrt() # need to cast for compile
x = x * (1 - n / step1sqrt) + torch.randn_like(x) * n / step1sqrt
elif self.config.test_time_noise_type == "line":
noise = max(n, (self.config.mean_recurrence - current_step) / self.config.mean_recurrence) # type: ignore
x = x * (1 - noise) + torch.randn_like(x) * noise
elif self.config.test_time_noise_type == "chi":
noise = 2 * torch.rand(1, device=x.device, dtype=x.dtype) * n
x = x * (1 - noise) + torch.randn_like(x) * noise
elif self.config.test_time_noise_type == "fixed":
x = x * (1 - n) + torch.randn_like(x) * n
else:
raise ValueError()
if renorm:
x = self.transformer.core_block[-1].norm_4(x) # type: ignore moduledict types still broken in pytorch
return x
def prepare_inputs_for_generation(
self,
input_ids: torch.Tensor,
past_key_values: Optional[Cache] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
cache_position: Optional[torch.Tensor] = None,
cache_lookup_strategy: str = "full",
**kwargs,
):
model_inputs = {}
model_inputs["cache_position"] = cache_position
current_input_length = input_ids.shape[1]
if past_key_values is not None:
if not isinstance(past_key_values, (HuginnDynamicCache, HuginnStaticCache)):
assert past_key_values.get_seq_length() == 0 # only replace empty caches
# Need to use custom cache, detect and replace HF cache if generate injects it
if isinstance(past_key_values, StaticCache):
past_key_values = HuginnStaticCache(
max_length=getattr(self.generation_config, "max_length", self.config.block_size),
max_num_steps=4 + kwargs.get("num_steps", self.config.mean_recurrence) * 4,
num_heads=self.config.num_key_value_heads,
hidden_dim=self.config.n_embd // self.config.num_attention_heads,
dtype=torch.bfloat16,
device=input_ids.device,
lookup_strategy=cache_lookup_strategy,
)
else:
past_key_values = HuginnDynamicCache(lookup_strategy=cache_lookup_strategy)
model_inputs["past_key_values"] = past_key_values if kwargs["use_cache"] else None
input_ids = input_ids[:, cache_position] # type: ignore
model_inputs["input_ids"] = input_ids.clone(memory_format=torch.contiguous_format)
if cache_position is None:
position_ids = torch.arange(current_input_length)[None, :].to(input_ids.device)
model_inputs["position_ids"] = position_ids[:, -current_input_length:].clone(
memory_format=torch.contiguous_format
) # some form of position_ids is a critical argument for the model to correctly apply rope!
# forward all other entries
for key, value in kwargs.items():
if key not in model_inputs:
model_inputs[key] = value
return model_inputs
@torch.no_grad()
def generate(self, *args, **kwargs):
"""Dispatcher - use HF generate in all normal cases."""
self.generation_config = args[1] if len(args) > 1 else self.generation_config
if any(k in kwargs for k in ("criterion", "exit_threshold")):
# print("Dispatching to custom generate_adaptive function call")
return self.generate_with_adaptive_compute(*args, **kwargs)
elif "continuous_compute" in kwargs:
# print("Dispatching to custom generate_minimal function call")
return self.generate_minimal(*args, **kwargs)
else:
return super().generate(*args, **kwargs)
@torch.no_grad()
def _prep_generate_args(
self,
input_ids: torch.Tensor,
generation_config: Optional[GenerationConfig] = None, # type: ignore
cache_lookup_strategy: str = "full",
model_kwargs: dict = {},
):
# Setup
if generation_config is None:
generation_config: GenerationConfig = self.generation_config # type: ignore
if "max_new_tokens" in model_kwargs:
max_new_tokens = model_kwargs["max_new_tokens"]
if "max_length" in model_kwargs:
max_new_tokens = min(max_new_tokens, model_kwargs["max_length"] - input_ids.shape[1])
else:
max_length = model_kwargs.get("max_length", generation_config.max_length)
max_new_tokens = max_length - input_ids.shape[1]
if "cache_implementation" not in model_kwargs or model_kwargs["cache_implementation"] == "dynamic":
model_kwargs["past_key_values"] = HuginnDynamicCache(lookup_strategy=cache_lookup_strategy)
else:
model_kwargs["past_key_values"] = HuginnStaticCache(
max_length=max_length,
max_num_steps=4 + model_kwargs.get("num_steps", self.config.mean_recurrence) * 4,
num_heads=self.config.num_key_value_heads,
hidden_dim=self.config.n_embd // self.config.num_attention_heads,
batch_size=input_ids.shape[0],
dtype=torch.bfloat16,
device=input_ids.device,
lookup_strategy=cache_lookup_strategy,
)
model_kwargs["use_cache"] = True
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
return model_kwargs, generation_config, max_new_tokens
@torch.no_grad()
def generate_minimal(
self,
input_ids: torch.Tensor,
generation_config: Optional[GenerationConfig] = None, # type: ignore
tokenizer=None,
streamer=None,
continuous_compute=False, # warm-start state / continuous CoT
init_scale: float = 1.0,
cache_lookup_strategy: str = "full",
**model_kwargs,
) -> Union[torch.Tensor, dict[str, Any]]:
"""Minimal single-sequence generation. Template for more complicated generate tasks"""
model_kwargs, generation_config, max_new_tokens = self._prep_generate_args(
input_ids, generation_config, cache_lookup_strategy
)
stop_tokens = self._get_stops(generation_config, tokenizer, model_kwargs).to(input_ids.device)
unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
# Set up continuous compute if enabled
if continuous_compute:
embedded_inputs, _ = self.embed_inputs(input_ids)
model_kwargs["input_states"] = self.initialize_state(embedded_inputs, scale=init_scale)
# Generate tokens
batch_size = input_ids.shape[0]
for _ in range(max_new_tokens):
# Forward pass
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
outputs = self(**model_inputs, init_scale=init_scale)
# Get next token
next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
next_token = self._sample_next_token(next_token_logits, generation_config)
# Append token to sequence
input_ids = torch.cat([input_ids, next_token], dim=-1)
if streamer:
streamer.put(next_token.cpu())
# Update model kwargs
model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs)
if continuous_compute:
model_kwargs["input_states"] = outputs.latent_states[:, -1:, :]
if stop_tokens is not None:
for i in range(batch_size):
if unfinished_sequences[i] and next_token[i, 0].item() in stop_tokens:
unfinished_sequences[i] = 0
if "stopping_criteria" in model_kwargs:
unfinished_sequences = unfinished_sequences & ~model_kwargs["stopping_criteria"](input_ids, None)
if unfinished_sequences.max() == 0:
break
if streamer:
streamer.end()
if generation_config.return_dict_in_generate:
return GenerateDecoderOnlyOutput(
sequences=input_ids, # type: ignore
scores=None,
logits=None,
attentions=None,
hidden_states=None,
past_key_values=model_kwargs.get("past_key_values"),
)
return input_ids
@torch.no_grad()
def generate_with_adaptive_compute(
self,
input_ids: torch.Tensor,
generation_config: Optional[GenerationConfig] = None, # type: ignore
tokenizer=None,
streamer=None,
continuous_compute=False, # warm-start state / continuous CoT
criterion="none", # off by default, turn on by choosing an exit criterion
exit_threshold: Union[str, float, int] = "auto",
init_scale: float = 1.0,
cache_lookup_strategy: str = "full",
**model_kwargs,
) -> Union[torch.Tensor, GenerateDecoderOnlyOutput]:
"""
Generate tokens with adaptive compute. This is NOT the most efficient implementation.
For batches, on each token, we iterate until the entire batch finishes.
"""
model_kwargs, generation_config, max_new_tokens = self._prep_generate_args(
input_ids, generation_config, cache_lookup_strategy, model_kwargs
)
max_steps = model_kwargs.get("num_steps", self.config.mean_recurrence)
stop_tokens = self._get_stops(generation_config, tokenizer, model_kwargs).to(input_ids.device)
logit_type = dict(copy=True, dtype=torch.float32, device=input_ids.device)
batch_size = input_ids.shape[0]
compute_steps = []
# Set up continuous compute if enabled
if continuous_compute:
embedded_inputs, _ = self.embed_inputs(input_ids)
model_kwargs["input_states"] = self.initialize_state(embedded_inputs, scale=init_scale)
# Track which sequences have finished (using unfinished_sequences to match generate_minimal)
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
# Generate tokens
for _ in range(max_new_tokens):
# Adaptive compute forward
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
aux_inputs = {
k: model_inputs[k] for k in ["cache_position", "past_key_values", "attention_mask"] if k in model_inputs
}
embedded_inputs, block_idx = self.embed_inputs(model_inputs["input_ids"], **aux_inputs)
current_latents = (
self.initialize_state(embedded_inputs, scale=init_scale)
if not continuous_compute
else model_kwargs["input_states"]
)
# Initialize criterion tracking for each sequence in batch
exit_values_per_seq = [[] for _ in range(batch_size)]
compute_steps_per_seq = [0] * batch_size
exit_reached = torch.zeros(batch_size, dtype=torch.bool, device=input_ids.device)
# Set up criterions based on selected strategy
if criterion == "entropy-diff":
entropy = torch.ones(batch_size, device=input_ids.device) * 100.0
exit_threshold = 1e-3 if exit_threshold == "auto" else float(exit_threshold)
elif criterion == "latent-diff":
exit_threshold = 0.03 if exit_threshold == "auto" else float(exit_threshold)
elif "kl" in criterion:
V = self.config.padded_vocab_size
log_probs = ((1 / V) * torch.ones(batch_size, V, dtype=torch.float, device=input_ids.device)).log()
if criterion == "minp-kl":
exit_threshold = 1e-6 if exit_threshold == "auto" else float(exit_threshold)
else:
exit_threshold = 5e-4 if exit_threshold == "auto" else float(exit_threshold)
elif criterion == "argmax-stability":
stable_for_n_steps = torch.zeros(batch_size, dtype=torch.long, device=input_ids.device)
current_argmax = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) * -1
exit_threshold = 5 if exit_threshold == "auto" else int(exit_threshold)
elif criterion == "none":
exit_threshold = 1.0 if exit_threshold == "auto" else float(exit_threshold)
else:
raise ValueError("Invalid adaptive compute strategy.")
next_token_logits = None
# Iterate through compute steps
for compute_step in range(max_steps):
prev_latents = current_latents.clone()
current_latents, block_idx, _ = self.iterate_one_step(
embedded_inputs,
current_latents,
block_idx=block_idx,
**aux_inputs,
current_step=compute_step,
)
if _ > 0: # do not exit in prefill
# Check exit condition for each sequence in batch
if criterion == "entropy-diff":
prev_entropy = entropy
outputs = self.predict_from_latents(current_latents, **aux_inputs)
logits: torch.Tensor = outputs.logits # type: ignore
probs = F.softmax(logits[:, -1, :], dim=-1)
entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=-1)
exit_values = (entropy - prev_entropy).abs()
elif criterion == "latent-diff":
norm_diff = (prev_latents - current_latents).norm(dim=-1) / current_latents.norm(dim=-1)
exit_values = norm_diff.mean(dim=-1)
elif "kl" in criterion:
outputs = self.predict_from_latents(current_latents, **aux_inputs)
logits: torch.Tensor = outputs.logits # type: ignore
prev_log_probs = log_probs
if criterion == "minp-kl":
probs = F.softmax(logits[:, -1, :].float(), dim=-1)
max_probs = probs.max(dim=-1, keepdim=True)[0]
probs_mask = probs < (0.1 * max_probs)
masked_probs = probs.clone()
masked_probs[probs_mask] = 1 / V
probs = masked_probs / masked_probs.sum(dim=-1, keepdim=True)
log_probs = probs.log()
else:
log_probs = F.log_softmax(logits[:, -1, :].float(), dim=-1)
exit_values = F.kl_div(log_probs, prev_log_probs, reduction="none", log_target=True).sum(dim=-1)
elif criterion == "argmax-stability":
prev_argmax = current_argmax
outputs = self.predict_from_latents(current_latents, **aux_inputs)
logits: torch.Tensor = outputs.logits # type: ignore
current_argmax = logits[:, -1, :].argmax(dim=-1)
stable_for_n_steps = torch.where(
current_argmax == prev_argmax, stable_for_n_steps + 1, torch.zeros_like(stable_for_n_steps)
)
exit_values = stable_for_n_steps
elif criterion == "none":
exit_values = torch.ones(batch_size, device=input_ids.device) * 2.0 * exit_threshold
# Record values and check exits for each sequence
for i in range(batch_size):
if not exit_reached[i] and unfinished_sequences[i].bool():
exit_values_per_seq[i].append(exit_values[i].item())
# Check for new exits, respecting unfinished_sequences
new_exits = (
exit_values < exit_threshold
if criterion != "argmax-stability"
else exit_values >= exit_threshold
)
new_exits = new_exits & ~exit_reached & unfinished_sequences.bool()
if new_exits.any():
exit_reached = exit_reached | new_exits
if criterion == "latent-diff":
# Normally we don't compute the output for latent-diff, but when there is an exit,
# we need to compute and save the output
outputs = self.predict_from_latents(current_latents, **aux_inputs)
logits: torch.Tensor = outputs.logits # type: ignore
if next_token_logits is None:
next_token_logits = logits[:, -1, :].to(**logit_type) # type: ignore
else:
for i in range(batch_size):
if new_exits[i]:
next_token_logits[i] = logits[i, -1, :].to(**logit_type) # type: ignore
for i in range(batch_size):
if new_exits[i]:
compute_steps_per_seq[i] = compute_step + 1
# If all sequences have exited or finished, break early
if (exit_reached | ~unfinished_sequences.bool()).all():
break
# This else is if the for loop finished without breaking
else:
outputs = self.predict_from_latents(current_latents, **aux_inputs)
# For sequences that didn't exit early, use the final logits
if next_token_logits is None:
next_token_logits = outputs.logits[:, -1, :].to(**logit_type) # type: ignore
else:
for i in range(batch_size):
if not exit_reached[i] and unfinished_sequences[i].bool():
next_token_logits[i] = outputs.logits[i, -1, :].to(**logit_type) # type: ignore
compute_steps_per_seq[i] = max_steps
# Save latent states for continuous compute if enabled
if continuous_compute:
model_kwargs["input_states"] = current_latents[:, -1:, :]
# Record compute steps for this token generation
compute_steps.append([compute_steps_per_seq, exit_values_per_seq])
# Sample or select next token based on generation config
next_token = self._sample_next_token(next_token_logits, generation_config)
# Append token to sequence
input_ids = torch.cat([input_ids, next_token], dim=-1)
if streamer:
streamer.put(next_token.cpu())
# Update model kwargs for next iteration
model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs)
# Check for stop tokens and update unfinished sequences
for i in range(batch_size):
if (
unfinished_sequences[i].bool()
and stop_tokens is not None
and next_token[i, 0].item() in stop_tokens
):
unfinished_sequences[i] = 0
# Apply any custom stopping criteria
if "stopping_criteria" in model_kwargs:
unfinished_sequences = unfinished_sequences & ~model_kwargs["stopping_criteria"](input_ids, None)
# Break if all sequences are finished
if unfinished_sequences.max() == 0:
break
if streamer:
streamer.end()
if generation_config.return_dict_in_generate:
return GenerateDecoderOnlyOutput(
sequences=input_ids, # type: ignore
scores=compute_steps, # type: ignore
logits=None,
attentions=None,
hidden_states=None,
past_key_values=model_kwargs.get("past_key_values"),
)
return input_ids
def _get_stops(self, generation_config, tokenizer, model_kwargs):
stop_tokens = {65504, 65505, 65508} # begin_text, end_text, end_turn
if generation_config.eos_token_id is not None:
stop_tokens.add(generation_config.eos_token_id)
if "stopping_criteria" in model_kwargs and tokenizer is None:
tokenizer = model_kwargs["stopping_criteria"][0].tokenizer
if hasattr(generation_config, "stop_strings") and tokenizer and generation_config.stop_strings:
for s in generation_config.stop_strings:
token_id = tokenizer(s, add_special_tokens=False)["input_ids"][0]
stop_tokens.add(token_id)
return torch.tensor(list(stop_tokens))
def _sample_next_token(self, next_token_logits, generation_config):
"""Helper function to sample the next token."""
if generation_config.do_sample:
if generation_config.temperature:
next_token_logits = next_token_logits.float() / generation_config.temperature
probs = F.softmax(next_token_logits, dim=-1)
# Apply top_k
if generation_config.top_k:
top_k_values, _ = torch.topk(probs, generation_config.top_k, dim=-1)
min_values = top_k_values[:, -1].unsqueeze(-1).expand_as(probs)
probs = torch.where(probs < min_values, torch.zeros_like(probs), probs)
# Apply top_p (nucleus sampling)
if generation_config.top_p:
sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
# Create mask for probs to keep
remove_indices = cumulative_probs > generation_config.top_p
remove_indices[:, 0] = False # Keep at least the top probability
# Convert sorted indices mask back to original indices mask
mask = torch.zeros_like(probs, dtype=torch.bool)
for i in range(probs.shape[0]):
mask[i, sorted_indices[i, remove_indices[i]]] = True
probs = torch.where(mask, torch.zeros_like(probs), probs)
# Apply min_p
if generation_config.min_p:
max_probs = probs.max(dim=-1, keepdim=True)[0]
min_p_threshold = generation_config.min_p * max_probs
probs = torch.where(probs < min_p_threshold, torch.zeros_like(probs), probs)
# Renormalize probabilities
probs = probs / probs.sum(dim=-1, keepdim=True).clamp(min=1e-10)
# Sample from the distribution
return torch.multinomial(probs, num_samples=1)
else:
return torch.argmax(next_token_logits, dim=-1, keepdim=True)
@torch.no_grad()
def generate_speculative(
self,
input_ids: torch.Tensor,
generation_config: Optional[GenerationConfig] = None, # type: ignore
tokenizer=None,
streamer=None,
continuous_compute=False, # warm-start state / continuous CoT
init_scale: float = 1.0,
cache_lookup_strategy: str = "full",
draft_steps=32,
lookahead_for_draft=8,
verification_threshold=1,
num_steps: int = 32, # intercept deliberately
**model_kwargs,
) -> Union[torch.Tensor, dict[str, Any]]:
"""Batched speculative decoding with per-sequence acceptance."""
assert lookahead_for_draft > 0
pad_id = 65509
model_kwargs, generation_config, max_new_tokens = self._prep_generate_args(
input_ids, generation_config, cache_lookup_strategy, model_kwargs
)
stop_tokens = self._get_stops(generation_config, tokenizer, model_kwargs).to(input_ids.device)
unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
# Set up continuous compute if enabled
if continuous_compute:
embedded_inputs, _ = self.embed_inputs(input_ids)
model_kwargs["input_states"] = self.initialize_state(embedded_inputs, scale=init_scale)
tokens_generated = 0
# Prefill cache with full num_steps
if model_kwargs["past_key_values"].get_seq_length() == 0:
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
outputs = self(**model_inputs, num_steps=num_steps, init_scale=init_scale)
next_token = self._sample_next_token(
outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32), generation_config
)
input_ids = torch.cat([input_ids, next_token], dim=-1)
tokens_generated += 1
if streamer:
streamer.put(next_token.cpu())
model_kwargs["cache_position"] = torch.as_tensor(
[model_inputs["past_key_values"].get_seq_length()], device=input_ids.device
)
if continuous_compute:
model_kwargs["input_states"] = outputs.latent_states[:, -1:, :]
# Generate tokens
batch_size, prefix_seq_len = input_ids.shape[0], input_ids.shape[1]
accepted_tokens = []
while tokens_generated < max_new_tokens:
### Run the next draft ####
drafted_inputs = input_ids.clone()
current_len = input_ids.shape[1]
for _ in range(lookahead_for_draft):
model_inputs = self.prepare_inputs_for_generation(drafted_inputs, **model_kwargs)
outputs = self(**model_inputs, num_steps=draft_steps, init_scale=init_scale)
next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32)
next_token = self._sample_next_token(next_token_logits, generation_config)
drafted_inputs = torch.cat([drafted_inputs, next_token], dim=-1)
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
if continuous_compute:
model_kwargs["input_states"] = outputs.latent_states[:, -1:, :]
model_kwargs["past_key_values"].clear_last_k_entries(lookahead_for_draft)
## Verify drafted tokens ###
model_kwargs["cache_position"] = torch.arange(
current_len - 1, current_len + lookahead_for_draft - 1, device=input_ids.device
)
model_inputs = self.prepare_inputs_for_generation(drafted_inputs, **model_kwargs)
outputs = self(**model_inputs, num_steps=num_steps, init_scale=init_scale)
verified_next_token_preds = outputs.logits.argmax(dim=-1)
if verification_threshold >= 1:
mismatched_tokens = (
verified_next_token_preds[:, -lookahead_for_draft:] != drafted_inputs[:, current_len:]
)
not_all_matched, first_mismatch = torch.max(mismatched_tokens, dim=1)
else:
verified_logits = outputs.logits[:, -lookahead_for_draft:, :]
verified_probs = F.softmax(verified_logits, dim=-1)
drafted_token_probs = torch.gather(
verified_probs, -1, drafted_inputs[:, current_len:].unsqueeze(-1)
).squeeze(-1)
max_probs = verified_probs.max(dim=-1)[0]
verification_passed = drafted_token_probs >= verification_threshold * max_probs
not_all_matched, first_mismatch = torch.max(~verification_passed, dim=1)
# Per-sequence acceptance handling
acceptance_lengths = torch.where(not_all_matched, first_mismatch, lookahead_for_draft)
# Build next_tokens for each sequence
next_tokens_batch = []
for i in range(batch_size):
seq_acceptance = acceptance_lengths[i].item()
if not_all_matched[i] and seq_acceptance < lookahead_for_draft:
# Accept up to mismatch + sample final token
accepted_part = drafted_inputs[i : i + 1, current_len : current_len + seq_acceptance]
final_token_logits = outputs.logits[i : i + 1, seq_acceptance, :].to(copy=True, dtype=torch.float32)
final_token = self._sample_next_token(final_token_logits, generation_config)
seq_tokens = torch.cat([accepted_part, final_token], dim=-1) if seq_acceptance > 0 else final_token
else:
# Accept all drafted tokens
seq_tokens = drafted_inputs[i : i + 1, current_len : current_len + seq_acceptance]
next_tokens_batch.append(seq_tokens)
# Clean up KV cache - only if any sequence had mismatches
if not_all_matched.any():
min_first_mismatch = first_mismatch.min().item()
model_inputs["past_key_values"].clear_last_k_entries(lookahead_for_draft - min_first_mismatch - 1)
# Concatenate accepted tokens to input_ids
batch_accepted_counts = [tokens.shape[1] for tokens in next_tokens_batch]
max_len = max(batch_accepted_counts)
padded_tokens = [
torch.cat(
[
tokens,
pad_id * torch.ones((1, max_len - tokens.shape[1]), dtype=tokens.dtype, device=tokens.device),
],
dim=-1,
)
if tokens.shape[1] < max_len
else tokens
for tokens in next_tokens_batch
]
next_tokens = torch.cat(padded_tokens, dim=0)
input_ids = torch.cat([input_ids, next_tokens], dim=-1)
accepted_tokens.append(batch_accepted_counts)
tokens_generated += max(batch_accepted_counts)
if streamer:
streamer.put(next_tokens_batch[0].cpu())
model_kwargs["cache_position"] = torch.as_tensor(
[model_inputs["past_key_values"].get_seq_length()], device=input_ids.device
)
if continuous_compute:
model_kwargs["input_states"] = outputs.latent_states[:, -1:, :]
# Check stopping conditions
if stop_tokens is not None:
for i in range(batch_size):
if unfinished_sequences[i] and torch.isin(next_tokens_batch[i], stop_tokens).any():
unfinished_sequences[i] = 0
if "stopping_criteria" in model_kwargs:
unfinished_sequences = unfinished_sequences & ~model_kwargs["stopping_criteria"](input_ids, None)
if unfinished_sequences.max() == 0:
break
if streamer:
streamer.end()
# Cut off extraneous parts of the sequence per batch element
if stop_tokens is not None:
for i in range(batch_size):
stop_positions = torch.isin(input_ids[i, prefix_seq_len:], stop_tokens).nonzero()
if len(stop_positions) > 0:
input_ids[i, prefix_seq_len + stop_positions[0].item() + 1 :] = pad_id
# Trim tensor to remove columns that are pad_id across all sequences
non_pad_mask = input_ids != pad_id
last_real_token = non_pad_mask.any(dim=0).nonzero()
if len(last_real_token) > 0:
input_ids = input_ids[:, : last_real_token[-1].item() + 1]
if generation_config.return_dict_in_generate:
return GenerateDecoderOnlyOutput(
sequences=input_ids, # type: ignore
scores=accepted_tokens, # type: ignore
logits=None,
attentions=None,
hidden_states=None,
past_key_values=model_kwargs.get("past_key_values"),
)
return input_ids
def get_stats(self, logits, x, latent_states, xk, input_embeds, num_steps_no_grad, num_steps_with_grad):
probs = torch.softmax(logits.float(), dim=-1)
prob_entropy = torch.where(probs > 0, -probs * probs.log(), 0).sum(dim=-1)
residual_diff = (x - latent_states).norm(dim=-1)
rel_residual = residual_diff / latent_states.norm(dim=-1)
stats = {
"entropy": prob_entropy,
"residual_diff": residual_diff,
"rel_residual": rel_residual,
"num_steps_no_grad": num_steps_no_grad,
"num_steps_with_grad": num_steps_with_grad,
}
return stats
#################################### Utils #######################################################################
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, condense_ratio: int = 1):
with torch.autocast("cuda", enabled=False):
inv_freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
t = torch.arange(end, dtype=torch.float32, device=inv_freqs.device) / condense_ratio
freqs = torch.outer(t, inv_freqs).float()
return torch.stack([torch.cos(freqs)[None, :, None, :], torch.sin(freqs)[None, :, None, :]], dim=4)
# equivalent to
# freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
# cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
def apply_rotary_emb_complex_like(q: Tensor, k: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
with torch.autocast("cuda", enabled=False):
qk_r2 = torch.cat([q, k], dim=2).unflatten(dim=-1, sizes=(-1, 2)).float() # cast to float32 for smooth skin
rotated_qk_r2 = torch.stack(
[
qk_r2[..., 0] * freqs_cis[..., 0] - qk_r2[..., 1] * freqs_cis[..., 1],
qk_r2[..., 1] * freqs_cis[..., 0] + qk_r2[..., 0] * freqs_cis[..., 1],
],
-1,
).flatten(3)
rotated_qk = rotated_qk_r2
return torch.split(rotated_qk.type_as(q), q.shape[2], dim=2) # type: ignore
#################################### HF registration ############################################################
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
# New
RavenConfig.register_for_auto_class()
RavenForCausalLM.register_for_auto_class("AutoModel")
RavenForCausalLM.register_for_auto_class("AutoModelForCausalLM")
# Old?
AutoConfig.register("huginn_raven", RavenConfig)
AutoModel.register(RavenConfig, RavenForCausalLM)
AutoModelForCausalLM.register(RavenConfig, RavenForCausalLM)