# Uploaded by whynlp — commit 61c49cc (verified): "Upload LCKVLlamaForCausalLM"
import re
from dataclasses import dataclass, field
from typing import List, Optional

import torch
from transformers.modeling_flash_attention_utils import _flash_attention_forward
@dataclass
class IterStep:
    """A helper class for the iteration plan: one step of LayerTypeParser.iteration_plan."""

    # Layers to run in this step; the default slice(None) selects every layer.
    # A default_factory is required: slice objects are unhashable before Python 3.12,
    # and since Python 3.11 dataclasses reject unhashable defaults as "mutable",
    # so a plain `= slice(None)` default fails at class creation on 3.11.
    layer_slice: slice = field(default_factory=lambda: slice(None))
    # iteration_plan sets this to False for the gradient-free passes over a cyclic group
    # (presumably the consumer then runs the step without autograd — consumer not visible here).
    requires_grad: bool = True
    # iteration_plan sets this to False for every pass over a cyclic group except the last one.
    update: bool = True
@dataclass
class LayerType:
    """Layer type information for one layer, as produced by LayerTypeParser.__getitem__."""

    # index of the layer this record describes
    layer_idx: int
    # True when the layer's entry in the layer type string carries the `s` suffix
    use_sliding_window: bool
    # index of the layer whose key-value pair this layer uses as its kv cache
    attends_to: int
    # True when attends_to is strictly greater than layer_idx (attends to a layer above it)
    attends_top: bool
    # True when some layer's attends_to equals layer_idx
    computes_kv: bool
class LayerTypeParser:
    """
    A helper class to parse the layer type string and provide some useful methods.

    Arguments:
        layer_type (str): A string of integers separated by underscores. The i-th integer
            means the layer will use the key-value pair in the i-th layer as the kv cache.
            Special characters may be placed after the integers:
            - `s` means the layer will use sliding window attention.

    Raises:
        ValueError: if an underscore-separated entry is not an integer optionally
            followed by `s`.

    >>> layer_type = LayerTypeParser("0_0_0_5s_5s_5s_8_8_8")[3]
    >>> layer_type.attends_to
    5
    >>> layer_type.attends_top
    True
    >>> layer_type.use_sliding_window
    True
    """

    # One entry of the layer type string: a layer index with an optional `s` suffix.
    _ENTRY_PATTERN = re.compile(r"^(\d+)(s)?$")

    def __init__(self, layer_type: str):
        self._layer_type = layer_type
        # parse the layer type:
        # layer_indices[i] is the layer whose kv layer i uses,
        # sliding_window[i] tells whether layer i uses sliding window attention
        self.layer_indices: List[int] = []
        self.sliding_window: List[bool] = []
        for entry in layer_type.split("_"):
            matched = self._ENTRY_PATTERN.match(entry)
            if matched is None:
                # previously surfaced as an opaque AttributeError on `None.groups()`
                raise ValueError(
                    f"Invalid layer type entry {entry!r} in {layer_type!r}: "
                    "expected an integer optionally followed by 's'."
                )
            layer_idx, sliding_window = matched.groups()
            self.layer_indices.append(int(layer_idx))
            self.sliding_window.append(bool(sliding_window))

    def __len__(self):
        return len(self.layer_indices)

    def __getitem__(self, layer_idx: int) -> LayerType:
        """return the layer type information for the given layer index"""
        return LayerType(
            layer_idx=layer_idx,
            use_sliding_window=self.sliding_window[layer_idx],
            attends_to=self.layer_indices[layer_idx],
            attends_top=self.layer_indices[layer_idx] > layer_idx,
            computes_kv=layer_idx in self.layer_indices,
        )

    def use_sliding_window(self) -> bool:
        """whether there exists a layer that uses sliding window attention"""
        return any(self.sliding_window)

    def attends_top(self) -> bool:
        """whether there exists a layer that attends to layers above it"""
        return any(target > idx for idx, target in enumerate(self.layer_indices))

    def iteration_plan(self, forward_passes: int = 7, backward_passes: int = 2) -> List[IterStep]:
        """
        Return an iteration plan for the layer types. The plan is a list of IterStep objects.

        Consecutive layers that do not attend to a layer above them are grouped into a
        single step with the default flags. A layer that attends to a layer above it starts
        a cyclic group, which is extended until no layer inside it attends beyond its end.
        Each cyclic group is emitted as:
        - `forward_passes` steps with `requires_grad=False, update=False`,
        - `backward_passes - 1` steps with `update=False`,
        - one final step with the default flags (`requires_grad=True, update=True`).

        Args:
            forward_passes: number of gradient-free passes over each cyclic group.
            backward_passes: number of gradient-enabled passes over each cyclic group
                (the last of which also has `update=True`).

        Returns:
            `[IterStep()]` (a single step over all layers) when no layer attends to a
            layer above it; otherwise steps whose slices cover the layers in order.
        """
        # if there is no cyclic dependency, return the default plan
        if not self.attends_top():
            return [IterStep()]

        # otherwise, return the plan for the cyclic dependency
        plan = []
        i = 0
        while i < len(self):
            if self[i].attends_top:
                # find the top layer in the cyclic dependency: grow the group until no
                # layer inside it attends to a layer beyond its upper bound
                top = self[i].attends_to
                while True:
                    furthest = max(self.layer_indices[i: top + 1])
                    if furthest <= top:
                        break
                    top = furthest
                top += 1

                # create iteration plan for this group
                layer_slice = slice(i, top)
                plan.extend([
                    *forward_passes * [IterStep(layer_slice, requires_grad=False, update=False)],
                    *(backward_passes - 1) * [IterStep(layer_slice, update=False)],
                    IterStep(layer_slice),
                ])
            else:
                # otherwise, create a default plan for the run of non-cyclic layers
                top = i + 1
                while top < len(self) and not self[top].attends_top:
                    top += 1
                plan.append(IterStep(slice(i, top)))

            # update the index
            i = top

        return plan

    def check(self, num_hidden_layers: int):
        """
        Check if the layer type is valid.

        Raises:
            ValueError: if the number of entries differs from `num_hidden_layers`, or an
                entry refers to a layer index outside [0, num_hidden_layers).
        """
        if len(self.layer_indices) != num_hidden_layers:
            raise ValueError("The number of layer types should be equal to the number of hidden layers.")
        if any(not 0 <= target < num_hidden_layers for target in self.layer_indices):
            raise ValueError("The layer type should be in the range of the number of hidden layers.")
def flash_attention_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    query_length: int,
    is_causal: bool,
    dropout: float = 0.0,
    position_ids: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    sliding_window: Optional[int] = None,
    use_top_left_mask: bool = False,
    softcap: Optional[float] = None,
    deterministic: Optional[bool] = None,
    no_diag: bool = False,
):
    """
    This function is a wrapper around the _flash_attention_forward function in the
    transformers library. It adds support to mask the diagonal elements of the attention
    matrix. The diagonal mask is used to resolve the cyclic dependencies in the LCKV model.

    Tensors are handled in a (batch, seq_len, num_heads, head_dim) layout: dimension 1 is
    treated as the sequence length below. Every argument except `no_diag` is forwarded to
    `_flash_attention_forward`, adjusted as follows when `no_diag=True`:

    - a single key/value entry: nothing remains after masking the diagonal, so zeros of
      shape (batch, query_len, num_query_heads, value_head_dim) are returned directly;
    - equal query and key lengths: the first query and the last key/value are dropped, so
      with causal masking each remaining query only sees keys strictly before its own
      position. `query_length` and `sliding_window` are decreased by one,
      `attention_mask` loses its first column, and the first query's output is zeros;
    - any other shapes: inputs are forwarded unchanged (no diagonal masking here).

    Args:
        query_states: query tensor, (batch, query_len, num_heads, head_dim).
        key_states: key tensor, (batch, kv_len, num_kv_heads, head_dim).
        value_states: value tensor, (batch, kv_len, num_kv_heads, value_head_dim).
        attention_mask: optional mask forwarded to `_flash_attention_forward`; sliced along
            dimension 1 when queries are pruned (presumably a (batch, seq_len) padding
            mask — TODO confirm).
        query_length: query length forwarded to `_flash_attention_forward`.
        is_causal: whether causal masking is applied.
        dropout, position_ids, softmax_scale, sliding_window, use_top_left_mask, softcap,
            deterministic: forwarded to `_flash_attention_forward`.
        no_diag: mask out the diagonal of the attention matrix as described above.

    Returns:
        The attention output of `_flash_attention_forward` (with a zero row prepended for
        the pruned first query), or zeros in the single-key case.
    """
    prune_query = False
    if no_diag:
        query_len = query_states.size(1)
        kv_len = key_states.size(1)

        if kv_len == 1:
            # The only key is masked out, so there is nothing left to attend to. The
            # output follows the query length/heads and the value head dim (previously
            # the key length was used, which only matched when query_len == 1).
            batch, _, num_heads, _ = query_states.size()
            _, _, _, head_dim = value_states.size()
            return value_states.new_zeros((batch, query_len, num_heads, head_dim))

        if kv_len == query_len:
            # Shift queries against keys by one: query i now lines up with key i - 1, so
            # the causal mask excludes each query's own key.
            prune_query = True
            query_states = query_states[:, 1:, :, :]
            query_length -= 1
            if attention_mask is not None:
                attention_mask = attention_mask[:, 1:]
            key_states = key_states[:, :-1, :, :]
            value_states = value_states[:, :-1, :, :]
            if sliding_window is not None:
                # the diagonal no longer counts towards the window
                sliding_window = sliding_window - 1
            # NOTE(review): position_ids is forwarded unsliced although q/k/v are now one
            # shorter; verify this is safe whenever _flash_attention_forward consumes
            # position_ids (e.g. packed sequences).
            # NOTE(review): a single attention_mask[:, 1:] serves both the shifted queries
            # and the shifted keys; verify the result with left padding.

        # NOTE(review): for other shapes (kv_len > 1 and kv_len != query_len, e.g. decoding
        # with a cache) no diagonal masking is done here; presumably the caller ensures the
        # keys do not contain the query's own entry — TODO confirm.

    result: torch.Tensor = _flash_attention_forward(
        query_states=query_states,
        key_states=key_states,
        value_states=value_states,
        attention_mask=attention_mask,
        query_length=query_length,
        is_causal=is_causal,
        dropout=dropout,
        position_ids=position_ids,
        softmax_scale=softmax_scale,
        sliding_window=sliding_window,
        use_top_left_mask=use_top_left_mask,
        softcap=softcap,
        deterministic=deterministic,
    )

    if prune_query:
        # the first query has no key strictly before it; its output row is zeros
        batch, _, num_heads, head_dim = result.size()
        result = torch.cat([result.new_zeros((batch, 1, num_heads, head_dim)), result], dim=1)
    return result