import os
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
from awq.modules.fused.cache import WindowedCache
from awq.utils.fused_utils import get_attention_shapes

try:
    import ft_inference_engine
    FT_INSTALLED = True
except:
    FT_INSTALLED = False
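
# FT_INSTALLED gates the fused decode path: when the optional ft_inference_engine
# extension is importable, single-token (seqlen == 1) attention is dispatched to its
# single_query_attention kernel in QuantAttentionFused.forward; otherwise the plain
# PyTorch attention path is used.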


class RoPE(nn.Module):
    """Rotary positional embedding (RoPE) with a precomputed complex rotation table."""

    def __init__(self, hidden_size, n_heads, max_seq_len, device):
        super(RoPE, self).__init__()
        # Precompute rotations for up to twice the maximum sequence length.
        self.freqs_cis = nn.Parameter(
            self.precompute_freqs_cis(hidden_size // n_heads, max_seq_len * 2).to(device),
            requires_grad=False
        )

    @staticmethod
    def precompute_freqs_cis(dim: int, end: int, theta=10000.0):
        # Frequencies 1 / theta^(2i/dim) per channel pair, then one complex unit
        # rotation per (position, channel pair): e^(i * t * freq).
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
        t = torch.arange(end)
        freqs = torch.outer(t, freqs).float()
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
        return freqs_cis
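
    # Example: with head_dim = 128 and end = 8192, the table above has shape
    # (8192, 64) and dtype complex64, i.e. one unit rotation per position and
    # per pair of channels.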

    @staticmethod
    def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
        # Reshape (seqlen, head_dim // 2) rotations to (1, seqlen, 1, head_dim // 2)
        # so they broadcast over the batch and head dimensions of x.
        ndim = x.ndim
        assert 0 <= 1 < ndim
        assert freqs_cis.shape == (x.shape[1], x.shape[-1])
        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return freqs_cis.view(*shape)

    def forward(self, xq: torch.Tensor, xk: torch.Tensor, start_pos: int, seqlen: int):
        # Pack channel j with channel j + head_dim // 2 into one complex number
        # (NeoX-style half rotation), rotate, then unpack back to real channels.
        xq_ = torch.view_as_complex(
            xq.float().reshape(*xq.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
        )
        xk_ = torch.view_as_complex(
            xk.float().reshape(*xk.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
        )
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
        freqs_cis = self.reshape_for_broadcast(freqs_cis, xq_).to(xq_.device)
        xq_out = torch.view_as_real(xq_ * freqs_cis).transpose(-2, -1).flatten(3)
        xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3)
        return xq_out.type_as(xq), xk_out.type_as(xk)


class ALiBi(nn.Module):
    """Attention with Linear Biases (ALiBi): adds a per-head linear position bias to scores."""

    def __init__(self, n_heads, max_seq_len, device, alibi_bias_max=8):
        super(ALiBi, self).__init__()
        # Initialize ALiBi slopes and bias
        slopes, bias = self.build_alibi_bias(n_heads, max_seq_len, alibi_bias_max=alibi_bias_max)
        self.slopes = nn.Parameter(slopes.float().to(device), requires_grad=False)
        self.bias = nn.Parameter(bias.float().to(device), requires_grad=False)

    @staticmethod
    def gen_slopes(n_heads, alibi_bias_max=8):
        # Geometric per-head slopes; if n_heads is not a power of two, interleave the
        # slopes of the next power of two and truncate.
        _n_heads = 2 ** math.ceil(math.log2(n_heads))
        m = torch.arange(1, _n_heads + 1, dtype=torch.float32)
        m = m.mul(alibi_bias_max / _n_heads)
        slopes = 1.0 / torch.pow(2, m)
        if _n_heads != n_heads:
            slopes = torch.cat([slopes[1::2], slopes[::2]])[:n_heads]
        return slopes.view(1, n_heads, 1, 1)
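
    # Example: n_heads = 8, alibi_bias_max = 8 gives slopes 1/2, 1/4, ..., 1/256
    # with shape (1, 8, 1, 1); head 0 penalizes distant positions the most strongly.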

    @staticmethod
    def build_alibi_bias(n_heads, seq_len, alibi_bias_max=8, dtype=torch.float32):
        # Relative key positions 1 - seq_len, ..., 0 scaled by each head's slope:
        # bias has shape (1, n_heads, 1, seq_len); slopes are returned as (n_heads,).
        alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32).view(1, 1, 1, seq_len)
        slopes = ALiBi.gen_slopes(n_heads, alibi_bias_max)
        alibi_bias = alibi_bias * slopes
        slopes = slopes.squeeze(0).squeeze(-1).squeeze(-1)
        return slopes.to(dtype=dtype), alibi_bias.to(dtype=dtype)

    def forward(self, scores, seqlen):
        # Add the per-head linear key-position bias to the attention scores.
        scores += self.bias[..., :seqlen]
        return scores


class QuantAttentionFused(nn.Module):
    def __init__(self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj, dev, max_seq_len,
                 use_alibi=False, attention_shapes=None):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.n_kv_groups = n_heads // n_kv_heads if n_kv_heads != 0 else 0
        self.head_dim = self.hidden_size // n_heads
        self.qkv_proj = qkv_layer
        self.o_proj = o_proj
        self.start_pos = 0
        self.use_alibi = use_alibi
        self.cache_batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))
        self.max_seq_len = max_seq_len

        # attention shapes for self attention
        self.attention_shapes = get_attention_shapes(
            attention_shapes, max_seq_len, self.cache_batch_size, n_heads, n_kv_heads, self.head_dim
        )

        # cache store that rolls cache
        self.cache = WindowedCache(
            self.attention_shapes["cache_v"], self.attention_shapes["cache_k"], dev
        )

        if use_alibi:
            # ALiBi models (e.g. MPT) use additive score biases and no rotary embedding.
            self.alibi = ALiBi(n_heads, max_seq_len, dev)
            self.rotary_dim = 0
            self.is_neox = False
        else:
            # RoPE models (e.g. LLaMA) rotate the full head dimension, NeoX style.
            self.alibi = None
            self.rope = RoPE(hidden_size, n_heads, max_seq_len, dev)
            self.rotary_dim = self.head_dim
            self.is_neox = True
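
    # forward() has two paths:
    #   * seqlen > 1 (prefill) or ft_inference_engine not installed: RoPE/ALiBi and
    #     attention are computed in plain PyTorch against the rolling WindowedCache.
    #   * seqlen == 1 (decode) with ft_inference_engine installed: the raw QKV slices,
    #     cache tensors, rotary parameters and ALiBi slopes are handed to the fused
    #     single_query_attention kernel instead.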

    def forward(self, hidden_states: torch.Tensor, attention_mask=None, *args, **kwargs):
        bsz, seqlen, _ = hidden_states.shape

        if bsz != self.cache_batch_size:
            raise RuntimeError(
                f"Batch size is incorrectly set - input batch size {bsz}, kv-cache batch size {self.cache_batch_size}. "
                f"Use: AutoAWQForCausalLM.from_quantized(batch_size={bsz})"
            )

        # Roll the KV cache window when the next tokens would overflow max_seq_len.
        if self.start_pos > self.max_seq_len or self.start_pos + seqlen > self.max_seq_len:
            excess_length = self.start_pos + seqlen - self.max_seq_len
            self.start_pos = self.cache.roll_kv(excess_length, self.start_pos)

        xqkv = self.qkv_proj(hidden_states)
        xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"])

        xq = self.attention_shapes["xq_slice"](xqkv)
        xk = self.attention_shapes["xk_slice"](xqkv)
        xv = self.attention_shapes["xv_slice"](xqkv)

        if seqlen > 1 or not FT_INSTALLED:
            # Plain PyTorch attention path (prefill, or decode without the fused kernel).
            xq = xq.view((bsz, seqlen) + self.attention_shapes["xq_view"])
            xk = xk.view((bsz, seqlen) + self.attention_shapes["xk_view"])
            xv = xv.view((bsz, seqlen) + self.attention_shapes["xv_view"])

            if not self.use_alibi:
                xq, xk = self.rope.forward(xq, xk, self.start_pos, seqlen)

            self.cache.to(xq)

            values_store = xv.transpose(2, 1)
            keys_store = (
                xk.reshape((bsz, seqlen) + self.attention_shapes["xk_reshape"])
                .permute(0, 2, 3, 1, 4)
                .contiguous()
            )

            self.cache.update_kv(values_store, keys_store, bsz, self.start_pos, seqlen)

            if seqlen == 1:
                xv, xk = self.cache.get_kv(bsz, self.start_pos, seqlen, self.head_dim)

            keys = xk
            values = xv

            # Grouped-query attention: repeat KV heads to match the number of query heads.
            if self.n_kv_groups != 0:
                keys = torch.repeat_interleave(keys, dim=2, repeats=self.n_kv_groups)
                values = torch.repeat_interleave(values, dim=2, repeats=self.n_kv_groups)

            xq = xq.transpose(1, 2)
            keys = keys.transpose(1, 2)
            values = values.transpose(1, 2)
            scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)

            if self.use_alibi:
                scores = self.alibi.forward(scores, seqlen)

            if attention_mask is not None:
                scores = scores + attention_mask  # (bs, n_local_heads, slen, cache_len + slen)

            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
            attention_weight = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        else:
            # Fused single-token decode path.
            xq = xq.view((bsz,) + self.attention_shapes["single_xq_view"])
            xk = xk.view((bsz,) + self.attention_shapes["single_xk_view"])
            xv = xv.view((bsz,) + self.attention_shapes["single_xv_view"])

            alibi_slopes = self.alibi.slopes if self.alibi is not None else None
            attention_weight = ft_inference_engine.single_query_attention(
                xq,               # query
                xk,               # key
                xv,               # value
                self.cache.k,     # key cache
                self.cache.v,     # value cache
                None,             # length per sample
                alibi_slopes,     # alibi slopes
                self.start_pos,   # timestep
                self.rotary_dim,  # rotary embedding dimension
                10000,            # rotary embedding base
                self.is_neox,     # is neox
            )
            attention_weight = attention_weight.reshape(bsz, 1, -1)

        attn_output = self.o_proj(attention_weight)
        self.start_pos += seqlen

        # past_key_value is replaced with cache_v, cache_k, returning empty data
        past_key_value = [torch.Tensor([[[[0]], [[0]], [[0]]]])]

        return attn_output, attention_weight, past_key_value
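

# Minimal smoke test for the positional-encoding helpers above; an illustrative sketch,
# not part of the public API. It only exercises the torch-level RoPE and ALiBi classes
# (shape checks on CPU) and assumes the awq imports at the top of this file resolve so
# the module can be loaded at all.
if __name__ == "__main__":
    bsz, seqlen, n_heads, head_dim = 1, 8, 4, 32
    hidden_size = n_heads * head_dim

    rope = RoPE(hidden_size, n_heads, max_seq_len=64, device="cpu")
    xq = torch.randn(bsz, seqlen, n_heads, head_dim)
    xk = torch.randn(bsz, seqlen, n_heads, head_dim)
    xq_rot, xk_rot = rope.forward(xq, xk, start_pos=0, seqlen=seqlen)
    assert xq_rot.shape == xq.shape and xk_rot.shape == xk.shape

    alibi = ALiBi(n_heads, max_seq_len=64, device="cpu")
    scores = torch.zeros(bsz, n_heads, seqlen, seqlen)
    biased = alibi.forward(scores, seqlen)
    assert biased.shape == (bsz, n_heads, seqlen, seqlen)
    print("RoPE/ALiBi shape checks passed")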