File size: 22,117 Bytes
ae81e0f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 |
"""
Linear attention classes
"""
from typing import List, Tuple, Optional
import copy
import torch
import torch.nn as nn
from omegaconf import OmegaConf, DictConfig
from transformers.cache_utils import Cache # starting at Transformers v4.36
# Causal linear attention dot product CUDA kernel from fast-transformers
try:
from csrc import causal_dot_product as fast_causal_dot_product
except ImportError:
fast_causal_dot_product = None
from src.model.feature_map import init_feature_map, init_learned_kernel
from src.model.rotary import get_rotary_embeddings, apply_rotary_pos_emb
from .utils import repeat_kv
# -------------------
# Attention functions
# -------------------
def causal_dot_product(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
    """
    Causal (prefix-sum) linear attention numerator.

    - Uses the fast-transformers CUDA kernel when it was importable.
    - Otherwise falls back to pure PyTorch: materialize the per-position outer
      products k_l v_l^T, take their running sum along dim 2, and contract
      each position's query with its prefix sum.
    """
    if fast_causal_dot_product is not None:
        return fast_causal_dot_product(q, k, v)
    # Per-position outer products, (b, h, l, f, d) per the einsum subscripts
    outer = torch.einsum('bhlf,bhld->bhlfd', k, v)
    # Running sum over positions is what enforces causality
    prefix = outer.cumsum(dim=2)
    return torch.einsum('bhlf,bhlfd->bhld', q, prefix)
def linear_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                     fp32_attention: bool = False, eps: float = 1e-12,
                     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """
    Causal linear attention over feature-mapped queries and keys.

    The numerator comes from `causal_dot_product`, always evaluated in float32
    (CUDA kernel from https://github.com/idiap/fast-transformers when
    available). It is normalized by q_l . (sum_{j<=l} k_j) + eps.

    Args:
        q, k: assumed shape (batch_size, num_heads, seq_len, feature_dim)
        v: assumed shape (b, h, l, head_dim)
        fp32_attention: if True, normalize in float32 and cast the result to
            q's dtype; otherwise cast the numerator first and normalize in q's dtype.
        eps: added to the denominator to avoid division by zero.

    Returns:
        (y, None, None) where y has q's dtype. Causality is implicit in the
        cumulative sums, so no explicit mask is applied here.
    """
    out_dtype = q.dtype
    q32, k32, v32 = (t.contiguous().to(dtype=torch.float32) for t in (q, k, v))
    numer = causal_dot_product(q32, k32, v32)
    if not fp32_attention:
        k_prefix = k.float().cumsum(dim=2).to(dtype=out_dtype)
        denom = torch.einsum("bhld,bhld->bhl", q, k_prefix) + eps
        return numer.to(dtype=out_dtype) / denom[..., None], None, None
    denom = torch.einsum("bhld,bhld->bhl", q.float(), k.float().cumsum(dim=2)) + eps
    return (numer / denom[..., None]).to(dtype=out_dtype), None, None
def softmax_attention(q: torch.Tensor, k: torch.Tensor, v: Optional[torch.Tensor] = None,
                      causal: bool = True, fp32_attention: bool = True,
                      ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """
    Standard scaled dot-product softmax attention.

    Args:
        q, k: assumed shape (batch_size, num_heads, seq_len, head_dim). If q has
            fewer positions than k, the causal mask aligns queries with the
            last positions of k.
        v: values; when None only the attention weights are computed.
        causal: mask out keys that come after each query's position.
        fp32_attention: run the softmax in float32, then cast back to q's dtype.

    Returns:
        (y, a, None): y is the attention output (None when v is None) and a
        the (b, h, m, n) attention weights.
    """
    scale = k.shape[-1] ** -0.5
    scores = torch.einsum('bhmd,bhnd->bhmn', q, k) * scale
    if causal:
        n_q, n_k = scores.shape[-2:]
        # True strictly above the offset diagonal, i.e. on "future" keys
        future = torch.ones((n_q, n_k), device=scores.device, dtype=torch.bool).triu(n_k - n_q + 1)
        scores = scores.masked_fill(future, -torch.finfo(scores.dtype).max)
    if fp32_attention:
        weights = torch.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
    else:
        weights = torch.softmax(scores, dim=-1)
    out = None if v is None else torch.einsum('bhmn,bhnd->bhmd', weights, v)
    return out, weights, None
def quadratic_attention(q: torch.Tensor, k: torch.Tensor, v: Optional[torch.Tensor] = None,
                        causal: bool = True, fp32_attention: bool = False, eps: float = 1e-12,
                        ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """
    Compute attention with feature maps by instantiating L x L matrix of attention weights
    -> Use for attention distillation
    -> Assume q, k are shape (batch_size, num_heads, seq_len, feature_dim); v is shape (b, h, l, head_dim)

    Weights are the (unscaled) dot products q . k, zeroed on future keys when
    `causal`, then row-normalized by their sum (+ eps).

    Args:
        q, k: feature-mapped queries / keys.
        v: values; when None only the weights are returned.
        causal: zero out keys after each query's (offset-aligned) position.
        fp32_attention: compute weights in float32, then cast back to q's dtype.
        eps: added to the row sums to avoid division by zero.

    Returns:
        (y, a, None): y is the output (None if v is None) and a the weights.
        NaNs in the inputs propagate into a and y; they are not trapped here.
    """
    y = None
    dtype = q.dtype
    if fp32_attention:
        q, k = q.float(), k.float()
    a = torch.einsum('bhmd,bhnd->bhmn', q, k)  # note we don't scale, tho we could
    if causal:  # Apply causal mask
        m, n = a.shape[-2:]
        causal_mask = torch.ones((m, n), device=a.device, dtype=torch.bool).triu(n - m + 1)
        a = a.masked_fill(causal_mask, 0)
    # Normalize to compute attention
    a = a / (a.sum(dim=-1, keepdim=True) + eps)
    if fp32_attention:
        a = a.to(dtype=dtype)
    # A former debug `breakpoint()` on NaN was removed: it hung or crashed
    # non-interactive / multi-process runs and forced a device sync every call.
    if v is not None:
        y = torch.einsum('bhmn,bhnd->bhmd', a, v)
    return y, a, None
# ---------------------
# Attention layer class
# ---------------------
class LolcatsLinearAttention(nn.Module):
    """
    LoLCATs attention implementation initialized from a
    `LlamaAttention` or `MistralAttention` object (base_attn)
    Most of the arguments are directly tied to argparse args
    - For now we don't support padding.

    `forward` runs in one of three modes:
    - `base_inference`: the original softmax attention only (under no_grad)
    - `train_attention`: distill the feature maps against softmax attention
    - otherwise: causal linear attention over feature-mapped q and k
    """
    def __init__(self,
                 base_attn: nn.Module,  # like LlamaAttention
                 feature_map: str,
                 feature_map_kwargs: dict,
                 layer_idx: Optional[int] = None,
                 max_layer_idx: Optional[int] = None,
                 learned_kernel: Optional[str] = None,
                 learned_kernel_kwargs: Optional[dict] = None,
                 tie_qk_kernels: Optional[bool] = False,
                 rotary_config: Optional[dict] = None,
                 train_attention: Optional[bool] = False,
                 remove_base_attn: Optional[bool] = True,
                 attention_type: Optional[str] = 'lolcats_llama',
                 mask_value: int = 0,
                 eps: float = 1e-12,
                 fp32_attention: bool = False,
                 track_state_grads: bool = False,
                 rank: Optional[int] = 0,
                 **kwargs: any) -> None:
        super().__init__()
        self.base_config = getattr(base_attn, 'config', None)
        if self.base_config is not None:
            self.base_config = self.base_config.to_dict()
        self.attention_type = attention_type
        self.mask_value = mask_value
        self.eps = eps
        self.layer_idx = (layer_idx if layer_idx is not None else base_attn.layer_idx)
        self.max_layer_idx = max_layer_idx
        self.tie_qk_kernels = tie_qk_kernels
        self.train_attention = train_attention
        self.base_inference = False
        self.fp32_attention = fp32_attention
        self.track_state_grads = track_state_grads
        # Print the configuration once: on rank 0 (multi-gpu) for the first layer.
        # Uses the resolved self.layer_idx so it also fires when the index comes from base_attn.
        if rank == 0 and self.layer_idx == 0:
            if fp32_attention:
                print(f'-> fp32_attention is {fp32_attention}')
            for _cfg in (feature_map_kwargs, learned_kernel_kwargs):
                if _cfg is not None:
                    for k, v in _cfg.items():
                        print(f'-> {k}: {v}')
        self.remove_base_attn = remove_base_attn
        # Rotary embeddings (patch for Llama 3.1, Transformer v4.43.0)
        self.rotary_config = rotary_config
        if isinstance(self.rotary_config, DictConfig):  # ensure dict
            self.rotary_config = OmegaConf.to_container(self.rotary_config)
        self.rotary_emb = None
        if self.base_config is not None and self.rotary_config is None:
            self.rotary_emb = base_attn.rotary_emb
        self.init_weights_(base_attn, remove_base_attn)
        self.init_feature_map_(feature_map, feature_map_kwargs,
                               learned_kernel, learned_kernel_kwargs)

    def init_feature_map_(self,
                          feature_map: str,
                          feature_map_kwargs: dict,
                          learned_kernel: str = None,
                          learned_kernel_kwargs: dict = None):
        """
        Initialize MLP-based feature map

        Builds `feature_map_q` via `init_feature_map`, optionally wrapping an
        MLP from `init_learned_kernel`. `feature_map_k` is either the same
        module (`tie_qk_kernels`) or a deep copy of it.
        """
        self.fmap_gqa = False  # Turn True if specified below
        # Without a learned kernel there is no MLP. Previously this name was left
        # unbound and init_feature_map raised UnboundLocalError.
        # NOTE(review): assumes init_feature_map accepts mlp=None — TODO confirm.
        mlp_learned_kernel = None
        if learned_kernel is not None:
            # Copy into a plain dict so the caller's mapping is not mutated
            learned_kernel_kwargs = dict(learned_kernel_kwargs.items())
            learned_kernel_kwargs['num_heads'] = self.num_heads
            learned_kernel_kwargs['head_dim'] = self.head_dim
            learned_kernel_kwargs['dtype'] = self.q_proj.weight.dtype
            learned_kernel_kwargs['device'] = self.q_proj.weight.device
            # Create MLP
            mlp_learned_kernel = init_learned_kernel(learned_kernel, **learned_kernel_kwargs)
        # Add "activation"; see src.models.feature_map.py
        self.feature_map_q = init_feature_map(name=feature_map,
                                              mlp=mlp_learned_kernel,
                                              **feature_map_kwargs)
        if self.tie_qk_kernels:  # tie mlp weights for query and key feature maps
            self.feature_map_k = self.feature_map_q
        else:
            self.feature_map_k = copy.deepcopy(self.feature_map_q)

    def init_weights_(self, base_attn: nn.Module, remove_base_attn: bool = True):
        """
        Initialize module layers, weights, positional dependencies, etc.
        from original softmax attention layer (base_attn)

        Reuses (does not copy) base_attn's q/k/v/o projection modules. Builds
        rotary embeddings unless __init__ already took them from base_attn.
        """
        # Make other attributes accessible
        self.attention_dropout = 0  # We don't use dropout
        self.hidden_size = base_attn.hidden_size
        self.num_heads = base_attn.num_heads
        self.head_dim = base_attn.head_dim
        self.num_key_value_heads = base_attn.num_key_value_heads
        self.num_key_value_groups = base_attn.num_key_value_groups
        self.q_shape = [self.num_heads, self.head_dim]
        self.k_shape = [self.num_key_value_heads, self.head_dim]
        self.v_shape = [self.num_key_value_heads, self.head_dim]
        device = base_attn.q_proj.weight.device
        # Rotary embeddings
        if self.rotary_emb is None:
            self.max_position_embeddings = base_attn.max_position_embeddings
            scaling_factor = getattr(base_attn.rotary_emb, 'scaling_factor', 1.)
            if self.rotary_config is None:
                self.rotary_emb = get_rotary_embeddings(
                    rope_scaling_type=None,
                    head_dim=self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,  # base_attn.rotary_emb.max_position_embeddings,
                    rope_theta=base_attn.rotary_emb.base,
                    rope_scaling_factor=scaling_factor,  # base_attn.rotary_emb.scaling_factor,
                    device=device,
                )
            else:
                if 'device' not in self.rotary_config:
                    self.rotary_config['device'] = device
                self.rotary_emb = get_rotary_embeddings(**self.rotary_config)
        # Copy original model projection layers
        self.q_proj = base_attn.q_proj
        self.k_proj = base_attn.k_proj
        self.v_proj = base_attn.v_proj
        self.o_proj = base_attn.o_proj
        try:  # If wanting to use FA2 for ground-truth inference
            self._flash_attn_uses_top_left_mask = base_attn._flash_attn_uses_top_left_mask
        except AttributeError:
            pass
        if self.remove_base_attn or remove_base_attn:
            del base_attn  # We don't need to keep these around
        else:
            self.base_attn = base_attn  # For some training runs helpful to just call

    def process_qkv(self,
                    hidden_states: torch.Tensor,
                    attention_mask: Optional[torch.Tensor] = None,
                    position_ids: Optional[torch.LongTensor] = None,
                    past_key_value: Optional[Tuple[int, torch.Tensor, torch.Tensor]] = None,):  # "legacy" cache approach
        """
        Compute queries, keys, and values

        Projects hidden_states and reshapes to (b, num_heads, l, head_dim) for
        q and (b, num_key_value_heads, l, head_dim) for k, v. Applies rotary
        embeddings, then repeats k, v across GQA groups.

        Returns:
            (q, k, v, kv_seq_len). kv_seq_len counts the current tokens plus any cached ones.
        """
        b, l, _ = hidden_states.size()
        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)
        kv_seq_len = k.shape[-2]
        # Shape is (batch_size, seq_len, num_heads, head_dim)
        q = q.view(b, l, *self.q_shape).transpose(1, 2)
        k = k.view(b, l, *self.k_shape).transpose(1, 2)
        v = v.view(b, l, *self.v_shape).transpose(1, 2)
        if past_key_value is not None:  # and k.shape[2] > q.shape[2]: # e.g., when generating
            if isinstance(past_key_value, (tuple, list)):
                # Legacy cache: tuples cannot take new attributes, so skip window_size
                kv_seq_len += past_key_value[0].shape[-2]
            else:
                past_key_value.window_size = getattr(self, 'decode_window_size', None)  # self.decode_window_size
                if isinstance(past_key_value, Cache):  # In Transformers v4.36+ this is a DynamicCache object
                    kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
                else:
                    kv_seq_len += past_key_value[0].shape[-2]
        # Apply rotary embeddings and repeat for GQA
        if position_ids is not None and kv_seq_len <= position_ids[0, -1]:
            kv_seq_len = position_ids[0, -1] + 1  # hack for adjusting position ids
        try:  # As in Transformers v4.36
            cos, sin = self.rotary_emb(k, seq_len=kv_seq_len)
            q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
        except TypeError:  # As in Transformers v4.39+
            cos, sin = self.rotary_emb(v, position_ids)
            q, k = apply_rotary_pos_emb(q, k, cos, sin)
        k = repeat_kv(k, self.num_key_value_groups)
        v = repeat_kv(v, self.num_key_value_groups)
        return q, k, v, kv_seq_len

    def forward(self,
                hidden_states: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.LongTensor] = None,
                past_key_value: Optional[Tuple[int, torch.Tensor, torch.Tensor]] = None,  # "legacy" cache approach
                output_attentions: bool = False,
                use_cache: bool = False,
                **kwargs,
                ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Forward pass modified from transformers.models.mistral.modeling_mistral (v4.36)
        - Consistent with HuggingFace Transformers for easy use with their pretrained models

        Returns:
            (output, attn_weights, past_key_value).
            - `base_inference`: attn_weights is (None, None).
            - `train_attention`: attn_weights is ((attn_pred, attn_true), (y_pred, y_true_heads)).
            - finetuning / generation: attn_weights is None.
        """
        b, l, _ = hidden_states.size()
        q, k, v, kv_seq_len = self.process_qkv(hidden_states, attention_mask,
                                               position_ids, past_key_value)
        if self.base_inference:
            with torch.no_grad():
                # 1. Compute "ground-truth" attention output and weights
                y_true, _, _ = softmax_attention(q, k, v, causal=True)
                y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
                y_true = self.o_proj(y_true)
                attn_weights = (None, None)
        elif self.train_attention:  # Distilling / learning attentions
            # Note for now we assume no padding when distilling; attention masks only enforce causality
            assert output_attentions is True, f'When training feature maps, output_attentions should be True but is {output_attentions}'
            with torch.no_grad():
                # 1. Compute "ground-truth" attention output and weights
                _y_true, attn_true, _ = softmax_attention(q, k, v, causal=True)
                y_true = _y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
                y_true = self.o_proj(y_true)
            # 2. Compute "predicted" attention (just weights)
            q, k = self.feature_map_q.q_map(q), self.feature_map_k.k_map(k)
            y_pred, attn_pred, _ = quadratic_attention(q, k, v, causal=True)
            attn_weights = ((attn_pred, attn_true), (y_pred, _y_true))  # Save both attention weights so we can supervise.
        else:  # Finetuning
            q, k = self.feature_map_q(q), self.feature_map_k(k)
            # Apply prefill mask: zero out feature-mapped keys at masked positions
            if attention_mask is not None and q.shape[2] > 1:
                if len(attention_mask.shape) == 4:
                    lin_attn_mask = (attention_mask == 0)[:, :1, -1, :l][..., None]  # b, 1, k_len, 1
                else:
                    # 2-D HF masks are usually integer (1 = keep, 0 = pad). `~` on an
                    # integer tensor is bitwise NOT, and masked_fill needs a bool mask.
                    lin_attn_mask = attention_mask.bool()[:, None, :, None]  # b, 1, k_len, 1
                k = k.masked_fill(~lin_attn_mask, 0)
            if past_key_value is not None:  # Initialize states
                if len(past_key_value.kv_states) == self.layer_idx:
                    _, h, _, f = k.shape
                    past_key_value.kv_states.append(
                        torch.zeros(b, h, f, self.head_dim, dtype=q.dtype, device=q.device)
                    )
                    past_key_value.k_states.append(
                        torch.zeros(b, h, 1, f, dtype=q.dtype, device=q.device)
                    )
                # Generating
                if q.shape[2] == 1 and kv_seq_len > 1 and past_key_value is not None:
                    assert use_cache is True
                    kv_state, k_state = past_key_value.update(k, v, self.layer_idx,
                                                              accumulate_in_fp32=self.fp32_attention)
                    # NOTE(review): unlike linear_attention, no eps in this denominator
                    if self.fp32_attention:
                        q = q.float()
                        y_true = (torch.einsum('bhlf,bhfd->bhld', q, kv_state.float()) /
                                  torch.einsum('bhlf,bhlf->bhl', q, k_state.float())[..., None]).to(dtype=k.dtype)
                    else:
                        y_true = (torch.einsum('bhlf,bhfd->bhld', q, kv_state) /
                                  torch.einsum('bhlf,bhlf->bhl', q, k_state)[..., None])
                else:
                    # Prefill: attend over the full prompt, then fold it into the cached states
                    y_true, _, _ = linear_attention(q, k, v, self.fp32_attention, self.eps)
                    past_key_value.update(k.detach(), v.detach(), self.layer_idx,
                                          accumulate_in_fp32=self.fp32_attention)
                    # doing some unnecessary recomputation here
            else:
                y_true, _, _ = linear_attention(q, k, v, self.fp32_attention, self.eps)
            # Concatenate heads and apply output projection
            y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
            y_true = self.o_proj(y_true)
            attn_weights = None
        return y_true, attn_weights, past_key_value
class LinearAttentionState(Cache):
    """
    Handle the KV and K states for linear attention
    - Adopts HF Transformers `past_key_values` convention
    - Inherits from `Cache` class
    - Modified from transformers.cache_utils.DynamicCache (v4.36)

    Per layer, `update` accumulates:
    - kv_states[layer]: running sum of k^T v, (b, h, f, d) per its einsum
    - k_states[layer]: running sum of k over positions, (b, h, 1, f)
    """
    def __init__(self) -> None:
        # DynamicCache calls the base initializer too. Newer Transformers make
        # Cache an nn.Module, which must be initialized before use.
        super().__init__()
        self._seen_tokens = 0  # should be `self.seen_tokens` in Transformers v4.36
        self._seen_tokens_by_layer: List[int] = []
        self.kv_states: List[torch.Tensor] = []
        self.k_states: List[torch.Tensor] = []

    def _ensure_layer_counter(self, layer_idx: int) -> None:
        """Grow the per-layer token counter so `layer_idx` is a valid index."""
        while len(self._seen_tokens_by_layer) <= layer_idx:
            self._seen_tokens_by_layer.append(0)

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """
        Returns the sequence length of the cached states. A layer index can be optionally passed.
        """
        self._ensure_layer_counter(layer_idx)  # Initializing kv and k states
        return self._seen_tokens_by_layer[layer_idx]

    def get_max_length(self) -> Optional[int]:
        """
        Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length.
        """
        return None

    def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
        """Given the sequence length of the new inputs, returns the usable length of the cache."""
        # Cache without size limit -> all cache is usable
        # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
        # length, we will need to evict part of the cache (and thus not all cache is usable)
        max_length = self.get_max_length()
        previous_seq_length = self.get_seq_length(layer_idx)
        if max_length is not None and previous_seq_length + new_seq_length > max_length:
            return max_length - new_seq_length
        return previous_seq_length

    def update(self, key_states: torch.Tensor, value_states: torch.Tensor,
               layer_idx: Optional[int] = None, cache_kwargs: Optional[any] = None,
               accumulate_in_fp32: bool = True, **kwargs: any,
               ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Add the contribution of `key_states` / `value_states` to layer `layer_idx`.

        Sums are formed (optionally in float32) and stored back in the incoming
        key dtype. Runs under no_grad and detaches, so the states carry no graph.

        Returns:
            (kv_state, k_state) for the layer after the update.
        """
        with torch.no_grad():
            if layer_idx == 0:
                self._seen_tokens += key_states.shape[-2]
            dtype = key_states.dtype
            if accumulate_in_fp32:
                key_states, value_states = key_states.float(), value_states.float()
            kv_state = torch.einsum('bhlf,bhld->bhfd', key_states, value_states).detach()
            k_state = key_states.sum(dim=-2, keepdim=True).detach()  # b, h, 1, f; note the 1
            # Update the cache
            if len(self.k_states) <= layer_idx:  # Initializing kv and k states
                self.kv_states.append(kv_state.to(dtype))
                self.k_states.append(k_state.to(dtype))
            else:
                # Accumulate in the (possibly fp32) dtype of the new contribution, store in `dtype`
                kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(dtype)
                k_state = (self.k_states[layer_idx].to(k_state.dtype) + k_state).to(dtype)
                self.kv_states[layer_idx] = kv_state
                self.k_states[layer_idx] = k_state
            # Guard against update() being called before get_seq_length() for this layer
            self._ensure_layer_counter(layer_idx)
            self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2]
        return self.kv_states[layer_idx], self.k_states[layer_idx]

    def to_legacy_cache(self):
        """Hack, but just return self"""
        return self

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """
        Reorders the cache for beam search, given the selected beam indices.
        -> Not supported for linear attention states; always raises.
        """
        raise NotImplementedError('Reordering cache not implemented for LinearAttentionState')
|