File size: 6,575 Bytes
61c49cc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 |
import re
from dataclasses import dataclass, field
from typing import List, Optional

import torch
from transformers.modeling_flash_attention_utils import _flash_attention_forward
@dataclass
class IterStep:
    """One step of an iteration plan produced by ``LayerTypeParser.iteration_plan``.

    Attributes:
        layer_slice: the layers covered by this step; the default ``slice(None)``
            covers every layer.
        requires_grad: set to False by ``iteration_plan`` for the steps generated
            from ``forward_passes``; True otherwise.
        update: set to True on the single step of a non-cyclic group and on the
            last step of a cyclic group; presumably marks the pass whose results are
            kept -- TODO confirm against the caller.
    """

    # default_factory instead of a plain default: since Python 3.11 dataclasses
    # reject defaults of an unhashable type, and slice objects are unhashable
    # before Python 3.12, so `= slice(None)` fails at class creation on 3.11.
    layer_slice: slice = field(default_factory=lambda: slice(None))
    requires_grad: bool = True
    update: bool = True
@dataclass
class LayerType:
    """Per-layer attention settings decoded from a layer-type string.

    Records of this type are built by ``LayerTypeParser.__getitem__``.
    """

    # position of the described layer in the stack
    layer_idx: int
    # True when the layer's token carried the trailing ``s`` marker
    use_sliding_window: bool
    # index of the layer whose key/value pairs this layer uses as its kv cache
    attends_to: int
    # True when ``attends_to`` is strictly greater than ``layer_idx``
    attends_top: bool
    # True when some layer (possibly this one) uses this layer's key/value pairs
    computes_kv: bool
class LayerTypeParser:
    """
    A helper class to parse the layer type string and provide some useful methods.

    Arguments:
        layer_type (str): A string of integers separated by underscores. The i-th integer
            means the layer will use the key-value pair in the i-th layer as the kv cache.
            Special characters may be placed after the integers:
            - `s` means the layer will use sliding window attention.

    Raises:
        ValueError: if a token of ``layer_type`` is not an integer optionally followed by ``s``.

    >>> layer_type = LayerTypeParser("0_0_0_5s_5s_5s_8_8_8")[3]
    >>> layer_type.attends_to
    5
    >>> layer_type.attends_top
    True
    >>> layer_type.use_sliding_window
    True
    """

    # One token per layer: a non-negative integer, optionally suffixed by "s".
    _TOKEN_PATTERN = re.compile(r"^(\d+)(s)?$")

    def __init__(self, layer_type: str):
        self._layer_type = layer_type
        # Parallel per-layer lists: the layer whose kv is used, and the sliding-window flag.
        self.layer_indices = []
        self.sliding_window = []
        for token in layer_type.split("_"):
            match = self._TOKEN_PATTERN.match(token)
            if match is None:
                # Fail with a clear message instead of an AttributeError on `.groups()`.
                raise ValueError(
                    f"Invalid layer type token {token!r} in {layer_type!r}: "
                    "expected an integer optionally followed by 's'."
                )
            kv_layer, sliding = match.groups()
            self.layer_indices.append(int(kv_layer))
            self.sliding_window.append(bool(sliding))

    def __len__(self):
        return len(self.layer_indices)

    def __getitem__(self, layer_idx: int) -> LayerType:
        """Return the layer type information for the given layer index.

        ``layer_idx`` is expected to be non-negative: a negative index still indexes
        the lists, but ``attends_top`` and ``computes_kv`` would then be computed
        against the negative value.
        """
        return LayerType(
            layer_idx=layer_idx,
            use_sliding_window=self.sliding_window[layer_idx],
            attends_to=self.layer_indices[layer_idx],
            attends_top=self.layer_indices[layer_idx] > layer_idx,
            computes_kv=layer_idx in self.layer_indices,
        )

    def use_sliding_window(self) -> bool:
        """Whether there exists a layer that uses sliding window attention."""
        return any(self.sliding_window)

    def attends_top(self) -> bool:
        """Whether there exists a layer that attends to a layer above it."""
        return any(kv_layer > i for i, kv_layer in enumerate(self.layer_indices))

    def iteration_plan(self, forward_passes: int = 7, backward_passes: int = 2) -> List[IterStep]:
        """
        Return an iteration plan for the layer types, as a list of IterStep objects.

        If no layer attends to a layer above it, the plan is a single step over all
        layers. Otherwise the layers are split into consecutive groups:

        * a group containing a cyclic dependency gets ``forward_passes`` steps with
          ``requires_grad=False, update=False``, then ``backward_passes - 1`` steps with
          ``update=False``, then one final step with both flags True;
        * every other group gets a single default step.

        Args:
            forward_passes: number of no-grad, no-update steps per cyclic group.
            backward_passes: number of gradient steps per cyclic group, including the
                final updating step; values below 1 still yield that final step.
        """
        # if there is no cyclic dependency, return the default plan
        if not self.attends_top():
            return [IterStep()]

        plan = []
        i = 0
        while i < len(self):
            if self[i].attends_top:
                # Grow the group until no layer in [i, top] attends beyond `top`, so the
                # whole cyclic dependency is contained in the group.
                top = self[i].attends_to
                while top < max(self.layer_indices[i: top + 1]):
                    top = max(self.layer_indices[i: top + 1])
                top += 1
                layer_slice = slice(i, top)
                # Build distinct IterStep objects: list multiplication would repeat one
                # mutable instance, so mutating one step would mutate all of them.
                plan.extend(
                    IterStep(layer_slice, requires_grad=False, update=False)
                    for _ in range(forward_passes)
                )
                plan.extend(
                    IterStep(layer_slice, update=False)
                    for _ in range(backward_passes - 1)
                )
                plan.append(IterStep(layer_slice))
            else:
                # Extend a plain group up to (not including) the next layer that attends up.
                top = i + 1
                while top < len(self) and not self[top].attends_top:
                    top += 1
                plan.append(IterStep(slice(i, top)))
            i = top
        return plan

    def check(self, num_hidden_layers: int):
        """Check if the layer type is valid.

        Raises:
            ValueError: if the number of layer types differs from ``num_hidden_layers``,
                or if a layer refers to an index outside ``[0, num_hidden_layers)``.
        """
        if len(self.layer_indices) != num_hidden_layers:
            raise ValueError("The number of layer types should be equal to the number of hidden layers.")
        for kv_layer in self.layer_indices:
            if kv_layer not in range(num_hidden_layers):
                raise ValueError("The layer type should be in the range of the number of hidden layers.")
def flash_attention_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask: torch.Tensor,
    query_length: int,
    is_causal: bool,
    dropout: float = 0.0,
    position_ids: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    sliding_window: Optional[int] = None,
    use_top_left_mask: bool = False,
    softcap: Optional[float] = None,
    deterministic: Optional[bool] = None,
    no_diag: bool = False,
):
    """
    This function is a wrapper around the _flash_attention_forward function in the
    transformers library. It adds support to mask the diagonal elements of the attention
    matrix. The diagonal mask is used to resolve the cyclic dependencies in the LCKV model.

    Tensors are indexed as (batch, seq_len, num_heads, head_dim): dim 1 is the sequence
    axis and dim 2 the head axis. All arguments other than ``no_diag`` are forwarded to
    ``_flash_attention_forward``, adjusted as described below when ``no_diag`` is set.

    With ``no_diag=True``:

    * a single key is treated as the diagonal entry itself, so nothing is left to attend
      to and an all-zero output of shape (batch, q_len, num_heads, value_head_dim) is
      returned without calling flash attention;
    * when keys and queries have the same length, the diagonal is removed by dropping
      the first query and the last key/value. With causal masking, query ``i`` of the
      original sequence then sees keys ``0..i-1``. ``query_length`` and
      ``sliding_window`` are reduced by one to match, and a zero row is prepended for
      the dropped first query;
    * otherwise the inputs are passed through unchanged (presumably the caller already
      excludes the diagonal key in that case -- TODO confirm).

    NOTE(review): when pruning, a single ``attention_mask[:, 1:]`` is used for both the
    pruned queries and the pruned keys. It lines up with the queries but is shifted by
    one position relative to the keys. That looks harmless without padding or with right
    padding, but may let a padded key through with left padding -- TODO confirm.
    ``position_ids`` is forwarded unsliced -- TODO confirm it is unused on this path.

    Returns:
        The attention output, with one row per query along dim 1.
    """
    prune_query = False
    if no_diag:
        if key_states.size(1) == 1:
            # Size the output from the queries (not the keys) so every path returns
            # one row per query; the head dim comes from the values.
            batch, q_len, num_heads, _ = query_states.size()
            head_dim = value_states.size(-1)
            return value_states.new_zeros((batch, q_len, num_heads, head_dim))
        if key_states.size(1) == query_states.size(1):
            prune_query = True
            # Shift queries forward and keys/values back by one position, so that the
            # causal diagonal of the shifted problem is the original sub-diagonal.
            query_states = query_states[:, 1:, :, :]
            query_length -= 1
            if attention_mask is not None:
                attention_mask = attention_mask[:, 1:]
            key_states = key_states[:, :-1, :, :]
            value_states = value_states[:, :-1, :, :]
            if sliding_window is not None:
                # Keep the window over the same original keys, now excluding the diagonal.
                sliding_window = sliding_window - 1

    result: torch.Tensor = _flash_attention_forward(
        query_states=query_states,
        key_states=key_states,
        value_states=value_states,
        attention_mask=attention_mask,
        query_length=query_length,
        is_causal=is_causal,
        dropout=dropout,
        position_ids=position_ids,
        softmax_scale=softmax_scale,
        sliding_window=sliding_window,
        use_top_left_mask=use_top_left_mask,
        softcap=softcap,
        deterministic=deterministic,
    )

    if prune_query:
        # The first query has no key strictly before it; give it a zero output row.
        batch, _, num_heads, head_dim = result.size()
        result = torch.cat([result.new_zeros((batch, 1, num_heads, head_dim)), result], dim=1)
    return result
|